どこから見てもメンダコ

軟体動物門頭足綱八腕類メンダコ科

オフライン強化学習① Conservative Q-Learning (CQL)の実装

オフライン強化学習の有名手法CQLについて、簡単な解説とともにブロック崩し環境向けのtf2実装を紹介します

[2006.04779] Conservative Q-Learning for Offline Reinforcement Learning

sites.google.com

前提手法:
horomary.hatenablog.com


はじめに:オフライン強化学習とは

元ネタ: [2005.01643] Offline Reinforcement Learning: Tutorial, Review, and Perspectives on Open Problems

問題設定:ゲーム実況を見るだけで上手にプレイできるか?

一般的な強化学習においては、環境においてエージェントが試行錯誤することで方策を改善していきます。環境で実際に行動することでデータを集めつつ学習する(オンライン学習)ことで、方策更新→環境からのフィードバック→ 方策更新 という改善サイクルを回せることこそが、逐次的意思決定問題において強化学習が単純な教師あり学習(模倣学習)よりも良い性能を獲得することができる根拠の一つとなっています*1 *2。高速なフィードバックサイクルこそが成長の近道というのは仕事やスポーツ、ソフトウェア開発などでも実感できることではないでしょうか?

Offline Reinforcement Learning: Tutorial, Review, and Perspectives on Open Problems Fig1より:
オフライン強化学習とは事前に用意されたリプレイバッファだけを使用するoff-policy強化学習である

一方、オフライン強化学習とは事前に用意されたデータセットからの学習だけで良い方策を獲得することを目的とした理論であり、すなわち強化学習の強みであるオンライン学習が禁止されている問題設定です。事前に用意されたデータセットとはエキスパートの行動履歴かもしれませんし素人の行動履歴かもしれません。一人によるものかもしれませんし複数人によるものかもしれません。つまるところオフライン強化学習とはダークソウル*3を一切やったことない人がYouTubeのゲーム実況動画だけを見るだけでダークソウルをノーデスクリアできるようになるか?を目指す問題設定とも表現できます。(無理では?)

なお、オフライン強化学習(offline RL)という用語は Sergey LevineによるNeurIPS2020のチュートリアル講演 Offline RL Tutorial - NeurIPS 2020 から広く使われるようになった感があります。だいたい同じ意味を示す他の単語として batch RL, data-driven RL, fully off-policy RL などがあります*4


実世界でのユースケース

オフライン強化学習とはダークソウルを一切やったことない人がYouTubeのゲーム実況動画だけを見るだけでダークソウルをノーデスクリアできるようになるか?を目指す理論と喩えましたが、そんな縛りプレイみたいなことをするのは実世界にユースケースがあるためです。

たとえば、実環境での大失敗が許容できないドメインではオフライン強化学習が有用です。これは医療や化学プラントあるいは自動運転のように失敗が人命に関わるような分野が該当します。

www.microsoft.com

他には試行錯誤の時間/金銭コストが高いドメインでも活用が期待できます。たとえば試行錯誤の過程での大失敗が高額なハードウェアの損傷につながるロボティクスや、レコメンデーションや自然言語タスクなどにおける人間のフィードバックを報酬とした強化学習 *5のように試行錯誤に人力が必要なためにやり直しコストが大きい分野などです。

arxiv.org

また、そのような極端な環境でなくとも既存の試行錯誤履歴からオフライン学習してそこそこの性能にしたうえでオンライン強化学習をファインチューニング程度に使いたいというのもよくあるユースケースでしょう。アカデミックではあまり研究されない設定だとは思いますが、アルゴ/ハイパラ変更のたびにサンプル収集をゼロからやり直すのはあまりに非効率的なのでこのような転移学習的な視点もまた実務的には大変有用です。


模倣学習との違いなど

オフライン強化学習も模倣学習も事前に用意されたデータセットだけを使う*6という問題設定は共通していますがSergey Levineのチュートリアル論文 では、あくまでオンライン学習前提の強化学習(TD学習や方策勾配法)手法の拡張を指して”オフライン強化学習”と呼称しているように見えるので、基本的に教師あり学習である模倣学習とはこの点で異なります。

与えられたデータセットだけを使う設定に対してオフラインと呼称するので、たとえばシミュレーション環境でオンライン強化学習を行うことは実環境と相互作用しないという意味ではオフラインですが通常はこれをオフライン強化学習と呼称しません。ただし、与えられた固定データセットでシミュレータ(環境モデル)を構築してそこでオンライン強化学習を行うような場合にはオフライン強化学習(オフライン設定の世界モデルベース強化学習)と呼称されてる気がします。


