どこから見てもメンダコ

軟体動物門頭足綱八腕類メンダコ科

オフライン強化学習② Decision Transformerの系譜

Decision transoformer (2021)は、自然言語モデルGPTにおける次トークン予測の枠組みでオフライン強化学習タスクを解けることを示し新たなパラダイムをもたらしました。最近ではDeepMindの超汎用エージェントGATOなどもDecision Transformerベースのアーキテクチャを採用しており、その重要性が増しています。


オフライン強化学習シリーズ:
オフライン強化学習① Conservative Q-Learning (CQL)の実装 - どこから見てもメンダコ
オフライン強化学習② Decision Transformerの系譜 - どこから見てもメンダコ
オフライン強化学習③ Implicit Q-Learning (IQL)の実装 - どこから見てもメンダコ
オフライン強化学習④: 拡散モデルの台頭 - どこから見てもメンダコ


Decision Transformer とは

sites.google.com

オフライン強化学習の新たなパラダイム

オフライン強化学習、すなわち環境からのサンプル収集(=オンライン学習)を一切行わず、事前に用意されたデータセットのみで強化学習するという問題設定は実用上大きな意味があります。ビジネスにおいて時間とコストがかかるデータ収集をハイパラ設定を変えるたびにゼロからやり直すのはわりに合いませんし、医療や化学プラントなどではそもそも気軽な試行錯誤が許されないためです。

:オフライン強化学習とは事前用意されたデータセットだけで学習するoff-policy強化学習
([2005.01643] Offline Reinforcement Learning: Tutorial, Review, and Perspectives on Open Problems より)

オフライン設定は実用上大変重要ですが、従来のオフポリシー強化学習(例: DQN, SACなど)を単純にオフライン設定で実行するといくつかの理由*1で壊滅的なパフォーマンスになることが知られています。ゆえにそれまでオフライン強化学習とはオフライン設定においてオフポリシー強化学習手法をどうやってperformさせるか、ということが主な論点となっていました。しかし、Decision TransformerではGPTライクな系列生成アプローチを模倣学習に導入することにより、教師あり学習によって強化学習タスクを当時のSotAオフライン強化学習手法(CQL)に相当する性能で解けることを示しオフライン強化学習に新たなパラダイムをもたらしました。

実スコアでなくゲーマー標準化スコアであることに注意


言語を生成するように行動を生成する

強化学習というのは基本的に時系列における逐次意思決定タスク、すなわち各タイムステップごとに過去の観測を考慮して適切な行動を選択するタスクに対して使用されます。そして自然言語分野における次単語生成タスク、たとえば「吾輩は猫である、名前は~」に続く自然な単語を選択するというのもまた逐次意思決定タスクです。

ならば、BERTやGPTモデルによる教師あり学習で自然な文章を生成するように、BERTやGPTモデルによる教師あり学習で自然な行動を生成できるのではないか? というのがDecision Transformerの基本的なコンセプトです。

Decision Transformer:sの出力が次アクションを予測する。R, aの出力はとくに使わない
https://sites.google.com/berkeley.edu/decision-transformer

Decision transfomerではネットワークにGPTアーキテクチャを採用しており、*2 上図においてcausal transformerと記載されている部分はほぼGPTを転用したものです。学習においてもGPTのpretrainingにおける次単語予測と同様に、時刻tにおける状態s_tの出力がアクションa_tを予測するように訓練します。(なお、R_t, a_t の出力はとくに利用しません。論文によるとa_tの出力が次状態s_tを予測するようにするなども試したけどとくにパフォーマンスに寄与しなかったとのことです。)

ここでR_tとは即時報酬ではなく、t+1時以降に獲得する将来報酬和であることに注意してください。R_tによってDecision Transformerは条件つきの行動予測が可能になっています。たとえばR_tが大きいならエキスパートのような行動を選択し、R_tが小さいなら初心者のような行動を選択します。(詳細は後述)


自然言語風アプローチのメリット

強化学習的なタスクにおいてGPTのような自然言語モデル風のアプローチを採用することには大きく2つの嬉しさがあります。

1. ブートストラップ法を回避できる
強化学習におけるブートストラップ法とは(動的計画法やTD学習のように)教師ラベルに予測値が含まれているような更新方法を示します。ブートストラップ法はオフライン設定において価値の推定誤差を蓄積することがわかっており、これがオフライン強化学習の主要な困難のひとつとなっています。DTではGPTと同じように教師あり学習するだけなので厄介なブートストラップ法を使う必要がありません。

