どこから見てもメンダコ

軟体動物門頭足綱八腕類メンダコ科

スッキリわかるAlphaZero

The game of Go has long been viewed as the most challenging of classic games for artificial intelligence
囲碁はAIにとってもっとも困難なボードゲームの一つと考えられてきました
(Mastering the game of Go with deep neural networks and tree search | Nature より)

Alpha Zero: https://science.sciencemag.org/content/362/6419/1140.full?ijkey=XGd77kI6W4rSc&keytype=ref&siteid=sci (オープンアクセス版)

Alpha Go Zero: Mastering the game of Go without human knowledge | Nature (オープンアクセス版)

DeepMindブログ: AlphaZero: Shedding new light on chess, shogi, and Go | DeepMind



AlphaZeroとは

AlphaZero(2018)とは人間の対局データや定石の知識(human knowledge)を使わずに学習を行うボードゲームAIであり、2015年に世界で初めてプロ囲碁棋士に勝利した囲碁アルゴリズムAlpha Goの後継手法です。Alpha Goの”プロ棋士に勝利”という実績はAIの時代を感じさせるわかりやすいインパクトがありましたが、AlphaGoのパフォーマンスは大量の棋譜およびドメイン知識に基づくヒューリスティクス*1に強烈に依存しており、いわゆる"強いAI"からははるか遠いものでした。

このAlpha Goの後継手法であるAlpha Go Zeroでは棋譜すなわち人間の対局データを使わずに、自己対局(selfplay)のみでAlpha Goを超える性能を実現することに成功しました。また、Alpha Goで採用されていた囲碁ドメイン知識に基づくヒューリスティクスも除外されました。

さらにAlpha Go Zeroは微調整され、 AlphaZero として再発表されました。AlphaZeroは Alpha Go Zeroとほぼ同じアルゴリズムをチェス、将棋向けに汎用化したよ、という内容ですので、アルゴリズム的には Alpha Go Zero = AlphaZero という理解で問題ありません。

ただし、事前知識ゼロとは言ってもAlphaZeroにはゲームのルールブックが与えられている、すなわち

  1. ある局面における可能な手(合法手)を問い合わせることができる
  2. ある盤面である手を選択したときの次盤面(盤面の遷移)を問い合わせることができる

ということには留意ください。

f:id:horomary:20210613232914p:plain:w500
MuZero: Mastering Go, chess, shogi and Atari without rules | DeepMind より切り抜き

AlphaZeroは事前データおよびドメイン知識なしで囲碁/チェス/将棋において超人的パフォーマンスを達成したことにより世界に大きなインパクトを与えましたが、真に驚くべきはそのアルゴリズムのシンプルさです。本記事ではこのAlphaZeroアルゴリズムを簡単に解説しつつオセロ向けの実装を紹介します。


アルゴリズム概要

囲碁のような完全情報ゲームにおける理論上最強のアルゴリズムはゲーム木の全探索です。完全探索されたゲーム木を用意できればそこからの検索によって常に最善手をとりつづけることができるためです。

f:id:horomary:20210619193144p:plain:w400
wikipedia: ゲーム木 より

しかし、囲碁や将棋ではゲーム木サイズが巨大であるため*2、全探索は実質的に不可能です。

そこで、AlphaZeroの登場以前から囲碁AIではモンテカルロ木探索(MCTS: Monte-Carlo Tree Search)というゲーム木探索アルゴリズムが中心的に使用されてきました。モンテカルロ木探索とは、"ある盤面がどれだけ良いか"をその盤面から始まる無数のランダムプレイの結果から評価することによって行動決定を行うアルゴリズムです。盤面評価がランダムプレイのみに依存するために盤面評価関数の設計不要であることが特徴です。

モンテカルロ木探索を利用したゲームプレイの流れは下図のようになります。すなわち、その盤面におけるすべての可能な行動(合法手)について遷移先の盤面をモンテカルロ木探索によって評価することによって行動を決定するというものです。ゆえにモンテカルロ木探索は人間が思考内で行う先読みのようなことを行っています。

f:id:horomary:20210619180806p:plain:w600
MuZero: Mastering Go, chess, shogi and Atari without rules | DeepMind の図を改変して掲載

AlphaZeroもまたモンテカルロ木探索を中心としたアルゴリズムであり、その貢献はモンテカルロ木探索に深層学習を導入することによって探索パフォーマンスの大幅な向上に成功したことです。より具体的にはAlphaZeroではselfplay(自己対戦)によって生成された対局データからの教師あり学習によって明らかに筋の悪い手を足切りすることでモンテカルロ木探索を効率化・高精度化します。


Step by stepで理解するAlphaZero版モンテカルロ木探索

参考資料:
モンテカルロ⽊探索の理論と実践

コンピュータ囲碁研究の歩み
モンテカルロ木探索 - Wikipedia

AlphaZeroの中心にあるのはモンテカルロ木探索(MCTS)であり、MCTSを理解することがそのままAlpha Zeroを理解することです。ここではもっとも単純なモンテカルロ木探索のアルゴリズムから始めて、AlphaZero版MCTSに至る過程を3目並べを題材にステップ by ステップで解説します。

※各手法名は必ずしも正式なものではなく、説明の便宜上てきとうに命名したものが含まれていることに留意ください

1. 原始モンテカルロ木探索

まずはもっとも単純な原始モンテカルロ木探索(pure Monte-Carlo Tree Search)を理解しましょう。

f:id:horomary:20210619221037p:plain:w400

3目並べにおいてあなたが先手(黒)だとします。現在の盤面には何も置かれていないので有効なアクションは(左上, 上, 左上, 左, 中央, 右, 左下, 下, 右下)の計9つです。さて、このうちどのアクションを選択するのがよいでしょうか?

単純モンテカルロ木探索では、この9つのアクションそれぞれの次盤面から始まるランダム対局を無数に行います。ここでは、たとえば各アクションについて100回ずつランダム対局を行ったとします。

※ランダム対局ではゲームが終了するまで相手も自分もランダムに行動選択を行います。

f:id:horomary:20210619223159p:plain:w400

9つの合法手のそれぞれについてランダム対局を100回行ったところ、初手は中心に石を置くのが勝率60%でもっともよい結果となりましたので*3、現状の盤面では中心に石を置くのが最善手と判断するというのが単純モンテカルロ木探索のアルゴリズムの全てです、本当に単純ですね。あまりにも単純なアルゴリズムですが、"盤面の良さ"の評価がランダム対局のみに依存しゲームに関する一切の事前知識や評価関数設計が必要ないので、評価関数を設計しにくいゲームでは非常に有効に機能します


2. UCT-原始モンテカルロ木探索

原始モンテカルロ木探索の大きな欠点はすべての合法手について均等な回数のランダム対局(以下ではプレイアウトとも呼称)を行わなければならないことです。

ランダムプレイである以上、盤面評価の信頼性を確保するためにはある程度の試行回数が必要です。しかし、それでは3目並べのような合法手が少ないゲームならともかく囲碁のような合法手が多い(たとえば囲碁の初手は361通り)ゲームでは膨大な回数のプレイアウトが必要になってしまいます。

そこで導入するのが多腕バンディット問題の考え方です。多腕バンディット問題とは多数のスロットマシーン(バンディットマシン)から有限の試行回数で高設定の台を見つけ出すための理論です。これをモンテカルロ木探索に適用すると、有限のプレイアウト回数でできるだけ良い評価値が得られる行動を見つけ出したい、という問題設定になります。

Vol.31.No.5(2016/9)多腕バンディット問題 – 人工知能学会 (The Japanese Society for Artificial Intelligence)
バンディット問題の理論とアルゴリズム (機械学習プロフェッショナルシリーズ)


ゲーム木探索においてはバンディット問題の有名手法であるUCBアルゴリズムのゲーム木への応用であるUCT(Upper Confidence bound applied to Trees)アルゴリズムがしばしば使用されます。UCTアルゴリズムではすべての合法手についてまずは最低一回評価した後は、次式で定義されるUCTスコアが最大のアクションを選択することでモンテカルロ木探索の効率を向上させます。

f:id:horomary:20210620012618p:plain:w800
Monte Carlo tree search - Wikipedia

 \displaystyle{
\text{アクションa'の選択確率} = \frac{\text{アクションa'のプレイアウト累計勝利数}}{\text{アクションa'の累計試行回数}} + c  \sqrt{ \frac{  \ln{\text{累計試行回数}} }{ \text{アクションa'の累計プレイアウト回数} } }
}

右辺第一項は勝率であり、第二項は全体の試行回数に占めるそのアクションの試行回数割合が小さいほど大きくなるので、プレイアウトの勝率実績が高い(活用)が、あまりプレイアウトされていないアクション(探索)を優先的に選択するアルゴリズムとなります。なお、cは探索と活用をバランスするパラメータです。

では具体例を見てみましょう。下図はさきほどの三目並べの例においてUCTアルゴリズムに従い10回のプレイアウトが終わった状態です。

f:id:horomary:20210620015940p:plain:w600

計算されたUCTスコアが最も高い行動は中央配置の1.47なので次の試行、すなわち11回目の試行ではこの盤面からのプレイアウトを行います。

このようにUCTスコアに従い十分な数のプレイアウトを行うと最終的にはプレイアウト勝率のもっとも高い盤面をもっとも多く試行することになります*4。よって、UCT-原始MCTSの最終的な結論としては試行回数(プレイアウト回数)のもっとも多いアクションを最善手とします。


3. PUCT-モンテカルロ木探索

上述したUCT-原始モンテカルロ探索にさらにN手先読みの要素が追加されることでPUCT-MCTSとなります。ちなみに、一般に”モンテカルロ木探索”と言ったらこのPUCT-MCTSを指すことが多いかと思います。

原始モンテカルロ木探索では”一手読み”しか行いませんが、より深く先読みすることでよりゲーム終端に近づくのでプレイアウト(ランダム対局)の信頼性も高まっていきます。 とはいえ先読みはすればするほど評価しなければいけない盤面が増えるので、先読みを行うのはある程度有望な行動に限定したいところです。

そこでPUCT-モンテカルロ木探索では試行回数が一定回数に到達した盤面のみ、さらに次の盤面を展開します。下図ではたとえば子盤面を展開する閾値を10回の試行としています。

f:id:horomary:20210620111651p:plain:w500

UCT方策によって子盤面が展開済みの盤面が選択された場合は、子盤面のUCTスコアを算出しUCTスコアのもっとも大きい子盤面からプレイアウトを行います。もしその子盤面がさらに子盤面を展開済みであれば、展開されていない盤面を見つけるまで同じことを繰り返します(下図)。

f:id:horomary:20210620113929p:plain:w500

ここで注意すべきは、子盤面からのプレイアウトの結果はその親盤面にも逆伝播すること、さらに囲碁のような2人ゲームでは(当然ではありますが)勝ち/負けのカウントはその盤面が先手番か後手番かで逆になることです。

ここまでで、有望な行動に多くの試行回数を割りあてつつ、十分に有望な手に限ってはさらに深い先読みを行うことで効率的にゲーム木を探索するPUCT-モンテカルロ木探索を説明しました。ここまで理解できればAlphaZero-モンテカルロ木探索の理解までもう少しです。