オフライン強化学習の難しさ

図の出典: Offline RL Tutorial - NeurIPS 2020

データセットサイズは問題を解決しない

データセットが十分に大きいならば、オフライン設定のDQNは少なくともデータセットを収集したポリシーと同程度の性能になることを期待したくなります。しかし現実には単純なオフライン設定のDQNはデータセット収集ポリシーどころか、シンプルな模倣学習ポリシーよりもはるかに劣悪な性能となることがしばしばあります。

これはQ学習(など)におけるargmaxオペレータが価値の関数近似において誤差を増幅する効果があるために、オフライン方策とデータセットの行動選択に乖離が生じるためです。すなわちオフライン方策は模倣学習(教師あり学習)による方策と一致しません。

 \displaystyle
{ Q(s, a)  \leftarrow r(s, a) + \arg \max_{a'} Q(s', a)  }

データセットと異なる行動ををすれば当然データセットには無い初見状態に突入するのでパフォーマンスが連鎖的に悪化します。 ダークソウルRTAで完璧に動きを覚えていたのにワンミスで敵の行動パターンが変わってしまいチャート崩壊するようなものです。

価値近似の誤差によりデータセットの状態行動選択(πβ)とオフライン方策(πθ)が乖離すると初見状態に突入する


Out of Distribution: データセット分布外アクションの過大評価

オフラインQ学習においてargmaxオペレータが悪さをする具体例を見ていきましょう。

価値の関数近似の良い点でもあり悪い点でもあるのは、任意の(s, a)についてたとえ一度も試行したことが無くともQ(s,a)を評価することができてしまうことです。この性質とargmaxの悪魔合体によりたまたまQ(s, a)が無根拠に上振れした未実施アクションが採用されてしまいます。オンライン設定であれば次の試行時にうっかり過大評価していたことに気づき修正が行われるのですが、オフライン設定では無根拠な過大評価が永久に修正されません、いわゆるエアプガチ勢です。

argmaxがデータセットのアクション分布(赤)の外にまで最大値を探しに行ってしまう

この課題の解決アプローチとして、TRPOのようにデータセットの行動分布とオフライン方策の分布間距離に制約を課す方法や、[1907.04543] An Optimistic Perspective on Offline Reinforcement Learningのように予測値の不確実性の高い行動(≒未学習領域)にペナルティを課す方法、今回紹介するCQLのようにデータセットに存在するサンプルの評価値に制約を課す方法などがあります。


もっと詳しく

分布シフトの原因として、他にも極値付近で汎化誤差が上振れする問題(Double DQNのアレ)などが指摘されています。すでに何度もリンクを貼っていますがこのようなオフライン強化学習の課題について、より詳しく知りたい場合はSergey LevineによるNeurIPS2020のチュートリアル講演がベストな開始点です。

sites.google.com

この講演は [2005.01643] Offline Reinforcement Learning: Tutorial, Review, and Perspectives on Open Problems をベースとしています。ボリューム多すぎてつらい場合は書籍「AI技術の最前線」 4章にてこの論文を3ページ要約してくれているのでこちらもおすすめです。

AI技術の最前線 これからのAIを読み解く先端技術73 | 岡野原 大輔 | コンピュータサイエンス | Kindleストア | Amazon


CQL:保守的なQ学習

[2006.04779] Conservative Q-Learning for Offline Reinforcement Learning

前例が無いからダメです

ようやく本題です。上述したOut of Distributionな行動選択問題についてCQLではサンプルベースのアプローチで対処します。すなわち、データセットに存在しない状態行動(s, a)の評価値Q(s, a)にペナルティを与えることでデータセットに前例のない行動が選択されることを防ぎます。