2. 長期credit割当てを効率的に行える
TD学習では行動と報酬発生の時間差が大きい(例: ブロック崩しにおいて報酬が発生するのはボールを弾いてからしばらく後)と、行動と報酬の因果関係を学習するのにかなりの時間がかかることが問題になります。一方、Transformerはそもそも離れた単語間の関連性を効率よく学習することを意図して設計されているのでこの問題を自然に解決することが可能です。


条件付き生成:Reward conditioned

前述のように、R_tとは即時報酬ではなくt+1時以降に獲得する将来報酬和です。ゆえにR_tが大きい=これからたくさんスコアを獲得する=エキスパートによるプレイであり、R_tが小さい場合は同様の理由で初心者のプレイであると理解できます。

将来報酬和R_tが入力シーケンスに含まれているために、DTは与えられたR_tが大きいならエキスパートのような行動を出力し小さいなら初心者のような行動を出力することが可能です。これは単純な模倣学習とは異なりデータセットに質の悪いサンプルが混ざっていても問題ないことを意味します。逆にR_tが無いと単なるGPTアーキテクチャのBehavior Cloningとなります。

Google AI blog: Multi-Game Decision Transformerより(アップロードの都合上フレーム数を削減)

また、このような条件付けの役割を担うのは必ずしも将来報酬和Rtである必要はありません。たとえば格ゲーデータセットがあるならばR_tをプレイヤーの名前に置き換えることでウメハラ氏やときど氏の行動を再現することができるでしょう。ゲームNPCのAI開発なんかに使うといい感じに難易度調整できそうですね。


Sequence modelingの系譜

Decision Transformerがもたらした新たなパラダイム強化学習 via 自然言語風Sequence Modeling”。 数多くの派生手法が考案されている中から個人的にとくに印象深かったものを紹介します。

Multi-Game Decision Transoformer(NeurIPS 2022)

GPTやBERTなどの巨大transformerアーキテクチャは、教師無し事前学習による未知タスクへのfew-shot learning性能パラメータ数を増やすとパフォーマンスも向上する"べき乗則"などの好ましい特性を持つことが知られています。Multi-Game Decision Transformerでは、このような自然言語基盤モデルの持つ好ましい特性をDecision Transformerも持つのだろうか?ということを調べています。

Multi-Game Decision Transformers

1. Few-shot learningは可能か?

BERTやGPTの最大のメリットは教師無し事前学習済みモデルのファインチューニングによって未知タスクにスモールデータでも対応できることです。Decision Transformerも同様の転移学習性能を持つのかの検証のために、Atariドメイン46ゲーム中から41ゲームを単一のネットワークで訓練した後、未知の5ゲームについて100K ステップのファインチューニングを行い性能を比較しています。事前学習、ファインチューニングどちらもオフラインデータセットでの訓練であることに留意。

 

すべてのゲームで事前学習ありDT(DT pretraining)が事前学習なしDT(DT no pretraining)にoutperformしているので未知タスクに対するpre-trainingの効果はたしかに存在するようです。とくに事前学習なしDTではまったく性能がでていないPongでも事前学習ありDTではしっかりperformしているのが興味深い。PongはなんかDTと相性が悪いみたいで、オリジナルDT論文でもPongだけは入力シーケンスを長くしないとうまく学習しないなど苦労していたようです。

2. パラメータを増やしてGPUで殴ればいい

GPT-3論文がTransformerモデルはパラメータ数を増やせばパフォーマンスもべき乗則 (Power Law)で向上するという実験結果を発表したことによって、NLP分野は大企業が計算リソースで殴り合う末法の世と化しました。DTではどうでしょうか?

Decision Transformerでもべき乗則が適用されるっぽい

3. ViTは有用か?

DTでは観測のトークナイズをDQNのCNNで行っていますが、MGDTではViTを採用しています。これについて比較実験をしたようですがCNNとViTで顕著な差は見られなかったとのことです。

ちなみにMGDTではR, Sも予測対象になっている

ViTだとattentionを簡単に可視化できるので楽しい。

Training Generalist Agents with Multi-Game Decision Transformers – Google Research Blog

Uni[Mask](NeurIPS 2022): MaskedLMの導入

Uni[MASK]: Unified Inference in Sequential Decision Problems | OpenReview

DTは単方向モデルであるGPTアーキテクチャを採用しているため次行動予測タスクに特化しています。そこでUni[Mask]ではBERTのMaskedLMアーキテクチャを採用したうえでマスクの置き方をタスクごとに工夫することで、次行動予測はもちろん状態遷移予測やゴール状態指定などさまざまなタスクに対応できることを提唱しています。アイデアはシンプルですがよくまとまってます。ただし検証はMaze2D環境のみ。

 

