どこから見てもメンダコ

軟体動物門頭足綱八腕類メンダコ科

オフライン強化学習④: 拡散モデルの台頭

オフライン強化学習における拡散方策の近年の適用例を概観し、tensorflowで実装します。

オフライン強化学習シリーズ:
オフライン強化学習① Conservative Q-Learning (CQL)の実装 - どこから見てもメンダコ
オフライン強化学習② Decision Transformerの系譜 - どこから見てもメンダコ
オフライン強化学習③ Implicit Q-Learning (IQL)の実装 - どこから見てもメンダコ
オフライン強化学習④: 拡散モデルの台頭 - どこから見てもメンダコ

[2006.11239] Denoising Diffusion Probabilistic Models

※注:本稿における拡散モデルとは基本的にDDPMを指します

背景

拡散方策(Diffusion Policy)の登場

拡散モデル(Diffusion model)が様々なデータ生成タスクにおけるゲームチェンジャーとなっています。DALLE-EやStable Diffusionなど拡散モデルを利用した画像生成サービスはいまやAI研究者だけでなく一般の人々にも広くが知られるものとなり、ChatGPTなど大規模言語モデルと並んで近年の生成AI(Generative AI)ブームを牽引しています。

拡散モデルは画像生成タスクにおける研究が先行してきましたが、画像に限らず動画、音声、点群などあらゆる連続値データの生成において有用な技術です。そして強化学習においては連続値アクション環境における方策関数として、すなわち行動生成器として拡散モデルを利用することができます

Using generative AI to imitate human behavior - Microsoft Research

拡散モデルが強化学習における方策関数として優れている点は2つあります。まず、拡散モデルは様々な条件付き生成を容易に実現できることです。これにより、観測oで条件づけられた行動aの生成モデル π(a | o) として拡散モデルを実装することで、決定論*1方策関数のように利用することができます

もうひとつは、拡散モデルで実装された方策(拡散方策)は高い表現力をもつために、データの多様性を捉える能力が高いということです。事前データセットだけで方策を訓練しなければならないオフライン強化学習にとってモデルの表現力は非常に重要な要素であり、とくに模倣学習では飛躍的な高性能化が期待できます。


模倣学習の大幅な性能向上

表現力豊かな拡散方策によって最も大きな恩恵を受けるのは模倣学習でしょう。これは多峰性をもつオフラインデータセットを従来的なガウス方策など表現力の低いモデルでフィッティングしようとすると致命的なエラーが生じるためです。

Using generative AI to imitate human behavior - Microsoft Research

実世界のオフラインデータセットでは基本的に複数人の過去の行動決定の寄せ集めとなっていることが想定されます。ゆえにデータセットは同じ状態oにおいて必ずしも同じ行動aをとるとは限らないため、データセット方策 π(a | o)はマルチモーダルな分布となるはずです。このようなマルチモーダルな状態行動分布p(a | o)をユニモーダルなガウス方策(Gaussian)や回帰(MSE)でフィッティングしてうまくいくはずがないことを上図は分かりやすく示しています。

一方、拡散方策は任意の分布を表現することができるためにマルチモーダルな行動分布を持つオフラインデータセットであってもうまくフィッティングすることができ、これによって模倣学習の大幅な性能向上を見込むことができます。

従来的には混合正規分布を採用することでも方策の表現力を向上させることができますが、この場合は適切な混合数Kが状態oに依存するという問題が生じます。拡散方策であれば混合正規分布方策のように適切な混合数を決定する必要がないため、実装上の困難さを軽減しつつもより柔軟なモデルを構築することが可能となります。


Diffusion-QLの衝撃

連続値アクション環境において、多くのオフライン強化学習手法では模倣学習方策を価値関数で多少チューニングすることで方策を獲得しています。ゆえに模倣学習方策の性能向上がそのままオフライン強化学習の性能向上につながることが期待できます。そしてこれを実際にやったDiffusion-QLが多くのタスクでSOTAを示したことにより、拡散モデルは画像生成分野のみならずオフライン強化学習分野でもゲームチェンジャーとなりました。

[2208.06193] Diffusion Policies as an Expressive Policy Class for Offline Reinforcement Learning