4. PUCTモンテカルロ木探索+エピソード記憶

UCTアルゴリズムは効率のよい探索手法ですが、毎回知識ゼロからのMCTS開始を強いられる、というつらみがあります。とくに囲碁や将棋のような合法手の多いゲームでは致命的につらいので、過去のモンテカルロ木探索の履歴をうまく利用して筋の悪い手を足切りできないかということを考えます。

たとえば、100回ゲームをプレイすれば少なくとも初期盤面から始まるMCTSについては100回行っているので筋の良い/悪い初手のあたりはつくはずです。このような過去のMCTSの履歴から考えて明らかに有望でない手は試さなくてもいいよね、というようにepisode contextを活用したくなります。

Multi-armed bandits with episode context | SpringerLink

このコンセプトを実現するためにはどうすればよいでしょうか?

たとえば、ごく単純なやり方としてエピソード(ゲーム)をまたいで盤面ごとのMCTS試行回数を記録しておく方法が考えられますが、この方法は可能な盤面の多いゲームではつらいですし、ほんの少し盤面が異なるだけで使えなくなってしまいます。

そこで、AlphaZeroでは過去のモンテカルロ木探索の結果を近似(再現)するようなニューラルネットワークを構築します。 つまり、ある盤面を入力として、過去のMCTSの結果を出力するように教師あり学習で訓練します。

f:id:horomary:20210620162645p:plain:w500
Alpha Go Zero論文 Fig.1を改変して掲載

このように、過去のモンテカルロ木探索の結果を近似するネットワークを訓練することで、有望な手とそうでない手のあたりをつけることができるようになりました。このニューラルネット方策 P(s, a)をUCTスコアに組み込みます。(C_puctはハイパラ)

f:id:horomary:20210620163100p:plain:w400

U(s, a)ではUCTスコアの第二項にP(s, a)を掛けており、過去のMCTS試行の記憶であるP(s, a)がMCTSの事前信念のような役割を担っていると解釈できます。

なお、行動選択ではUCTと同様に、活用 Q(s, a) と 探索 U(s, a) の和が最大の行動を選択します。

f:id:horomary:20210620163223p:plain:w300

このようにAlphaZeroではPUCT-MCTSに過去のMCTS試行情報(episode context)を反映することで、さらに効率よくゲーム木を探索できるようになります。

ところで、ここまでの手法では Q(s, a) とはプレイアウト(=ランダム対局)の勝率でした。しかし、AlphaZeroではこの項もまたニューラルネットワークで置き換えます。


5. PV-モンテカルロ木探索(AlphaZero)

エピソード記憶に基づく事前信念付きPUCTアルゴリズムの導入によって探索効率は劇的に改善されますが、盤面の良さの評価はいまだにランダム対局(プレイアウト)に完全に依存しています。プレイアウトでは盤面評価関数の設計をせずに盤面を評価できる良さはありますが、人間はランダムプレイするわけではないのでやはり評価精度に限界があります。よい盤面評価関数が利用可能なら本当はそっちを使いたいのです*5

そこで、Alpha Zeroではある盤面Sを入力としその盤面Sが最終的に勝利したか敗北したかをラベルとする教師あり学習によってニューラルネットワークを訓練し、盤面の評価関数V(s) とします。すなわち、V(s)が1.0に近いほど勝ちそうな盤面ということです。

これによりプレイアウト(ランダム対局)をニューラル盤面評価関数Vで置き換える*6ことで、盤面の良さの評価精度もまた大きく改善することとなりました。


AlphaZeroのMCTSまとめ

過去のモンテカルロ木探索の履歴をニューラルネットで近似し、方策の事前分布としてPUCT-MCTSへ導入
→ 筋の悪い行動の足切りを行いMCTSの効率化を実現

過去のある盤面の最終的な勝敗結果をニューラルネットで学習し、盤面評価関数Vとする
MCTSによる行動評価の高精度化を実現

この2つのネットワークを備えたMCTSがAlpha Zeroのモンテカルロ木探索であり、論文ではPV-MCTS (Policy Value - MCTS)と呼称されています。そして2つのネットワークを訓練するためのデータはすべて Selfplay (自己対局) によって生成されるため人間のナレッジを一切必要としません。これがAlphaZeroのすべてです、なんとシンプルなアルゴリズムでしょうか。


Policy-Value ネットワークについて

上述のMCTS説明では、過去のMCTS履歴を近似するニューラル方策関数(Policyネットワーク)と盤面評価関数(Valueネットワーク)は別のもののように書きましたが、実際はパラメータを部分共有する双出力ネットワーク(下図)になっています。

f:id:horomary:20210620182615p:plain:w200
パラメータを部分共有するネットワーク構造

このようなネットワーク構造はモデルフリー強化学習手法のA3Cでもお馴染みであり、視覚情報をそのまま入力するようなネットワークでは入力画像の表現抽出を担う部分をマルチタスク学習することでロバストな特徴表現が得られる(だろう)ために、強化学習エージェントのパフォーマンスが向上することが経験的に知られています。

rayで実装する分散強化学習 ①A3C(非同期Advantage Actor-Critic) - どこから見てもメンダコ


盤面状態の入力形式

盤面状態をどのような形式でニューラルネットに入力するかについて、alphazero論文の記載がややわかりづらいので補足説明します。囲碁も将棋もチェスもちゃんとルールわかってないので間違ってる場所あるかも。

f:id:horomary:20210620183524p:plain:w500
Alpha Zero論文より

囲碁
囲碁では19×19の平面(碁盤に対応)を17枚重ねることで盤面状態を表現しニューラルネットへの入力とします。17枚の平面の内訳は(自分の石配置を1としそれ以外を0とした平面, 相手の石配置を1それ以外を0とした平面)× 直近8step分 + (自分の色が黒ならすべて1で白ならすべて0の平面) となっています。

直近8step分を入力するのは同じ手の反復を検出するため、自分の色情報を含めるのは”コミ”への対応のため記述があります。

チェス・将棋:
チェスでは8×8の平面(チェス盤に対応)を119枚重ねることで盤面状態を表現しニューラルネットへの入力とします。119枚の平面の内訳は([自分のポーンの配置を1としそれ以外を0とした平面,、同様にルーク、同様にナイト、同様にビショップ、同様にクイーン、同様にキング]、[同様に相手の各駒の配置] + その盤面をすでに1回見たかどうかを示す平面 + その盤面をすでに2回見たかどうかを示す平面)× 直近8step分 + (自分の色が黒ならすべて1で白ならすべて0の平面) + チェスの特殊ルールを表現した平面×6 となっています。

チェスと将棋は単純な反復手の禁止に加えて、ゲーム中に同じ盤面が3回出たら流局のルールがあるためその盤面の出現回数を表現する平面があることに留意してください。

将棋も基本はチェスと同様ですが、持ち駒(prisoner)や成り駒によって表現が煩雑になっています。たとえばP1 piece(自分の駒)の14平面というのは、(記事を修正)自分の歩、香、桂、銀、金、飛車、角、王の配置を示す平面7枚に加えて、各駒を持ち駒として保持しているかを示す平面7枚。これに加えてP1 prisoner countの7平面で各駒を何枚持っているかreal valueで示しているのだと思います(ちがうかも)。 自分の駒の8種類[歩,香車,桂馬,銀,金,角,飛車,王]に成駒[成歩,成香,成桂,成銀,馬,竜]の6種類を加えたものとなります。

とくに将棋についてはこんな無理やりな盤面表現でも学習するんだから深層学習大したものだなと思います。


ニューラル方策関数の出力形式

同様にニューラル方策ネットワークの出力形についても解説します。

f:id:horomary:20210620183633p:plain:w500
alphazero論文より

囲碁
囲碁は19×19×1の平面で行動を表現します。平面の各要素は碁盤の対応する場所に石を置く確率を示します。

チェス・将棋:

チェスでは8×8の平面を73枚重ねることで行動を表現します。8×8で盤上のどの位置に置いてある駒を動かすかを表現し、73次元でどのように動かすかを表現します。たとえば下の盤面における方策出力をPとします(P.shape == (8, 8, 73) )。

f:id:horomary:20210620214213p:plain:w300
P[7, 2, 9]の行動

ここで、P[7, 2, 9] == 1.0 のとき、(7, 2)に対応する(1, c)に存在する駒(ビショップ)が73種類の動きのうち9番目の動きを行うことを意味します。12番目の動きは56種類(8方向×7マス)で表現されるQueen move、すなわち上下左右斜めの8方向に1-7マスのいずれかです。素直に上から{0, 1, 2, 3, 4, 5, 6}, 次に右斜めを {7, 8, 9, 10, 11, 12, 13} としていた場合は、P[7, 2, 9] == 1.0は(1, c)に置いてある駒を右斜め上に3つ移動させるという行動となります。


AlphaZeroの実装

ここからはオセロを題材としたAlphaZero実装の解説を行います。オセロはゲームの進行につれて行動の自由度が減っていくためMCTSと相性が良く、それほど苦労せずにsuperhumanな性能に到達することができます。

コード全文:
github.com


レーニングループ

AlphaZeroのトレーニングの流れはごく単純で、

  1. N回のselfplayによりデータ収集
  2. 収集したデータでネットワークを更新

のループを繰り返すだけです。

※自己対局(selfplay)の繰り返しはPythonの並列処理ライブラリrayで並列実行しています。

Pythonの分散並列処理ライブラリRayの使い方 - どこから見てもメンダコ



Selfplay(自己対局)

Selfplayもそれ自体は特筆することはありません。モンテカルロ木探索での試行回数に応じて行動決定することをゲーム終了まで繰り返します。


PV-MCTSの実装

Alpha Zeroの中心であるPV-MCTSの実装です。PV-MCTSそのものを理解していればプログラミング的には特段難しいところはありません。

盤面の評価には常に手番のプレイヤー視点で行うので、子盤面の評価値には-1を掛けて符号を逆転させることに注意してください。相手視点での最悪の評価値=自分視点での最良の評価値というわけです。


PVネットワークの実装

AlphaZeroのネットワーク構造はResNet-v1にpolicu headとvalue headが乗った構造となっています。せっかくなのでわりと論文に忠実に実装しましたが、オセロ程度なら正直もっとシンプルなネットワーク構造にしたほうが安定します。


学習結果

パフォーマンスの推移

GCPで 24-vCPU/64GB メモリ/NVIDIA Tesla P4 GPU のプリエンティブルVMインスタンスを使って4時間くらい学習した結果です*7

横軸がselfplayの回数、縦軸がテスト用NPCと20回対戦した時の勝率となっています。

テスト用NPCは70%の確率で貪欲手(もっとも多くの石が取れる手)を選択し30%の確率でランダムな手を選択するアルゴリズムです。2000回のselfplayを終えたころにはほぼこのNPCには負けなくなっていることがわかります。

f:id:horomary:20210618191950p:plain:w400


人間 vs. AlphaZero

