どこから見てもメンダコ

軟体動物門頭足綱八腕類メンダコ科

Segment Tree(セグメント木)による重み付きランダムサンプリング

競技プログラミング界隈では一般教養であるらしいセグメント木のSum-tree構造で高速な重み付きサンプリングを実装します。


はじめに

強化学習の重要手法である優先度付き経験再生(Prioritized Experience Replay)では、重みづけされた100万の経験(遷移情報)からランダムにサンプリングしてミニバッチを作成する、という処理があります。このような重みづけサンプリングはnp.random.choiceの引数pに重み情報を与えることで楽に実装できます。コードの見通しが大変よくなるので過去記事ではこの方法での実装例を紹介しました。

しかし論文ではsum-treeデータ構造で実装すると速いと書いてあります。本記事ではせっかくなのでこちらの実装を試してみます。

[1511.05952] Prioritized Experience Replay

DQNの進化史 ③Prioritized experience replay, Multi-step learning, Categorical DQN - どこから見てもメンダコ

DQNの進化史 ④Rainbowの実装 - どこから見てもメンダコ


A. numpy.choiceによる重み付きランダムサンプリング

まずはベースラインとしてnumpy.random.choiceによる重み付きランダムサンプリングのパフォーマンスを見ます。要素数は100万で各要素には0-5の優先度が割り当てられます。DQNでのミニバッチサイズの32に従って1iterで32要素をサンプリングします。また、Breakout(ブロック崩し)環境ではそれなりに学習が進むと1 episodeで200回くらいはミニバッチ作成するので200iter繰り返します。

f:id:horomary:20210215220249p:plain:w500

結果は約3.2秒となりました。1episodeあたり3.2秒なら趣味で強化学習やるくらいなら許容できる程度ではありますがちょっと遅いですね。


B. 累積和による重み付きランダムサンプリング

つぎに愚直な実装として逆関数法による重み付きランダムサンプリングを実装します。累積密度関数が計算できる確率分布なら逆関数法を使うことでい一様乱数から目的の確率分布に従う乱数に変換できます。

逆関数法 - Wikipedia

逆関数法を用いた乱数生成の証明と例 | 高校数学の美しい物語

たとえば、4要素のリストにおいて各要素の優先度が [4, 2, 1, 3] のときは、0≦ z ≦ 4+2+1+3 = 10 の範囲で一様乱数を発生させ、累積和がzとなるのがどの要素のときかを調べることで優先度に従ったサンプリングを行うことができます。

f:id:horomary:20210215223407p:plain:w600
0≦z≦10で乱数を発生させて累積和がzに該当する要素を選択すれば優先度の大きさにサンプリング確率が従う

f:id:horomary:20210215224616p:plain:w500

たしかに優先度に従ってサンプリングできていることがわかります。では要素数が増えた時のパフォーマンスがどうなるかをnp.random.choiceと同じ条件で確かめます。

f:id:horomary:20210215231046p:plain:w500

結果は144秒、遅い! すべての要素に対して累積和チェックをしているので計算量がNになってしまっていることが原因です。

C. Sum-tree構造を活用した重み付きランダムサンプリング

上で重いのは累積和がzになるのはどの要素番号のときであるかを調べる処理です。これはSegment-tree(セグメント木)構造を使うことで高速に検索することができます。さきほどと同様に各要素の優先度が [4, 2, 1, 3] のときのSum-treeを構築すると下図のようになります。

f:id:horomary:20210215232341p:plain:w500
[4, 2, 1, 3]に対するSum-tree

たとえば累積和が6.5を超える要素番号を検索したいとしましょう。ルートノードである10の左子ノードが6なので、要素0, 1までの累積和が6であることがわかります。よって、要素番号2,3の区間における累積和が0.5(= 6.5 - 6)になる要素を探せばよいというわけです。そこでルートノードの右子ノード4に進みます。この子ノードを見ると1, 3なので左子ノードで累積和が0.5になることが分かります。左子ノードは実要素なのでここで探索終了となります。

実際に格納される要素数Nが  N = 2^{K} のとき、Sum-treeの深さ(階層?)はKになるので探索回数は要素数Nに対してlogNとなり効率的であることがわかります。

Sum-TreeのPython実装

競プロ界隈の人はわざわざ遅いpythonで実装とかしないかもしれませんが、強化学習で使う分にはそこそこのパフォーマンスが出ればよいのでpythonでSum-Treeを実装します。この実装は ray/segment_tree.py at master · ray-project/ray · GitHub から抽象化を削りシンプルに再実装したものです。

実装のポイント:
・格納される要素数がNのときSumtree全体の要素数は2N-1
・ルートノードのインデックス番号を1に設定すると左子ノードのインデックス=2×親のインデックス、右子ノードのインデックス=2×親のインデックス+1 となり便利なので、インデックス番号0を使わない長さ2NのリストでSum-treeを実装する
__setitem____getitem__ を活用してsumtreeであることを感じさせない使い勝手を実現する

f:id:horomary:20210215235032p:plain:w500
クラス外から見た要素番号(赤)と実体の要素番号(青)

f:id:horomary:20210216004207p:plain:w500

速度パフォーマンスの確認

numpyのときと同様にバッチサイズ32のミニバッチを200回作った時の速度を計測します。

f:id:horomary:20210216004530p:plain:w400

サンプリング時間だけ見ればnumpy.random.choiceより30倍程度は速いことがわかります。


おわりに

__setitem____getitem__ が一番輝くのはsegment tree説