GATO(2022):超汎用エージェント

A Generalist Agent | OpenReview

www.deepmind.com

ネットワークアーキテクチャが同じなんだから対話生成もロボット操作もレトロゲームも全部Transformerでマルチタスク学習できるのでは?という頭の悪い発想を世界最高峰の頭脳集団が実現してしまったのがGATOなる超汎用エージェント。実際すごい。

DeepMind Blogより

Decision Transformerでconditioningの役割を担っていた将来報酬和RがGATOでは汎用性確保のため使われていません。かわりにデモンストレーションシーケンスを最初に入力することでconditioningを行っているようです。この方式はprompt conditioningと呼称されています。

prompt

Algorithm Distillation(ICLR2023):学ぶことを学ぶ

openreview.net

※under review

DTのメタ強化学習への発展強化学習アルゴリズム(e.g. DQN, A3C)によって生成されたオフラインデータセットを十分に長い入力シーケンス長を設定したDecision Transformer(というかGATOっぽい)で学習すれば、"学び方を学ぶ"のでは?という手法。強化学習アルゴリズムが探索と活用によって学ぶ様子を再現するのでアルゴリズム蒸留なのでしょう。 

ADはパラメータ更新せずにパフォーマンスを改善する

上図は強化学習アルゴリズムによって生成されたデータセットをオフライン学習済み(学び方の学習済み)のADでパフォーマンスを評価したものですが、ぱっと見では何がすごいのかわかりにくい。ポイントはこの評価中にADは一切のパラメータ更新を行っていないにも関わらずパフォーマンスが向上していることです。これはネットワーク自体が自己改善機能を獲得したことを示しています。なおTransformerでなくLSTMでもOKらしい。(すごそうな手法に見えますが私はメタ強化学習分野に明るくないので先行研究と比べてどれくらいすごいか正直よくわかってない)

同様のコンセプトの先行研究としてOptFormerなどがあります。これはGP, TPEなどが機械学習モデルのハイパラを最適化する過程をTransformerに学ばせることでブラックボックス最適化アルゴリズムの蒸留ができるというものです。

ai.googleblog.com


Decision TransformerのTF2実装

実装全文:

github.com

元論文:[2106.01345] Decision Transformer: Reinforcement Learning via Sequence Modeling

公式実装(pytorch):
GitHub - kzl/decision-transformer: Official codebase for Decision Transformer: Reinforcement Learning via Sequence Modeling.

データセット DQN Replay Dataset

ネットワーク構造

ではDecision TransformerをTF2, Atari/Breakout環境向けに再現実装します。 と言ってもやることは30遷移分の(R, s, a)をトークナイズしてGPTにつっこむだけなので正直コメントすることが無い。

各入力(将来報酬和Rt, 観測状態st、アクションat)は128次元になるよう埋め込む。いずれも最後にtanhでactivationされるのがポイント。

観測状態についてはDQNのCNNで3136次元にした後に全結合層で128次元にする。

stの出力がatを予測するようにカテゴリカルクロスエントロピーを損失関数にネットワークの訓練を行う。Rt, atの出力はとくに使わない。Rtの出力でstを予測するなども試したらしいがとくにパフォーマンス向上に貢献しなかったと論文に書いてあった。

トークンの整列処理はちょっと分かりにくいかもしれないので簡易版を置いておく。

>> rtgs = np.array([R1, R2, R3])
>> states = np.array([s1, s2, s3,])
>> actions = np.array([a1, a2, a3,])
>> tokens = np.stack([rtgs, states, actions], axis=0).T.reshape(1, -1)
>> tokens
[[ R1, s1, a1, R2, s2, a2, R3, s3, a3]]


学習結果

テストではClippedスコアが70点になるようにconditioningして実行。最初の40Kまででほぼ目的の性能に到達している。

clippedスコアで70点を取るようにconditioning

論文掲載スコアを概ね再現できている模様

論文掲載スコア


所感

実装も訓練もデバッグもめっちゃ楽。クロスエントロピー最小化するだけでいいのが楽園すぎる。こんな手軽なら実務でも使いどころあるかも。

*1:[2005.01643] Offline Reinforcement Learning: Tutorial, Review, and Perspectives on Open Problems

*2:BERTでなくGPTを使っているのは、GPTが過去コンテクストのみを活用する単方向モデルであるために次行動予測と相性が良かったためと思われる