Re: 今日も使えないすごいビット演算

昨日も使ってなかったですよね(それはそう)

参考文献↓

tl;dr

乗算 $2$ 回で msb 演算を実現する。 AtCoder の環境で $64$ ビットの計算をする計測では、条件分岐を $6$ 回やるいつものやつよりも高速だった。

導入

ビットごとの論理演算・加減算・乗算を $O(1)$ 時間とすると、 msb の番号を求めるのが $O(1)$ 時間になるが、 builtin に勝てないという話でした。

たぶん builtin は CPU 命令になってしまえば本当に 3 サイクルとかで実現してくるので、乗算こねこねしたようなやつでは勝てる見込みがありません。接戦に見えているならば乱数生成器が律速だと思います。とはいえ、 builtin はいつでも使えるわけではないので、言語規格にある程度の機能だけで実装したいなと思う今日この頃でした。

えびちゃんの記事で言ってる方法は乗算を合計 $5$ 回使いますが、今回はそれを $2$ 回に抑えることで高速化を期待します。

考察の上で重要だと思ったこと

もしかしたら別のビット演算を考察したい人がいるかもしれないので、思ったことを書きます。

  • 加減算で得られる結果はどちらかというと lsb と親和性があり、今回必要な msb に直接使うのは難しそうです。
  • ある $1$ ビットが全体に影響を及ぼすような計算、および全体が $1$ ビットに影響を及ぼすような計算は、定数回の論理演算や加減算では表現できません。乗算が必要です。定数回でなくてもよい場合は、シフト幅を変えながら繰り返すアルゴリズムが有効です。
  • MoFR のようななにか
    $O(\log w)$ ビットくらいまで落とすと、 $1$ ワードにすべての答えを埋め込めて、ビットシフトで呼び出せます。

計算

$64$ ビットが非常に都合がよいので、あえてビット長を一般化しないで計算します。

64 → 32

乗算結果の上位ビットを利用する方法のため、あらかじめビット長を半分にします。ビット長を半分にするのは $O(1)$ 時間です。

uintmax_t msb(uintmax_t n) {
    int q = (n >> 32) ? 32 : 0;
    q += macromsb(n >> q);
    return q + micromsb(n >> q);
}

summary

えびちゃんの記事にも書いてあります。

$4$ ビットずつについて、そのうち $1$ ビットでも立っていれば最上位ビットを立てます。

  • ビット列が 1xxx のとき、そこでマスクを取ります。
  • ビット列が 0xxx のとき、 1000 から引くことで最上位ビットに特徴が現れます。
uintmax_t summary(uintmax_t n){
    constexpr uintmax_t himask = 0x8888'8888;
    n = (n | ~(himask - (n & ~himask))) & himask;
    return n;
}

macro msb

summary 後のビット列について、 msb を求めます。ここで乗算を $2$ 回使います。

まず、 msb より下のビットをすべて立てます。 0x11111 みたいなのを掛ければ lsb より上のビットを立てることができますが、そのような乗算の上位ビットを見ると msb より下のビットが立っているような部分があります。その部分で summary するとよいです。

イメージ
                  1 1 1 1 1 1 1 1
 x)               1 1 1 1 1 1 1 1
----------------------------------
  0(1 2 3 4 5 6 7 8)7 6 5 4 3 2 1

もう一度同じ乗算をすると popcount が得られます。これが msb の番号 +1 になります。

uintmax_t macromsb(uintmax_t n) {
    constexpr uintmax_t multiplier = 0x1111'1111;
    n = summary(n);
    n = (n * multiplier) >> 35;
    n = summary(n);
    n = (n * multiplier) >> 31;
    return (n & 0xf) * 4;
}

micro msb

$4$ ビット分しか残っていないので、全パターンの答えを $4\times 16$ ビットに埋め込めます。

uintmax_t micromsb(uintmax_t n) {
    return 0x3333'3333'2222'1100 >> (n << 2) & 0xf;
}

実測

いつものように、 AtCoder のコードテストで計測します。

int main(){
    Xoshiro256pp rng;
    uintmax_t x = 0;
    for(int i=0; i<100'000'000; i++){
        uintmax_t n = rng.rng64() | 1;
        x += n; // (0)
        // x += 63 - __builtin_clzll(n); // (1)
        // x += msb(n); // (2)
        // x += msbx(n); // (3)
        // x += msbx2(n); // (4)
        // x += msbebi(n); // (5)
    }
    printf("%llu\n", x); // expect 6199992434 for (1-5)
    return 0;
}
全文(クリックで展開)
#include <cstdint>
#include <cstdio>
uintmax_t msb4(uintmax_t n) {
    return 0x3333'3333'2222'1100 >> (n << 2) & 0xf;
}

uintmax_t msbblock(uintmax_t n){
    constexpr uintmax_t himask = 0x8888'8888;
    n = (n | ~(himask - (n & ~himask))) & himask;
    constexpr uintmax_t multiplier = 0x1111'1111;
    n *= multiplier;
    n >>= 35;
    n = (n | ~(himask - (n & ~himask))) & himask;
    n *= multiplier;
    n >>= 31;
    return (n & 0xf) * 4;
}