通常のQ学習ではTD誤差を最小化します。すなわち損失関数は、

 \displaystyle
{ E_{s,a,s' \sim \it{D}} \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right] }

ここで、 \displaystyle { E_{s,a,s' \sim \text{D}} } とは 遷移サンプル(s, a -> s')がデータセット=オフラインリプレイバッファからサンプリングされたことを示します。CQLではこのTD誤差に対して前例のない行動へのペナルティ項が追加したものを損失関数とします。すなわち、


 \displaystyle
{ \text{ CQL(1) } = \alpha \left( E_{s \sim \it{ D} , a \sim \mu} \left[Q(s, a) \right] \right) + E_{s,a,s' \sim \text{D}} \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right] }


ここでμとはオフライン学習によって獲得した方策であるので、 \displaystyle { E_{s \sim \it{D}, a \sim \mu} \left[Q(s, a) \right] } とはオフライン学習方策によって行動選択(  \displaystyle { \max_{a} Q(s, a) } , 離散行動空間の場合 )したときの状態行動価値Q(s, a) です。これを損失関数として最小化することでオフライン学習方策μによって選択される行動aのQ(s, a)が低くなるようにネットワークが更新されていきます。データセットに存在しない行動に対してペナルティがつくので保守的(Conservative)なQ学習というわけです。

論文では、CQL(1)の更新式によって獲得されたQ関数はすべての(s, a) についてQ(s,a)の下界が得られることを証明しているのですが、しかしこのCQL(1)は保守的を通り越して新人イジメみたいな更新式になっています。というのもこの更新式では、オフライン方策が選択した行動であればデータセットに(s, a)の前例、があったとしても問答無用でペナルティが与えられるためです。つまりは、過去に前例が無いからダメです、前例があっても提案してるのが新人だからダメです、という感じ。


そこで、すべての(s, a) についてQ(s,a)の下界を得るのではなく、すべてのsについてV(s)の下界を得られればOKとして制約を緩和するとよりタイトな下界が得られます。


 \displaystyle
{ \text{ CQL(2) } = \alpha \left( E_{s \sim \it{ D} , a \sim \mu} \left[Q(s, a) \right] - E_{s \sim \it{ D} , a \sim \pi \beta } \left[Q(s, a) \right]  \right) + E_{s,a,s' \sim \text{D}} \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right] }


ここでπβとはデータセットにおける行動選択確率です。よって、 \displaystyle
{  E_{s \sim \it{ D} , a \sim \mu} \left[Q(s, a) \right] - E_{s \sim \it{ D} , a \sim \pi \beta } \left[Q(s, a) \right]   }
はオフライン方策の行動選択とデータセットにおける行動選択が一致したときゼロとなり、逆にオフライン方策が選択した行動aについてのQ(s, a) > データセットに記録された行動aについてのQ(s, a) となったときに増大します。したがってCQL(2)を損失関数として最小化することで、前例のある行動についてのQ(s ,a) が常に前例のないデータセット外アクションについてのQ(s,a)より大きくなるようなQ関数を獲得することができます

前例があるならまずは前例を最重視する、まさに保守的なQ学習です。 洗練されたBehavior Cloningという印象。


方策の正則化

実用的にはCQL(2)にさらに方策の正則化項を追加したものを更新式とします。論文では正則化のバリエーションがいくつか提示されていますがもっとも簡単なのは伝統的なエントロピー最大化による方策正則化項を追加したCQL(H)です。


 \displaystyle
{ \text{ CQL(H) } =  \min_{Q} \max_{\mu}   \alpha \left( E_{s \sim \it{ D} , a \sim \mu} \left[Q(s, a) \right] - E_{s \sim \it{ D} , a \sim \pi \beta } \left[Q(s, a) \right]   + \text{H}(\mu) \right) + \text{TDError} }


ここで、H(μ)は方策のエントロピーです。Qとオフライン方策μについての二重最適化問題になっていてややこしそうに見えますが、max_μについてはclosed formで解けるので案外シンプルになります。オフライン方策μについての最大化に関係ない定数項をすべて取り除くと、

 \displaystyle
{   \max_{\mu}  E_{s \sim \it{ D} , a \sim \mu} \left[Q(s, a) \right]  +  \text{H}(\mu)  }

 \displaystyle
{  = \max_{\mu}  E_{s \sim \it{ D} , a \sim \mu} \left[ Q(s, a) - \log\mu(a | s) \right]    }

これを方策μが確率分布である(非負で総和1)という制約つきラグランジュ最適化問題として解くと、最適方策μ* として、

 \displaystyle
{  \mu{*} =  \frac{ \exp{Q(s,a)} }{ Z(s) }    }

が得られます。ここでZは規格化定数(あるいは分配関数)であり、 \displaystyle {  Z(s) =  \sum_{a}  \exp{Q(s,a)} }  です。このあたりは強化学習 as Inference: Maximum a Posteriori Policy Optimizationの実装 - どこから見てもメンダコ と同じような流れです。


得られた最適方策μ*をCQL(H)に代入すると、

 \displaystyle
{ \text{ CQL(H) } =  \min_{Q}    \alpha \left( E_{s \sim \it{ D} , a \sim \mu} \left[Q(s, a)  \right] - E_{s \sim \it{ D} , a \sim \pi \beta } \left[Q(s, a) \right]   + E_{s \sim \it{ D} , a \sim \mu} \left[ -\log{\mu{*}(a | s)}  \right]  \right) + \text{TDError} }

 \displaystyle
{  =  \min_{Q}    \alpha \left( E_{s \sim \it{ D} , a \sim \mu} \left[Q(s, a) -\log{\mu{*}(a | s)}  \right] - E_{s \sim \it{ D} , a \sim \pi \beta } \left[Q(s, a) \right]   \right) + \text{TDError} }

 \displaystyle
{  =  \min_{Q}    \alpha \left( E_{s \sim \it{ D} , a \sim \mu} \left[Q(s, a) -\log{ \frac{ \exp{Q(s,a)} }{ Z(s) } }  \right] - E_{s \sim \it{ D} , a \sim \pi \beta } \left[Q(s, a) \right]   \right) + \text{TDError} }

 \displaystyle
{  =  \min_{Q}    \alpha \left( E_{s \sim \it{ D} , a \sim \mu} \left[ \log{ Z(s)}  \right] - E_{s \sim \it{ D} , a \sim \pi \beta } \left[Q(s, a) \right]   \right) + \text{TDError} }

 \displaystyle
{  =  \min_{Q}    \alpha  E_{s \sim \it{ D} }  \left[ \log{ Z(s)}  - E_{ a \sim \pi \beta } \left[Q(s, a) \right]  \right]  + \text{TDError} }

※分配関数Zの期待値は方策に依存しないことに注意


CQL(H)のTF2実装

実装全文:
github.com

最終的に得られたCQL(H)のロス関数をtensorflow2で実装します。

 \displaystyle
{  \text{ CQL(H) } =  \min_{Q}    \alpha \left( \log{ Z(s)} - E_{s \sim \it{ D} , a \sim \pi \beta } \left[Q(s, a) \right]   \right) +  E_{s,a,s' \sim \text{D}} \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]  }

論文では分布強化学習手法であるQR-DQNをベースとしてCQLを実装しています。これはQR-DQNは実装がシンプルなわりに性能が良好であるためでしょう。QR-DQNのTD誤差項(分位点huber-loss)の実装は過去記事を参照ください。

horomary.hatenablog.com

CQL項自体は論文でも言及されているようにたった20行程度で実装することができます


DQN Replay Datasetの利用

CQLの実装に関して面倒だったのはアルゴリズム実装よりもDQNリプレイデータセットの利用です。

データセットのダウンロード方法:
research.google

DQNリプレイデータセットatariの60 のゲームのそれぞれについて、異なるシードで 5 つの DQN エージェントを200Mフレームでトレーニングしたデータセットです。ただし、4 frame skip設定でのトレーニングのため1エージェントごとに200M/4 = 50M遷移が記録されています。また、50M遷移は(s, a, r, s', a', r', terminal) ごとに50のファイルに分割されて同じディレクトリに保存されています。

このデータセットgoogle強化学習フレームワークであるdopamine-rlに実装されたインメモリreplaybufferであるOutOfGraphReplayBuffer.load()を用いてロードすることが想定されています。

github.com

しかしRAM要求がかなり厳しかったので一度tfrecord形式に変換し、tf.data.TFRecordDatasetを使ってトレーニングしました。この実装だとメモリ16GB程度で十分なのでお財布に優しいです。

github.com


ブロック崩しの学習結果

Breakout(ブロック崩し)環境についてCQL論文の1%データセット設定を再現するように実装し、1Mstepの更新を行いました。論文掲載のスコアが60点くらいなのでだいたい再現できています。

CQL

学習結果


次:Decision Transformer

horomary.hatenablog.com


*1: https://arxiv.org/pdf/2005.01643.pdf 2.4 What Makes Offline Reinforcement Learning Difficult?

*2: DAggerはオンライン設定の模倣学習的な手法

*3:高難易度で有名なフロムソフトウェア社のゲーム

*4:個人的には fully off-policy RL が分かりやすくて好き

*5:例:Learning to Summarize with Human Feedback

*6:ただし模倣学習についてはDAggerみたいにhuman in the loop設定になっていることも多い