どこから見てもメンダコ

軟体動物門頭足綱八腕類メンダコ科

オフライン強化学習② Decision Transformerの系譜

Decision transoformer (2021)は、自然言語モデルGPTにおける次トークン予測の枠組みでオフライン強化学習タスクを解けることを示し新たなパラダイムをもたらしました。最近ではDeepMindの超汎用エージェントGATOなどもDecision Transformerベースのアーキテクチャを採用しており、その重要性が増しています。


オフライン強化学習についての説明は過去記事:

horomary.hatenablog.com


Decision Transformer とは

sites.google.com

オフライン強化学習の新たなパラダイム

オフライン強化学習、すなわち環境からのサンプル収集(=オンライン学習)を一切行わず、事前に用意されたデータセットのみで強化学習するという問題設定は実用上大きな意味があります。ビジネスにおいて時間とコストがかかるデータ収集をハイパラ設定を変えるたびにゼロからやり直すのはわりに合いませんし、医療や化学プラントなどではそもそも気軽な試行錯誤が許されないためです。

:オフライン強化学習とは事前用意されたデータセットだけで学習するoff-policy強化学習
([2005.01643] Offline Reinforcement Learning: Tutorial, Review, and Perspectives on Open Problems より)

オフライン設定は実用上大変重要ですが、従来のオフポリシー強化学習(例: DQN, SACなど)を単純にオフライン設定で実行するといくつかの理由*1で壊滅的なパフォーマンスになることが知られています。ゆえにそれまでオフライン強化学習とはオフライン設定においてオフポリシー強化学習手法をどうやってperformさせるか、ということが主な論点でした。しかし、Decision TransformerはBERTやGPTのような自然言語モデルのアプローチを模倣学習に導入することで教師あり学習でも強化学習タスクを当時のSotAオフライン強化学習手法(CQL)に相当する性能で解けることを示したことにより、オフライン強化学習に新たなパラダイムをもたらしました。

実スコアでなくゲーマー標準化スコアであることに注意


言語を生成するように行動を生成する

強化学習というのは基本的に逐次意思決定タスク、すなわち各タイムステップごとに過去の観測を考慮した適切な行動を選択するタスクに対して使用されます。一方、自然言語分野における次単語生成タスク、たとえば「吾輩は猫である、名前は~」に続く自然な単語を選択するというのもまた逐次意思決定タスクです。ゆえに次単語生成タスクでは強化学習が用いられることもありますが、しかし現在主流となっているのはTransformerベースの自己回帰的なモデルによる教師あり学習です。

ならば、BERTやGPTモデルによる教師あり学習で自然な文章を生成するように、BERTやGPTモデルによる教師あり学習で自然な行動を生成できるのではないか? というのがDecision Transformerの基本的なコンセプトです。

Decision Transformer:sの出力が次アクションを予測する。R, aの出力はとくに使わない
https://sites.google.com/berkeley.edu/decision-transformer

Decision transfomerではネットワークにGPTアーキテクチャを採用しており、*2 上図においてcausal transformerと記載されている部分はほぼGPTを転用したものです。学習においてもGPTのpretrainingにおける次単語予測と同様に、時刻tにおける状態s_tの出力がアクションa_tを予測するように訓練します。なお、R_t, a_t の出力はとくに利用しません。論文によるとa_tの出力が次状態s_tを予測するようにするなども試したけどとくにパフォーマンスに寄与しなかったとのことです。

ここでR_tとは即時報酬ではなく、t+1時以降に獲得する将来報酬和であることに注意してください。R_tによってDecision Transformerは条件つきの行動予測が可能になっています。たとえばR_tが大きいならエキスパートのような行動を選択し、R_tが小さいなら初心者のような行動を選択します。(詳細は後述)


自然言語風アプローチのメリット

強化学習的なタスクにおいてGPTのような自然言語モデル風のアプローチを採用することには大きく2つの嬉しさがあります。

1. ブートストラップ法を回避できる
強化学習におけるブートストラップ法とは(動的計画法やTD学習のように)教師ラベルに予測値が含まれているような更新方法を示します。ブートストラップ法はオフライン設定において価値の推定誤差を蓄積することがわかっており、これがオフライン強化学習の主要な困難のひとつとなっています。DTではGPTと同じように教師あり学習するだけなので厄介なブートストラップ法を使う必要がありません。

2. 長期credit割当てを効率的に行える
TD学習では行動と報酬発生の時間差が大きい(例: ブロック崩しにおいて報酬が発生するのはボールを弾いてからしばらく後)と、行動と報酬の因果関係を学習するのにかなりの時間がかかることが問題になります。一方、Transformerはそもそも離れた単語間の関連性を効率よく学習することを意図して設計されているのでこの問題を自然に解決することが可能です。


条件付き生成:Reward conditioned

前述のように、R_tとは即時報酬ではなくt+1時以降に獲得する将来報酬和です。ゆえにR_tが大きい=これからたくさんスコアを獲得する=エキスパートによるプレイであり、R_tが小さい場合は同様の理由で初心者のプレイであると理解できます。

将来報酬和R_tが入力シーケンスに含まれているために、DTは与えられたR_tが大きいならエキスパートのような行動を出力し小さいなら初心者のような行動を出力することが可能です。これは単純な模倣学習とは異なりデータセットに質の悪いサンプルが混ざっていても問題ないことを意味します。逆にR_tが無いと単なるGPTアーキテクチャのBehavior Cloningとなります。

Google AI blog: Multi-Game Decision Transformerより(アップロードの都合上フレーム数を削減)

また、このような条件付けの役割を担うのは必ずしも将来報酬和Rtである必要はありません。たとえば格ゲーデータセットにおいてはR_tをプレイヤーの名前に置き換えることでウメハラの行動を再現することができるでしょう。ゲームNPCのAI開発なんかに使うといい感じに難易度調整できそうですね。


Sequence modelingの系譜

Decision Transformerがもたらした新たなパラダイム強化学習 via 自然言語風Sequence Modeling”。 数多くの派生手法が考案されている中から個人的にとくに印象深かったものを紹介します。

Multi-Game Decision Transoformer(NeurIPS 2022)

GPTやBERTなどの巨大transformerアーキテクチャは、教師無し事前学習による未知タスクへのfew-shot learning性能パラメータ数を増やすとパフォーマンスも向上する"べき乗則"などの好ましい特性を持つことが知られています。Multi-Game Decision Transformerでは、このような自然言語基盤モデルの持つ好ましい特性をDecision Transformerも持つのだろうか?ということを調べています。

Multi-Game Decision Transformers

1. Few-shot learningは可能か?

BERTやGPTの最大のメリットは教師無し事前学習済みモデルのファインチューニングによって未知タスクにスモールデータでも対応できることです。Decision Transformerも同様の転移学習性能を持つのかの検証のために、Atariドメイン46ゲーム中から41ゲームを単一のネットワークで訓練した後、未知の5ゲームについて100K ステップのファインチューニングを行い性能を比較しています。事前学習、ファインチューニングどちらもオフラインデータセットでの訓練であることに留意。

 

すべてのゲームで事前学習ありDT(DT pretraining)が事前学習なしDT(DT no pretraining)にoutperformしているので未知タスクに対するpre-trainingの効果はたしかに存在するようです。とくに事前学習なしDTではまったく性能がでていないPongでも事前学習ありDTではしっかりperformしているのが興味深い。PongはなんかDTと相性が悪いみたいで、オリジナルDT論文でもPongだけは入力シーケンスを長くしないとうまく学習しないなど苦労していたようです。

2. パラメータを増やしてGPUで殴ればいい

GPT-3論文がTransformerモデルはパラメータ数を増やせばパフォーマンスもべき乗則 (Power Law)で向上するという実験結果を発表したことによって、NLP分野は大企業が計算リソースで殴り合う末法の世と化しました。DTではどうでしょうか?

Decision Transformerでもべき乗則が適用されるっぽい

3. ViTは有用か?

DTでは観測のトークナイズをDQNのCNNで行っていますが、MGDTではViTを採用しています。これについて比較実験をしたようですがCNNとViTで顕著な差は見られなかったとのことです。

ちなみにMGDTではR, Sも予測対象になっている

ViTだとattentionを簡単に可視化できるので楽しい。

Training Generalist Agents with Multi-Game Decision Transformers – Google AI Blog

Uni[Mask](NeurIPS 2022): MaskedLMの導入

Uni[MASK]: Unified Inference in Sequential Decision Problems | OpenReview

DTは単方向モデルであるGPTアーキテクチャを採用しているため次行動予測タスクに特化しています。そこでUni[Mask]ではBERTのMaskedLMアーキテクチャを採用したうえでマスクの置き方をタスクごとに工夫することで、次行動予測はもちろん状態遷移予測やゴール状態指定などさまざまなタスクに対応できることを提唱しています。アイデアはシンプルですがよくまとまってます。ただし検証はMaze2D環境のみ。

 

GATO(2022):超汎用エージェント

A Generalist Agent | OpenReview

www.deepmind.com

ネットワークアーキテクチャが同じなんだから対話生成もロボット操作もレトロゲームも全部Transformerでマルチタスク学習できるのでは?という頭の悪い発想を世界最高峰の頭脳集団が実現してしまったのがGATOなる超汎用エージェント。実際すごい。

DeepMind Blogより

Decision Transformerでconditioningの役割を担っていた将来報酬和RがGATOでは汎用性確保のため使われていません。かわりにデモンストレーションシーケンスを最初に入力することでconditioningを行っているようです。この方式はprompt conditioningと呼称されています。

prompt

Algorithm Distillation(ICLR2023):学ぶことを学ぶ

openreview.net

※under review

DTのメタ強化学習への発展強化学習アルゴリズム(e.g. DQN, A3C)によって生成されたオフラインデータセットを十分に長い入力シーケンス長を設定したDecision Transformer(というかGATOっぽい)で学習すれば、"学び方を学ぶ"のでは?という手法。強化学習アルゴリズムが探索と活用によって学ぶ様子を再現するのでアルゴリズム蒸留なのでしょう。 

ADはパラメータ更新せずにパフォーマンスを改善する

上図は強化学習アルゴリズムによって生成されたデータセットをオフライン学習済み(学び方の学習済み)のADでパフォーマンスを評価したものですが、ぱっと見では何がすごいのかわかりにくい。ポイントはこの評価中にADは一切のパラメータ更新を行っていないにも関わらずパフォーマンスが向上していることです。これはネットワーク自体が自己改善機能を獲得したことを示しています。なおTransformerでなくLSTMでもOKらしい。(すごそうな手法に見えますが私はメタ強化学習分野に明るくないので先行研究と比べてどれくらいすごいか正直よくわかってない)

同様のコンセプトの先行研究としてOptFormerなどがあります。これはGP, TPEなどが機械学習モデルのハイパラを最適化する過程をTransformerに学ばせることでブラックボックス最適化アルゴリズムの蒸留ができるというものです。

ai.googleblog.com


Decision TransformerのTF2実装

実装全文:

github.com

元論文:[2106.01345] Decision Transformer: Reinforcement Learning via Sequence Modeling

公式実装(pytorch):
GitHub - kzl/decision-transformer: Official codebase for Decision Transformer: Reinforcement Learning via Sequence Modeling.

データセット DQN Replay Dataset

ネットワーク構造

ではDecision TransformerをTF2, Atari/Breakout環境向けに再現実装します。 と言ってもやることは30遷移分の(R, s, a)をトークナイズしてGPTにつっこむだけなので正直コメントすることが無い。

各入力(将来報酬和Rt, 観測状態st、アクションat)は128次元になるよう埋め込む。いずれも最後にtanhでactivationされるのがポイント。

観測状態についてはDQNのCNNで3136次元にした後に全結合層で128次元にする。

stの出力がatを予測するようにカテゴリカルクロスエントロピーを損失関数にネットワークの訓練を行う。Rt, atの出力はとくに使わない。Rtの出力でstを予測するなども試したらしいがとくにパフォーマンス向上に貢献しなかったと論文に書いてあった。

トークンの整列処理はちょっと分かりにくいかもしれないので簡易版を置いておく。

>> rtgs = np.array([R1, R2, R3])
>> states = np.array([s1, s2, s3,])
>> actions = np.array([a1, a2, a3,])
>> tokens = np.stack([rtgs, states, actions], axis=0).T.reshape(1, -1)
>> tokens
[[ R1, s1, a1, R2, s2, a2, R3, s3, a3]]


学習結果

テストではClippedスコアが70点になるようにconditioningして実行。最初の40Kまででほぼ目的の性能に到達している。

clippedスコアで70点を取るようにconditioning

論文掲載スコアを概ね再現できている模様

論文掲載スコア


所感

実装も訓練もデバッグもめっちゃ楽。クロスエントロピー最小化するだけでいいのが楽園すぎる。こんな手軽なら実務でも使いどころあるかも。

*1:[2005.01643] Offline Reinforcement Learning: Tutorial, Review, and Perspectives on Open Problems

*2:BERTでなくGPTを使っているのは、GPTが過去コンテクストのみを活用する単方向モデルであるために次行動予測と相性が良かったためと思われる

オフライン強化学習① Conservative Q-Learning (CQL)の実装

オフライン強化学習の有名手法CQLについて、簡単な解説とともにブロック崩し環境向けのtf2実装を紹介します

[2006.04779] Conservative Q-Learning for Offline Reinforcement Learning

sites.google.com

前提手法:
horomary.hatenablog.com


はじめに:オフライン強化学習とは

元ネタ: [2005.01643] Offline Reinforcement Learning: Tutorial, Review, and Perspectives on Open Problems

問題設定:ゲーム実況を見るだけで上手にプレイできるか?

一般的な強化学習においては、環境においてエージェントが試行錯誤することで方策を改善していきます。環境で実際に行動することでデータを集めつつ学習する(オンライン学習)ことで、方策更新→環境からのフィードバック→ 方策更新 という改善サイクルを回せることこそが、逐次的意思決定問題において強化学習が単純な教師あり学習(模倣学習)よりも良い性能を獲得することができる根拠の一つとなっています*1 *2。高速なフィードバックサイクルこそが成長の近道というのは仕事やスポーツ、ソフトウェア開発などでも実感できることではないでしょうか?

Offline Reinforcement Learning: Tutorial, Review, and Perspectives on Open Problems Fig1より:
オフライン強化学習とは事前に用意されたリプレイバッファだけを使用するoff-policy強化学習である

一方、オフライン強化学習とは事前に用意されたデータセットからの学習だけで良い方策を獲得することを目的とした理論であり、すなわち強化学習の強みであるオンライン学習が禁止されている問題設定です。事前に用意されたデータセットとはエキスパートの行動履歴かもしれませんし素人の行動履歴かもしれません。一人によるものかもしれませんし複数人によるものかもしれません。つまるところオフライン強化学習とはダークソウル*3を一切やったことない人がYouTubeのゲーム実況動画だけを見るだけでダークソウルをノーデスクリアできるようになるか?を目指す問題設定とも表現できます。(無理では?)

なお、オフライン強化学習(offline RL)という用語は Sergey LevineによるNeurIPS2020のチュートリアル講演 Offline RL Tutorial - NeurIPS 2020 から広く使われるようになった感があります。だいたい同じ意味を示す他の単語として batch RL, data-driven RL, fully off-policy RL などがあります*4


実世界でのユースケース

オフライン強化学習とはダークソウルを一切やったことない人がYouTubeのゲーム実況動画だけを見るだけでダークソウルをノーデスクリアできるようになるか?を目指す理論と喩えましたが、そんな縛りプレイみたいなことをするのは実世界にユースケースがあるためです。

たとえば、実環境での大失敗が許容できないドメインではオフライン強化学習が有用です。これは医療や化学プラントあるいは自動運転のように失敗が人命に関わるような分野が該当します。

www.microsoft.com