先手(黒)が人間で後手(白)がAlphaZeroです。普通に負けました。

f:id:horomary:20210619115017g:plain:w400
先手(黒):人間, 後手(白):Alpha Zero


次:MuZero

AlphaZeroの後継手法であるMuZero(2020)は、ボードゲームのように状態遷移が明らかな環境以外でも使える手法です。具体的にはAlphaZeroのアルゴリズムブロック崩しができるようになりました。

horomary.hatenablog.com


参考

Deep Reinforcement Learning Hands-On

AlphaZero 深層学習・強化学習・探索 人工知能プログラミング実践入門

GitHub - PacktPublishing/Deep-Reinforcement-Learning-Hands-On: Hands-on Deep Reinforcement Learning, published by Packt

https://github.com/suragnair/alpha-zero-general

www.youtube.com


その他雑記

・Alpha Zeroの計算量について

対局時に本家AlphaZeroは一手につき800iterのPV-モンテカルロ木探索を行いますが、GPUによる推論は各iterでbatchsize=1の推論が一回なので、1手ごとにResNet-40相当のニューラルネットワークでの推論を800回行うことになります。

NVIDIAのベンチマークによると Jetson Nano (NVIDIAGPU搭載ラズパイ的な製品)でも224×224の画像を1秒で38回推論(ResNet-50, FP16)できるらしいので、最近のGPUが1枚あれば人間との対戦では困らなさそうです。

推論ではなくトレーニングを高速化するためには、SelfPlayを可能な限り並列実行しつつその裏でネットワークを訓練し続けるApe-Xアーキテクチャが(たぶん)理想です。とくにSelfPlayの並列数が学習速度のボトルネックになるでしょうから、少数の高価なGPUを用意するよりは、並列selfplayのための安価なGPUたくさん+勾配計算用の高性能GPU1枚 という構成がもっともコスパ良くトレーニングを高速化できると思います。

rayで実装する分散強化学習 ③Ape-X DQN - どこから見てもメンダコ


・これ強化学習

AlphaZeroのニューラルネット更新は普通の教師あり学習でありQ学習ではありません。状態行動価値Qや方策ネットワークなど強化学習でおなじみの用語が登場しますが、ネットワーク更新方法としてのQ学習や方策勾配法は一切使われません。しかし、環境と相互作用してデータを自ら収集するという点においてはAlphaZeroは強化学習であると言えます。

AlphaZeroが強化学習かどうかという議論自体はどうでもいいのですが、モデルフリー強化学習の事前知識が無いと非常にconfusingです。


モンテカルロ木探索という名称について

AlphaZeroのPV-MCTSではプレイアウト(ランダム対局)がニューラル盤面評価関数にとって代わられています。しかし、そもそもモンテカルロ木探索の ”モンテカルロ” はプレイアウト(ランダム対局)に由来するのでランダム対局をしないPV-MCTSが”モンテカルロ”であることには違和感があります。

これはUCTアルゴリズムのことを指して”モンテカルロ木探索”と言われることが多かったために、モンテカルロ要素であるプレイアウトが無くともわかりやすさのためモンテカルロ木探索と言っているのだと思います。(さすがにもう聞かないけど)携帯で写真をとることを”写メ”って言ってたようなものですね。


・ネットワークの改良について

AlphaZeroではネットワークにResNetを採用していますが、採用理由は開発当時にComputer Vision 分野での性能が良かったから程度で深い意味はないと思われます。CV向けのDLネットワークもここ数年でだいぶ進化しましたのでネットワーク構造を最先端のものに置き換えるだけでAlphaZeroのパフォーマンスが向上するかもしれません。

また、個人的に興味があるのは将棋分野でCNNをAttentionに置き換えることで性能が向上するのかどうかです。将棋では各ピクセル(マス目)にタテ・ヨコ・ナナメの情報を伝搬させることが囲碁よりも重要そうなので Axial attentionのような手法がうまくハマりそうな気がします。逆に囲碁はゲーム特性上CNNのままで問題なさそうな気もします。

[2003.07853] Axial-DeepLab: Stand-Alone Axial-Attention for Panoptic Segmentation


・オープニングブックについて

オープニングブックとは序盤の動きの定石を示すチェスの用語です。AlphaZeroのPV-MCTSでは過去のMCTSの履歴を近似する方策ネットワークを行動選択の事前分布(?)とするために、学習が進むにつれて特定のオープニングブックに収束していくと推測されます。

しかし、だからと言ってそれ以外のオープニングブックが明確に劣っているとは限らない、ということに注意してください。あくまでAlphaZeroの学習アルゴリズムがそのオープニングブックの深堀り研究(探索)を打ち切っただけです。たとえば、大器晩成型の複雑な戦術と先行逃げ切り型のわかりやすい戦術が同じ程度の強さだったとしても、AlphaZeroのアルゴリズムでは先行逃げ切り型に収束する確率が高い*8ことが予想されます。

よって、学習済みAlphaZeroに初動の制約をつけて追加訓練することにより、特定のオープニングブックに特化したAlphaZeroを作成することも可能なはずです。*9

Alのオープニングブックはついオラクル(神託)的に受け止めてしまいそうになりますが、少なくともAlphaZeroについては学習アルゴリズムをしっかり理解すれば「AIさんはそんな深く考えてないよ」という感じに思います。


*1:オセロで言えばカドをとられないようにするみたいな

*2:wiki囲碁 - Wikipedia)によると囲碁のゲーム木複雑性は10400

*3:勝率は適当です。実際にシミュレーションしたわけではありません

*4:証明は割愛

*5:将棋なんかでは人間による盤面評価関数の設計がそれなりにうまくいっていたらしいです

*6:モンテカルロ”の所以たるプレイアウトを排除したのにモンテカルロ木探索って呼称するのは混乱を招くと思う。ニューラルPUCTとかでいいんでは。

*7:メモリは16GBあれば十分だった

*8:探索と活用のトレードオフを決定するハイパラ調整次第ではあるけど

*9:このあたりのじゃんけん的な要素のがあるメタゲームをより深く研究したのがAlphaStarなのかもしれません

rayで実装する分散強化学習 ④R2D2