特筆すべき点として、Diffusion-QLは従来のベースライン手法であるTD3+BCのBC(behavior cloning, 模倣学習)を拡散モデルにしただけ、というシンプルさでSOTA性能を達成したことがあります。TD3+BCは実用性重視で開発された手法であるゆえにDiffusion-QLもまた実用的な手法となっているため、実世界の課題にも適用しやすそうなのが嬉しいポイントです。


主要な手法・論文

※順序は発表時系列に一致しない

Diffusion-QL:拡散方策のミニマリストアプローチ

[2208.06193] Diffusion Policies as an Expressive Policy Class for Offline Reinforcement Learning

前述の通り、Diffusion-QLは[2106.06860] A Minimalist Approach to Offline Reinforcement Learningにて提案されたオフライン強化学習手法TD3+BCに拡散方策を導入することで高性能化に成功した手法です。TD3+BCは模倣学習方策をTD3でチューニングするだけという非常にミニマルな構成ながら頑健な性能を示す実用的な手法であるためにオフライン強化学習のベースラインとして使われ続けてきた手法です。ちなみにこの論文の著者の一人であるShixiang Gu氏はOpenAIの日本担当ということで一躍時の人となっています。

TD3の解説・実装(強化学習) - どこから見てもメンダコ

Diffusion-QLではTD3+BCにおける模倣学習を拡散方策で行うことで、マルチモーダルな状態行動分布にもうまくフィッティングできること(図上段)、さらに拡散モデルによる模倣学習方策をTD3でチューニング*2することで高報酬領域に方策を集中させることができること(図下段)を示しました。

個人的に衝撃だったのは、Diffusion-QLにおけるQ学習はオフライン強化学習向けに調整されていないTD3スタイルQ学習であるにも関わらず既存SOTA手法を凌駕するパフォーマンスを示したことです。これまでのオフラインQ学習とは何だったのかという・・・。逆に言うとDIffusion-QLはQ学習に改良の余地を残しているということでもあり、この方向のアプローチを行ったのが後述するIDQLです。


IDQL: Implicit Q-Learning+拡散方策

[2304.10573] IDQL: Implicit Q-Learning as an Actor-Critic Method with Diffusion Policies

Sergey Levineのオフライン強化学習シリーズ。連続値アクション環境における従来SotA手法 Implicit Q-Learning(IQL)の方策を拡散方策に切り替えたらやっぱり強かった、以上。(ただし比較的単純なタスクではあまり差が出ていない)

IDQLではDiffusion-QLとは異なりQ関数による拡散方策のチューニングは行わず、模倣学習とQ学習を完全に切り離す。このため、推論時にはある状態sについて拡散方策からたくさんのaをサンプリングしたうえでQ評価値を採択確率として確率的に行動決定します。

horomary.hatenablog.com


深堀り模倣学習:Using generative AI to imitate human behavior

Using generative AI to imitate human behavior - Microsoft Research

模倣学習のための拡散モデルという観点でDeepDiveしたMicrosoftの研究。論文タイトルについて、拡散モデルではなくGenerative AIというワーディングに資本主義を感じる。

Using generative AI to imitate human behavior - Microsoft Research

技術的な目新しさは乏しいものの、ネットワークアーキテクチャ、分類器なしガイダンスの効果、サンプリングスキームなど、実務的に重要な項目について詳細な比較検討を行っている。とくに、画像生成タスクで広く普及している分類器なしガイダンス(CFG)を模倣学習に導入すると頻度の低い行動選択をするようになるのでパフォーマンスが悪化するという解析は面白い(3.3および付録E)


Decision Diffuser :分類器無しガイダンス(CFG)の活用

Is Conditional Generative Modeling all you need for Decision-Making?

Decision Diffuserは拡散モデルのガイダンス付き生成能力を活用した手法です。たとえば将来報酬和(returns-to-go)をガイダンスとした使用した場合、データセットから高報酬領域のトラジェクトリを選択して生成することできます。タイトルからも想起されるようにこのコンセプトはDecision Transformerと近いものです。ただし、Decision transformerはGPTアーキテクチャを使用して自然言語生成のように行動生成するために条件付けは(言語モデルでいう)プロンプトによって行う一方で、Decision Diffuserでは分類器無しガイダンス(Classifier-Free Guidance, CFG)を使用して条件付けを行います。