他には試行錯誤の時間/金銭コストが高いドメインでも活用が期待できます。たとえば試行錯誤の過程での大失敗が高額なハードウェアの損傷につながるロボティクスや、レコメンデーションや自然言語タスクなどにおける人間のフィードバックを報酬とした強化学習 *5のように試行錯誤に人力が必要なためにやり直しコストが大きい分野などです。

arxiv.org

また、そのような極端な環境でなくとも既存の試行錯誤履歴からオフライン学習してそこそこの性能にしたうえでオンライン強化学習をファインチューニング程度に使いたいというのもよくあるユースケースでしょう。アカデミックではあまり研究されない設定だとは思いますが、アルゴ/ハイパラ変更のたびにサンプル収集をゼロからやり直すのはあまりに非効率的なのでこのような転移学習的な視点もまた実務的には大変有用です。


模倣学習との違いなど

オフライン強化学習も模倣学習も事前に用意されたデータセットだけを使う*6という問題設定は共通していますがSergey Levineのチュートリアル論文 では、あくまでオンライン学習前提の強化学習(TD学習や方策勾配法)手法の拡張を指して”オフライン強化学習”と呼称しているように見えるので、基本的に教師あり学習である模倣学習とはこの点で異なります。

与えられたデータセットだけを使う設定に対してオフラインと呼称するので、たとえばシミュレーション環境でオンライン強化学習を行うことは実環境と相互作用しないという意味ではオフラインですが通常はこれをオフライン強化学習と呼称しません。ただし、与えられた固定データセットでシミュレータ(環境モデル)を構築してそこでオンライン強化学習を行うような場合にはオフライン強化学習(オフライン設定の世界モデルベース強化学習)と呼称されてる気がします。


オフライン強化学習の難しさ

図の出典: Offline RL Tutorial - NeurIPS 2020

データセットサイズは問題を解決しない

データセットが十分に大きいならば、オフライン設定のDQNは少なくともデータセットを収集したポリシーと同程度の性能になることを期待したくなります。しかし現実には単純なオフライン設定のDQNはデータセット収集ポリシーどころか、シンプルな模倣学習ポリシーよりもはるかに劣悪な性能となることがしばしばあります。

これはQ学習(など)におけるargmaxオペレータが価値の関数近似において誤差を増幅する効果があるために、オフライン方策とデータセットの行動選択に乖離が生じるためです。すなわちオフライン方策は模倣学習(教師あり学習)による方策と一致しません。

 \displaystyle
{ Q(s, a)  \leftarrow r(s, a) + \arg \max_{a'} Q(s', a)  }

データセットと異なる行動ををすれば当然データセットには無い初見状態に突入するのでパフォーマンスが連鎖的に悪化します。 ダークソウルRTAで完璧に動きを覚えていたのにワンミスで敵の行動パターンが変わってしまいチャート崩壊するようなものです。

価値近似の誤差によりデータセットの状態行動選択(πβ)とオフライン方策(πθ)が乖離すると初見状態に突入する


Out of Distribution: データセット分布外アクションの過大評価

オフラインQ学習においてargmaxオペレータが悪さをする具体例を見ていきましょう。

価値の関数近似の良い点でもあり悪い点でもあるのは、任意の(s, a)についてたとえ一度も試行したことが無くともQ(s,a)を評価することができてしまうことです。この性質とargmaxの悪魔合体によりたまたまQ(s, a)が無根拠に上振れした未実施アクションが採用されてしまいます。オンライン設定であれば次の試行時にうっかり過大評価していたことに気づき修正が行われるのですが、オフライン設定では無根拠な過大評価が永久に修正されません、いわゆるエアプガチ勢です。

argmaxがデータセットのアクション分布(赤)の外にまで最大値を探しに行ってしまう

この課題の解決アプローチとして、TRPOのようにデータセットの行動分布とオフライン方策の分布間距離に制約を課す方法や、[1907.04543] An Optimistic Perspective on Offline Reinforcement Learningのように予測値の不確実性の高い行動(≒未学習領域)にペナルティを課す方法、今回紹介するCQLのようにデータセットに存在するサンプルの評価値に制約を課す方法などがあります。


もっと詳しく

分布シフトの原因として、他にも極値付近で汎化誤差が上振れする問題(Double DQNのアレ)などが指摘されています。すでに何度もリンクを貼っていますがこのようなオフライン強化学習の課題について、より詳しく知りたい場合はSergey LevineによるNeurIPS2020のチュートリアル講演がベストな開始点です。

sites.google.com

この講演は [2005.01643] Offline Reinforcement Learning: Tutorial, Review, and Perspectives on Open Problems をベースとしています。ボリューム多すぎてつらい場合は書籍「AI技術の最前線」 4章にてこの論文を3ページ要約してくれているのでこちらもおすすめです。

AI技術の最前線 これからのAIを読み解く先端技術73 | 岡野原 大輔 | コンピュータサイエンス | Kindleストア | Amazon


CQL:保守的なQ学習

[2006.04779] Conservative Q-Learning for Offline Reinforcement Learning

前例が無いからダメです

ようやく本題です。上述したOut of Distributionな行動選択問題についてCQLではサンプルベースのアプローチで対処します。すなわち、データセットに存在しない状態行動(s, a)の評価値Q(s, a)にペナルティを与えることでデータセットに前例のない行動が選択されることを防ぎます。