uintmax_t msb(uintmax_t n) {
    int q = (n >> 32) ? 32 : 0;
    q += msbblock(n >> q);
    return q + msb4(n >> q);
}

uintmax_t msbx(uintmax_t n) {
    int q = (n & 0xffff'ffff'0000'0000) ? 32 : 0;
    q += (n >> q & 0xffff'0000) ? 16 : 0;
    q += (n >> q & 0xff00) ? 8 : 0;
    q += (n >> q & 0xf0) ? 4 : 0;
    q += msb4(n >> q);
    return q;
}

uintmax_t msbx2(uintmax_t n) {
    int q = (n & 0xffff'ffff'0000'0000) ? 32 : 0;
    q += (n >> q & 0xffff'0000) ? 16 : 0;
    q += (n >> q & 0xff00) ? 8 : 0;
    q += (n >> q & 0xf0) ? 4 : 0;
    q += (n >> q & 0xC) ? 2 : 0;
    q += (n >> q & 0x2) ? 1 : 0;
    return q;
}

constexpr uintmax_t operator ""_ju(unsigned long long n) {
  return n;
}

uintmax_t popcount(uintmax_t n) {
  constexpr uintmax_t multiplier = 0x0101010101010101_ju;
  return (n * multiplier) >> 56;
}

uintmax_t minimsb(uintmax_t n) {
  if (n >= 0x80) return 7;

  constexpr uintmax_t multiplier = 0x0101010101010101_ju;
  uintmax_t minuend = n * multiplier;

  constexpr uintmax_t subtrahend = 0x8040201008040201_ju;
  uintmax_t rest = minuend - subtrahend;

  constexpr uintmax_t mask = 0x8080808080808080_ju;
  return popcount((~rest & mask) >> 7) - 1;
}

uintmax_t summary(uintmax_t n) {
  constexpr uintmax_t hi_mask = 0x8080808080808080_ju;  
  constexpr uintmax_t lo_mask = 0x7F7F7F7F7F7F7F7F_ju;
  uintmax_t ones = (n | ~(hi_mask - (n & lo_mask))) & hi_mask;
  
  constexpr uintmax_t multiplier = 0x0002040810204081_ju;
  return (ones * multiplier) >> 56;
}

uintmax_t access(uintmax_t n, uintmax_t i) {
  constexpr uintmax_t mask = 0xFF_ju;
  return (n >> (i * 8)) & mask;
}

uintmax_t msbebi(uintmax_t n) {
  uintmax_t i = minimsb(summary(n));
  uintmax_t j = minimsb(access(n, i));
  return i * 8 + j;
}

#include <vector>
#include <algorithm>
#include <unordered_map>
#include <cassert>


class Xoshiro256pp{
public:

    using i32 = int32_t;
    using u32 = uint32_t;
    using i64 = int64_t;
    using u64 = uint64_t;


private:
    uint64_t s[4];

    // https://prng.di.unimi.it/xoshiro256plusplus.c
    static inline uint64_t rotl(const uint64_t x, int k) noexcept {
        return (x << k) | (x >> (64 - k));
    }
    inline uint64_t next(void) noexcept {
        const uint64_t result = rotl(s[0] + s[3], 23) + s[0];
        const uint64_t t = s[1] << 17;
        s[2] ^= s[0];
        s[3] ^= s[1];
        s[1] ^= s[2];
        s[0] ^= s[3];
        s[2] ^= t;
        s[3] = rotl(s[3], 45);
        return result;
    }

    // https://xoshiro.di.unimi.it/splitmix64.c
    u64 splitmix64(u64& x) {
        u64 z = (x += 0x9e3779b97f4a7c15);
        z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9;
        z = (z ^ (z >> 27)) * 0x94d049bb133111eb;
        return z ^ (z >> 31);
    }

public:

    void seed(u64 x = 7001){
        assert(x != 0);
        s[0] = x;
        for(int i=1; i<4; i++) s[i] = splitmix64(x);
    }

    Xoshiro256pp(){ seed(); }
    
    u64 rng64() {
        return next();
    }
};

int main(){
    Xoshiro256pp rng;
    uintmax_t x = 0;
    for(int i=0; i<100'000'000; i++){
        uintmax_t n = rng.rng64() | 1;
        x += n;
        x += msb(n);
        x += 63 - __builtin_clzll(n);
        x += msbx(n);
        x += msbx2(n);
        x += msbebi(n);
    }
    printf("%llu\n", x); // expect 6199992434
    return 0;
}
(#)時間[ms](0) との差備考
(0)1550msb を求めない
(1)17419builtin
(2)439284乗算 2 回
(3)499344条件分岐+micromsb
(4)615460条件分岐 6 回
(5)705550えびちゃんの

やはり builtin は CPU 命令 $1$ 回ぶんくらいの時間しかかかってなさそうです。それにしても速すぎるような気がするんですが、処理時間が別の処理に溶けこんでいたりするんでしょうか?

そして builtin 以外に試したもののうちで最も速かったのがメインの方法でした。やったね

おわり

タイトルとURLをコピーしました