ガイダンスの対象には将来報酬和(returns-to-go)だけでなく制約条件やスキル(ロボット犬制御なら歩く、走るなど)も含まれます。興味深いのは画像生成における「月面で馬に乗る宇宙飛行士」の例のように、データセット内の既知の概念を組み合わせることでデータセット外のサンプルを生成する能力をDecision Diffuserも備えていることです。(下図)

アーキテクチャについて、Decision Diffuserでは拡散モデルによって行動aではなく状態sを生成します。これまで紹介した手法ではいずれも拡散モデルを行動aを生成する方策モデルとして利用していたこととは対照的です。 方策モデルは明示的に持たず、まずSt+1の生成を拡散モデルによって行った後に、St→St+1の遷移に対応する行動aを予測するinverse dynamics問題を解くことで行動選択とする二段構えのアプローチになっています。

 


Tensorflowによる拡散方策の実装

[2208.06193] Diffusion Policies as an Expressive Policy Class for Offline Reinforcement Learning

GitHub - Zhendong-Wang/Diffusion-Policies-for-Offline-RL

Diffusion-QLの公式Pytorch実装を参考にして、拡散モデルによる模倣学習方策をTensorflow2で実装しました。

実装全文:
github.com

拡散方策

拡散モデルはしっかり理解しようとするとややこしいですが、実装するだけならわりと簡単です。実際、下記のたった70行のコードがほぼすべてであり、模倣学習だけならcompute_bc_lossを最小化すれば完了です。

gist.github.com

ノイズスケジュール

画像生成向けのナイーブなDDPMだと拡散過程を数百ステップ以上繰り返すようですが、強化学習の行動次元は画像に比べてはるかに小さいために数ステップの繰り返しで十分なようです。しかし、よく使われるコサインスケジューラはそれほど小さいタイムステップを想定していないので、Diffusion-QLでは以下に示す特殊なスケジューラを使用しています。

ノイズスケジュールβ

拡散過程/逆拡散過程

コンピュータビジョン最前線 Summer 2023に書いてある通りに実装。式中のα、ハット付きαはノイズスケジュールβから算出される値。

拡散過程

 \displaystyle
{ x_{t} = \sqrt{\hat{\alpha_{t}} } x_{0} + \sqrt{ 1 - \hat{\alpha_{t}}}\epsilon_{t} }

gist.github.com

(ノイズεを予測する場合の)逆拡散過程

 \displaystyle
{
 \mu_{t} = \frac{1}{ \sqrt{1-\beta_t}} (x_{t} -  \frac{\beta_t}{\sqrt{1-\hat{\alpha_t}}} \epsilon_{\theta} )
}

 \displaystyle
{
 \sigma_{t} = \frac{1-\hat{\alpha}_{t-1}}{1 - \hat{\alpha_{t}}} \beta_{t}
}

 \displaystyle
{
 x_{t-1} \sim \mathcal{N}(\mu_{t},  \sigma_{t})
}

正規分布からのサンプリングではreparametarization trickを使う

gist.github.com


テスト結果

Box2D/BipedalWalker-v3環境。安定した学習ができています。

bipedalwalker-v3

参考文献

コンピュータビジョン最前線 Summer 2023

わずか30Pの解説にDDPMのエッセンスが凝縮されていて大変理解しやすく、実装時のリファレンスとして最適。拡散モデルではないが同号掲載されている品川先生のCLIP解説もわかりやすくてお得感がある。

拡散モデル データ生成技術の数理

PFN岡之原氏による本格派の拡散モデル解説書。難解な内容を扱っている割に読みやすいのは一般向けに技術解説を書き続けてきた岡之原氏の文章力の賜物か。


*1:拡散モデルの生成自体は決定論的ではないが生成確率の算出が困難であるために決定論的方策の枠組みで扱うしかないため

*2:実際にはTD3ロスと模倣学習ロスの和を損失関数として同時に学習するのでファインチューニングではない