Ape-XにRNNを導入することでatari環境において圧倒的SotAを叩き出した分散強化学習手法 R2D2(Recurrent Experience Replay in Distributed Reinforcement Learningをtensorflow+pythonの分散並列処理ライブラリrayで実装します

Recurrent Experience Replay in Distributed Reinforcement Learning | OpenReview

rayで実装する分散強化学習
Pythonの分散並列処理ライブラリRayの使い方 - どこから見てもメンダコ
rayで実装する分散強化学習 ①A3C(非同期Advantage Actor-Critic) - どこから見てもメンダコ
rayで実装する分散強化学習 ②A2C(Advantage Actor-Critic) - どこから見てもメンダコ
rayで実装する分散強化学習 ③Ape-X DQN - どこから見てもメンダコ


はじめに

R2D2(Recurrent Experience Replay in Distributed Reinforcement Learning) とは Ape-XLSTMを導入した手法と表現できます。

DQNにRNNを導入すればエージェントのパフォーマンス向上するのでは? というのは(私ですら思いつく)ごく自然な発想ですが、学習の難しさからか目立った結果を残せていませんでした*1R2D2 (2018) はこのDQN+LSTMアーキテクチャにおける学習を安定化するテクニックを確立し、atari環境において圧倒的SotAを達成しました。RNNを導入するという発想は自然でも学習が困難で実現できていなかったという意味では、強化学習+CNNにおける学習安定化トリックを確立したDQNと似たような立ち位置とも言えます。

f:id:horomary:20210506013517p:plain:w600


RNNの必要性

Q学習はMDP(マルコフ決定過程)を前提としています。MDPとは乱暴に言うなら適切な行動決定に必要な情報はすべて現在の状態観測に含まれている、という仮定が成立するような系です。atari環境では現在の状態観測とはゲーム画面1フレームにあたりますが、しかし1フレームだけではアクション決定には情報がまったく不十分であることは明らかです。

たとえばBreakout(ブロック崩し環境)では1フレームだけの観測情報ではボールの進行方向がわかりません。

f:id:horomary:20210125000147p:plain:w400
1フレームではボールの進行方向がわからない

適切な行動選択のためにはより過去の観測情報も考慮する必要があります。このような系をPOMDP(部分観測マルコフ決定過程)と言います。そこで、DQN(2013)では直近4フレームの観測を重ねてQネットワークの入力とすることで、atariの多くのゲームをPOMDPからMDPっぽい系にすることに成功し、エポックメイキングな手法となりました。

とはいえ、DQNで考慮できる過去とは所詮直近4フレームまでです*2。 直近4フレームはボールの進行方向を判断する程度なら十分ですが、たとえばMs. Pacmanにおいて”そろそろパワーエサ状態が切れそうだな”というような数秒スケールの判断を適切に行うには全く不十分です。 この課題に対する有望なアプローチは Deep Q-networkへ時系列情報を考慮できる Recurrent Neural Network (RNN, 再帰ニューラルネットワーク) を導入することです。R2D2ではRNNファミリーの中でもよく使われるLSTMを採用しています。


RNN(LSTM)の困難

前述の通りPOMDP打破のためにRNNを使うというアイデアは何ら独創的なものではないので、過去にも同様の検討がされてきましたが華々しい結果とはなっていませんでした。これはRNNに関する2つの困難により学習が不安定化するためであると考えられます。


困難①:経験再生時の初期LSTM状態をどうするか?

LSTMの入力は3つであり(下図)、すなわち 入力  \displaystyle{
x_{t}
}, 1step前の出力  \displaystyle{
h_{t-1}
}, そして1step前のセル記憶  \displaystyle{
c_{t-1}
} です。LSTMを持つネットワークで推論するときには当然これらすべてを入力する必要があります。また以下ではc,hをまとめてLSTM状態と呼称します。

※エピソード開始時、つまりt=1の  \displaystyle{
h_{0}
},  \displaystyle{
c_{0}
} はゼロ行列です。

f:id:horomary:20210508235320p:plain:w400
LSTMの構造(Long short-term memory - Wikipedia より)

R2D2では連続する40遷移のセグメントを1サンプルとしてreplay bufferに格納します。ここで、通常のDQNのように遷移情報として(s, a, r, s')だけを蓄積していると、セグメントが再生されたときに(そのセグメントからエピソード開始される場合を除いて)、 \displaystyle{
h_{t-1}
}および  \displaystyle{
c_{t-1}
} が無い=LSTMの初期状態が無いため困ってしまいます。この問題へのもっとも単純な対応策は、エピソード全体を保存しておいてt=0からunrollする(タイムステップを進めていく)ことで対応するセグメントへの初期入力を作ることです。この方法は正確なLSTMの初期状態が得られる一方で、しかし計算量が酷いことになるので実用的ではありません。

そこでR2D2が採用しているのがStored stateトリックです。このトリックでは経験バッファにセグメントの初期LSTM状態 \displaystyle{
(c_{t-1}, h_{t-1})
} も保存しておくことで、セグメントが再生されたときは保存されている初期LSTM状態 \displaystyle{
(c_{t-1}, h_{t-1})
} をLSTMへの初期入力として使用し、t=0からの愚直なunrollを回避します。


困難②:ネットワーク更新によるStored LSTM state の陳腐化

Stored state トリックだけでは経験再生時の初期LSTM状態の問題は解決していません。なぜならば保存されているLSTM状態は過去のQネットワークによって計算されたLSTM状態であり、現在のQネットワークでLSTM状態を計算しなおすと異なる値になるはずだからです。

この保存されたLSTM状態の陳腐化問題を軽減するためにR2D2で提案されたのがBurn-inトリックです。これはStored state トリックで保存された初期LSTM状態を初期入力に使うものの、Stored stateによる入力に近いところでは実際のLSTM Stateとの乖離が大きいと予想されるため、しばらくタイムステップを進めてから学習に使うことで鮮度の低いLSTM状態の問題を軽減しようというアイデアです(下図)。よってburn-inを最大限長くした場合は上述したt=0からの愚直なunrollと同じになります。

f:id:horomary:20210509220012p:plain:w600
burn-inフェーズはtimestepを進めるだけでネットワーク更新に使わない


余談ですが日本語を当てるなら、burn-inの意味・使い方・読み方 | Weblio英和辞書 に例文として記載されている”ならし運転”がしっくりきます。

f:id:horomary:20210510232659p:plain:w700


LSTM+大規模分散学習

上述したStored state & Burn-in トリックを使っても古すぎるセグメントの初期LSTM状態を再現することは難しいと考えられるため、経験バッファにはできるだけ鮮度の高い(on-policynessの高い)セグメントが蓄積されていることが望ましいはずです。

単純には経験バッファのサイズを小さくすれば全体の鮮度が高まることが期待できますが、そうするとサンプル多様性が失われ学習が不安定化することが予想されます。この問題を力押しで解決するのがApe-X で提案された大規模並列分散マルチ方策学習です。分散並列による圧倒的なサンプル投入速度とマルチ方策(異なる探索率ε)エージェントによって経験バッファ内のサンプル多様性を確保します。

ただし、分散並列の効果についてApe-X論文のFig.6でやってたような検証実験が無いので確実なところはわかりません。


その他の重要なトリック

R2D2はLSTMに目が行きますが、パフォーマンスに大きな影響を与えうる(Ape-Xには無かった)トリックがいくつか追加されています。

報酬クリッピングの廃止

atari環境ではいかなる報酬でも (-1, 0, 1) にクリップする reward clippingトリックが長らく使われてきました。これは多くのゲームで学習を安定化させる有用なトリックである一方、一部のゲームの学習を困難にしてしまいます。

[1805.11593] Observe and Look Further: Achieving Consistent Performance on Atari ではその分かりやすい例として、

For example, the agent no longer differentiates between striking a single pin or all ten pins in Bowling.
ー たとえば、agentはボーリングゲームでピンを1本倒すことと10本倒すことを区別できなくなります。

と述べています。もう少し親しみのあるゲームで言えば、Pacmanで通常クッキーを食べるのもオバケを倒すのも同じ+1点になってしまいます。そこで、同論文ではこの問題低減のためによりソフトな報酬(というか target-Q の)スケーリング関数を提案しており、R2D2でもこれを採用しています。


f:id:horomary:20210509172158p:plain:w600
R2D2論文より

f:id:horomary:20210509171851p:plain:w600
Observe and Look Further より


割引率 γ=0.997

これも同様に [1805.11593] Observe and Look Further: Achieving Consistent Performance on Atari で報告されていることですが、R2D2ではγ=0.997という従来(γ=0.99とか)よりかなり高い割引率を採用することでパフォーマンスを向上させています。ablation studyは Fig.7を参照。


Life loss as episode end の廃止

残機を使い切ることではなく、残機が1減ることをエピソード終了と見なすトリックは、報酬クリッピングと共にatari環境のヒューリスティックスとして長らく使われてきましたがR2D2ではこれを廃止しています。ablation studyを見るとこれによって必ずしもパフォーマンスが向上するわけではないようですが、少なくともヒューリスティクスを一つ排除してSotAを達成したことは重要な成果です。

f:id:horomary:20210509222401p:plain:w600
life loss (roll) が従来のやり方


R2D2の実装(CartPole-v0)

ここからはtensorflow+rayによる実装レベルの解説です。まずは単純なCartPole環境でR2D2の実装を確認してみます。ただし、ここでは簡単のためにDueling-network, n-step return, およびvalue-rescalingは省略しています。分散学習部分はApe-X DQN とほぼ同じなので過去記事も併せて参照ください horomary.hatenablog.com

コード全文:
github.com

分散学習の流れ

前述の通り、分散学習の流れ自体はApe-Xと何も変わりません。


R2D2のネットワーク構造

DQNアーキテクチャのDense層がLSTMに変更されただけです。このネットワークはLSTM状態(c, h)に加えて前のアクションの入力も要求することに留意ください。前ステップのアクションはonehot化したうえでconv層からの出力とconcatします。

※論文ではさらに前ステップのrewardも入力すると書いていますが省略しました。


Actor

各セグメントはepisode-endを跨がないという設定から、rolloutは1episode区切りにすると実装が楽です。1episode分のrolloutが終わったらセグメントの切り出しを行い、優先度付き経験再生のための初期優先度を算出したうえでセグメントを送信します。セグメントへの優先度の割り当てはR2D2論文にて提案された方法です。


ActorはEpisodeBufferに1episode分の遷移を蓄積します。


Replay

Actorから受け取ったセグメントを蓄積するSegmentReplayBufferは、対象がセグメントであること以外はApe-Xとまったく同じ優先度つき経験再生バッファなので掲載を省略します。

rayで実装する分散強化学習 ③Ape-X DQN - どこから見てもメンダコ


Learner

Learnerは16セグメントで構成されるミニバッチを16セット受け取りネットワークを更新します。Actorで初期優先度割り当てとほぼ同じ処理ですが、ターゲットネットワーク(target_q_network)はオンラインネットワーク(q_network)とは別にburn-inする必要があるため計算量が増えています。

学習結果

CartPoleでLSTM使う意味はほぼありませんが、問題なく学習出来ています。

f:id:horomary:20210511005946p:plain:w400
x軸:Leaner.update_networkが呼ばれた回数


R2D2の実装(Breakout)

Breakoutでの実装はCartPoleのコードに

  • N step-return

  • Dueling network

  • Value function rescaling

  • RAM節約のためのsegment圧縮

を追加したものとなっていますが、Ape-Xと同様にコードを直接掲載するには多すぎるので結果だけ示します。 詳細はGithubを参照ください。

github.com

学習結果

BreakoutDeterministic-v4環境(ブロック崩し)を、GCPで 24-vCPU/128GB RAM/GPU T4 のプリエンティブルVMインスタンスを使って24時間学習しました。actor数は論文では256であるのに対してここでは20と圧倒的に少ないですが、なんとか正常に学習出来ているっぽくはあります。プリエンティブルインスタンスの24時間制限によりパフォーマンスが急激に向上してきたところで時間切れとなってしまいました。

f:id:horomary:20210514004442p:plain:w600

速度パフォーマンスは論文記載の20%程度しかでていなかったので単純計算でR2D2論文の5時間時点相当くらいの更新回数になっています。プロファイリングしたところLearnerのネットワーク更新がボトルネックになっていたので、GPUをもっと性能が良いものにするかGPU利用効率の良い実装を考える必要があります。


次: Agent57

R2D2をベースに内発的報酬を追加し、さらにエージェントへの方策割り当てをバンディット問題と捉えることでついにすべてのatariゲームで人間超えを達成した手法。そのうち。

*1:https://arxiv.org/abs/1507.06527

*2:NoFrameSkip環境でなければ実質16フレーム

深層分布強化学習 ③FQF: Fully Parameterized Quantile Function for Distributional RL

単体でRainbow越えを達成した深層分布強化学習手法FQFをtensorflow2で実装します。


前提手法
horomary.hatenablog.com

horomary.hatenablog.com


はじめに

現実のほとんどの環境はランダム性を内包するため、状態価値は分布であると考えるのが妥当です。しかし、典型的な状態価値ベースの強化学習では状態価値分布の期待値のみの近似を目的とするため、状態価値が明示的に分布としてモデル化されることはありません。

これに対して、深層分布強化学習では状態価値を明示的に分布として深層学習により近似し、状態価値分布から状態価値分布の期待値を算出するというアプローチをとります。このような深層分布強化学習はリスク考慮型方策が可能になるなどいくつかのメリットがありますが、その最大の利点は状態価値を明示的に分布としてモデル化することは(なぜか)パフォーマンスの向上に寄与する、という点です。

状態価値分布のモデル化によりなぜエージェントのパフォーマンスが向上するかは(私の知る限りは)理論的に解明されていないものの、分布の近似がQネットワーク訓練のためのよい補助タスク(Auxiliary Tasks, 詳細はUNREALを参照)になっているのだろうと推察されます。

このような深層学習と分布強化学習の組み合わせの有用性は、 Categorical DQN, C51 論文から注目されるようになり、その後もQR-DQN, IQNなどいくつかの改良手法が提案され続けています。本記事で実装を紹介するFQF(Fully Parameterized Quantile Function for Distributional RL)もそのひとつであり、特筆すべきはついに単体でRainbow超えを達成したことです。

f:id:horomary:20210417161518p:plain:w500
論文Table1より

論文:
[1911.02140] Fully Parameterized Quantile Function for Distributional Reinforcement Learning

Microsoftの実装・解説:
Finding the best learning targets automatically: Fully Parameterized Quantile Function for distributional RL - Microsoft Research

github.com


C51 → QR-DQN → IQN

FQFの話を始める前にこれまでの深層分布強化学習の各手法がどのようなアプローチで分布をモデル化してきたのかを確認しましょう。

Categorical DQN (C51) では、素直にカテゴリ分布によって状態価値の確率分布を近似します。このアプローチは一定の成功を収めたものの、分布の最大値/最小値の設定が重要なハイパラになっていたり、ベルマンオペレータの適用で生じるビン幅のずれの修正処理が煩雑だったりといくつかの欠点を抱えていました。

f:id:horomary:20210328225227p:plain:w500
C51: 状態価値の確率質量を予測する

QR-DQNでは、状態価値分布の分位点を予測する=状態価値分布の累積分布関数の逆関数を近似するというアプローチによりC51の残したいくつかの課題を解決しました。

f:id:horomary:20210329010227p:plain:w500
QR-DQN:分位点を予測する

IQNでは、QR-DQNはあらかじめ設定された均等幅の分位しか予測しないため真の状態価値分布を近似することができないという課題に対して、Qネットワークに状態sとともにランダムサンプリングされた分位τを与えて、対応する分位点を予測させるIQNアーキテクチャを提案しました。訓練済みのIQNネットワークは任意の分位τについて分位点を予測することができるので、十分に多くの数の分位τをサンプリングすれば滑らかな状態価値分布を近似することができるはずです。

f:id:horomary:20210418232351p:plain:w600
IQN論文Fig1より


FQFとは:いい感じのτを提案する機構付きのIQN

IQNで提案されたQ関数に任意の分位τの分位点を予測させるアーキテクチャでは、十分に多くの分位τをQ関数に入力することで実質的に状態価値の累積分布関数の逆関数  \displaystyle{
F_{z}^{-1}
}を近似することができます。しかし、IQNアーキテクチャでは与えられる分位τの数に応じてニューラルネットワークのパラメータ数が増え学習が不安定になるため、可能ならばできるだけ少ない分位の予測で済ませたいところです。

少ない分位で  \displaystyle{
F_{z}^{-1}
}をうまく近似するには、 累積分布関数の形状(=状態価値分布の形状)に応じていい感じにτを選ぶことが必要です(下図)。そこで、状態sに応じていい感じの分位τセットを提案するネットワークをIQNに追加したのがFQFであると理解できます。

f:id:horomary:20210419204245p:plain:w600
論文Fig1, どちらも6つの分位点だがうまくτを選べばWasserstein距離を小さくできる

FQFとは具体的には下図(論文著者の解説記事より転載)のようになります。このFQFアーキテクチャから分位提案ネットワーク(frction proposal network)が除去されるとQR-DQNとなります。また、分位提案ネットワークが一様分布からのサンプリングに置き換えられるとIQNとなります。CNN層 (future network) & Quantile function network の訓練と 分位提案ネットワークの訓練は独立して別のロス関数で行うことに留意してください。詳細は後述。

f:id:horomary:20210419201956p:plain:w600
Finding the best learning targets automatically: Fully Parameterized Quantile Function for distributional RL - Microsoft Research


FQFネットワークの実装

※この実装はMicrosoftによる公式実装 を参考にしています。基本のトレーニングループについては基本とDQNと変わらないので割愛します。

コード全文:

github.com


FQFアーキテクチャ

上のアーキテクチャ図に示したようにFQFは複数のネットワークで構成され、そのままでは扱いづらいのでそれらをとりまとめるFQFモデルを実装します。(各構成要素についての詳細は後述。)

このモデルはまず入力として受け取った状態sをFeature Networkに通して特徴抽出を行います。さらに抽出された状態特徴(state_embedded)を分位提案ネットワークに入力することにより分位 \displaystyle{ \tau }のセットおよびその中点 \displaystyle{ \hat{\tau} } を提案させます。

たとえばnum_quantiles=4のときに  \displaystyle{ \tau }=[0, 0.2, 0.6, 0.9, 1.0]のように分位τが提案された場合は、この中点 \displaystyle{ \hat{\tau} }=[0.1, 0.4, 0.75, 0.95] となります。このうち、  \displaystyle{ \hat{\tau} } をQuantile function networkに入力し、対応する分位点を出力します。 \displaystyle{ \tau } については分位提案ネットワークの更新にのみ使用します。

Feature network:特徴抽出ネットワーク

状態Sから特徴抽出するネットワークですが、これはただのDense層を除いたDQNアーキテクチャなので解説不要ですね。入力がDQN論文と同じ(84, 84, 4)であれば出力は(3136,)となります。


Fraction proposal network:分位提案ネットワーク

FQFのキモである状態特徴を入力として分位τを提案するネットワークです。出力する分位τが 単調増加 かつ 0≦τ≦1 であることを保証するために、softmaxを取った後に累積分布を計算します。さらに、0.01や0.99など極端に0 or 1に近い数の提案を許可すると学習が不安定化したため、この実装ではtf.clip_by_valueで提案できる分位を0.1から0.9の範囲に制限しています。


Quantile function network:分位点予測ネットワーク

構成要素で一番ややこしいのが。状態特徴state_embeddedと提案分位quantiesを入力として、提案分位に対応する分位点を予測する分位点予測ネットワーク(Quantile function network)です。難解ではなくややこしいだけです。

状態特徴state_embeddedと提案分位quantiesを入力として分位点を予測するネットワーク構造は(たとえばDDPGのように入力直後にconcatするなど)いろいろと考えられますが、FQFでは IQN論文 で提案されたものをそのまま使用します。すなわちCosine embedding(下式)によりquantilesの次元を状態特徴state_embeddedと同じ3136次元まで増幅した後、state_embeddedとの要素積をとります。

f:id:horomary:20210419224330p:plain:w400

Cosine Embedding周りのshape操作が煩雑でわかりにくいのでshapeの遷移図を描きました。

f:id:horomary:20210419232943p:plain
Cosine Embedding (batch_size=1, N=4の場合)


FQFネットワークの更新

前述の通り、(Feature network + Quantile function network ) と (Fraction proposal network) は別のロス関数で独立した訓練を行います。

(Feature network + Quantile function network )のネットワーク更新はQR-DQNとほぼ同じです。ただし、ベルマンオペレータの適用において、オンラインネットワーク*1が提案した分位τおよびオンラインネットワークが出力したターゲットネットワークでも利用していることにだけ注意してください。


分位提案ネットワークの更新

前述の通り、分位提案ネットワークの役割はいい感じの分位τを提案することです。そして分布強化学習におけるいい感じの分位とは2つの分布間のWasserstein距離が最小化されるような分位τです。

f:id:horomary:20210422001845p:plain:w500
Wasserstein距離を小さくするようにτを提案したい

よって、安直にはWasserstein距離をロス関数として分位提案ネットワークを訓練したいところですが、しかしWasserstein距離は直接計算することが現実的ではないため*2 このアプローチは不可能です。

代替案として、FQF論文ではWasserstein距離を直接計算するのは困難だけども、提案分位τについての1-Wasserstein距離の微分なら近似的に計算できるよ、ということを証明(Appendix: Proof for proposition 1)しました。

f:id:horomary:20210422002606p:plain:w600
τについての微分なら計算できる

分位提案ネットワークのパラメータをθとすると、θについての1-Wasserstein距離の微分は連鎖律を利用して、

 \displaystyle{
 \frac{\partial W_1}{\partial \theta} = \frac{\partial W_1}{\partial \tau_{i}} \frac{\partial \tau_{i}}{\partial \theta}
}

と表せます。ここで  \displaystyle{
 \frac{\partial W_1}{\partial \tau_{i}}
} は論文が示す計算式によって、また  \displaystyle{
 \frac{\partial \tau_{i}}{\partial \theta}
} はtensorflowの自動微分によって計算できるので分位提案ネットワークを訓練できるようになりました。

実装上の注意として論文にも記載があるのですがtensorflowで明示的に連鎖律を使用するときは、tensorflow1.Xではtf.gradient(taus, network_params, grad_ys=dw_dtau) のようにgrad_ys引数を利用します。*3。一方、tensorflow2.X系でwith GradientTape() as tapeを使う場合は引数名が変わり tape.gradient(taus, network_params, output_gradients=dw_dtau) とします。

ただし、論文には記載されてませんが Mictosoftの公式実装 のREADMEでは  \displaystyle{
 \frac{\partial W_1}{\partial \tau_{i}}
} の二乗をロス関数として使うことを推奨しています*4。こちらの方が実装がわかりやすいので下の例ではL2ロスを採用しています。


学習結果:Breakout環境

BreakoutDeterministic-v4環境(ブロック崩し)において、GCPのn1-standard-4(4-vCPU, 15GBメモリ) + GPU K80 のプリエンティブルVMインスタンスを使って24時間学習を行い、妥当な性能が得られることを確認しました。アーキテクチャが複雑なのでやはり計算処理が重く、速度パフォーマンスはQR-DQN比較でざっくり60%程度となりました。

f:id:horomary:20210421233630p:plain:w600


*1:target networkじゃないほう

*2:このあたりの議論はC51論文を参照

*3:tensorflow - tf.gradients, how can I understand `grad_ys` and use it? - Stack Overflow

*4:Readme.md, BugFixedの項:It is recommended to use the L2 loss on gradient for probability proposal network

深層分布強化学習 ②QR-DQN

QR-DQNをtensorflow2で実装します。
元論文: [1710.10044] Distributional Reinforcement Learning with Quantile Regression

前記事:
horomary.hatenablog.com

参考:

https://physai.sciencesconf.org/data/pages/distributional_RL_Remi_Munos.pdf

Going beyond average for reinforcement learning | DeepMind

Quantile regression - Wikipedia


はじめに

DeepMindDQNに代表される典型的なQ学習においては、状態行動価値Q(s, a)の期待値関数近似します。

一方、前記事で実装を紹介したCategorical DQN ([1707.06887] A Distributional Perspective on Reinforcement Learning)は、状態行動価値Q(s, a)を明示的に確率分布Z(s, a)としてモデル化することを提案し、これにより大きくパフォーマンスが向上することを当時のatari環境のSotAという結果で示しました。

本記事で紹介するQR-DQNはCategoricalDQNの直接の後継手法*1です。Categorical DQNでは価値分布をそのままカテゴリ分布で近似しようとしたのに対し、QR-DQNは状態行動価値分布Z(s, a)の分位点を近似するというアプローチによりCategorical DQNの残した多くの課題を解決しました。


Categorical DQNの分布モデル

分布強化学習でモデル化したい真の(ground truth?)状態行動価値分布Z(s, a)は連続分布であるはずですが、連続分布は大変扱いづらいのでCategorical DQNではその名の通りZ(s, a)をカテゴリカル分布で近似します。Categorical DQN論文ではカテゴリカル分布のビン数=51の場合がatari環境でもっとも性能が良かったので、この場合をとくにC51と呼称しています。*2

f:id:horomary:20210328225227p:plain:w600
カテゴリ分布によるZ(s, a)のモデル化

状態行動価値分布Z(s,a)へのベルマンオペレータの適用は下図のように行います。rewardによって分布が水平スライドし、割引率によって分布が縮むようなイメージです。※見た目にわかりやすいようにreward=7, 割引率γ=0.6という極端な値で作図していることに留意ください。

f:id:horomary:20210328230353p:plain:w600
分布ベルマン方程式

状態行動価値分布をCategorical分布で近似するC51のアプローチはいくつかの大きな問題を抱えています。

1つはベルマンオペレータの適用によって分布のビン幅がずれることです。上図でも元の分布Z(s,a)のビン幅である赤破線からTZ(s, a)のビン幅はずれてしまっていることがわかります。よってCategorical DQNではこのずれたビン幅を無理に再割り当てして修正する処理*3が必要なのですが、この処理の実装がかなり煩雑&やや重い*4です。

別の問題はカテゴリカル分布では有限領域しか扱えないため、分布の最大値/最小値の設定が非常に重要なハイパーパラメータになってしまうことです。 この問題は学習初期と学習終盤で報酬のスケールが大きく変化するような場合には顕著な問題となります *5。また、最大/最小幅を大きくとった場合はカテゴリカル分布の性質上ビンの数を十分に多くしないと細かな分布の形状を捉えにくいという問題も生じます。

さらにCategorical DQNの最大の問題は、Categorical DQN論文で証明された"p-Wasserstein距離を分布間の距離尺度に設定するとベルマンオペレータが縮小写像である"という理論とCategorical分布のKL距離をロス関数とする実装にギャップがあることです。大雑把には、確率的勾配降下法でWasserstein距離をロス関数にすると biased gradient になるので、言っていることとやっていることが違うのだけどKL距離をロスにするヒューリスティックな実装にしたよ、という感じです。( 前記事を参照)


QR-DQNの分布モデル

Categorical DQNではZ(s,a)をそのままカテゴリカル分布で近似しましたが、QR-DQNではZ(s,a)の累積分布関数Fを近似します。※Z(s,a)とその累積分布関数Fは1対1変換であるのでどちらを近似してもよいことに留意。

f:id:horomary:20210329004825p:plain:w500
Z(s,a)とその累積分布関数F

ここで、QR-DQNのポイントは累積分布関数Fそのものではなく、Fの逆関数をカテゴリカル分布で近似することです。

f:id:horomary:20210329010227p:plain:w500
各ビンは分位点と解釈する

したがって、Categorical DQNでは各ビンの値はZ(s,a)がある状態行動価値θをとる確率でしたが、QR-DQNでは各ビンの値はZ(s,a)の τ%分位点 (Quantile)の値となります。あえてZ(s,a)の累積分布関数の逆関数をカテゴリカル分布で近似することにより、前述したCategorical DQNの問題点を解消することができます。

まず、Categorical-DQNではx軸のカテゴリカル分布でZ(s,a)を近似していましたが、ベルマンオペレータの適用によってビン幅がずれるため煩雑なビンの再割り当て処理(projection)が必要でした。一方、QR-DQNではカテゴリカル分布で価値分布の累積分布関数をy軸にそってモデル化する(つまり累積分布関数の逆関数を近似)ためビン幅ずれ問題に煩わされることは無くなりました(下図)。

f:id:horomary:20210331230243p:plain:w500
Z(s, a)とTZ(s, a)でquantileは当然変わらない

また、カテゴリ分布の最大値/最小値の設定に悩まなくてよくなりました。なぜならば累積分布関数の逆関数は0-1の有限区間で定義される関数であるためです。

さらに、Categorical DQN論文の最大の残課題は理論的にはWasserstein距離を最小化したいのだけれども、Wasserstein距離をそのままSGDのロス関数にするとBiased gradientとなってしまうので仕方なく分布間のKL距離を最小していたことです( 前記事を参照)。

そこでQR-DQNではターゲット分布の分位点を予測することが1-Wasserstein距離を最小化することを示し、このためにSGDのロス関数に 分位点回帰を使用することを提案しました。これにより直接Wasserstein距離をロス関数として使用することを回避してWasserstein距離を最小化できます

f:id:horomary:20210401001211p:plain:w500
論文Fig.2より:分位点を予測することが1-Wasserstein距離を最小化するになることの視覚的な説明


分位点回帰

分位点回帰とそのロス関数を簡単に説明します。

分布 \displaystyle{ Z
}の70%分位点を予測することを考えます。この分布 \displaystyle{ Z
}の 10%, 30%, 50%, 70%, 90% 分位点を  \displaystyle{ \hat{Z}
} = [-1.23, -0.29, 0. , 0.29, 1.23] とします。*6

f:id:horomary:20210401232756p:plain:w500
ターゲット分布Z

70%分位点の予測値をθと置くと、論文より分位点ロスは下式となります。
 \displaystyle{ \delta_{u \lt 0}} \displaystyle{ (Z - \theta) \lt 0 } のとき1、そうでなければ0という意味です。

f:id:horomary:20210401232954p:plain:w500
分位点ロス関数

この分位点ロスの視覚的な説明が下図です。ポイントは分布 \displaystyle{ \hat{Z} } のすべてのサンプルについて計算した分位点ロス(赤破線で表示)の平均が最終的な分位点ロスであることです。直感的には、予測値θより大きい値との距離総和と予測値θより小さい値との距離総和を予測したい分位点に応じてバランスしているという感じです。

f:id:horomary:20210402001048p:plain:w500
70%分位点(τ=0.7)を予測したい場合


分位点Huberloss

この分位点ロスをニューラルネットのロス関数にそのまま使うとu=0付近で滑らかでないため学習が不安定化するらしく、論文ではQuantile HuberLossを提案しています。と言っても |u|≦1のときは \displaystyle{ \rho_{\tau}(u) = 0.5u^{2}(\tau - \delta_{u \lt 0}) }、|u|>1のときは \displaystyle{ \rho_{\tau}(u) = (u-0.5)(\tau - \delta_{u \lt 0}) } とただのHuberLossに分位点重みがかかるだけのなので特に難しくはありません。


QR-DQNの実装

Breakout (ブロック崩し)環境向けにQR-DQNを実装します。 ネットワーク構造とネットワーク更新以外はオリジナルのDQNと完全に同じです。

horomary.hatenablog.com


QRネットワークの実装

ネットワーク構造自体はCategorical DQNとまったく同じです。構造は同じですが解釈が違うだけです。

アクション選択もCategorical DQNの場合と同様に価値分布Z(s, a)の平均値が最も大きいactionを選択します。ここで、分位の刻み幅を均等にとっている場合は、E[Z(s,a)]は分位点の単純平均と一致することに留意しましょう。


分位点ロスによるネットワーク更新


やってることは上述の分位点回帰の説明と同じです。しかし、上述の例では70%分位だけを計算していましたがQR-DQNでは設定されたすべての分位についてそれぞれ分位点ロスを計算する必要があるのでけっこう煩雑です。そこで、やってことがわかりやすいようにbatchsize=1の場合を下記に示しておきます。

import numpy as np
import tensorflow as tf

N = 5  #:分位の分割数
quantiles = np.array([0.1, 0.3, 0.5, 0.7, 0.9], dtype=np.float32)

target_quantile_values = np.array([23, 35, 42, 56, 76], dtype=np.float32).reshape(1, -1)
quantile_values = np.array([20, 32, 45, 50, 70], dtype=np.float32).reshape(1, -1)

target_quantile_values = tf.repeat(target_quantile_values, N, axis=0)
quantile_values = tf.repeat(quantile_values.reshape(-1, 1), N, axis=1)

td_error = target_quantile_values - quantile_values
indicator = tf.where(td_error < 0, 1., 0.)

#: k=1.0の場合のhuberloss
huberloss = tf.where(tf.abs(td_error) < 1.0, 
                     0.5 * tf.square(td_error), 
                     tf.abs(td_error) - 0.5)
quantiles = tf.repeat(quantiles.reshape(-1, 1), 5, axis=1)
quantile_weights = tf.abs(quantiles - indicator)

quantile_huberloss = quantile_weights * huberloss
total_quantile_huberloss = tf.reduce_mean(quantile_huberloss, axis=1, keepdims=True)
loss = tf.reduce_sum(total_quantile_huberloss, axis=0)


Breakoutでの学習結果

BreakoutDeterministic-v4環境(ブロック崩し)において、GCPのn1-standard-4(4-vCPU, 15GBメモリ) + GPU K80 のプリエンティブルVMインスタンスを使って24時間学習した結果十分なパフォーマンスを確認できました。

Breakoutはatariの中では比較的単純な環境であることを考慮して、Adamの学習率は論文より高め設定のlr=0.00025(論文記載はlr=0.00005) & 分位点の刻み数Nを論文より小さめ設定のN=50(論文記載は分位点の刻み数N=200)にしています。

f:id:horomary:20210401004514p:plain:w500

f:id:horomary:20210401004556p:plain:w500

コード全文: github.com


次:FQF

horomary.hatenablog.com


*1:Bellemareさんが著者リストに入ってる

*2:Categorical 51でC51。もしDistributional 51でD51命名されてたとしてもやっぱり蒸気機関車

*3: 論文ではprojectionと呼称

*4:とくにバッチサイズ大きいと処理が重い。このあたりの煩雑さがパフォーマンスは優秀なのにApeX-DQNではハブられた理由なのではないかと邪推している

*5: atari環境ではreward clippingが有効なのであまり問題になりません

*6:分かりやすさのため分位を明示しているが、確率密度に従ってサンプリングされていれば分位が分かっている必要はない

rayで実装する分散強化学習 ③Ape-X DQN

深層強化学習における超大規模分散並列化の有用性を示したApeX-DQN(Distributed Prioritized Experience Replay)をtensorflow2とrayで実装します。手法の構成要素自体はRainbowとだいたい同じであるため、本記事の焦点は分散並列学習の実装です。

f:id:horomary:20210227174718p:plain:w1000

rayで実装する分散強化学習
Pythonの分散並列処理ライブラリRayの使い方 - どこから見てもメンダコ
rayで実装する分散強化学習 ①A3C(非同期Advantage Actor-Critic) - どこから見てもメンダコ
rayで実装する分散強化学習 ②A2C(Advantage Actor-Critic) - どこから見てもメンダコ


前提手法:
DQNの進化史 ①DeepMindのDQN - どこから見てもメンダコ
DQNの進化史 ②Double-DQN, Dueling-network, Noisy-network - どこから見てもメンダコ
DQNの進化史 ③優先度付き経験再生, Multi-step learning, C51 - どこから見てもメンダコ
DQNの進化史 ④Rainbowの実装 - どこから見てもメンダコ


はじめに

[1803.00933] Distributed Prioritized Experience Replay
Distributed Prioritized Experience Replay | OpenReview

Distributed Prioritized Experience Replay、あるいはApe-X*1はその名の通り 優先度付き経験再生を大規模分散並列学習に対応させた手法です。Ape-XはDDPGにもDQNにも適用可能な手法ですが、後者に適用された場合にはApe-X DQNと呼称されます。Ape-X DQNはオフポリシー強化学習の大規模分散並列化は訓練時間*2を短縮するだけでなく、パフォーマンスの向上にも寄与することを当時のatari環境における圧倒的SotAで示しました。

f:id:horomary:20210227130719p:plain:w500
並列化しているので訓練時間が短縮されるのは当然であり、パフォーマンス向上が著しいことのインパクトが強い

DQNはオフポリシー手法であるので、収集した遷移情報(経験)を何度でも再学習してよいはずなのですが、Ape-Xはそのような循環式の経験再生よりも、源泉かけ流し的な贅沢な経験再生の方がパフォーマンスが良くなるということを示しました。この発見が以降の強化学習手法における大規模分散並列学習トレンドを加速していくこととなります*3


Ape-X DQN の概要

Ape-Xの基本コンセプトは遷移情報を収集するプロセス、遷移情報を蓄積を担うプロセス、および勾配計算してネットワークを更新するプロセスを完全に分離することによる効率化です。この分散学習アーキテクチャはFig.1に示されており、Learner, Actor, Replay という3つの主要な役割があることが分かります。

f:id:horomary:20210227144313p:plain:w600
apexの分散並列アーキテクチャ

Learnerの役割

Leanerプロセスには1CPU, 1GPUが割り当てられます。

Leanerの役割はReplayから供給されるミニバッチでひたすらにQネットワークを更新しつづけること、およびActorからのネットワーク重み同期要求に応じることです。学習速度を最大化する(≒GPU稼働率を最大化する)ためにReplayからのミニバッチの供給を途切れさせないことが重要となります。

Actorの役割

各Actorプロセスには1CPU, 0GPUが割り当てられます。このActorプロセスは論文では最大360並列実行されています。

Actorの役割は、遷移情報の収集各遷移の初期優先度の算出です。ActorはQネットワークを持ち自律的にrolloutを行います。100step程のrolloutを行った後に勾配計算は行わず集めた遷移情報をそのままRepalyプロセスに送信します。ただし、遷移情報の送信時にはローカルQネットワークでの推論により初期優先度(∝TD誤差)を算出しておきます。

オリジナルの優先度付き経験再生では、各遷移の初期優先度には最大値を割り当てることで必ず一回は経験が再生されるようにしていましたが、経験再生される速度よりも経験の供給速度の方が圧倒的に速いApe-Xアーキテクチャでそれを行うと直近の遷移ばかりが再生されることになりReplayが意味をなさないため、Actorプロセスで初期優先度を計算することにより遷移情報をふるいに掛けています。換言するとApe-Xアーキテクチャでは一度も再生されないまま消えゆく遷移情報もあるということで、当然サンプル効率は劣悪です*4

ActorがローカルQネットワークを保持して自律的にrolloutを行うという点ではA3Cアーキテクチャと同じですが、Ape-XではA3Cと異なりActorは勾配計算せずReplayへ遷移情報を送信するだけです。これは、勾配情報だとグローバルQネットワークへの反映遅れに気を使う必要がありますが、遷移情報ならばReplayへの反映が多少遅れても問題ないので大規模分散学習にて扱いやすいためです。また、勾配計算するならActorにもGPUが無いと厳しいですが推論だけならCPUだけでもそれほど苦しくないという利点もあります。

また、分散並列化されたActorはすべて異なる探索率εが割り当てられる、というのも重要なポイントです。従来のシングルActorのDQNでは、高い探索率εで学習を開始しゆっくりとεを下げていくアニーリング方式によって探索と活用のトレードオフをバランスしていました。これに対してApe-Xではさまざまな探索率のActorが存在するので自然に多様な経験を収集することができます。このような異なる方策(探索率)を持った並列Actorでのrolloutは、DQNがoff-policyであること生かしたテクニックと言えます。

f:id:horomary:20210227171147p:plain:w300
並列Actorはすべて異なる探索率εを割り当てられる

Replayの役割

Replayの役割はActorからの遷移情報受け取りLeanerに供給するミニバッチの作成、およびLeanerからの更新優先度情報の受け取りです。実態はただの優先度付きReplayBufferなのですが、ActorともLeanerともやり取りしなければいけないため一番忙しいプロセスです。


Rainbowからの継承要素

Ape-X DQNは優先度つき経験再生の後継手法というよりは、DQNの改良トリック全部盛り手法である Rainbow + 大規模分散並列学習 と表現するほうが正確でしょう。実際に、Rainbowが採用していた6つのDQN改良トリック(Double Q-learning, Dueling network, Noisy-network, Prioritized Experience Replay, Multi-step learning, Categorical DQN)のうち、Ape-X DQNではNoisy-networksとCategorical DQN (C51) 以外はすべて採用しています。

上述の通りApe-Xでは各Actorに異なる探索率εを割り当てることにより(計算パワーの力で)探索と活用のバランスをとるので、同じく探索戦略であるNoisy-networkを除外することはごく自然です。一方で、Rainbowに採用された6つDQN改良トリックのうち単体でもっともパフォーマンスの高いC51を除外するのは明らかに違和感がありますが、OpenReviewでの回答を見る限りでは単に実装の煩雑さを嫌っただけのようです*5

Q1: on using all Rainbow components and on using multiple learners.

These are both interesting directions which we agree may help to boost performance even further. For this paper, we felt that adding extra components would distract from the finding that it is possible to improve results significantly by scaling up, even with a relatively simple algorithm. (https://openreview.net/forum?id=H1Dy---0Z)

horomary.hatenablog.com

大規模並列Actorの効果検証

Actorを増やすほどReplayBuffer内の繊維状の入れ替わりサイクルが短くなるため、より最近に収集された遷移情報が再生されやすくなります。また、ある経験が再生される回数が少なくなり源泉かけ流しに近くなっていきます。これはある意味でon-policy学習のやり方でQネットワークを訓練していると解釈できます。

もしon-policyっぽくDQNの訓練を行うことがパフォーマンスの向上の理由ならば、actorの数を大規模並列化せずとも経験再生される回数を制限すればApe-Xと同等のパフォーマンスが得られるはずです。論文ではこれについての検証実験を行っており、Fig.6はactorの並列数(n)=32に固定したうえで、ある遷移情報が再生される回数(k)を変化させるとパフォーマンスにどう影響するのかを示しています。同じactor数(n=32)では再生回数kの違いは大差ないことがわかります。さらにactorの並列数(n)=256の場合はn=32の場合と比較してパフォーマンスに大きな差をつけています。

この結果から論文では経験再生の新陳代謝の速さに由来するon-policyっぽい学習だけでなく、並列マルチ方策(=異なる探索率εが割り当てられた)actorによって生成される多様な経験がパフォーマンスに寄与していると結論付けています。

f:id:horomary:20210227175440p:plain:w600
経験のrecencyの検証実験(Fig. 6)と探索率εの多様さの検証実験(Fig.7)

ついでにFig.7では各Actorへの探索率εの割り当ての多様性の効果について検証しています。各Actorにすべて異なる探索率を割り当てた時と、割り当てる探索率を6つに減らした時にどうパフォーマンスが変わるかの検証実験です。前者はepsilons = np.linspace(0.01, 0.4. num_actors)で後者はepsilons = np.linspace(0.01, 0.4. 6)という感じのイメージ*6だと思います。結果は直感通りで、パフォーマンスにそこまでの大差なしとのこと。


CartPole環境での簡易実装

まずは分散並列アーキテクチャを理解するために、CartPole環境で優先度付き経験再生以外のDQN改良トリックを除外したシンプルな実装を示します。

分散学習の実装

Ape-X アーキテクチャで重要なのはLearner(=GPU)を休ませないことです。このためにLeanerには16セットのミニバッチを渡し、Learnerがネットワークの更新をしている裏でReplayはせっせとActorから経験を受け取っていきます。*7 この流れはrayを使うことですっきり記述できます。

Pythonの分散並列処理ライブラリRayの使い方 - どこから見てもメンダコ


実装上のポイントは39行目のfinished_learner, _ = ray.wait(wip_learner, timeout=0)です。timeout=0を指定したray.waitは実行時点で対象プロセスが未完了である場合、finished_learnerとして空リストを返すためLearnerプロセスの終了判定が可能です。このLeanerプロセス終了判定をActor to Replayでの遷移情報送付が1回行われるごとに実行することで、Learnerプロセスが空き次第すぐに次のminibatchを渡すという疑似的な割り込み処理を実装することができます。

この実装は45,46行目でReplayプロセスが行う、次のミニバッチセットの作成と優先度の更新処理がLeanerプロセスに比べて十分に速いことを前提としていることに注意してください。もしReplayプロセスが遅すぎる、あるいはLeanerプロセスに渡すミニバッチの数が少ないためにLeanerプロセスの完了が早すぎる場合にはActorからの遷移情報送付が滞ってしまいます。このため、一度のネットワーク更新ごとに何回Actorからの遷移情報送付が行われているかはしっかりチェックしておきましょう。

Actorへ最新の重みを渡すためにray.putを使用していることに留意してください。ray.putは大きめのデータ、この場合はメインQ関数の重みを多数のリモートActorに配布する処理を効率化してくれます。

Ray Core Walkthrough — Ray v2.0.0.dev0

ApeXのコアとなるコードはこれだけです。 rayのおかげでシンプルに実装できていると思います。以下では各プロセスの詳細実装を紹介しますが、DQNおよび優先つき経験再生を理解できていればとくに難しいことは無いはずです。


Actorの実装

Actorプロセスは一定stepのrolloutを行って遷移情報するだけなので実装はごく単純です。A3Cと同様にApeXではActorがローカルQネットワーク(LearnerのQ関数のコピー)を持ち自律的にrolloutを行うので、Actor.rolloutではまず初めにleanerのQ関数と重みの同期を行います*8。その後、100step分のrolloutを行い、収集した遷移情報とそれらについてのTD誤差(初期優先度の計算に使用)を返します。


Replayの実装

Replayは単なる優先度付き経験再生です。Ape-XではReplayプロセスが行うミニバッチ作成(Replay.sample_minibatch)の速度パフォーマンスが求められるので、高速な重み付きサンプリングができるSumTree構造で優先度を保存しています。他の留意点として、オリジナルの優先度付き経験再生では、Importance Sampling weights(もどき)のハイパラであるβをアニーリングしていましたが、Ape-Xでは固定値になっています。

horomary.hatenablog.com


Learnerの実装

LearnerはReplayから受け取ったミニバッチでひたすらネットワーク更新するだけのプロセスです。16セットのミニバッチを消費したら最新の重み、および更新された優先度をメインプロセスへ返却します。優先度付き経験再生を理解していれば何も難しいことはありません。

horomary.hatenablog.com


Qネットワーク

特筆することは何もありませんが一応載せておきます。


学習結果:CartPole-v0

CartPole-v0のハイスコア200点に到達するまでおよそ20秒!

f:id:horomary:20210301223207p:plain:w400
1cycle(横軸)ごとにLeanerは16セットのミニバッチでネットワーク更新

(もちろんマシンスペックに依存しますが)8並列Actorの環境ではLearnerが16セットのミニバッチ(各batch_size=32)を消化する間にActorからReplayへの遷移情報送信が25回程度行われていました。


Atari環境(Breakout)での実装

さて、CartPoleでの簡易実装がうまくいったのでBreakout(ブロック崩し)でDQN改良トリックまで含めて本格実装したコード例も掲載します。と言いたいところだったのですが、過去記事で紹介したRainbowの実装と丸被り & コードが長大なので結果だけ示します。改良トリックの詳細は過去記事で、実装全体はGithubでご確認ください。大筋は上に示したCartPoleと同じですが、Dueling-net, Double-DQN, Multi-step Learning、およびメモリ節約のために遷移情報をzlibで圧縮するコードが追加されています。

horomary.hatenablog.com

github.com

学習結果:BreakoutDeterministic-v4

リソースの都合上*9、actorは20並列にしかできませんでしたがそれでも分散学習の威力を実感できる結果となりました。探索率εの大きいactorが混じっていることによる正則化?効果のおかげかパフォーマンスが大崩れせず順調に学習が進みます。ただし、Breakout環境では探索率εの上限値が論文通りの0.4では小さすぎるためか学習の立ち上がりが悪く感じたのでεの上限は0.5に変更しています。

f:id:horomary:20210301215216p:plain:w500
20actorでトータル15時間学習(ε=0.01)

f:id:horomary:20210301215756p:plain:w500
Learnerのlossの経過


次:R2D2

horomary.hatenablog.com

*1:Asynchroneous Prioritized EXperience replay?

*2:wallclock time, 実世界での時間

*3:そして一般人や小規模ラボが参入しにくくなった

*4:サンプル効率についてはFIg.10を参照

*5:C51はネットワーク更新時のCPU処理も多いから設計が面倒になるのかも

*6:εの割り当ての具体的な数値は記述無し

*7:もし実装レベルまで論文を再現したいならtensorflow.Queueで実装する

*8:論文では400stepごとに重みを同期、とあるのでこの実装のように毎回同期はしない

*9:GCPへのリソース割り当て増加リクエストが通らなかった

Segment Tree(セグメント木)による重み付きランダムサンプリング

競技プログラミング界隈では一般教養であるらしいセグメント木のSum-tree構造で高速な重み付きサンプリングを実装します。


はじめに

強化学習の重要手法である優先度付き経験再生(Prioritized Experience Replay)では、重みづけされた100万の経験(遷移情報)からランダムにサンプリングしてミニバッチを作成する、という処理があります。このような重みづけサンプリングはnp.random.choiceの引数pに重み情報を与えることで楽に実装できます。コードの見通しが大変よくなるので過去記事ではこの方法での実装例を紹介しました。

しかし論文ではsum-treeデータ構造で実装すると速いと書いてあります。本記事ではせっかくなのでこちらの実装を試してみます。

[1511.05952] Prioritized Experience Replay

DQNの進化史 ③Prioritized experience replay, Multi-step learning, Categorical DQN - どこから見てもメンダコ

DQNの進化史 ④Rainbowの実装 - どこから見てもメンダコ


A. numpy.choiceによる重み付きランダムサンプリング

まずはベースラインとしてnumpy.random.choiceによる重み付きランダムサンプリングのパフォーマンスを見ます。要素数は100万で各要素には0-5の優先度が割り当てられます。DQNでのミニバッチサイズの32に従って1iterで32要素をサンプリングします。また、Breakout(ブロック崩し)環境ではそれなりに学習が進むと1 episodeで200回くらいはミニバッチ作成するので200iter繰り返します。

f:id:horomary:20210215220249p:plain:w500

結果は約3.2秒となりました。1episodeあたり3.2秒なら趣味で強化学習やるくらいなら許容できる程度ではありますがちょっと遅いですね。


B. 累積和による重み付きランダムサンプリング

つぎに愚直な実装として逆関数法による重み付きランダムサンプリングを実装します。累積密度関数が計算できる確率分布なら逆関数法を使うことでい一様乱数から目的の確率分布に従う乱数に変換できます。

逆関数法 - Wikipedia

逆関数法を用いた乱数生成の証明と例 | 高校数学の美しい物語

たとえば、4要素のリストにおいて各要素の優先度が [4, 2, 1, 3] のときは、0≦ z ≦ 4+2+1+3 = 10 の範囲で一様乱数を発生させ、累積和がzとなるのがどの要素のときかを調べることで優先度に従ったサンプリングを行うことができます。

f:id:horomary:20210215223407p:plain:w600
0≦z≦10で乱数を発生させて累積和がzに該当する要素を選択すれば優先度の大きさにサンプリング確率が従う

f:id:horomary:20210215224616p:plain:w500

たしかに優先度に従ってサンプリングできていることがわかります。では要素数が増えた時のパフォーマンスがどうなるかをnp.random.choiceと同じ条件で確かめます。

f:id:horomary:20210215231046p:plain:w500

結果は144秒、遅い! すべての要素に対して累積和チェックをしているので計算量がNになってしまっていることが原因です。

C. Sum-tree構造を活用した重み付きランダムサンプリング

上で重いのは累積和がzになるのはどの要素番号のときであるかを調べる処理です。これはSegment-tree(セグメント木)構造を使うことで高速に検索することができます。さきほどと同様に各要素の優先度が [4, 2, 1, 3] のときのSum-treeを構築すると下図のようになります。

f:id:horomary:20210215232341p:plain:w500
[4, 2, 1, 3]に対するSum-tree

たとえば累積和が6.5を超える要素番号を検索したいとしましょう。ルートノードである10の左子ノードが6なので、要素0, 1までの累積和が6であることがわかります。よって、要素番号2,3の区間における累積和が0.5(= 6.5 - 6)になる要素を探せばよいというわけです。そこでルートノードの右子ノード4に進みます。この子ノードを見ると1, 3なので左子ノードで累積和が0.5になることが分かります。左子ノードは実要素なのでここで探索終了となります。

実際に格納される要素数Nが  N = 2^{K} のとき、Sum-treeの深さ(階層?)はKになるので探索回数は要素数Nに対してlogNとなり効率的であることがわかります。

Sum-TreeのPython実装

競プロ界隈の人はわざわざ遅いpythonで実装とかしないかもしれませんが、強化学習で使う分にはそこそこのパフォーマンスが出ればよいのでpythonでSum-Treeを実装します。この実装は ray/segment_tree.py at master · ray-project/ray · GitHub から抽象化を削りシンプルに再実装したものです。

実装のポイント:
・格納される要素数がNのときSumtree全体の要素数は2N-1
・ルートノードのインデックス番号を1に設定すると左子ノードのインデックス=2×親のインデックス、右子ノードのインデックス=2×親のインデックス+1 となり便利なので、インデックス番号0を使わない長さ2NのリストでSum-treeを実装する
__setitem____getitem__ を活用してsumtreeであることを感じさせない使い勝手を実現する

f:id:horomary:20210215235032p:plain:w500
クラス外から見た要素番号(赤)と実体の要素番号(青)

f:id:horomary:20210216004207p:plain:w500

速度パフォーマンスの確認

numpyのときと同様にバッチサイズ32のミニバッチを200回作った時の速度を計測します。

f:id:horomary:20210216004530p:plain:w400

サンプリング時間だけ見ればnumpy.random.choiceより30倍程度は速いことがわかります。


おわりに

__setitem____getitem__ が一番輝くのはsegment tree説

DQNの進化史 ④Rainbowの実装

Deep-Q-Network (2013) 以降の深層強化学習(Q学習)の発展を、簡単な解説とtensorflow2での実装例と共に紹介していきます。今回はDQNの改良トリックを全部盛りにしたら強いんでは?という脳筋発想によって生まれた手法であるRainbowを実装します。


DQNシリーズ
DQNの進化史 ①DeepMindのDQN - どこから見てもメンダコ
DQNの進化史 ②Double-DQN, Dueling-network, Noisy-network - どこから見てもメンダコ
DQNの進化史 ③優先度付き経験再生, Multi-step learning, C51 - どこから見てもメンダコ
DQNの進化史 ④Rainbowの実装 - どこから見てもメンダコ


はじめに

[1710.02298] Rainbow: Combining Improvements in Deep Reinforcement Learning

2017年に発表されたRainbowは、それまで報告されてきたDQN改良トリックをすべて搭載したDQNの総まとめ的な手法です。具体的にはオリジナルのDQNに、Double Q-learning, Dueling-network, Noisy-network, Prioritized Experience Replay, Categorical DQN(C51, or Distributional DQN), Multi-step learningの6つの手法を全部盛りにすることにより当時のatari環境のSotAを更新しました。

f:id:horomary:20210211143435p:plain:w400
虹色の線がオシャレ(論文Fig.1)

手法自体に目新しいことは無いので本記事では実装レベルの解説をしていきます。


構成要素の寄与について

論文ではRainbowの各構成要素をひとつ抜いた時にどれだけパフォーマンスが下がるか、という実験をしています。これによると優先度付き経験再生(prior)、分布強化学習(distributional)、Multi-step learningの寄与が大きいようです。

f:id:horomary:20210211144800p:plain:w300
構成要素を一つ抜いたらどれだけパフォーマンスが下がるか(論文 Fig. 3)

ただし、あくまで要素の一つ抜きだけでありすべての組み合わせを試しているわけではないので、各要素の寄与はFigに示されているほど単純に見積もることはできません。実際にtensorflowのチュートリアルDQN C51/Rainbow  |  TensorFlow Agents)ではDistributional DQNとMulti-step learningの組み合わせだけでRainbowと同等のパフォーマンスが得られたという記述がされています。

Although C51 and n-step updates are often combined with prioritized replay to form the core of the Rainbow agent, we saw no measurable improvement from implementing prioritized replay. Moreover, we find that when combining our C51 agent with n-step updates alone, our agent performs as well as other Rainbow agents on the sample of Atari environments we've tested.
(C51とn-step更新は、優先リプレイと組み合わされてRainbowエージェントのコアを形成することがよくありますが、優先リプレイを実装しても測定可能な改善は見られませんでした。さらに、C51エージェントをnステップ更新のみと組み合わせると、テストしたAtari環境のサンプルで他のRainbowエージェントと同様に機能することがわかりました。)


レーニングループ

レーニングループはDQNとほぼ同じです。


ハイパーパラメータはrainbow論文ではなく、 Categorical DQN論文(Distrubutional DQN)に従っていることに注意してください。これはBreakout環境ではDistributional DQN単体の方がパフォーマンスが良いためです。

f:id:horomary:20210211162307p:plain:w600
breakoutはrainbowより分布DQN単体の方が強い

Q-networkの実装(tensorflow2)

Qネットワークには Dueling-network, Categoical DQN (Distributional DQN), Noisy-network の3要素が導入されます。


ReplayBufferの実装

経験バッファには 優先度付き経験再生, Multi-step learning の2要素が導入されます。

注意:このリプレイバッファの実装は見通しの良さを重視しており、速度パフォーマンスを気にしないで実装しています。

SegmentTree構造を利用した高速な優先度付きReplayBufferの実装は別記事を参照ください。

horomary.hatenablog.com


ネットワーク更新(tensorflow2)

ネットワーク更新に絡んでくるのは、Double Q-learning、 優先度付き経験再生、Categorical DQN(Distributional DQN)の3要素です。

見ればわかりますがDistributional DQNの存在がTD誤差の計算をやたらと煩雑にしています。この詳細は過去記事を参照ください。

horomary.hatenablog.com


Breakoutでの学習結果

Breakout(ブロック崩し)を GCPのn1-standard-4(4-vCPU, 15GBメモリ) + GPU K80 のプリエンティブルVMインスタンス*1で24時間学習させました。1Mstep未満で40点取れてるので動作確認としては十分なスコアだと思います。(※rainbow論文では200Mstepを学習)

f:id:horomary:20210211172902p:plain:w500

いろいろ試してみましたが、どうもBreakout環境はNoisy-networkとの相性が悪い印象を受けました。また、優先度つき経験再生を雑に実装*2しているため処理速度がかなり遅く24時間で1Mstepしか進行していません。

実装全体はgithubへ:
github.com


次:Ape-X DQN

C-51がいないおかげでRainbowよりApe-Xのほうが実装が楽に感じる。

horomary.hatenablog.com

*1:24時間でシャットダウンされる代わりに激安なインスタンス

*2:低スぺ対応のためにzlibでオブジェクト圧縮したり重みつきサンプリングをnp.random.choiceで実装したり