通常のQ学習ではTD誤差を最小化します。すなわち損失関数は、

 \displaystyle
{ E_{s,a,s' \sim \it{D}} \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right] }

ここで、 \displaystyle { E_{s,a,s' \sim \text{D}} } とは 遷移サンプル(s, a -> s')がデータセット=オフラインリプレイバッファからサンプリングされたことを示します。CQLではこのTD誤差に対して前例のない行動へのペナルティ項が追加したものを損失関数とします。すなわち、


 \displaystyle
{ \text{ CQL(1) } = \alpha \left( E_{s \sim \it{ D} , a \sim \mu} \left[Q(s, a) \right] \right) + E_{s,a,s' \sim \text{D}} \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right] }


ここでμとはオフライン学習によって獲得した方策であるので、 \displaystyle { E_{s \sim \it{D}, a \sim \mu} \left[Q(s, a) \right] } とはオフライン学習方策によって行動選択(  \displaystyle { \max_{a} Q(s, a) } , 離散行動空間の場合 )したときの状態行動価値Q(s, a) です。これを損失関数として最小化することでオフライン学習方策μによって選択される行動aのQ(s, a)が低くなるようにネットワークが更新されていきます。データセットに存在しない行動に対してペナルティがつくので保守的(Conservative)なQ学習というわけです。

論文では、CQL(1)の更新式によって獲得されたQ関数はすべての(s, a) についてQ(s,a)の下界が得られることを証明しているのですが、しかしこのCQL(1)は保守的を通り越して新人イジメみたいな更新式になっています。というのもこの更新式では、オフライン方策が選択した行動であればデータセットに(s, a)の前例、があったとしても問答無用でペナルティが与えられるためです。つまりは、過去に前例が無いからダメです、前例があっても提案してるのが新人だからダメです、という感じ。


そこで、すべての(s, a) についてQ(s,a)の下界を得るのではなく、すべてのsについてV(s)の下界を得られればOKとして制約を緩和するとよりタイトな下界が得られます。


 \displaystyle
{ \text{ CQL(2) } = \alpha \left( E_{s \sim \it{ D} , a \sim \mu} \left[Q(s, a) \right] - E_{s \sim \it{ D} , a \sim \pi \beta } \left[Q(s, a) \right]  \right) + E_{s,a,s' \sim \text{D}} \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right] }


ここでπβとはデータセットにおける行動選択確率です。よって、 \displaystyle
{  E_{s \sim \it{ D} , a \sim \mu} \left[Q(s, a) \right] - E_{s \sim \it{ D} , a \sim \pi \beta } \left[Q(s, a) \right]   }
はオフライン方策の行動選択とデータセットにおける行動選択が一致したときゼロとなり、逆にオフライン方策が選択した行動aについてのQ(s, a) > データセットに記録された行動aについてのQ(s, a) となったときに増大します。したがってCQL(2)を損失関数として最小化することで、前例のある行動についてのQ(s ,a) が常に前例のないデータセット外アクションについてのQ(s,a)より大きくなるようなQ関数を獲得することができます

前例があるならまずは前例を最重視する、まさに保守的なQ学習です。 洗練されたBehavior Cloningという印象。


方策の正則化

実用的にはCQL(2)にさらに方策の正則化項を追加したものを更新式とします。論文では正則化のバリエーションがいくつか提示されていますがもっとも簡単なのは伝統的なエントロピー最大化による方策正則化項を追加したCQL(H)です。


 \displaystyle
{ \text{ CQL(H) } =  \min_{Q} \max_{\mu}   \alpha \left( E_{s \sim \it{ D} , a \sim \mu} \left[Q(s, a) \right] - E_{s \sim \it{ D} , a \sim \pi \beta } \left[Q(s, a) \right]   + \text{H}(\mu) \right) + \text{TDError} }


ここで、H(μ)は方策のエントロピーです。Qとオフライン方策μについての二重最適化問題になっていてややこしそうに見えますが、max_μについてはclosed formで解けるので案外シンプルになります。オフライン方策μについての最大化に関係ない定数項をすべて取り除くと、

 \displaystyle
{   \max_{\mu}  E_{s \sim \it{ D} , a \sim \mu} \left[Q(s, a) \right]  +  \text{H}(\mu)  }

 \displaystyle
{  = \max_{\mu}  E_{s \sim \it{ D} , a \sim \mu} \left[ Q(s, a) - \log\mu(a | s) \right]    }

これを方策μが確率分布である(非負で総和1)という制約つきラグランジュ最適化問題として解くと、最適方策μ* として、

 \displaystyle
{  \mu{*} =  \frac{ \exp{Q(s,a)} }{ Z(s) }    }

が得られます。ここでZは規格化定数(あるいは分配関数)であり、 \displaystyle {  Z(s) =  \sum_{a}  \exp{Q(s,a)} }  です。このあたりは強化学習 as Inference: Maximum a Posteriori Policy Optimizationの実装 - どこから見てもメンダコ と同じような流れです。


得られた最適方策μ*をCQL(H)に代入すると、

 \displaystyle
{ \text{ CQL(H) } =  \min_{Q}    \alpha \left( E_{s \sim \it{ D} , a \sim \mu} \left[Q(s, a)  \right] - E_{s \sim \it{ D} , a \sim \pi \beta } \left[Q(s, a) \right]   + E_{s \sim \it{ D} , a \sim \mu} \left[ -\log{\mu{*}(a | s)}  \right]  \right) + \text{TDError} }

 \displaystyle
{  =  \min_{Q}    \alpha \left( E_{s \sim \it{ D} , a \sim \mu} \left[Q(s, a) -\log{\mu{*}(a | s)}  \right] - E_{s \sim \it{ D} , a \sim \pi \beta } \left[Q(s, a) \right]   \right) + \text{TDError} }

 \displaystyle
{  =  \min_{Q}    \alpha \left( E_{s \sim \it{ D} , a \sim \mu} \left[Q(s, a) -\log{ \frac{ \exp{Q(s,a)} }{ Z(s) } }  \right] - E_{s \sim \it{ D} , a \sim \pi \beta } \left[Q(s, a) \right]   \right) + \text{TDError} }

 \displaystyle
{  =  \min_{Q}    \alpha \left( E_{s \sim \it{ D} , a \sim \mu} \left[ \log{ Z(s)}  \right] - E_{s \sim \it{ D} , a \sim \pi \beta } \left[Q(s, a) \right]   \right) + \text{TDError} }

 \displaystyle
{  =  \min_{Q}    \alpha  E_{s \sim \it{ D} }  \left[ \log{ Z(s)}  - E_{ a \sim \pi \beta } \left[Q(s, a) \right]  \right]  + \text{TDError} }

※分配関数Zの期待値は方策に依存しないことに注意


CQL(H)のTF2実装

実装全文:
github.com

最終的に得られたCQL(H)のロス関数をtensorflow2で実装します。

 \displaystyle
{  \text{ CQL(H) } =  \min_{Q}    \alpha \left( \log{ Z(s)} - E_{s \sim \it{ D} , a \sim \pi \beta } \left[Q(s, a) \right]   \right) +  E_{s,a,s' \sim \text{D}} \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]  }

論文では分布強化学習手法であるQR-DQNをベースとしてCQLを実装しています。これはQR-DQNは実装がシンプルなわりに性能が良好であるためでしょう。QR-DQNのTD誤差項(分位点huber-loss)の実装は過去記事を参照ください。

horomary.hatenablog.com

CQL項自体は論文でも言及されているようにたった20行程度で実装することができます


DQN Replay Datasetの利用

CQLの実装に関して面倒だったのはアルゴリズム実装よりもDQNリプレイデータセットの利用です。

データセットのダウンロード方法:
research.google

DQNリプレイデータセットatariの60 のゲームのそれぞれについて、異なるシードで 5 つの DQN エージェントを200Mフレームでトレーニングしたデータセットです。ただし、4 frame skip設定でのトレーニングのため1エージェントごとに200M/4 = 50M遷移が記録されています。また、50M遷移は(s, a, r, s', a', r', terminal) ごとに50のファイルに分割されて同じディレクトリに保存されています。

このデータセットgoogle強化学習フレームワークであるdopamine-rlに実装されたインメモリreplaybufferであるOutOfGraphReplayBuffer.load()を用いてロードすることが想定されています。

github.com

しかしRAM要求がかなり厳しかったので一度tfrecord形式に変換し、tf.data.TFRecordDatasetを使ってトレーニングしました。この実装だとメモリ16GB程度で十分なのでお財布に優しいです。

github.com


ブロック崩しの学習結果

Breakout(ブロック崩し)環境についてCQL論文の1%データセット設定を再現するように実装し、1Mstepの更新を行いました。論文掲載のスコアが60点くらいなのでだいたい再現できています。

CQL

学習結果


次:Decision Transformer

horomary.hatenablog.com


*1: https://arxiv.org/pdf/2005.01643.pdf 2.4 What Makes Offline Reinforcement Learning Difficult?

*2: DAggerはオンライン設定の模倣学習的な手法

*3:高難易度で有名なフロムソフトウェア社のゲーム

*4:個人的には fully off-policy RL が分かりやすくて好き

*5:例:Learning to Summarize with Human Feedback

*6:ただし模倣学習についてはDAggerみたいにhuman in the loop設定になっていることも多い

論文メモ:深層強化学習によるトカマク型核融合炉の制御

DeepMindの深層強化学習による核融合制御論文を読んだので論文内容と論文を理解するために調べた技術背景をまとめます。

Accelerating fusion science through learned plasma control

www.nature.com



※筆者は原子物理/核融合理論について完全な素人であり当該分野の記述の正しさについてなんら保証がありません。間違いがあればコメントにて指摘お願いします


要約

  • トカマク核融合炉の磁気コイル群は複雑にネストしたPIDコントローラー群によって制御されている
    • この制御方法で性能的な問題はない
    • しかし制御系設計・構築コストの重さが研究開発のボトルネックになっている
  • 深層強化学習でPID制御を置き換えることで制御システムの構築をとても簡単かつ低コストにした
    • シミュレータ環境で方策を事前学習した後、実データでファインチューニング
    • 強化学習手法としての目新しさはなく、Maximum a Posterior Optimization (2018) をほぼそのまま使用
    • ドメイン知識に基づいた即時報酬エンジニアリングがたぶん成功のkey factor
  • これからは研究者の新たなアイデアをすぐに実験投入することができるから核融合炉研究が捗る


技術背景:核融合炉の仕組み

核融合へのとびら / 自然科学研究機構 核融合科学研究所

なっとく!核融合 - 大阪大学

トカマク型 - Wikipedia

核分裂エネルギーと核融合エネルギー

一般に、原子力発電とは核分裂反応によってエネルギーを取り出す方法を示します。この核分裂反応エネルギーとは大きな原子が分裂する際に生じるエネルギーです。核分裂反応による原子力発電は、CO2排出ほぼ無しで莫大なエネルギーを得られるという大きなメリットがある一方で、連鎖反応による暴走の危険性や使用済み燃料の廃棄および廃炉問題など多くの課題を抱えています。

核融合へのとびら / 自然科学研究機構 核融合科学研究所


一方、核融合では小さな原子核が融合し重原子核を生成する反応によってエネルギーを取り出します。CO2排出無しで莫大なエネルギーを得られるという従来の原子力発電のメリットはそのままに、連鎖反応による暴走の危険性が低く、高レベル放射性廃棄物を伴わず、かつ水素という普遍的な資源が燃料ということでしばしば「21世紀の夢のエネルギー」という表現を与えられ、活発に研究開発が行われています。身近な例では太陽のエネルギー源もまた核融合反応です。

核融合へのとびら / 自然科学研究機構 核融合科学研究所


核融合反応のおこしかた

原子核原子核を十分な速度(1000km/sec 以上)でぶつければ核融合反応を起こすことができます。そこでまずは原子を超高温に加熱することで原子運動を加速しプラズマ状態とすします。ここで、プラズマ相とは原子核と電子が自由に動きまわっている状態です。

誰でも分かる核融合のしくみ | プラズマって何? - 量子科学技術研究開発機構

しかしながら、原子核は正荷電しているために反発しあってすぐに散逸してしまうので単純に加熱しただけでは核融合に必要な原子核の衝突がほぼ起こりません。

実用的な核融合反応を起こすためにはプラズマを狭い空間に閉じ込めることで高い密度を保つ必要があります。お互いに反発しあうたくさんの原子核核融合しないと出られない部屋に閉じ込めてやりましょう。また、できるだけ長く閉じ込めることができるほど原子核衝突≒核融合反応発生の可能性が上がります。


トカマク型磁気閉じ込め方式

ここまで、プラズマを狭い空間に閉じ込めることができれば核融合を実現できることがわかりました。しかしどうやって閉じ込めましょう?

太陽では超重力によるプラズマ圧縮核融合を起こしているようですが、地球上では非現実的です。また、超合金的な素材で密閉容器を作るのは一見よさそうなアイデアですが、これをやると物凄い勢いで容器が削れる & プラズマが冷えてしまうのでダメっぽいです。なんとかプラズマに直接触れないで圧縮する必要があります。

そこで、ドーナッツ状の強力な磁力線を発生させることでプラズマに直接触らず閉じ込めを実現するのがトカマク型磁気閉じ込め方式です。

https://atomica.jaea.go.jp/data/detail/dat_detail_07-05-01-06.html

大雑把かつ乱暴な表現をすると、ドーナッツ環に沿った磁力線(トロイダル方向磁力線)と ドーナッツ断面に沿った環状磁力線(ポロイダル方向磁力線)の組み合わせによってドーナッツ表面に沿ったらせん状の磁力線を発生させることにより、プラズマはドーナッツ環内を移動しつづけるためにプラズマの散逸を防ぐことができるようです。


深層強化学習によるトカマク型核融合炉の制御

www.nature.com

研究者はさまざまなプラズマ形状の特性を探索したい

要約するとトカマク方式の核融合炉とは、プラズマを強力な磁力線によってドーナッツ型に圧縮することにより核融合反応を発生させる方式のようです。このトカマク方式の重要な研究課題の一つは閉じ込めの安定性やエネルギー取り出しを最適化するために、より良いプラズマ分布形状(ドーナッツの断面)を探索することです。

様々なプラズマ断面形状 (DeepMind Blog)

課題:目標形状ごとに磁気コイル制御システム構築するのがつらい

論文がターゲットとしているスイス/ローザンヌのトカマク炉ではプラズマの周囲に配置された19個の磁気コイルの精密制御によってプラズマ形状をコントロールしているとのことです*1。問題は、目標とするプラズマ形状ごとに磁気コイルの制御システムを実装する必要があることです。

複雑な従来制御:mはセンサ測定値、aは磁気コイル操作値

制御系の実装には相当なエンジニアリング/設計作業および専門知識を必要とするため、研究者のアイデアを容易に実験で確認できず試行錯誤のコストが研究領域の進歩のボトルネックになってしまっています。


提案:深層強化学習で任意形状の制御を実現する

論文ではセンサー測定値と達成すべきプラズマ形状を入力として、19個の磁気コイルの操作値を出力する方策ネットワークの訓練が深層強化学習で実現できることを示しました。この系においては実験者が新たなプラズマ形状を実験したいときに行うべきことは目標とするプラズマ形状を指示するだけであるため、制御系の実装コストが大きく低減され、トカマク核融合炉の研究が大きく加速することが期待されます。

Control Policy: センサ測定値と目的形状を入力に磁気コイル操作値を出力する4層MLP


制御ポリシーのトレーニン

論文Fig.1

Sim to Real

トカマク炉は物理モデルに基づくシミュレータが利用可能ということで、まずはシミュレータ環境で方策の事前訓練を行った後、実機から収集されたサンプルでファインチューニングを行います。ただし、実機データからのファインチューニングについては厳しいリアルタイム制約からかオンライントレーニングを行ったわけではないようです。

{a, m, t, r} = {アクション, センサ観測値, 目標状態, 即時報酬} (fig.1)

学習回避領域の設定:
トカマクシミュレータ環境について、シミュレーションがうまく現実に一致しないことが事前に分かっている領域があるようです。そこで、指定された条件が発生したときにシミュレーションを停止することで、このような領域をエージェントが学習してしまうことを回避する仕組みを導入しています。


分散並列強化学習

多数の並列actorによって収集される多様な遷移情報が学習の安定性を向上させることが経験的に知られており、この論文でもオフポリシー分散並列強化学習アーキテクチャを採用しています。論文によると環境と相互作用して遷移情報を収集しReplayBufferに送信する多数のActorとひたすらネットワーク更新を繰り返すLearnerで構成されるApe-Xっぽいアーキテクチャとなっているようです。

horomary.hatenablog.com


MPOアルゴリズム

Maximum a Posteriori Policy Optimisation | OpenReview

Actor-Critic系のMaximum a Posteriori Policy Optimization (MPO)を採用して方策ネットワークを訓練します。MPOオフポリシーゆえの高いサンプル効率TRPOのような更新安定性を兼ね備えた使いやすいアルゴリズムです。*2 ざっくりとは方策関数がQ値のボルツマン分布を近似するように更新する手法です。


Actor-Criticネットワーク:
大きめかつRNNつきのCriticネットワークたった4層MLPのPolicyネットワークという非対称な構造を採用しています。これはPolicyネットワークは実機環境におけるリアルタイム推論が必要なために十分に高速に動作しなければならないという実用上の制約のためであるようです。Criticネットワークはネットワーク更新時にしか使われないので動作が遅くともまったく問題ありません。方策ネットワークは92次元のセンサー測定値と132次元で表現される目標状態を入力され、19個のコイルそれぞれの電圧値を出力します。


目標状態tを方策に入力する:
132次元で表現される目的状態tがセンサ測定値mとともに方策関数に入力されます。この研究の目的は研究者の指定する任意のプラズマ形状を実現する制御ポリシーを訓練することなので実現すべきゴール状態tを方策に知らせる必要があるためです。階層強化学習でサブゴールをpolicyに入力するのと同様です。

horomary.hatenablog.com


時報酬rの設計

時報酬rは目標状態tと現在のセンサー観測値mに基づいて決定されるスカラ値です。とても大雑把にはセンサー観測mから推定される現在のプラズマ状態と目標プラズマ状態tが似ていれば高い即時報酬が与えられるような即時報酬関数になっているのですが、ドメイン知識に基づいたさまざまな指標(下表)が即時報酬要素として使われており、さらにベースとするプラズマ形状の種別に応じて即時報酬要素を取捨選択しているようなので報酬設計にはかなりの力が入っていることが推察されます。

Extended Data Table 4 Reward components


実機へのデプロイ

実機への方策デプロイにおいては厳しいリアルタイム推論性能(50μs以内の応答)が求められるため、tfcompile(https://www.tensorflow.org/xla/tfcompile?hl=ja) で高速化を行っているとのことです。


性能検証

目標分布形状(青点)に観測プラズマ分布(オレンジ)が収まっている

所感

PID制御でうまくいっている系を強化学習で置き換えることにパフォーマンス上の意味はないが、R&Dにおいて試行錯誤コストが大きく低減される意義は大きい、という着眼点がさすがだなと感じます。やってることはただの効率化でも、その効率化のケタが違うとゲームチェンジをもたらすというあたりは研究開発デジタルトランスフォーメーションの見本といった印象。

*1:DeepMind Blogより

*2:ちなみにMPOのfirst authorはこの論文のauthorにも入ってる

強化学習 as Inference: Maximum a Posteriori Policy Optimizationの実装

方策が最適である確率の下界をEMアルゴリズムっぽく最大化する強化学習手法 Maximum a Posteriori policy Optimization (ICLR2018) をBipedalWalker-v3向けにtensorflow2で実装します。

openreview.net

MPOで訓練したBipedalWalker-v3

※コード全文:
GitHub - horoiwa/deep_reinforcement_learning_gallery: Deep reinforcement learning with tensorflow2

MPOの実装はdeepmind/acmeを参考にしています
GitHub - deepmind/acme: A library of reinforcement learning components and agents


はじめに

方策勾配法: 劣悪なサンプル効率と不安定な更新

深層強化学習において、連続値行動環境をコントロールする方策(Policy)を訓練するためには方策勾配法が広く使われています。しかし方策勾配法は基本的にオンポリシーなのでサンプルを使い捨てるゆえの劣悪なサンプル効率と、勾配の分散の大きさ*1ゆえの不安定なネットワーク更新という2つの困難が産業応用の壁となっています。

Trust Region Policy Optimization (2015)では信頼領域法の導入により方策勾配法の安定性を大きく高めることに成功しましたが、オンポリシーゆえのサンプル効率の悪さは課題として残ります。源泉かけ流しのごとくサンプルを消費するオンポリシーアルゴリズムはシミュレータならよいですが実機を使うロボティクスなんかでは時間/金銭コストが高すぎます。

horomary.hatenablog.com


Maximum a Posteriori Policy Optimization : 確立推論の枠組みで方策を最適化

Maximum a Posteriori Policy Optimization (MPO) は方策勾配法ではなく Control as Inference、すなわち確率推論のフレームワークで制御ポリシーを訓練する手法です。 MPOはオフポリシーアルゴリズムゆえにサンプル効率が良好であり、かつTRPOのような信頼領域法を用いることでロバストな更新を実現するという実用的な手法です。論文の発表はICLR2018ですが、DeepMindの2022年のNature論文 "深層強化学習によるトカマク型核融合炉の制御 " でMPOが採用されていることからも実用性の高さがうかがえます。

また、MPOにおけるポリシー更新は本質的に教師あり学習のためオフラインセッティングと相性がよいことも重要なポイントです。

www.nature.com


Control as inference では行動の最適性を確率分布で表現する

Control as inferenceの枠組みでもっとも重要なコンセプトは最適性確率変数Oです。例えばあるトラジェクトリτが与えられた時、そのトラジェクトリが最適トラジェクトリである確率はP(O=1 | τ)、そうでない確率はP(O=0 | τ)と表現されます。同様に状態s_tにおいてアクションa_t が最適行動である確率は P(O_t=1 | s_t, a_t)となります。最適性変数Oの導入の最大のメリットはMDPにおける"最適な制御"をグラフィカルに表現できるようになることです。

Sergey Levineの講義資料より

最適性変数Oにより行動の最適性を明示的に確率分布で表現できるために環境の不確実性を自然に扱えるようになります。また、確率推論のさまざまなツールを利用可能になるのも大きなメリットであり、実際にMPOではEMアルゴリズムやELBOなどを活用しています。

↓しっかり学びたい方はSergey Levineの講義動画へ↓

www.youtube.com

PDF: http://rail.eecs.berkeley.edu/deeprlcourse-fa17/f17docs/lecture_11_control_and_inference.pdf


Control as InferenceとしてのSoft Actor-Critic

連続値コントロール環境ではTRPO/PPOと並んで大人気の Soft Actor-Critic も Control as Inference フレームワークの手法と解釈することができます*2。 しかしSAC論文内ではControl as Inference観点からの説明が乏しいので詳細を知りたい方はやはりSergey Levineの講義動画(上記)を見ましょう。

雑に説明するならSACもMPO方策関数の形状をQ関数に似せることを目的としているので似たような手法という感覚です。 なお、実装の観点ではSACがDDPGの派生みたいな感じになっているのに対してMPOは既存の有力手法とははっきり異なる実装になっています。

horomary.hatenablog.com


最適制御確率の下界を導出する

MPOの目的は方策πで行動決定を実行したときにそれが最適制御である確率 Pπ(O=1) を最大化することです。

 \displaystyle
{ \log p_\pi (O=1 )  }

このままではナンセンスなのでより具体化すると、

 \displaystyle
{ \log p_\pi (O=1 ) = \log \int  p_\pi( \tau )p( O=1  | \tau ) d\tau }

すなわち、最適制御である確率 = トラジェクトリτが方策πに従って生成されるときにトラジェクトリτが最適である確率 の期待値 です。

つづいてVAEなどでもお馴染みのイエンゼンの不等式より任意の確率分布q(τ)について、

 \displaystyle
{ \log \int  p_\pi( \tau )p( O=1  | \tau ) d\tau \geqq \int  q(\tau) (  \log p( O=1  | \tau ) + \log \frac{p_\pi( \tau )}{q(\tau)} ) d \tau }

 \displaystyle
{ = \int  q(\tau) \log p( O=1  | \tau) d\tau + \int  q(\tau) \log \frac{p_\pi( \tau )}{q(\tau)} ) d \tau }

第一項はトラジェクトリτがqに従って生成されるときに、トラジェクトリが最適である確率の期待値です。ここで、トラジェクトリが最適である確率 p(O=1 | τ) はトラジェクトリτの報酬和と指数比例する と想定すると、

 \displaystyle
{  \int  q(\tau) \log p( O=1  | \tau) d\tau  + \int  q(\tau) \log \frac{p_\pi( \tau )}{q(\tau)} ) d \tau  }

 \displaystyle
{  =  E_{\tau \sim q} \left[  \log  \exp  \sum_{t} \frac{r_t}{\alpha}   \right] + \int  q(\tau) \log \frac{p_\pi( \tau )}{q(\tau)} ) d \tau }

 \displaystyle
{  =  E_{\tau \sim q} \left[  \sum_{t} \frac{r_t}{\alpha}   \right]  + \int  q(\tau) \log \frac{p_\pi( \tau )}{q(\tau)} ) d \tau }

報酬和が高いほど最適トラジェクトリである確率が指数的に高まるというのは、まあそりゃそうだけどもという感じ。ここまで方策のパラメータ(≒ニューラルネットの重み)を明示せずにPπ(τ)と表記してきたのでPπ(τ) = P(τ | θ)P(θ) と書き直すと、

 \displaystyle
{  =  E_{\tau \sim q} \left[  \sum_{t} \frac{r_t}{\alpha}   \right]  + \int  q(\tau) \log \frac{p( \tau | \theta )p(\theta)}{q(\tau)} ) d \tau }

 \displaystyle
{  =  E_{\tau \sim q} \left[  \sum_{t} \frac{r_t}{\alpha}   \right]  + \int  q(\tau) \log \frac{p( \tau | \theta )}{q(\tau)} ) d \tau + \log p(\theta) }


第二項はKLダイバージェンス  \displaystyle {  KL( q(\tau) ||  p_\pi( \tau | \theta) ) } そのものなのですが、トラジェクトリτは扱いが困難なので、やや強引ながらトラジェクトリのKLダイバージェンスは各状態行動ステップのKLダイバージェンスに分解できると想定すると、

 \displaystyle
{  =  E_{\tau \sim q} \left[  \sum_{t} \frac{r_t}{\alpha}  -  KL(q( \cdot | s_t) || \pi( \cdot | s_t, \theta)) \right]  +  \log p(\theta) }

 \displaystyle
{ = J(q, \theta )}

ここまでで、方策πで行動決定を実行したときにそれが最適制御である確率 Pπ(O=1)の下界Jを、任意の確率分布qと方策パラメータθで表現することができました。


MPOとは: 最適制御確率の下界JをEMアルゴリズムっぽく最大化

※煩雑になるので割引率γ=1として省略

MPOとは上で導出された最適制御である確率Pπ(O=1)の下界J(q, θ)をEMアルゴリズムっぽい方法で最大化する手法です。すなわち、Estep: θを定数と見なしてqについてJを最大化 → Mstep: qを固定してθについてJを最大化を繰り返します。

 \displaystyle
{ J(q, \theta )  =  E_{\tau \sim q} \left[  \sum_{t}^{\infty} \frac{r_t}{\alpha}  -  KL(q( \cdot | s_t) || \pi( \cdot | s_t, \theta)) \right]  +  \log p(\theta) }


E-step:変分分布qの最適化、あるいはノンパラ版TRPO

Eステップでは方策パラメータθ を定数とみなし、下界Jを最大化するような方策分布qを算出します。

 \displaystyle
{ \text{arg} \max_{q} J(q, \theta ) }

 \displaystyle
{ = \text{arg} \max_{q}  E_{\tau \sim q} \left[  \sum_{t}^{\infty} \frac{r_t}{\alpha}  -  KL(q(\cdot | s_t) || \pi(\cdot | s_t, \theta)) \right]  +  \log p(\theta) }

logp(θ)はE-stepでは定数扱いなのでJの最大化には寄与しないので無視できます。さらに温度パラメータαは非負定数なのでJに掛けてもargmax_qの結果は変わらないため、

 \displaystyle
{ = \text{arg}\max_{q}  E_{\tau \sim q} \left[  \sum_{t}^{\infty} r_t  -  \alpha KL(q(\cdot | s_t) || \pi(\cdot | s_t, \theta)) \right]  }

さらにトラジェクトリ期待値Eτ~q を 状態行動ステップ期待値Eμ(s)Ea~qに書き直すと、

 \displaystyle
{ = \text{arg}\max_{q}  E_{\mu (s)} \left[  E_{a \sim q}  \left[  \sum_{t}^{\infty} r_t  \right] -  \alpha KL(q(\cdot | s_t) || \pi(\cdot | s_t, \theta))  \right] }

ここで初登場するμ(s)というのは定常分布から状態sがサンプリングされるという意味ですが、実装的にはリプレイバッファからとってくるの意味なのであんま気にしなくてOK。ところで時刻t...∞までの報酬和の期待値というのはつまり 状態行動価値Q(s, a) であるので、

 \displaystyle
{ = \text{arg}\max_{q}  E_{\mu (s)} \left[  E_{a \sim q}  \left[  Q(s_t, a_t)  \right] -  \alpha KL(q(\cdot | s_t) || \pi(\cdot | s_t, \theta))  \right] }

つまりJ(q, θ)を最大化する方策分布qとは、[第二項] 現在の方策πから離れすぎないというKL制約のもとで、[第一項] 状態行動価値Q(s, a)が大きな行動を出力する方策分布ということになります。このソフト制約付き最適化問題をハード制約つき最適化問題の形式に書き直したものが以下です。

論文:式7

この制約付き最適化問題において変分分布qを更新後方策、πを更新前方策と見ればほぼTRPOと同じ問題になります。*3。ゆえにTRPOは実質的にEステップのみのMPOと見ることができるます。

ハムスターでもわかるTRPO ①基本編 - どこから見てもメンダコ

しかし、TRPOではqをパラメタライズ(≒ニューラルネットで表現)してハード制約付き最適化問題を解いたのに対して、MPOではノンパラでハード制約付き最適化問題を解きます。

論文Appendix D.2 E-Step より、期待値オペレータを積分形式に書き直す & 確率分布なので総和1の制約を追加した制約付き最適化問題

ラグランジュの未定乗数法で解いた結果が次式(論文:式8)となります。

つまり最適方策qとは現在の方策πを状態行動価値Qのボルツマン分布で重みづけしたものというのがこの制約付き最適化問題の答えとなります。

ここで、温度パラメータη*は次式を最小化するηを求めることによって得られます。


M-step:方策パラメータθの更新

M-stepではE-stepで求めたqを固定して方策パラメータθを更新します。

 \displaystyle
{ \text{arg} \max_{\theta} J(q, \theta ) }

 \displaystyle
{ = \text{arg}\max_{\theta}  E_{\mu (s)} \left[  E_{a \sim q}  \left[  \sum_{t}^{\infty} \frac{r_t}{ \alpha}  \right] -  KL(q(\cdot | s_t) || \pi(\cdot | s_t, \theta))  \right] + \log p(\theta) }

argmaxθに影響しない項を消去すると、

 \displaystyle
{ = \text{arg}\max_{\theta}  E_{\mu (s)} \left[   q(\cdot | s_t) \log \pi(\cdot | s, \theta)  \right] + \log p(\theta) }

 \displaystyle
{ = \text{arg}\max_{\theta}  E_{\mu (s)} \left[  E_{a \sim q}  \left[ \log \pi( a | s, \theta)   \right]  \right] + \log p(\theta) }

第一項はqにしたがってアクションがサンプリングされたときのlogπの期待値であるので、重点サンプリング法でπからアクションがサンプリングされたときの期待値に挿げ替えると、

 \displaystyle
{ = \text{arg}\max_{\theta}  E_{\mu (s)} \left[  E_{a \sim \pi}  \left[ \frac{q(a | s, \theta)}{\pi(a | s, \theta)}  \log \pi( a | s, \theta)    \right]  \right] + \log p(\theta) }

 \displaystyle
{ = \text{arg}\max_{\theta}  E_{\mu (s)} \left[  E_{a \sim \pi}  \left[ \frac{ \pi(a | s, \theta) \exp \frac{Q(s,a)}{\eta} }{\pi(s | a, \theta)}  \log \pi( a | s, \theta)   \right]  \right] + \log p(\theta) }

 \displaystyle
{ = \text{arg}\max_{\theta}  E_{\mu (s)} \left[  E_{a \sim \pi}  \left[  \exp \left( \frac{Q(s,a)}{\eta}  \right)  \log \pi( a | s, \theta)    \right]  \right] + \log p(\theta) }

第一項はexp(Q(s, a)/η)と現在方策πのクロスエントロピーを最大化するということなので、直感的にはボルツマン分布で重みづけされたQ(s,a)に方策πができるだけ似るように更新しよう、と解釈できます。

第二項は方策パラメータの事前分布であるので任意の分布を設定できます。更新前パラメータθiを平均、フィッシャー情報行列F/λを共分散とする多変量正規分布

を事前分布に設定し、さらにargmaxθに影響しない定数項を消去する *4 と、

 \displaystyle
{ = \text{arg}\max_{\theta}  E_{\mu (s)} \left[  E_{a \sim \pi}  \left[  \exp \left( \frac{Q(s,a)}{\eta}  \right)  \log \pi( a | s, \theta)    \right]  \right] - \lambda (\theta_i - \theta)^{T} F^{-1}_{\theta i} (\theta_i - \theta) }

ここで、第二項は更新前方策πθiと更新後方策πθのKLダイバージェンスの二次までのテイラー展開であると解釈できます*5

よって最終的にM-stepの目的関数は、

 \displaystyle
{ = \text{arg}\max_{\theta}  E_{\mu (s)} \left[  E_{a \sim \pi}  \left[  \exp \left( \frac{Q(s,a)}{\eta}  \right)  \log \pi( a | s, \theta)   -  \lambda KL(\pi_{\theta i}(a | s)) || \pi_{\theta }(a | s))   \right]  \right] }

結局M-stepもソフトKL制約つき最適化問題に帰着しました。E-stepではハードKL制約問題に書き直してまじめに解きましたが、M-stepではこのままソフト制約問題としてそのまま最大化(実装的には-1をかけて最小化)してしまいます。

ただし、λは制約違反の大きさに応じて適応的に更新していくことに注意してください。すなわち、KL制約がハイパラとして設定する許容値εを上回るようならλを大きくしてKLペナルティを増大させ、そうでないならλを縮小します。Soft Actor-Criticの温度パラメータ自動調整とやっていることはほとんど同じです。


BipedalWalker-2dでの学習結果

実装詳細はgithubへ:

github.com

BipedalWalker-v3での学習結果

横軸:エピソード数、縦軸:スコア(300点くらいが満点)

雑記

  • 探索力が弱い

MPOにおける方策学習は本質的にはQ関数に似せることを目指す教師あり学習なので、Q学習の弱点である探索力の弱さが目立つ。対照的にSACは探索力が高いのでよく知らない環境に対して雑に使ってもうまくいくのはSACという印象。

  • ポリシーの分散が過度に増大しやすい

方策に正規分布を設定するとQ関数に似せようとする強い力により方策の分散がとんでもなく大きくなりやすい。対処できなかったので本実装ではtf.clip_by_value(sigma, min, max)してしまった。

  • KL制約違反の許容閾値 ε の調整がけっこうセンシティブ

ターゲットネットワークの同期頻度に応じてεをいい感じに調整する必要がある。εは小さいほど安定更新だが小さすぎると学習に時間がかかりすぎる。


*1:方策勾配定理は勾配方向の"平均値"の正しさは保証してくれるが分散は考慮しないのでバッチサイズを相当大きくしないと不安定化する。また、保証してくれるのは適切な更新方向だけで適切な更新サイズは保証してくれない

*2:そもそもSAC論文のlast author はsergey levineである

*3: 正確にはMPOはKL(q || π)であるのに対してTRPOではKL(π || q)

*4:参考: 多変量正規分布の確率密度関数

*5:参考資料:CMUのTRPO講義資料のTaylor expansion of KLのスライド

DreamerV2の実装: 世界モデルベース強化学習でブロック崩し

世界モデル系強化学習の先端手法であるDreamerV2ブロック崩し(BreakoutDeterministic-v4)向けに実装しました。


論文
DreamerV2: [2010.02193] Mastering Atari with Discrete World Models

Google AI Blog: Mastering Atari with Discrete World Models

DreamerV1: [1912.01603] Dream to Control: Learning Behaviors by Latent Imagination

Google AI Blog: Introducing Dreamer: Scalable Reinforcement Learning Using World Models


はじめに

2012年にDeepMindDQNが世界を驚かせてから、たったの10年間で深層強化学習研究は目覚ましい進歩を遂げました。2012年のDQNの頃は深層強化学習エージェントが”人間レベル”でブロック崩しをプレイ可能というだけで驚くべき成果でしたが、2022年現在のモデルフリー強化学習の先端手法である R2D2MuZero ではブロック崩し環境において超人的パフォーマンスに到達しており余裕のパーフェクトプレイを見せつけてきます。


Atari2600 Breakout (ブロック崩し

しかし、このような手法の飛躍的発展にも関わらず深層強化学習の産業応用はそれほど進んでいません。原因の一つは、深層強化学習"劣悪なサンプル効率" 汎化性能の低さという課題を抱えていることでしょう。 たとえばブロック崩し環境について、人間であれば4, 5回の試行錯誤で到達できるようなパフォーマンスに深層強化学習エージェントが到達するためには最先端の手法でさえ何千回の試行錯誤が必要となるのです。また、何千何万回の試行錯誤の末に到達した超人的な強化学習エージェントでさえ、背景画像を変えられてしまうだけでパフォーマンスが崩壊してしまいます。

もし人間のように少数の試行でもそこそこのパフォーマンスを発揮できるような高度な学習能力を持つ強化学習手法を開発できたのならば強化学習の産業実用性は大きく高まります。この実現を目指すアプローチのひとつが世界モデルベース強化学習です。


世界モデルベース強化学習とは

「人間の高度な学習能力を支えているのは脳内シミュレータ―に基づく未来予測と行動計画である」という仮説のもと、脳内シミュレータを備えた強化学習エージェントを訓練しようというのが世界モデルベース強化学習のコンセプトです。

野球を例にとってみましょう。野球の打者は、バットをどのように振るかを決定するのに数ミリ秒かかります。これは、目からの視覚信号が脳に到達するのにかかる時間よりも短い時間です。100マイルの速球を打つことができる理由は、ボールがいつどこに行くかを本能的に予測できるためです。プロのプレーヤーにとって、これはすべて無意識のうちに起こります。彼らの筋肉は、内部モデルの予測に沿って、適切な時間と場所でバットを反射的に振ります。計画を立てるために考えられる将来のシナリオを意識的に展開する必要なしに、将来の予測にすばやく対応できます。(Google翻訳)  ― World Models より

つまりは、『World Models (世界モデル)』ベースの強化学習手法が目指すのは環境のダイナミクス(状態遷移)のモデル化です。もし系の状態遷移を完全にモデル化できたのならばその状態遷移モデルは系の(微分可能な)シミュレータ―として機能します。もし完全なシミュレータ―が得られたらばあとはそのシミュレータ上で満足いく結果が得られるまで行動計画を続ければよいのです。

https://worldmodels.github.io/


DreamerV2:Atari環境で初めてモデルフリー手法に並んだ世界モデルベース強化学習

世界モデルベース強化学習はコンセプトは面白いのですが、モデルフリー手法と比較するとパフォーマンスはいまいちというのがこれまでの傾向であり、とくにAtari2600ベンチマークレトロゲーム)のように視覚的ノイズの大きい環境が苦手でした。

しかしICLR2021にて発表された"DreamerV2"は、非分散並列手法としてはatari環境で長くトップにあったモデルフリー手法Rainbowを超えるパフォーマンスを示したことにより、世界モデルベース強化学習への期待感を大きく高めることとなりました。本稿ではこのDreamerV2のブロック崩し環境向け実装例を簡単な解説とともに紹介します。

論文Fig1 DreamerV2は世界モデルベース強化学習で初めてアタリ環境でのRainbow越えを達成

[2010.02193] Mastering Atari with Discrete World Models(DreamerV2) 


世界モデル(World Models)について

世界モデルがやりたいこととは、t時刻における状態xとアクションaからt+1時刻における状態x を予測することです。しかし画像のような高次元かつ密なデータにおいて精度よく遷移予測を行うのはなかなかに困難な課題です。

画像観測から次の画像観測を直接予測するのは難しい

そこで、系の直接観測はより少数の潜在変数に支配されているはずだ(多様体仮説)という考えにもとづき、Variational AutoEncoder(VAE) の訓練によって得られる潜在変数空間における遷移モデルを関数近似しようというのがWorld Modelの基本コンセプトです。

下図について、潜在変数空間における状態遷移モデル  \displaystyle{
 q( z_{t+1} | z_{t}, a_{t})
} は確率モデルであることに注意してください。これは状態遷移を決定論的な関数で近似してしまうと、(極端には)ブラックジャックのようなランダム系にまったく対応できなくなってしまうためです

World Modelは潜在変数空間における状態遷移モデルを近似する

このような「潜在変数空間における状態遷移のモデル化」というコンセプト自体は実際のところ何ら新しいものではありません。Dreamer論文内でも言及されているようにこれは非線形カルマンフィルタや隠れマルコフモデルのバリエーションと捉えることができるためです。

それにも関わらずWorld Modelを魅力的な手法にしている理由は主に2点です

  1. 高い説明可能性
  2. 微分可能シミュレータの獲得

まず前者について、World ModelではVAEアーキテクチャを用いるので潜在変数から実観測を復元することができるため、高い説明可能性を有しています*1。すなわち「今この行動をしたらこの後はこういう風になるよ」ということを視覚的に説明できるわけです。説明可能性の高さは産業活用へのハードルを大きく低減するため大変重要な特性です。

VAEなので潜在変数を実空間に復元できる

さらに後者について、World Model が獲得する状態遷移モデル=系のシミュレータの実体はニューラルネットワークであるため微分可能です。よって方策も関数近似している場合は、シミュレータから獲得できる報酬を最大化するように方策ネットワークを直接最適化できることとなります。イメージとしてはDDPGにおけるQ関数がシミュレータに置き換わった感じでしょうか。

参考: 微分可能なシミュレータ上での方策最適化 - Preferred Networks Research & Development


微分可能であることを置いておいたとしても、軽量なシミュレータを獲得できることは一つの系で様々なタスクを学習させたい場合や試行の金銭/時間的なコストが重い系においては魅力的な特性です。また、世界モデルの構築は強化学習ではなく単に教師あり学習なのでデータの再利用が容易であるためオフラインセッティングと相性がよいことも特筆すべき利点です。


方策の獲得

世界モデル=系のシミュレータがうまく構築できているならば方策の獲得はどのように行ってもOKです。世界モデルを普通にシミュレータとして使って方策勾配法で最適化してもよいですし、微分可能シミュレータであることを生かしてダイレクトに方策ネットワークを最適化してもよいです。あるいは軽量シミュレータであることを生かしてCMA-ESのような進化戦略アルゴリズムで方策パラメータをブラックボックス最適化してしまうという方法もよく用いられているようです。


Dreamerへの系譜

DreamerV2に至るまでの先行研究の流れをまとめます。

Wolrd Models (2018)

[1803.10122] World Models ( David Ha, Jürgen Schmidhuber, 2018)

"World Models" という用語の初出は(たぶん)この論文です。コンセプトの似た研究はこの論文以前からありましたが、World Models というキャッチーなネーミングによってVAEベースの状態遷移モデルによるアプローチが広く認識されるようになりました。

論文では"MVC"の3コンポーネント*2で構成されるWorld Modelアーキテクチャ のコンセプトを提案しました。

  • Vision: 画像観測から潜在状態zを抽出するVariational AutoEncoder
  • Memory RNN: 過去の観測情報を潜在変数空間で記憶し現観測と合わせて次観測を予測
  • Controller:方策関数

World Models論文 Fig4より

Fig6 Memory RNNアーキテクチャ詳細

この手法における状態遷移の予測は次の3ステップで行います。

  1. 時刻tの潜在変数zと行動aをLSTMに入力
  2. LSTM出力から時刻t+1における潜在変数zの分布を予測(混合ガウス分布のパラメータを予測)
  3. 予測した時刻t+1における混合ガウス分布から潜在変数をサンプリング


Controller(=方策)について、World Models における方策関数は超シンプルな線形モデルです。パラメータ数が少ないことを生かして進化戦略アルゴリズム(CMA-ES)を用いて重みをブラックボックス最適化(勾配フリー最適化, derivaritive-free optimization)してしまいます。

World Models の方策関数はシンプルな線形モデル


PlaNet: Deep Planning Network (2019)

Google AI Blog: Introducing PlaNet: A Deep Planning Network for Reinforcement Learning

https://arxiv.org/pdf/1811.04551.pdf

オリジナルのWorld Modelでは系の潜在状態を確率変数(VAEの潜在変数 z)でのみ表現しますが、PlaNetでは確率的変数と決定論的変数の両方で系の潜在状態を表現することによって、World Modelアーキテクチャ時系列video予測を高精度化することに成功しました。

  • 決定論的潜在変数*3: RNN(GRU)の出力h
  • 確率的潜在変数: VAEの潜在変数z

多くのベンチマーク環境に置いて確率的な遷移をするのは画面のごく一部であり背景などはほぼ決定論的な遷移をすることを考えれば、決定論的潜在変数を導入することでvideo predictionが高精度化することは妥当な結果に感じます。このPlaNet版World Modelアーキテクチャはほぼ姿を変えずに以降のDreamerおよびDreamerV2に継承されています。

Google AI blog Dreamer より(※DreamerとPlaNetの世界モデル部はほぼ同じ)

PlaNetでは方策関数を作成せず、軽量高速並列シミュレータとしてのWorld Modelの特性を生かして最適なアクションシークエンスを直接探索します。たとえば離散行動空間4の系であればこの先15stepについて可能なアクションシークエンスは4**15通り存在します。いくら世界モデルが軽量シミュレータとして機能すると言えどもこれを全探索はさすがに厳しいので、PlaNetでは進化戦略アルゴリズムで最適シークエンスを探索します。

探索に使用する進化戦略アルゴリズムは(計算リソースが許すなら)遺伝的アルゴリズムでも何でもいいと思うのですが、PlaNetではシンプルに Cross-entropy method を採用しています。


Dreamer (2020)

Google AI Blog: Introducing Dreamer: Scalable Reinforcement Learning Using World Models

https://arxiv.org/pdf/1912.01603.pdf

Dreamerとは世界モデル上で強化学習エージェントを訓練するという SimPle的なアプローチがPlaNetの高精度なWorldModelと組み合わさることにより大きな成功を収めた手法と要約できます。

DreamerはPlaNetの高精度なWorld Model上で強化学習エージェントを訓練することによって、連続値コントロール環境において当時SotAのモデルフリー手法であるD4PG(Distributinal Distributed Deep Deterministic Policy Gradient)に相当する高いパフォーマンスを達成することに成功しました。しかもサンプル効率は20倍以上です。

モデルフリーSotA手法D4PG 相当/以上のパフォーマンス

Dreamerは世界モデル上で強化学習エージェントを訓練する

Dreamerの特徴はPlaNet版世界モデル上のみでActor-Critic型の強化学習エージェントを訓練する点です。エージェントのネットワーク更新に実環境との相互作用によって得られたサンプルは一切使わないゆえに、夢の中で訓練する"Dreamer"という命名がなされています。ニュアンスとしてはイメージトレーニングといった感じ。

また、Dreamerは世界モデルが即時報酬について微分可能であることを最大限活用し、方策勾配定理を使わず期待報酬和を直接最大化するように方策関数のパラメータを更新します。これはDDPGの更新式とよく似ており、すなわちDDPGではQ関数の出力値を大きくするように方策関数を更新しますが(下式)、Dreamerではシミュレータ(世界モデル)から獲得するN-step報酬和*4が大きくなるように方策関数を更新します。

これはDDPGの更新式


DreamerV2 (2021)

Mastering Atari with Discrete World Models | OpenReview

Google AI Blog: Mastering Atari with Discrete World Models

DreamerV2では、Dreamerに2つのトリックを追加することでAtari環境においてRainbowを超えるパフォーマンスを発揮することを報告しました。

  1. VAEの潜在変数分布に正規分布でなくOneHotCategoricalDistributionを仮定 (DiscreteVAE)
  2. KLバランシング

1についてはVAEの潜在変数分布に単峰正規分布という強い仮定を置くのではなく、より自在な分布を表現できるカテゴリ分布を用いることで系の確率的な遷移をより表現しやすくなるのではないかと考察されています。

DreamerV2ではVAEをDiscreteVAEに差し替える

潜在変数に単峰ガウス分布を仮定するくらいなら多少目が粗くても自由な分布を表現できるカテゴリ分布のほうがいい

世界モデルベースで初めてRainbowのパフォーマンスを超えたのは間違いなく大きな成果なのですが変更があまりにも単純なのでこれ自体は語るところの少ない論文でもあります。また、論文ではサンプル効率がRainbowと比較して良いことが示されているわけでもなく、モデルベースというくくりで見ればパフォーマンスでMuZeroにはまったく敵わないということもあり、OpenReviewでは"動機不明の研究"との辛口コメントも見られます。さらに、DreamerV1では最大限に活用されていた「世界モデルの微分可能性」もAtari環境ではあまり役に立たないようです。

・・・アカデミックにはあまり面白くない手法かもしれませんが、世界モデルでRainbow超えは素直にすごいと思ったので実装してみたというのが本記事のモチベーションです。


非VAEの状態遷移モデル:MuZero(2020)

MuZero: Mastering Go, chess, shogi and Atari without rules

世界モデルとは呼称されませんがコンセプトのよく似た手法であるMuZero についても簡単に説明します。

MuZeroとはAlphaZeroのモンテカルロ木探索アルゴリズム潜在変数空間*5における状態遷移モデルを導入することで、AlphaZeroを古典ボードゲーム以外(=Atari2600環境)にも適用可能にした手法です。”ニューラルネットワークで近似された状態遷移モデルを利用した先読み検索による行動プランニング”という、PlaNetに極めて近いコンセプトによって当時のトップモデルフリー手法 R2D2を超えAtari環境SotAを達成しました。

MuZero: 潜在変数空間の状態遷移モデルを利用したMCTS


コンセプトが似ているにも関わらず、atari環境に置いてMuZeroはDreamerV2よりもはるかに高いパフォーマンスを発揮するのはなぜでしょうか? あくまで個人の考察ですが、これはMuZeroがいろんなものを切り捨ててatari環境でのパフォーマンスに特化したからと思われます。

1. MuZeroは説明可能性を切り捨てている

VAEベース世界モデルの特徴のひとつは、潜在変数を実画像に復元可能であるために未来予測を人間が視覚的に理解可能であることです。これは実用上は本当に重要である一方で、実画像への復元可能性を保証するために潜在変数が冗長なものになってしまっている可能性があります。たとえば、ブロック崩し攻略において重要なのは極論 "ボールの位置" と ”パドルの位置” だけであるに関わらず、VAEベース世界モデルの潜在変数にはブロックの色やスコア表示など余分なノイズ情報が多量に含まれているはずです。

一方、MuZeroでは潜在変数を実画像へ復元することはできない代わりに効果的な状態表現獲得ができていると考えられます。

2. MuZeroは確率的遷移を切り捨てている

VAEの潜在変数は確率変数であるため、確率的な状態遷移を自然に取り扱うことができます。一方、MuZeroの潜在変数とはResNetの隠れ層でしかないので決定論的です。同じ状態で同じ行動を取っても確率的に遷移する系(コイントスとか)においては状態表現が確率的であることは必須であるように思いますが、しかしAtari2600のほとんどのゲームの挙動は決定論的であると見なせるため、Dreamerでは確率的な遷移を表現できるVAEベース状態表現のメリットがあまり得られない上に学習を不安定にしていると考えられます。

3. MuZeroはPOMDPを切り捨てている

MuZeroはRNNを採用していませんが、代わりに直近32フレーム(2-3秒相当)をまとめて入力することで短期記憶の代わりとしています。アタリ程度なら直近32フレーム入力すれば十分にMDPと見なせる、と割り切ることでシンプルなアーキテクチャを実現しています。

horomary.hatenablog.com


Tensorflow2による実装例

ここからはtensorflow2によるBreakout向け実装例を紹介します。要点のみの掲載ですのでコード全文はGithubを参照ください。

github.com

世界モデル部

DreamerV2の世界モデルアーキテクチャは論文中では下図のようにシンプルに示されていますが、実装はこの図よりもはるかにややこしいことになっています。

DreamerV2世界モデル

DreamerV2世界モデルのデータの流れをより詳細に図示したものがこちらです。

DreamerV2の世界モデル

ロス計算のための予測ヘッドまで含めると下図のようになります*6

予測ヘッドとDecoderまで含めた場合

ロス関数は通常のVAEと同様な(画像再構成誤差+KL項)に (即時報酬予測誤差+割引率予測誤差)を合わせたものとなっています。なお、割引率予測とはほぼエピソード終了予測同じようなものと考えてOKです。

実装例はこちら。discreteVAEなのでカテゴリ分布向けのReparameterization Trickを使ってるのがポイント


ロールアウト部

実観測から潜在状態s(zとhをconcatしたもの)を計算し、sに基づいて行動決定します。


強化学習エージェント部

上述の通り、Atari向けのDreamerV2では世界モデルの微分可能性を活用せず単に軽量なシミュレータとして使用するため、強化学習エージェントの訓練について技術的に特筆するべきことはありません。適当にActor-Criticを実装するだけです。私はPPOで実装しました。


学習結果

着実にパフォーマンス向上しつづけているのですが、実行速度パフォーマンス(実装レベルのチューニング&GPU性能)が悪すぎるため一週間ほど学習を続けてようやく30-40点台に到達するという何とも微妙な結果になってしまいました。R2D2を実装したときも思いましたがRNNが入ると本当に速度チューニングがつらい。こういうややこしいRNNはどうしたら高速化できるのだろうか?


雑記

※個人の定性的な感想です

  • センシティブなハイパラが多すぎる
  • とくに世界モデルとActorが共進化するように学習率を制御するのが大変
  • KLロス項の変動が大きいとActorの学習が破綻する傾向がある
  • VAEの一般的な弱点として小さくて動きの速い要素が存在するとつらい
  • しかしブロック崩しは小さくて動きの速いボールこそが最重要なオブジェクト
  • 世界モデルはコンテクストを理解している印象はなく、暗記ゲーをしている疑いがある
  • 安直だけどMaskedAutoEncoderのような表現学習と組み合わせるとよいのかも

*1:世界モデルから説明可能性を投げ捨てるとMuZeroになる

*2:WebアプリケーションのMVCアーキテクチャとかけてるのもオシャレポイント

*3:決定論的な隠れ層について潜在変数という呼称が適切かは知らない

*4:実際は単純和ではない

*5:ResNetの中間出力を潜在変数と言うことが適切とは思えないが他によいワードが思いつかない

*6:ただし割引率γの図示は割愛

スッキリわかるAlphaFold2

注意:

  • Alphafold2の手法解説です。使い方の説明ではありません
  • 構造生物学ドメインにはある程度の説明をつけます
  • アーキテクチャ設計の意図については個人の考察であり、正しさに何ら保証がありません
  • AttentionとTransformerそのものについての説明は行いません



AlphaFold2論文など:

Highly accurate protein structure prediction with AlphaFold | Nature

Supplementary information(アルゴリズム詳細説明)

AlphaFold: a solution to a 50-year-old grand challenge in biology | DeepMind

DeepMindのCAPS14プレゼン


AlphaFold2とは

タンパク質は生命に不可欠であり、実質的にすべての機能をサポートしています。タンパク質はアミノ酸の鎖で構成された大きな複雑な分子であり、タンパク質が何をするかはその独特の3D構造に大きく依存します。タンパク質がどのような形に折りたたまれるのかを解明することは「タンパク質の折り畳み問題」として知られており、過去50年間生物学の大きな課題となっていました。このたび、AIシステムAlphaFoldの最新バージョンは、隔年で開催されるタンパク質構造精密予測コンテスト(CASP)の主催者によってこの壮大な課題の解決策として認められました。この画期的な進歩は、AIが科学的発見に与える影響と、私たちの世界を説明し形作る最も基本的な分野のいくつかで進歩を劇的に加速する可能性を示しています。DeepMind Blog より)

DeepMind社のタンパク質の立体構造予測アルゴリズムAlphaFoldの最新版である「AlphaFold2」がCASP14(タンパク質立体構造予測コンペ)にてエポックメイキングな結果を残しました。とくに構造未知タンパク質の立体構造*1を予測するという非常に難しいタスクでも、GDT(ざっくり正しい位置を予測できている原子の割合)が87.0という驚くべき精度となっています。

f:id:horomary:20210925182159p:plain:w600
GDT:ざっくり正しい位置を予測できている原子の割合

今後AlphaFold2は構造生物学的研究の必須ツールの一つとなっていくことが予想されますが、内部の仕組みがよくわからないツールを使うのはいくら精度が高いことが実証されていても心地悪いものです。そこで、本稿ではこの「AlphaFold2」のタンパク質立体構造予測の仕組みを論文に基づいて解説していきます。


タンパク質折り畳み問題について

まずは前提知識としてタンパク質折り畳み問題の概略を説明します。※生物系の方はスキップ推奨

タンパク質はバイオ・ナノマシン

タンパク質とはアミノ酸が鎖状に連結した高分子化合物でありすべての生命に不可欠な物質です。タンパク質は筋肉を動かすモーターとして働いたり、赤血球として全身に酸素を運んだり、あるいは緑色に光ることさえできます。生命活動に必要なほとんどすべての機能発現がタンパク質によって担われているのです。

f:id:horomary:20210925191514p:plain:w600
さまざまなタンパク質の立体構造

・タンパク質は20種類のアミノ酸で構成される鎖

多様な機能を発現するタンパク質ですが、驚くべきことにすべてのタンパク質はたった20種類*2アミノ酸が鎖状に連結することによって構築されています。

f:id:horomary:20210925192643j:plain:w600
https://www.genome.gov/genetics-glossary/Amino-Acids より

20種類のアミノ酸のうちでも、アスパラギン酸(Aspartic acid) やトリプトファン (Tryptophan)などは栄養ドリンクやサプリに表記されているのを目にしたことがあるのではないでしょうか。他にはグルタミン酸(Glutamic acid)などはうまみ調味料の主成分としてお馴染みですね。

補足資料として下図ではタンパク質を原子レベルで描画した場合の構造タンパク質構造に関する用語をまとめました。とくに、Cα(α炭素)の位置とねじれ角(torsion angle)Φ/Ψ/ω は知らないとAlphaFold2論文を読むのが困難なので覚えておきましょう。

f:id:horomary:20210925222153p:plain:w700
タンパク質構造に関する基本用語集


・タンパク質は自発的に立体構造を形成する

ここまででタンパク質はアミノ酸できた鎖(ペプチド)であることを説明しましたが、このアミノ酸の鎖は適温の水*3に入れることで自発的に*4折りたたまれて立体構造を形成します。 異なるアミノ酸配列では異なる構造に折りたたまれるためにタンパク質は多様な立体構造と多様な生物学的な機能を持つことができるのです。

f:id:horomary:20210925225456p:plain:w600
DeepMind blog "Using AI for scientific discovery"より

補足情報:細胞内でのリアルな折り畳みについての解説
田口 英樹 「タンパク質フォールディングの「理想」と「現実」:凝集形成とシャペロンの役割」日本生化学会


タンパク質折り畳みの駆動力となっているのはタンパク質内部の分子間相互作用や水との相互作用などの物理化学的な力であるので、ごく小さいサイズのタンパク質であれば分子シミュレーション(分子動力学シミュレーション, Molecular Dynamics) によって折り畳み過程をシミュレートすることもできます。が、しかし分子シミュレーションで立体構造予測をするのは計算コストが重すぎるので現時点ではあまり実用的ではありません。

www.youtube.com


タンパク質立体構造の重要性

タンパク質の立体構造がわかることには様々な嬉しさがあるのですが、もっとも身近なのは創薬への応用です。たとえばインフルエンザ治療薬として有名なリレンザタミフルは、ウイルスが増殖するために重要なインフルエンザノイラミニダーゼというタンパクの働きを阻害するような低分子化合物をタンパク質構造に基づいて設計することによって見出された薬です。

f:id:horomary:20210925232521p:plain:w800
応用物理学会 創薬を目指したSPring-8/SACLAの構造生物学研究 より


・ タンパク質立体構造解析の困難

問題はタンパク質の立体構造を実験的に解明するためには多大な労力が必要であることです。

もっともポピュラーなタンパク質立体構造の解析手法はX線結晶構造解析法です。この手法ではその名の通りタンパク質の結晶をX線で解析するのですが、そもそもタンパク質結晶の作製難易度が極めて高いという難点を抱えています。結晶成長に重力が悪さするから宇宙で結晶つくろうぜという実験が実際に行われている、という事実からも苦労が推し量れます。

他には NMR(核磁気共鳴法)という分析機器を使ったり、最近ではCryo-電子顕微鏡でタンパク質を直接見るというような方法も行われています。


データ駆動の立体構造予測

過去50年間にわたり構造生物学者たちは大変な労力をかけて実験的なタンパク質立体構造解析を行ってきました。このような先人の血と汗とPEGの結晶として2021年現在では18万件以上のタンパク質の立体構造がPDB(Protein Data Bank) に登録されています。

これほどのデータが蓄積されているならば教師あり学習の機運が高まるのは自然な流れです。

タンパク質の立体構造は
・ 似ているアミノ酸配列であれば同じような立体構造となる
・ 同じアミノ酸配列で同じ温度の水中なら同じ立体構造となる (と概ね見なせる)
という教師あり学習に適した特性を持っているために、データ駆動での立体構造予測の試みが古くから行われていました。

ja.wikipedia.org

AlphaFold2ではこのような伝統的なタンパク質構造予測アプローチと深層学習をうまく融合させたことにより、圧倒的な性能を実現しました。特筆すべきはAlphaFold2で使われている深層学習のテクニックは最先端のものですが、設計コンセプト自体は伝統的なバイオインフォマティクスの発想(ドメイン知識)に基づいたものであり、所謂「ディープでポン」の対極にあるアプローチであるという事実です。

つまりはAlphaFold2は、深層学習にドメイン知識をどうやって注入すればよいのか?という視点で見ると大変示唆深く面白い手法と感じます。


AlphaFold2の概観

前提の説明が長くなりましたがここからが本題です。

4つのモジュール

f:id:horomary:20210928230812p:plain:w700
論文Fig.1に注釈を追記

AlphaFold2は4つのモジュール*5によって構成されています。

0. データ準備モジュール
立体構造予測を行いたいアミノ酸配列(input sequence)をクエリとし、バイオインフォマティクスのツールを用いて
・ DBからのMSA(Multiple sequence alignment, 後述)作製
・ DBからのテンプレート立体構造(鋳型構造, 後述)検索
を行います。ただしテンプレート立体構造の使用は任意であり無くても構いません


1. Embeddingモジュール
・生MSAにターゲット配列情報を紐づけた MSA representation の作成
・残基間の相対的な位置関係を記録する Pair representation の作成
・スパースな入力値に対してEmbedding(活性化なし全結合層)を行いdenseなベクトルに変換

f:id:horomary:20210929002511p:plain:w500
入力値のEmbedding (Supl. Fig. 1)

2. Evoformerモジュール
Evoformer=Evolution(分子進化)のためのTransformer

・MSAからの特徴抽出
・Pair representationからの特徴抽出
・MSAとPair representationでの情報交換

やりたいことはTransformerのEncoder部とほぼ同じですが、MSAの特性や空間グラフ(タンパク質)の物理的な制約を意識してちょっと変わったattentionの設計(axial attention, triangular attention)が行われています。

f:id:horomary:20210929002352p:plain:w600
Evoformerモジュール (Fig.3)

3. 構造モジュール

IPA (Invariant point attention)モジュールによるMSA表現, 残基ペア表現そして現在の立体構造の統合
各残基への相対的な移動指示(=残基数分の(3,3)の回転行列と(x, y, z)の並進ベクトル)および側鎖のねじれ角(χ1-4)を予測

f:id:horomary:20210929002244p:plain:w500
構造モジュール:shared weightsであることに注意(Fig. 3)


AF2のやってることをざっくり理解する

AlphaFold2とはタンパク質立体構造を入力としてよりrefineされたタンパク質立体構造を出力するネットワークであると理解できます。

AF2の処理を極めて単純化すると以下のようになります。

① すべての残基が原点に集合している構造で主鎖立体構造を初期化(ブラックホール初期化
② 現在の主鎖立体構造とMSA表現およびペア表現を構造モジュールに入力し、各残基への相対的な移動指示を出力
③ 現在の主鎖立体構造に②の出力を適用し、主鎖立体構造を更新する
④ 2-3を一定回数繰り返すことで最終的な立体構造を得る

f:id:horomary:20210928223934p:plain:w600
AlphaFold2の出力とは各残基への相対的な移動(=並進と回転)指示

(実際には側鎖のねじれ角も予測しますが)おおまかな理解としてはAlphaFold2は各残基ごとへの相対的な移動指示(回転と並進)の出力を繰り返すことによってより良い立体構造を模索しているイメージとなります。ここで、相対的な移動指示というのは「残基番号127のグルタミン酸は今の位置から10歩くらい右に動いて」というようなニュアンスです。

ツールとしてのAlphaFold2ユーザーの視点では目的配列を入力すると一発で精確な立体構造が出力されるように見えますが、実際には分子力学シミュレーションによる構造最適化と同じように相対的な立体構造改善のiterationを重ねることで最終的な立体構造を得ています。構造モジュール(Structure module)はこのような立体構造改善iterationを実行する役割を担っています。

また、AlphaFold2が立体構造改善の重要な手掛かりとしているのがMSA(Multiple sequence alignment)です。後述しますがこれはMSAにはアミノ酸配列-タンパク質立体構造相関情報が豊富に含まれているためです。EvoformerモジュールはMSAからの情報抽出を担っています。


0. データ準備

伝統的なバイオインフォマティクス手法による入力データ準備を行うモジュール*6です。立体構造を予測したいタンパク質のアミノ酸配列 (input sequence)をクエリとしたDB検索によってMSA (Multiple sequence alignment) を作成し、さらに自然言語モデルBERTにおけるMasked Language Modelのアイデアを転用してMSAへのマスク・変異導入を行います。

MSA (Multiple sequence alignment) の作成

ヒトもウマもカメも共通の祖先から分化して進化したために、ヒトが持っているタンパク質の多くはウマもカメも持っています。ヒトとウマとカメの外形的な見た目は全く異なる一方で、実はタンパク質レベルでの見た目であればそれほど変わりません。この傾向は下図に示すミオグロビン(酸素を運搬するヘモグロビンの仲間)のような生命維持に不可欠なタンパクであるほど顕著となります。

f:id:horomary:20210929140640p:plain:w300
PDB-101: Molecule of the Month: Globin Evolutionよりミオグロビンタンパクの生物種間比較図
白色で表示されている部分はヒトのミオグロビンとは全く異なるアミノ酸が使われている

ここで重要なのは、同種のタンパク質であれば生物種が変わっても全体的な立体構造はあまり変わらないが、アミノ酸配列レベルではそれなりに差異が生じているということです。ゆえに、あるタンパク質についてさまざまな生物種のアミノ酸配列を並べたもの(=MSA)は、アミノ酸配列-タンパク質立体構造相関を考えるうえで重要な手掛かりとなります。

ただし、アミノ酸配列は分子進化の過程で変異するだけでなく長くなったり短くなったりするために妥当なアミノ酸配列の整列を得るのは容易なことではなく、バイオインフォマティクス分野ではより妥当なMSAを作成するための手法研究が古くから行われてきました。AlphaFold2が使用している HMMER隠れマルコフモデルに基づくMSA作成手法の実装です。

f:id:horomary:20210929143611p:plain:w500
Multiple sequence alignment - Wikipedia より

なお、タンパク質立体構造解析に比べてタンパク質アミノ酸配列の解析は格段にコストが低いために、立体構造は不明だけどもアミノ酸配列ならわかっているというタンパク質が大量にあります。

MSA作成手順(概要):

  1. 入力配列をクエリとして配列の類似性スコアが一定値以上の最大5000配列でMSAを作成する
  2. 配列を雑に間引く(Supl. 1.2.6)
  3. 間引いてなお配列数が128以上の場合は、ランダムに128配列を選択する

配列数を128まで絞っているのは単純に計算コストの問題であり、深い意味はないことに留意ください。また、128配列に選ばれなかった配列についてそのまま捨ててしまうのはもったいないということで統計情報だけは利用するなどさまざまな工夫をしています(Supl. 1.2.7 MSA clustering)が、枝葉の処理なのでここでは詳細を割愛します。


MSAへのBERT風マスク導入

自然言語モデルBERTでは文章中の単語の15%程度をマスクトークンに置き換えて、この穴埋めクイズを事前学習(pre-training)として行うことにより言語への理解を獲得します。(Masked Language Model)

AlphaFold2でもMSAに対してBERTと同様なマスク置き換えを行い、MSA穴埋めクイズを通してAF2ネットワークにMSAの読み方を理解させることを目指します。

f:id:horomary:20210929154531p:plain:w700
Supl. 1.2.7 より

ただし、BERTとは異なりAF2では事前学習(MSA穴埋めクイズ)と目的タスク(立体構造予測)の訓練を分離せず、MSA穴埋め予測クイズのロス項と立体構造予測のロス項を足したものをトータルロスとしてまとめて学習します。(詳細は後述)。


テンプレート構造の検索(任意)

たとえばヒトのミオグロビンタンパクの立体構造を予測したいとします。ここで、もしヒトと十分に近縁であるサルのミオグロビンタンパクの立体構造が既知であるならば、サルのミオグロビンタンパクの立体構造をテンプレート(鋳型)としてヒトミオグロビンタンパクの立体構造を生成するだけで十分に品質の高い予測立体構造が得られます。

AlphaFold2においてもテンプレート構造情報を入力に含めることができますが、必須入力ではない上にablation study (論文 Fig. 4)より使わなくてもパフォーマンスがほぼ変わらないとされているので詳細説明は省略します。


1. Embeddingモジュール

このモジュールではスパースな入力値に対してEmbedding(活性化なし全結合層、つまり線形変換)を行うことでdenseなベクトルに変換します。さらに各入力データを統合し、MSA representation および Pair Representation を出力します。初見では複雑な処理に見えますが、必須のパス(任意パスは灰囲み)だけを見ればMSAに目的配列情報を付加する程度のごくシンプルな処理であることに気づきます。

f:id:horomary:20210929162435p:plain:w500
Supl. Fig1 に注釈を追記。Rの意味はRecycling項を参照


入力データのOne-hot化

入力データはonehot化されています。自然言語処理に慣れた人にはonehot化からのembeddingはお馴染みの定型処理ですが、そうでない人のために何をやってるかの図を置いときます。ここでアミノ酸タイプとは、天然アミノ酸20種+残基不明+欠損+マスクトークンの23タイプです。

f:id:horomary:20210929171905p:plain:w500
アミノ酸配列のonehot表現


MSA Representation

MSA representaion は生のMSAに立体構造予測を行いたいアミノ酸配列(ターゲット配列)情報を紐づけたものと理解できます。また、直後にテンプレート立体構造のtorsion angleがconcatされることから側鎖レベルの詳細な立体構造情報を保持する役割も担っていると解釈できます。

f:id:horomary:20210929183310p:plain:w700
MSA representaionは生MSAにターゲット配列情報が付与されたもの

Pair Representation(残基ペア表現)

Pair Representation(残基ペア表現)とは残基間の関係性(たとえば残基-残基タイプや残基間の空間的距離など)を表現することにより、主鎖レベルの大雑把な立体構造情報を保持する役割を担っていると解釈できます。初期状態では残基-残基タイプ程度の情報しか持ちませんが、AF2ネットワークを進むことにより残基間関係の情報が書き加えられていきます。たとえば、もし立体構造テンプレートを使用している場合は直後にテンプレート立体構造における残基間距離情報が追記(加算)されます。

f:id:horomary:20210929184626p:plain:w600
初期のPair representaionは残基タイプ-残基タイプ情報程度しか保持していない

2. Evoformerモジュール

Evoformerモジュールではself-attention機構によってMSA表現および残基ペア表現からの特徴量抽出を行います。やりたいことはTransformerのEncoder部とほぼ同じですが、MSAおよび空間グラフとしてのタンパク質の物理的制約を意識してやや変わったattention設計を行っています。また、MSA表現と残基ペア表現で情報交換が行われていることも特徴的なアーキテクチャです。

f:id:horomary:20210929215141p:plain:w600
Evoformerモジュール (論文Fig. 3 に注釈を追記)

① axial-attentionによるMSA表現の特徴抽出

上図よりMSA表現からの特徴抽出パスには

  1. MSA row-wise gated self-attention with pair bias
  2. MSA column-wise gated self-attentio
  3. Transition (ただの全結合層)

の3つのブロックが存在することがわかります。

A. MSA row-wise gated self-attention with pair bias:

f:id:horomary:20210929220818p:plain:w800
Supl. Fig2 に注釈を追記

このブロックでは配列方向(row-wise)にaxial self-attention(軸方向注意)を適用することにより、配列内での情報交換を促進します。この(配列, row-wise)軸方向注意の仕組みは、ちょうど人間がMSAを配列方向に眺めて「このあたりはヘリックス*7構造っぽい」とか「システインが複数あるから分子内でスルフィド結合(S-S結合)を形成するかも」と考えつつMSAに注釈をつけていくようなプロセスを再現しています。

なお、このaxial attention(軸方向注意)はAF2論文での新規提案手法ではなく画像認識分野で提案された手法の転用となっています。

f:id:horomary:20210929222358p:plain:w500
Google AI Blog: Axial-DeepLab: Long-Range Modeling in All Layers for Panoptic Segmentation

さてこのMSAについてのself-attentionですが、attention機構を見慣れた人であれば残基ペア表現(Pair representaion)をdot-product affinitiesにバイアスとして加算していることに強烈な違和感を覚えるのではないでしょうか。 安直には残基ペア表現をQuery, MSA表現をKey, Valueとしてattentionをしたくなります。しかし、残基ペア表現(Pair representaion)が保持している情報が残基-残基ペアの空間的関係性であることを思い起こせばバイアスとして加算することはごく自然な発想であることに気づきます。すなわち空間的に近い残基ペアについて大きなattention-weightsが割り当てられるようになるのです。

最後に目につくのはgatingですが、これはLSTMなどにおけるゲート機構と全く同じ役割を担っていると考えられます。すなわち不要な情報を削る”ゲート”の役割です。シグモイド関数で活性化されており素早く0→1を切り替えることができるために、効果的に情報の取捨選択を行えます。


B. MSA column-wise gated self-attention:
MSA row-wise gated self-attentionの軸を残基位置方向に変えただけでやってることはほぼ同じなので技術説明は割愛します。人間がMSAを残基位置方向に眺めて「この残基位置は生物種にわたって変異がほぼがないからきっと重要な残基なのだろう」とか「この残基位置はそれなりに変異あるけど疎水性アミノ酸ばかりだなあ」とか考えつつMSAに注釈をつけていくようなプロセスを再現しています。


② 残基ペア表現(Pair representation) の特徴抽出

残基ペア表現(Pair representation)は残基間の空間的な位置関係情報を保持するように設計されているために、 残基ペア表現の特徴抽出ブロックもまた残基間の空間的な位置関係に着目した設計となっています。

f:id:horomary:20210929231615p:plain:w500
論文Fig. 3より

たとえば上図において残基i-k間の空間的距離と残基j-k間の空間的距離が決まったとすると、三角不等式より "残基i-j間の空間的距離 <= 残基i-k間の空間的距離 + 残基j-k間の空間的距離" でなければならないという制約が生じます。 言い換えると、残基i-j間のペア表現を更新するときには残基i-k間のペア表現および残基j-k間のペア表現と事前に情報交換を行う必要があるということです。関係者への根回しは大事。

そこでEvoformerでは関係の深い残基ペア表現間での情報交換の促進を

  1. Triangular multiplicative update
  2. Triangular self-attention

の2つのブロックで実現しています。ただし、あくまで残基間の物理的な制約を意識して設計されたブロックというだけであり、実際に物理的制約を満たすことは何ら保証されていないということには注意してください。


A. Triangular multiplicative update:

このブロックでは残基iとすべての残基間のペア表現(i行目)および残基jとすべての残基のペア表現(j行目)を使用して残基i-j間のペア表現を更新します。論文Fig3では3残基の組み合わせごとにfor文で処理するような印象を受けますが、Supl. Fig6では行ごとにまとめた計算効率の良い等価処理を説明しています。また、更新前の残基i-j間のペア表現gating機構を通して情報の取捨選択をコントロールしていることがわかります。

f:id:horomary:20210930002353p:plain:w800
論文Fig.3, Supl. Fig6 に注釈を追記

ちなみにこのブロックはTriangular self-attentionの軽量versionとして設計されたものの、Triangular self-attentionと併用することによって精度向上することがわかったという開発秘話が論文Supl.に記述されています。

B. Triangular self-attention:
設計意図はTriangular multiplicative updateと全く同じですが、こちらのブロックではattentionが使用されています。論文Fig.3では残基iとすべての残基間のペア表現(i行)を使用して残基i-j間のペア表現を更新することをfor文で繰り返すような印象を受けますが、論文Supl.では残基iとすべての残基間のペア表現(i行)を使用して残基iとすべての残基間のペア表現(i行)を更新する、という計算効率のよい等価処理で説明されています。

f:id:horomary:20210930010144p:plain:w800
論文Fig.3とSupl. Fig.7 より


3. 構造モジュール (Structure module)

構造モジュール (Structure module)では、Evoformerによって特徴抽出されたMSA表現と残基ペア表現現在の主鎖立体構造
\displaystyle{T_{i, r} = ( \boldsymbol{R}_{i, r}, \vec{t}_{i, r} ) }
を入力として、各残基への「追加」の回転・並進指示 \displaystyle{T_{i+1, r} = ( \boldsymbol{R}_{i+1, r}, \vec{t}_{i+1, r} ) } および 各残基のねじれ角 \displaystyle{ (\omega_{r}, \phi_{r}, \psi_{r}, \chi_{1, r}, \chi_{2, r}, \chi_{3, r}, \chi_{4, r} ) } を出力します。※ i: iteration, r: 残基番号

f:id:horomary:20211001152111p:plain:w800
論文Fig3 に注釈を追記

ここで、"現在の主鎖立体構造"とは、(x, y, z)で表現されるようなリアルな座標ではなく、各残基についての原点からの回転・並進操作 T=(R, t)で表現されることに注意してください。回転・並進操作Tは各残基のCα(主鎖の中心炭素)に対して適用されます。

初めて構造モジュールに到達した場合は、"現在の立体構造"としてすべての残基が原点に集合する並進・回転操作  \displaystyle{
T_{0} = ( \boldsymbol{I}, \vec{0} )
}が初期立体構造として与えられます。なお、この構造初期化スキームは論文中でブラックホール初期化*8と呼称されています。


IPA(Invariant Point Attention)モジュール

構造モジュールではまずIPAモジュールによって、MSA表現と残基ペア表現、そして現在の主鎖立体構造 \displaystyle{T_{i, r} = ( \boldsymbol{R}_{i, r}, \vec{t}_{i, r} ) } が統合されます。

f:id:horomary:20211001003144p:plain:w800
Supplementary Figure 8 に注釈を追記

IPAモジュールの上半分についてはEvoformerのMSA row-wise gated self-attention with pair biasとほぼ同じことをやってるだけなので説明を省省略します。

IPAモジュールの下半分では現在の残基間距離情報を取得する処理を行っています。このために残基数×p点の座標セットを生成し、現在の主鎖構造を表現している回転・並進操作Tを適用したうえで座標セット間の距離を算出しています。

f:id:horomary:20211001000550p:plain:w500
squared distance affinitiesの算出

Tを適用する前の座標セットに固定点を使うのではなくネットワークに動的生成させている意義はよくわかりません。残基タイプを考慮してネットワークがいい感じに座標セットを生成してくれる、たとえばアスパラギン酸のような細長い残基であれば細長い座標セットになるような効果があるのかもしれません。

座標点数pについて、query, keyについてはp=4座標点を生成しています。これは3点以下の座標点では回転操作の効果が薄れるためでしょう。valueについてはp=8座標点と多めに生成しており、より詳細な側鎖構造を意識している気がします。


・ Invariant(不変性)とは?

Invariant Point Attentionの「Invariant」とは、IPAモジュールの出力が主鎖構造のグローバルな回転・並進に依存しないよということを示しています。残基間の相対距離情報のみを利用しているのでグローバルな回転・並進操作に出力が依存しないのは直感的にも明らかです。※証明はSupl.1.8.2

(主鎖のグローバルな回転・並進操作: Pymolなどの3D分子Viewerで対象分子をグルグル回す操作と同様)


Predict relative rotations and translations(主鎖構造の更新)

IPAモジュールの出力するMSA表現は現在の立体構造情報を含むすべての情報を統合した最終MSA表現です。このブロックでは最終MSA表現から相対的な回転・並進指示を予測します。

とはいえやってることは全結合層で残基数分の回転行列と並進ベクトルを出力するだけです。ただし、回転行列(3×3)は直接予測するのではなくquaternion*9の予測を回転行列に変換します。quaternion知らなくともこれを使うと予測すべきパラメータが減って嬉しいくらいの理解でOK。

f:id:horomary:20211001004041p:plain:w600
構造を改善できるような「相対的」な回転・並進指示の予測

出力されるのは「相対的」な回転・並進操作Tなので、現在のTに出力されたTを適用することで現在のTを更新します。

f:id:horomary:20211001005108p:plain:w400
Algorithm 20 Structure module より


各残基のねじれ角予測

Structureモジュールに突入する直前のMSA表現IPAの出力した最終MSA表現から各残基の側鎖レベルの構造=ねじれ角 \displaystyle{ (\omega_{r}, \phi_{r}, \psi_{r}, \chi_{1, r}, \chi_{2, r}, \chi_{3, r}, \chi_{4, r} ) } を予測します。

ω, Φ, Ψは主鎖のねじれ角(二面角)であり、χ1-4は側鎖のねじれ角です。たとえばグリシンなどは側鎖をもたないアミノ酸なのでχ1-4を利用しませんが深層学習の都合上で予測だけは行います。

f:id:horomary:20211001013848p:plain:w400
ω, Φ, Ψは主鎖のねじれ角

f:id:horomary:20211001010833p:plain:w700
Algorithm 20 Structure module より

最終MSA表現だけでなく、s_init = Structureモジュールに突入する直前のMSA表現 が併用されているのは、側鎖構造は主鎖構造に依存するので構造修正の手戻りが多いためでしょう。側鎖構造のバックアップを保持しているとも解釈できます。

なお、角度予測なので安直には[0, 2π]の範囲のスカラ値を予測したくなりますが、そうではなく2次元ベクトルを予測して回転行列に変換するほうが良いとのことです*10


立体構造の出力

ここまでで主鎖構造(各残基の位置)詳細構造(各残基のねじれ角)が決まったので一意的に立体構造を出力できます。しかしこれは暫定的な中間出力構造です。Structure moduleは8回繰り返すことで1サイクルとなっていますので、改善された主鎖立体構造と最終MSA出力を次ブロックの入力として同じ作業を繰り返しましょう。

最終的に出力された立体構造については軽量な分子シミュレーション(AMBER分子力場を使用した構造緩和)を行い、物理化学的に無理な構造を解消します。どの程度の無理があったかはfine-tuning時のみロス関数に使用します(後述)。


ロス関数

AF2のロス関数は複数種類のロス関数の重みづけ和をとったものとなっています。また、初期トレーニング時と自己蒸留データセットを用いたfine-tuning時(後述)ではロスの項が異なる(増える)ことに注意してください。これは物理化学的な無理(Lviol)のような細かい構造変化にセンシティブに反応する項を初期から考慮していると学習が不安定化するためであると考えられます。

f:id:horomary:20211001023502p:plain
AlphaFold2のロス関数

Lfape:
最終出力された予測立体構造と正解立体構造とのズレの指標であるFAPEスコアに基づいたロス項。FAPEはRMSD(RMSD - Protein Data Bank Japan)スコアと似たような意味合いだが回転・並進ベクトルに基づくためキラリティを考慮しやすい。また、構造ズレが大きすぎる場合はClippingされる*11

Laux:
8ブロックの繰り返しで構成される構造モジュールの、ブロックごとの中間予測構造についての平均簡易FAPEロス+ねじれ角予測ロス。簡易FAPEなのでCαについてのみ算出する。

Ldist:
残基間距離(distogram)の予測ロス。Lfapeと相関するはずだがこちらはより大雑把な構造ズレに特化しているので学習を安定させるのではないかと思われる。

Lmsa:

BERT風のマスクが適用されたMSAの穴埋めクイズを通じてAF2ネットワークにMSAの読み方を理解してもらうためのロス項。BERTのpre-trainingを同じ役割が期待されているはず。係数が大きいので実質的に事前学習(pre-training)っぽくなっているのでは。

Lconf:
モデルの信頼性予測スコアであるpLDDTの予測ロス。※後述

Lexp_resolved:
それが実験的に解かれた構造であるかの予測ロス?よく意味がわからない。

Lviol:
AMBER分子力場による構造緩和に基づく、立体衝突などの分子力学的な無理の大きさ。


「予測の信頼性」の予測:pLDDTスコア

AlphaFold2では「予測の信頼性(正確さ)」を予測することを目指します。

より具体的には最終予測立体構造と教師立体構造間の残基ごとに算出されるIDDT-Cαスコアを予測し、その予測誤差をネットワークのロス項に含めます(Lconf)。このIDDT-Cαスコアの予測値についてpLDDTと呼称します。

なおIDDT-Cαスコアとは2つのタンパク質立体構造のズレを妥当に算出するために設計されたスコアのようです。私は改造されたRMSD(RMSD - Protein Data Bank Japan)スコア程度の認識しかできていません。

pLDDTの定義から明らかなようにこのスコアに物理的な意味はありません。が、pLDDTは立体構造が既知の類似配列が乏しいような配列を予測する場合にはスコアが下がると思われますので予測信頼性の予測という表記通りの視点では一定の意味がありそうです。ただし、立体構造既知の類似配列に乏しい ∝ 定まった構造を取りにくいので立体構造が解かれてない ∝ 物理化学的に不安定、 のような疑似相関が見出される可能性は十分にあるかと思います。

pLDDTは後述する自己蒸留における、立体構造が既知の類似配列が乏しい配列についての予測立体構造を除外するような用途であれば存分にworkすることが期待できます。


更なる精度向上のためのトリック

Recycling

48のEvoformerブロックと8のStructureブロックを超えたあなたの手元には、十分に特徴抽出されたMSA表現Pair representaion、そしてすべての残基が原点に集合した初期構造と比較すればはるかに改善された主鎖立体構造を持っています。

ではこの3つを新たな入力として強くてニューゲームしましょう。再開始はEmbeddingモジュールの"R"と書いてある位置です。

f:id:horomary:20210929162435p:plain:w400
"R is for Recycling" (Supl. Fig1 に注釈を追記)

論文Fig.4ではこのRecyclingによって構造精度が徐々に改善するタンパク(緑 T1064*12)があることが示されています。※48ブロックで1cycle

f:id:horomary:20211001153052p:plain:w400
論文Fig.4


自己蒸留データセットによるFine-tuning

実験的に解かれたタンパク質立体構造データセットPDBデータセット)よるトレーニングをある程度終えた後は、データセットを水増ししてfine-tuningを行います。

すなわち配列だけはわかっているタンパク質の立体構造を予測し、予測信頼性スコアが高い立体構造をデータセットに合流させることで教師立体構造データセットの大幅な増強(水増し)を実現します。

このようなラベルなしデータに予測ラベリング(疑似ラベルづけ)を行い、データセットを増強することで精度向上させるという手法は、画像分類タスクで提案された Noisy-studentの手法を踏襲したものとなっています。Noisy-studentは有名かつ汎用性が高い手法であるのでweb上に解説記事が多く存在するため、詳細説明は割愛します。

arxiv.org

なお、予測構造のフィルターとなっている信頼性スコアには上述のpLDDTではなく残基間距離予測に基づいた信頼性スコアを使用しているのですが、これは「その時点でpLDDTが開発されてなかっただけであり、もしplDDT使ってたとしても同じような結果になると思うよ」、という記述があります。


まとめ

AlphaFold2は構造生物学のドメイン知識に基づいたコンセプトをディープラーニングで超人化することに成功した手法と言えます。ドメインの伝統的な発想を高度な深層学習エンジニアリング力で超人化するというDeepMindのアプローチは機械式時計のような精密工芸品を眺めている気持ちになります。こういうドメイン知識を活用するアプローチはAlphaZeroでも成功していますしDeepMindの勝ちパターンになってますね。

horomary.hatenablog.com


本記事の記述に間違いを見つけたらコメントにて指摘をお願い致します。


補足情報

数値のカテゴリ表現

AF2では残基間距離のように数値表現が必要な値は基本的にカテゴリ表現にエンコードします。

本ブログ別記事の転用ですが、下図を見れば数値をカテゴリ表現するという処理の意味がつかめるかと思います。

f:id:horomary:20210729110914p:plain:w400

数値をカテゴリ表現にすることにより
・ スケール感が揃うのでネットワークにやさしい
・ 数値予測時にはロス関数にクロスエントロピーが使えて学習が安定する
・ 複数のロス項間のスケール感の差を気にしなくて良いので学習が安定する
というようなメリットがあります。

このトリックはMuZero(AlphaZeroの後継手法)でも多用されています。

MuZeroの実装解説(for Breaktout) - どこから見てもメンダコ


*1:新たに立体構造が解明された未発表タンパク質が問題として出題される

*2:セレノシステインの話はややこしくなるのでNG

*3:膜タンパクの話はややこしくなるのでNG

*4:分子シャペロンの話はややこしくなるのでNG

*5:ただし論文内でモジュールと呼称されているのはEvoformerとStructure moduleのみ

*6:論文ではdata pipelineの呼称

*7:らせん型の部分立体構造

*8:中心にブラックホールでも存在しないと原子が激しく反発してはじけ飛ぶため

*9:3D-CGを扱うときは必須の概念

*10:Supl. Table 2の直後あたりとAlgorithm 25

*11:極端なHuber lossみたいなものと理解している

*12:ORF8: 新型コロナウイルスのアクセサリータンパクらしい

GKE+Rayで実装するマルチノード分散並列強化学習

Google Kubernetes Engine (GKE) とpythonの分散処並列理ライブラリRayで安価に大規模分散並列強化学習(Ape-Xアーキテクチャ)の実行環境をつくるチュートリアルです。GKEのプリエンプティブルインスタンスを活用することで、総リソース 128 vCPU, NVIDIA Tesla P4 x1, 256 GB memory のクラスタがざっくり150-200 円/時間になります。(2021年8月時点)

f:id:horomary:20210831230023p:plain:w700

rayで実装する分散強化学習シリーズ:
Pythonの分散並列処理ライブラリRayの使い方 - どこから見てもメンダコ
rayで実装する分散強化学習 ①A3C(非同期Advantage Actor-Critic) - どこから見てもメンダコ
rayで実装する分散強化学習 ②A2C(Advantage Actor-Critic) - どこから見てもメンダコ
rayで実装する分散強化学習 ③Ape-X DQN - どこから見てもメンダコ


はじめに

分散並列強化学習のメリット

分散強化学習は、単に処理が高速化するだけでなく多様な状態遷移の収集が可能になり学習が安定化することがメリットです。

このことをA3C論文が当時のatari環境のSotAという結果で示して以来、軽量なシミュレーターが利用可能な強化学習環境(atariやMuJoCoなど)では分散並列化が基本テクニックとして採用されるようになりました。とくにApe-X DQNR2D2Agent57などのDQN派生の手法では、並列化されたagentごとに異なる探索戦略(単純には探索率εの値など)を割り当てるマルチ方策学習が採用されているため、分散並列化することが手法の前提となっています。

このような強化学習の分散並列化トレンドに対応すべく、本稿ではGoogle Cloud PlatformのマネージドKubernetesであるGKE(Google Kubernetes Engine)を利用して分散並列強化学習環境を構築します。また、並列化のバックエンドとしてはRayライブラリを使用します。

cloud.google.com

docs.ray.io


Ape-X アーキテクチャ

より具体的には、本稿はGCPのマネージドKubernetesであるGKE上でApe-Xアーキテクチャを実装するチュートリアルです。
※Ape-Xの手法自体の説明は過去記事をご参照ください。

f:id:horomary:20210227144313p:plain:w600
apex論文より:Actorがマルチノード分散並列化される

rayで実装する分散強化学習 ③Ape-X DQN - どこから見てもメンダコ


各プロセスの概要および要求リソースは以下のようになります。

Replay(1CPU, 0GPU, メモリたくさん)×1プロセス:
Actorが収集した遷移情報の受け取り、およびLearnerへの遷移情報の送信を行います。また、メインプロセスを兼ねます。

Leaner (1CPU, 1GPU, メモリ多少)×1プロセス:
Replayから遷移情報のミニバッチを受け取ってひたすらネットワーク更新だけを行います。これによって1台のGPUを最大効率で活用できるというのがApe-Xアーキテクチャの嬉しさです。

Actor(1CPU, 0GPU, メモリ多少)×200~300プロセス:
ひたすら環境と相互作用(atari環境ならゲームをプレイ)して遷移情報を収集し、Replayプロセスへ送信します。Actorは行動選択時にQネットワークでの推論を行いますが、1サンプル推論なのでGPU無くても問題ないです。


なぜGKEを使うか?

プリエンプティブルVMが格安

GPUが不要であるならマルチノード分散並列化せずに単一ノードのウルトラハイスペックインスタンスをGCE(AWSでいうEC2)で用意してもいいのですが、GCEではGPUを利用する場合には1GPUあたりで利用可能なCPU数に上限がかかるので大規模な並列化はできません。たとえばNVIDIA T4ではGPU1枚あたり24 vCPUが上限です。

Compute Engine の GPU  |  Compute Engine ドキュメント  |  Google Cloud

かといってGCEインスタンスを手動でたくさん立ててクラスタ構築するのはあまりに煩雑ですので、GCPのマネージドKubernetesであるGKEを利用していきます。マネージドKubernetes自体はAWS(EKS)やAzure(AKS)でも提供されていますが、それらではなくGKEを選択するのは安価なプリエンプティブルVM が利用可能であるためです。プリエンプティブルVMとはGCPの余ったリソースを定価の7-8割引きという格安で提供するインスタンス形式です。ただし、最大24時間しか持続しない上にGCPのリソースの状況次第で突然停止されることもあります。

GKEのプリエンプティブルインスタンスを活用することで、総リソース 128 vCPU, NVIDIA Tesla P4 x1, 256 GB memory のクラスタがざっくり150-200 円/時間になります。※2021年8月時点の概算

GPU の料金  |  Compute Engine: 仮想マシン(VM)  |  Google Cloud


Autoscalingがお手軽

プリエンプティブルVM以外のGKEの利点としては、計算リソースのAutoScaling機能がお手軽かつ優秀*1なのでノード構成をあまり意識しなくてよいということがあります。とくに最近リリースされたGKE AutoPilotモードでは物理ノード構成を一切意識しなくてよいという大変便利なものになっていますが、AutoPilotはプリエンプティブルVMをサポートしていない(21年8月時点)ので本稿ではGKE Standardモードで構築します。お金のある人はAutoPilotモードがおすすめです。

cloud.google.com


なぜRayを使うか?

分散並列化のバックエンドにはRayライブラリを使用します。Rayを採用するとPythonのマルチノード分散並列処理が驚くほど楽にできるようになります。*2

horomary.hatenablog.com

Rayの利点①:並列化コードを書くのが楽

分散並列化でなく単なる並列化であれば、multiprocessingjoblibのような並列処理のライブラリも利用可能ですが、Rayはこれらの既存ライブラリと比べても遜色なくシンプルに並列処理のコードが書けます。


Rayの利点②:単ノード並列(MP)→マルチノード分散並列(MPI)でコード変更がほぼ不要

単一マシンでの並列処理コードをほぼ変更することなくマルチノードで分散並列処理ができることはrayの大きな利点のひとつです。ローカルマシンでデバッグしつつ作成した並列処理コードをそのままスケールアップして分散並列処理することができるため生産性が高くなります。

たとえば上のサンプルコードでは、並列処理と分散並列処理でコード変更が必要なのはクラスタの初期化処理(ray.init())の引数のみです。


Rayの利点③:クラスタのセットアップが楽

OpenMPIしかりMPICHしかり、分散並列処理フレームワークは環境構築が煩雑な印象があります。一方でrayはpip install rayだけで環境構築が完了です。クラスタの起動も簡単で、ヘッドノード(pythonスクリプトを起動するノード)でray start --head --port=6379、ワーカーノードで ray start --address='<ヘッドノードのIP>:6379' を実行すればクラスタの準備完了です。あとはヘッドノードにてrayで並列化が記述された任意のpythonコードを実行するだけとなります。


マルチノード分散強化学習チュートリアル

ここからは実際にGKEでクラスタを構築し、rayによって分散並列化された強化学習を実行するチュートリアルです。

f:id:horomary:20210831230023p:plain:w700

強化学習コードやKubernetesマニフェストファイルの詳細はGithubを参照ください
github.com

1. 並列強化学習の実装とDockerイメージ作成

まずは普通にローカルマシン上にてrayによって並列化されたApe-Xアーキテクチャを実装します。
※実装自体は過去記事とほぼ同じなので省略します

horomary.hatenablog.com


次に、実装したコード(code/以下)を動かすためのdocker imageを作成します。

# Dockerfile
FROM tensorflow/tensorflow:2.5.1-gpu
COPY ./code /code
RUN pip install -r code/requirements.txt

作成したイメージはdockerhubかGCR(Google container registry)にpushし、GKEから利用可能な状態にしておきます。


2. Kubernetesマニフェストの作成

分散学習クラスタkubernetesマニフェストに記述します。

  • entrypointtype=LoadBalancerServiceリソースです。
    このServiceリソースは外部からtensorboardおよびray-dashboardにアクセスして学習をモニタリングするためだけに使用するので必須ではありません。(※この設定はIPさえ知っていれば誰でもtensorboardにアクセスできるので機密プロジェクトでは使用しないでください。

  • ray-headless-svcはワーカーノードがヘッドノードを名前解決するためのHeadless Serviceです。

  • masterはおよびメインプロセスおよびtensorboardコンテナを起動するPodです。よってこのPodが稼働しているノードがヘッドノード(pythonスクリプトを起動するノード)です。Ape-XのメインプロセスはReplayBufferを持ちメモリを大量に消費するのでresources.requestsmemory=24GiBをリクエストしています。また、GPUを利用するLearnerプロセスとメインプロセスを物理的に同じノードに置きたいのでnvidia.com/gpu をリクエストしています。


  • actorはワーカーノードに配置されるactor-podを複製するReplicaSetです。各actor-podは15CPUを要求します。ワーカーノードからヘッドノードへの通信確立はray start --address='<ヘッドノードのIP or ホスト名>:6379' コマンドで行います。ただ当然ながらヘッドノードが立ち上がっていないとこのコマンドは失敗するのでmasterが起動するまで待機するスクリプトを先に実行します。この待機スクリプトConfigMapに記述してマウントしています。


3. 環境構築

GKEをローカルマシンから操作する前準備を行います

がローカルマシンから実行可能にしておいてください。


4. GKEへのクラスター構築

※ここからはすべてローカルマシンでの操作です

まずはGCPへのログインと新規プロジェクト作成

#: ブラウザが立ち上がりログイン画面が表示される
gcloud auth login

#: 任意のIDおよび名前でプロジェクトを作成
#: gcloud projects create <ProjectID> --name <ProjectName>
gcloud projects create distrl-project --name distrl


つぎにconfigにデフォルト値を設定することで以後のコマンド入力を楽にします。 regionによっては使えないGPUもあることに留意してください

GPU regions and zones availability  |  Compute Engine Documentation

#: gcloud config set project <ProjectID>
gcloud config set project distrl-project

#: gcloud config set compute/region <RegionName>
gcloud config set compute/region northamerica-northeast1

#: gcloud config set compute/zone <zoneName>
gcloud config set compute/zone northamerica-northeast1-a

gcloud config list

注意:
GCPアカウントの利用実績がない場合は、プロジェクトが同時に利用可能な総CPU数/GPU数/メモリ に強い制限がかかっています。この場合はReplicaSet/actorreplicas=1と設定して並列数を減らすか、下記リンクを参考にリソース割り当ての増加をリクエストしてください

割り当てと上限  |  Network Service Tiers  |  Google Cloud


以下のコマンドでクラスタを構築しますがここからは時間課金されるのでクラスタの消し忘れに注意してください。不安になったらWeb-GUIを確認しましょう。まあ万が一クラスタを消し忘れてもプリエンプティブルVMは24時間で消えるのでダメージは小さいです。

#:  GPU node-pool (1ノード) の作成
#:  16 vCPU,  32GiB memory,  1 NVIDIA Tesla P4 GPU
gcloud container clusters create rl-cluster \
    --accelerator type=nvidia-tesla-p4, count=1 \
    --preemptible --num-nodes 1 \
    --machine-type "custom-16-32768"

#: CPU node-pools (autoscale) の作成
#:  各ノード 16 vCPU,  32GiB memory, 0 GPU
gcloud container clusters node-pools create cpu-node-pool \
    --cluster rl-cluster \
    --preemptible --num-nodes 1 \
    --machine-type "custom-16-32768" \
    --enable-autoscaling --min-nodes 0 --max-nodes 30 \

# ローカルマシンからGKEクラスタにkubectlする権限取得
gcloud container clusters get-credentials rl-cluster

#: Install GPU driver
kubectl apply -f https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/nvidia-driver-installer/cos/daemonset-preloaded.yaml

重要なオプション:

  • --preemptible: プリエンプティブルインスタンスを指定
  • --enable-autoscaling: 計算リソースが不足した場合は自動でノードプール内のノード数を増やす

gcloud container clusters create  |  Cloud SDK Documentation


4. 学習実行とモニタリング

ようやく学習を開始します。

#: GKEにマニフェストを反映
kubectl apply -f apex-cluster.yml

#: ヘッドノードへログイン
kubectl exec -it master bash


ヘッドノードへログイン後、ray statusコマンドを実行することでクラスタの状態を確認できます。

>>> ray status
~~ 中略 ~~
Resources
------------------------------------------------------------
Usage:
 0.0/109.0 CPU
 0.0/1.0 GPU
 0.0/1.0 accelerator_type:P4
 0.00/172.930 GiB memory
 0.00/74.506 GiB object_store_memory

Demands:
 (no resource demands)


クラスタが正常に作成されていることが確認できたのでヘッドノードでpythonスクリプト実行することで学習を開始します。

#: 学習の実行(100プロセスのactorを分散並列実行)
python /code/main.py --logdir log/tfboard --cluster --num_actors 100 --num_iters 30000


6. モニタリング

kubectl get svc master-svcを実行して表示される <EXTERNAL-IP>:6006にブラウザアクセスすることでtensorboardを見ることができます。

f:id:horomary:20210831021636p:plain:w500


また、 <EXTERNAL-IP>:8265からはクラスタのリソース使用状況などを確認できるrayの素敵機能ray-dashboardにアクセスできます。

f:id:horomary:20210828170029p:plain:w600
ray dashboard


5. クラスタの削除

クラスタ削除を忘れずに!

gcloud container clusters delete rl-cluster


まとめ

ほんとうは256並列でCartPoleやろうと思ってたのですがCPU割り当て増加リクエストが128 vCPUまでしか承認されなかったので100並列に押さえました。こんな簡単にHPCできるなんてすごい時代になったものだ。


*1:EKSでも同じような機能があると思うけどAWSはよく知らない

*2:ただし用途が巨大なデータフレームの分析なら素直にpysparkとかdaskを使うのが吉