どこから見てもメンダコ

軟体動物門頭足綱八腕類メンダコ科

ハムスターでもわかるProximal Policy Optimization (PPO)①基本編

シンプルなようで厄介な強化学習アルゴリズム PPO (Proximal Policy Optimization) を実装レベルの細かいテクニックまで含めて解説します。

※TRPOの理解が前提です

horomary.hatenablog.com

[PPOシリーズ]

ハムスターでもわかるProximal Policy Optimization (PPO)①基本編 - どこから見てもメンダコ

ハムスターでもわかるProximal Policy Optimization (PPO)②TF2による実装 - どこから見てもメンダコ



はじめに

PPO (Proximal Policy Optimization, 2017) はシンプルにされた TRPO (Trust Region Optimization, 2015) として発表された手法です。

TRPOはアイデアも性能も素晴らしいのですが実装が複雑になりすぎる、ActorとCriticでパラメータ共有をするA3C 型のアーキテクチャが利用困難であるCNNやRNNで性能が悪いなどいくつかの課題がありました(TRPO著者の講義スライドより)。そこで、PPOではClipped Surrogate Objectiveという直感的にも実装的にもシンプルなアイデアでTRPOのコンセプトを実現できることを提案し、実際にMuJoCo環境でTRPOと同等以上のパフォーマンスを発揮することを示しました。

さらに同じ論文内で、A2CにPPOのアイデアを導入することでCNN+離散値アクションのAtari環境でも良いパフォーマンスを発揮することを示し、その実装の手軽さと性能の安定性から現在でも人気の強化学習手法となっています。

一方で、PPO(に限らず方策勾配学習全般)は実装レベルの細かいテクニック(たとえば報酬のスケーリングの有無、重みの初期化手法の選択など)によってパフォーマンスが大きく変わるので、コアコンセプトの影響の大きさが比較しにくいことが指摘されてきました。とくに近年のPPO検証論文([2005.12729] Implementation Matters in Deep Policy Gradients: A Case Study on PPO and TRPO, ICML2020)ではMuJoCo環境におけるPPOのパフォーマンス向上は主に実装上の細かいテクニック(code-level optimization)によって得られているものだと指摘しています。

というわけで、本記事ではそのような”実装上の細かいテクニック”まで含めてPPOを紹介します。


TRPO: KL制約付き最大化問題としての方策更新

PPOはシンプル化されたTRPOとして発表された手法ですので、まずはTRPOのコンセプトを紹介します。

方策勾配法は適切な更新サイズを決める困難さからしばしば大きく更新しすぎて方策ネットワークが破綻する(あるいは小さく更新しすぎて学習が全く進まない)という不安定さの課題を抱えています。この課題を”方策ネットワークの出力が更新前と更新後で変化しすぎないように更新サイズを毎回決める” ことによって解決したい、というのがTRPOのコアコンセプトです。

TRPOはこのコンセプトを毎回の方策ネットワークの更新をKL制約付き最大化問題として捉えることによって実現します。 すなわち、更新前方策 \displaystyle{\pi_{old}}の出力と更新後方策 \displaystyle{\pi_{new}}の出力のKLダイバージェンスが一定値以下という制約の下で、代理目的関数 (Surrogate objective)  \displaystyle{L(\theta_{new})}を最大化する \displaystyle{\pi_{new}}を計算します。これは方策パラメータθについての制約付き最大化問題であるのでラグランジュ乗数法で解くことができます。


 \displaystyle{
\underset{\theta_{new}} {\text{argmax }}  L(\theta_{new}) = E_{{s}\sim{d^{\pi{θold}}}, \ {a}\sim\pi_{θold} }\left[ \frac{\pi_{\theta_{new}}(a | s )}{\pi_{\theta_{old}}(a | s )} A^{\pi}(s, a) \right]
}

 \displaystyle{
{\text{subject to }} E\left[ {D_{KL}(\pi_{\theta_{old}} || \pi_{\theta_{new}}) } \right] \leq \delta
}


なお、慣れ親しんだ方策勾配法の目的関数( \displaystyle{ \log{\pi(a | s)}}が入ってるアレ)とは異なる代理目的関数 (Surrogate objective)  \displaystyle{L(\theta_{new})}がどこから湧いてきたのかが気になります。この代理目的関数L(θ)は重点サンプリング (Importance Sampling)によって導出される”方策関数の更新に伴う期待報酬の改善″なのですが解説すると長くなるので詳細は過去記事を参照ください。

horomary.hatenablog.com


PPO:シンプル化されたTRPO

TRPOの欠点は上述したKL制約付き最大化をラグランジュ問題として真面目に解くために実装が死ぬほど煩雑になることです。具体的にはフレームワークが用意するAdamなどのOptimizerを使わずパラメータ更新処理を自力実装する必要があります。またパラメータ数についてのスケーラビリティの問題、および方策ネットワークにCNNが含まれる性能がイマイチという問題もあります。

そこで、PPO論文ではTRPOのコンセプトであるKL制約つきパラメータ更新を継承しつつ実装の煩雑さを改善するために下記の2つのアプローチを提案しました。なお、一般にPPOと言った場合は① Clipped Surrogate Objectiveでの実装を指すことに留意ください。


① Clipped Surrogate Objective

※すべての式と図はPPO論文 より

TRPOでも登場した代理目的関数(Surrogate Objective)の内部には、更新前方策 \displaystyle{\pi_{old}}の出力と更新後方策 \displaystyle{\pi_{new}}の出力の変化の比が含まれます。この比を r(θ) と置きます。(rはimportance ratio のrでありrewardのrではないことに注意)

Clipped Surrogate Objectiveではr(θ)が(1-ε)以下、(1+ε)以上にならないようにクリッピングを行います(εはハイパーパラメータでありε=0.1-0.2くらいがよく採用されます)。このクリッピング処理は、Clipされた代理目的関数とClipされていない代理目的関数を比較し、小さいほう(=悪い方)を採用することで実現できます。

ネストしていて微妙に分かりにくいこのクリッピング処理を視覚的に示したのが下図です。実質的にはAdvantageの正負で場合分けして上限クリップするか下限クリップするか決めるという処理が行われます。


しかし、なぜ代理目的関数のClipping処理によってTRPOコンセプトである”方策関数の極端な更新を避ける”ことを達成できるのでしょうか?

直感的な理解のためにOptimizerの気持ちを考えてみます。

例えば与えられたミニバッチ内のある \displaystyle{(s_t, a_t)}の組が他の \displaystyle{(s, a)}の組に比べて極端に大きな報酬を獲得していたとき、すなわち \displaystyle{A(s_t, a_t)}が極端に大きいとき、Optimizerは代理目的関数L(θ)を最大化するためにミニバッチ内の他の \displaystyle{(s, a)}の組を差し置いても、方策関数が \displaystyle{s_t}において \displaystyle{a_t}を出力する確率  \displaystyle{\pi( a_t | s_t)} をできるだけ高めるように極端な方策関数の更新を行ってしまいます。

しかし、Clipつき代理目的関数の場合は、更新後方策関数が \displaystyle{s_t}において \displaystyle{a_t}を出力する確率 \displaystyle{\pi( a_t | s_t)}を、更新前方策関数が \displaystyle{s_t}において \displaystyle{a_t}を出力する確率  \displaystyle{\pi_{old}( a_t | s_t)} の (1+ε)倍 以上に大きくしても上限Clipによって代理目的関数が改善しないので方策関数の極端な更新を回避することができます

逆もまた同様で、ミニバッチ内のある \displaystyle{(s_t, a_t)}の組が他の \displaystyle{(s, a)}の組に比べて極端に小さな報酬を獲得していた時も下限のClippingが機能し方策関数の極端な更新を回避できます。


Clippingの効果を視覚的に示したのが下図です。更新前の方策関数を0、Clippingつき代理目的関数を用いて一回更新した方策関数を1として、他の目的関数を採用したときに値がどう変化するかを線形補間で示しています。

Clipされた代理目的関数(赤)の場合は、KLダイバージェンス(青)が0.02くらいのところで目的関数が最大になるのに対して、クリップがない場合(オレンジ)は大きく更新すればするほど目的関数の値が大きくなっていきます。さらに、代理目的関数をクリップするがクリップ前の代理目的関数との最小値比較をしない場合(緑)では、大きく更新しても目的関数の値は大きくならないものの、値が小さくもならないので方策関数の極端な更新を回避するモチベーションが弱いことがわかります。


② Adaptive KL ペナルティ

KL制約つき最大化問題を真面目に解かずに、目的関数にソフト制約(ペナルティ)として組み込もう、というのがこのアプローチです。ペナルティの大きさは適応的(adaptive)に変化させていくので Adaptive KL penalty となります。

シンプルですが論文内のパフォーマンスはパッとしなく、ベースライン扱いで提案されている感じがあります。 この実装例はtf_agents (agents/ppo_agent.py at master · tensorflow/agents · GitHub)などで見ることができます。


実装レベルの最適化

※式と図はA Case Study on PPO and TRPO より

方策勾配系の強化学習手法は、各手法のコアコンセプトではない”実装レベルの細かいテクニック”がパフォーマンスに大きく影響することが様々な論文で検証されています。また、この傾向はとくに連続値コントロール環境で顕著なようです。

論文① [1709.06560] Deep Reinforcement Learning that Matters

論文② [2006.05990] What Matters In On-Policy Reinforcement Learning? A Large-Scale Empirical Study

論文③ [2005.12729] Implementation Matters in Deep Policy Gradients: A Case Study on PPO and TRPO

とくに 論文③ ではPPOとTRPOのパフォーマンスを比較しており、この論文内では”パフォーマンスに大きな影響を与えうる実装レベルのトリック"として具体的に下記の示すものが検証されています。


Value Clipping

ポリシーネットワークで行ったClippingと同じことを価値関数でも行います。なんとなく価値ネットワークの学習安定性が向上が期待できます。

しかし、論文③ のabration studyではvalue clippingはあっても無くてもあまり変わらない、という結果になっています。ただし、このトリックは価値関数の出力の正規化を行うトリックとのシナジーがある可能性もあるのではということが論文②の注11に記述されています。

個人的にはεを大きく設定すれば効果は薄れるしとりあえず実装しとけばいいんでは?という印象。

② Reward Scaling

論文③では重要なトリックとされており、割引報酬和の標準偏差でrewardを割ることでスケーリングします。※平均値を引く処理は行いません。

ただしrungning stats の性質上、報酬がスパースな場合は注意が必要です。具体的にはBipedalWalker-v3のようにほとんどのステップでの報酬は-1 から+1であるのに転倒時だけ-100というような極端な負報酬をとるような系では学習が不安定化します。

③ 方策ネットワークの初期化スキーム

A Case Study on PPO and TRPOでも、"What Matters In On-Policy Reinforcement Learning?"でも同様にポリシーネットワークの初期化スキームがパフォーマンスに大きな影響を与えいるとの指摘がされています。

前者ではkerasで言うとkernel_initializer="Orthogonal"がパフォーマンスを向上させると報告しています。後者では初期化スキームについてより詳細な調査を行っています。

④ 学習率のアニーリング

PPOではOptimizerにADAMを使用しますがこの学習率を徐々に下げていくとパフォーマンスが向上することが論文③で指摘されています。

その他

細かいトリックはまだまだあります。たとえば観測の正規化をするか、valueネットワーク出力の正規化をするか、アドバンテージをどのような手法で実装するかなど。実際のところこれらの効果は系によるところが大きいので一概には言えないというのが正直なところではないでしょうか。


そして実装へ

Bipedalwalker-v3をターゲットにtensorflow2でPPOを実装します。

horomary.hatenablog.com


References

[1707.06347] Proximal Policy Optimization Algorithms

[1502.05477] Trust Region Policy Optimization

[1602.01783] Asynchronous Methods for Deep Reinforcement Learning

[1709.06560] Deep Reinforcement Learning that Matters

[2006.05990] What Matters In On-Policy Reinforcement Learning? A Large-Scale Empirical Study

[2005.12729] Implementation Matters in Deep Policy Gradients: A Case Study on PPO and TRPO

Pythonの分散並列処理ライブラリRayの使い方

いままでありがとうmultiprocessing。そしてこんにちはRay

関連:

horomary.hatenablog.com

horomary.hatenablog.com


Rayとは

Rayは分散アプリケーションを構築および実行するための高速でシンプルなフレームワークです。
What is Ray? — Ray v1.6.0

github.com

Rayライブラリを使用することで、既存のライブラリよりもはるかにシンプルでPythonicなコードで分散並列処理を実装できます。

分散並列処理は計算科学に欠かせない要素にもかかわらず、これまでのPythonでの分散並列処理の実装は直感的とは言い難いものでした。しかし、Rayライブラリの登場によってPython言語は真に"科学者のための言語"となったのではないでしょうか。


基本的な使い方

Rayによる並列化処理は非常にシンプルに記述できるためにまずはコードを見た方が早いでしょう。

並列化していないコード

5秒間 sleepする関数を3回実行するので実行時間は当然5×3=15秒程度です


Rayによる並列化コード

ほとんどコードを変更せずに並列化できることがわかります。
また、並列化によって実行時間が15秒から7秒程度まで短縮されていることを確認できます。

実行時間が5秒ではなく7秒であることからわかるように、rayはmultiprocessingなどと比較してプロセス分岐時のオーバーヘッドがやや大きいので、重いタスクの並列処理には向いていますが軽量なタスクの実行のためにサブプロセスが頻繁に生成されるような用途には向いていません。


待ち合わせ処理: ray.get

分岐したプロセスの戻り値はray.getによって取得できます。

各プロセスを格納したリストである変数result に対して ray.get することで、すべてのプロセスが終了して戻り値を返すのを待つ、待ち合わせ処理となります。

このコードからは ray.get する前のリスト result には 、ObjectRef というオブジェクトが格納されていることがわかります。

ObjectRefはプロセスの戻り値に対するプレースホルダと考えるとよいでしょう。ray.get(ObjectRef)することでプロセスの戻り値を取得できます。また、ray.get(ObjectRef)時点でプロセスが終了していない場合は対象のプロセスが終了するまで待機します。

そしてこのサンプルコードのように、ObjectRefが格納されたリストに対してray.getした場合はリスト内のすべてのサブプロセスが戻り値を返すまで待つので待ち合わせ処理となります。


逐次処理: ray.wait

同期処理はray.getで簡単に実装することができることがわかりました。では非同期処理はどうでしょう?

具体的には同時に多数の並列タスクを開始するが、各プロセスそれぞれの終了時間がまちまちであるために終了したタスクから順に次の処理を行いたい、というケースを考えます。これはray.waitを使うことで実装できます。

デフォルトでは、ray.wait はその時点ですでに終了しているタスクのリストと終了していないタスクのリストを返します。

ただし、引数 num_returns が指定されている場合は、すでに終了しているタスクがnum_returns個 格納されたリストと、それ以外のタスクが格納されたリストを返します。 ray.waitが実行された時点で終了しているタスクがnum_returns個に満たなかった場合はnum_returns個のタスクが終了するまで待機します。

このサンプルコードではray.waitを使用して、

  1. 終了済みタスクのObjectRefを1つ取得する
  2. ray.get(ObjectRef)を出力する

という流れで、終了したタスクから逐次戻り値を出力しています。


クラス単位でのサブプロセス化

Rayではクラス単位での並列化も簡単にできます。

これはmultiprocessingモジュールで実装するのはなかなか面倒だった、状態を保持するサブプロセスの実装が簡単にできるということです。


リソースの設定: ray.init

※ドキュメントに記載されている情報であり動作未検証です:The Ray API — Ray 0.3.0 documentation

ここまでのサンプルコードすべてで最初に記述されている ray.init() ではリソース配分を設定できます。

ray.init()のように引数を指定しないとローカルマシン内の全CPUが利用可能になるのに対して、ray.init(num_cpus=20, num_gpus=2) というように使用するリソースを明示することもできます。とくに子プロセスでGPUを使用したい場合はnum_gpusを明示する必要があるようです。

また、ray.init(local_mode=True) とするとコードそのままでもサブプロセス化されなくなり、printが効くようになるので開発時に便利です。


クラスターでの分散並列処理

rayを使えばマルチノードでの分散並列化も楽にできます

horomary.hatenablog.com


分散強化学習での利用

rayのユースケースのひとつは分散強化学習でありrayのサブプロジェクトとして強化学習フレームワークRLLIBが開発されています。が、もしスクラッチ実装に興味がある場合は場合はぜひ下記の記事も参照ください。

horomary.hatenablog.com

horomary.hatenablog.com

ハムスターでもわかるTRPO ③tensorflow2での実装例

強化学習初学者の鬼門であるTRPO(Trust region policy optimization)を丁寧に解説し、tensorflow2で実装します。その③。

[TRPOシリーズ一覧]

ハムスターでもわかるTRPO ①基本編 - どこから見てもメンダコ

ハムスターでもわかるTRPO ②制約付き最適化問題をどう解くか - どこから見てもメンダコ

ハムスターでもわかるTRPO ③tensorflow2での実装例 - どこから見てもメンダコ

関連: TRPOにおけるHessian-vector-productと共役勾配法 - どこから見てもメンダコ

f:id:horomary:20200826004751p:plain:w300


TRPOの更新式

前回までの解説で、結局のところTRPOでは次式の通りに方策パラメータを更新すればよいことがわかりました。

 \displaystyle{
\theta_{new} = \theta_{old} + \beta{s}^{\prime} = \theta_{old} + \sqrt{\frac{2{\delta}}{{s^{\prime}}^{\mathrm{T}} {H }{s^{\prime}}}}{H^{-1}g}
}

 \displaystyle{
\text{where }
}

 \displaystyle{
g = \nabla_{\theta} L(\theta) | _{\theta_{old}} = {\frac{\nabla_{\theta}\pi_{\theta}(a | s ) | _{\theta_{old}}}{\pi_{\theta_{old}}(a | s )} A^{\pi}(s, a)}
}

 \displaystyle{
H ={\nabla_{\theta}{^2} {D}_{KL}(\pi_{\theta_{old}} || \pi_{\theta})|_{\theta{old}}}
}

 \displaystyle{
s = {H^{-1}g}
}


 \displaystyle{g}については、現在(更新前)のパラメータ周辺でのL(θ)の勾配ですので実装でとくに難しいことはありません。

問題は更新前/更新後方策関数のKLダイバージェンスのヘシアン \displaystyle{H}とその逆行列とgの積  \displaystyle{s=H^{-1}g}の算出です。
この計算処理がTRPOの実装が複雑になる原因となっています。


ヘシアンの耐えられない重さ

上述の更新式を愚直に実装しようとするとヘシアン(ヘッセ行列)の逆行列を計算する必要がありますが、 ヘシアンの計算コストはパラメータ数の3乗オーダーで増加していきますので(参考:Hessian free)、深層学習ではとても実用的な計算速度にはなりません。

しかし、TRPO(に限らずヘシアンを利用する最適化手法)の更新式をよく見るとヘシアンそのものが計算に必要なのではなく、ヘシアンの逆行列とベクトルの積 および ヘシアンとベクトルの積( \displaystyle{H^{-1}g} および  \displaystyle{Hs} )さえわかれば事足りることに気づきます

※方策関数のパラメータ数がNのときヘシアン  \displaystyle{H} は N×N行列で勾配  \displaystyle{g} はN×1ベクトルなので  \displaystyle{s=H^{-1}g} は N×1ベクトル 、したがって  \displaystyle{Hs} も N×1ベクトル であることに注意しましょう。


ヘシアンの逆行列とベクトルの積

ヘシアンの逆行列と勾配ベクトルの積を x と置きます。

 \displaystyle{
x = H^{-1}g
}

これを変形すると

 \displaystyle{
Hx = g
}

となり連立一次方程式  \displaystyle{
Ax = b
} の形になります。そして連立一次方程式の解xは共役勾配法によってよい近似解を得ることができます。

このトリックによりヘシアン逆行列の愚直な計算を回避することができます。


Hessian-vector product の計算

資料4Efficiently Computing the Fisher Vector Product in TRPO より

 \displaystyle{
x = H^{-1}g
} を近似するための共役勾配法アルゴリズム中ではヘシアンと任意のベクトルの積  \displaystyle{
Hv
} を計算する必要があります。

共役勾配法についての詳細は過去記事を参照ください。
horomary.hatenablog.com

ヘシアンを真面目に計算してしまうとヘシアンの逆行列ほどではないにせよやはり計算量が多すぎるのでここでも計算トリックを使って、Hを陽に計算せずにHvを計算します。 具体的には”KLダイバージェンスの勾配と任意のベクトルvの積 の総和” についての勾配 を計算することによって Hv が得られます。

f:id:horomary:20200805225547p:plain:w400
引用元: https://www.telesens.co/2018/06/09/efficiently-computing-the-fisher-vector-product-in-trpo/

これらのテクニックにより  \displaystyle{H^{-1}g} および  \displaystyle{Hs} を現実的な計算量で得ることができるようになります。


実装

openAI/baselinesの実装(tensorflow1.X) を参考に tensorflow2で実装しました。
baselines/trpo_mpi.py at master · openai/baselines · GitHub

実装全体はgithubへ: https://github.com/horoiwa/deep_reinforcement_learning_gallery

コード全体

1024ステップ分のトラジェクトリを取得済み、アドバンテージを計算済みの状態から方策関数を更新するコードのみ掲載します。


更新ステップの計算まで

fullstep =  \displaystyle{
\sqrt{\frac{2{\delta}}{{s^{\prime}}^{\mathrm{T}} {H }{s^{\prime}}}}{H^{-1}g}
} の計算までは数式通りに実装するだけです。

tensorflow2.X だと GradientTapeのおかげでどこで勾配が流れるのかわかりやすいため、tensorflow1.X での実装に比べてだいぶわかりやすくなっています。


ステップサイズの線形探索

fullstep =  \displaystyle{
\sqrt{\frac{2{\delta}}{{s^{\prime}}^{\mathrm{T}} {H }{s^{\prime}}}}{H^{-1}g}
} はあくまでテイラー展開による近似によって計算される値なので、このステップで更新した結果本当にL(θ)が改善するか、KL距離制約を満たしているかを確認します。

もしL(θ)が改善しない or KL距離制約を満たさないならばステップサイズを縮小します。既定の回数この処理を繰り返しても条件を満たさないならばこの回でのパラメータ更新は諦めてトラジェクトリを破棄します。

これがTRPOにおける Line search です。


Pendulum-v0 でのテスト結果

安定した学習ができていることがわかります。

f:id:horomary:20200828000139p:plain:w500

TRPOは DDPG なんかの決定論的方策勾配と違って動きに人間味があっていいですね。

f:id:horomary:20200828001006g:plain:w400


そしてPPOへ

horomary.hatenablog.com

ハムスターでもわかるTRPO ②制約付き最適化問題をどう解くか

強化学習初学者の鬼門であるTRPO(Trust region policy optimization)を丁寧に解説し、tensorflow2で実装します(その②)。

[TRPOシリーズ一覧]

【強化学習】ハムスターでもわかるTRPO ①基本編 - どこから見てもメンダコ

【強化学習】ハムスターでもわかるTRPO ②制約付き最適化問題をどう解くか - どこから見てもメンダコ

【強化学習】ハムスターでもわかるTRPO ③tensorflow2での実装例 - どこから見てもメンダコ

関連: TRPOにおけるHessian-vector-productと共役勾配法 - どこから見てもメンダコ

f:id:horomary:20200821225540j:plain:w400
https://petponder.com/how-to-care-for-your-pet-djungarian-hamster


前回のまとめ

TRPOでは方策パラメータθの更新を、更新前/更新後 方策のKL距離制約つきの目的関数L(θ)最大化問題と捉えます。前回はKL距離制約の意義と目的関数L(θ)の式がどこから湧いたのかを確認しました。今回はこの制約付き最大化問題を実際にどう解くかを解説します。

 \displaystyle{
\underset{\theta_{new}} {\text{argmax }}  L(\theta_{new}) = E_{{s}\sim{d^{\pi{θold}}}, \ {a}\sim\pi_{θold} }\left[ \frac{\pi_{\theta_{new}}(a | s )}{\pi_{\theta_{old}}(a | s )} A^{\pi}(s, a) \right]
\tag{1}}

 \displaystyle{
{\text{subject to }} E\left[ {D_{KL}(\pi_{\theta_{old}} || \pi_{\theta_{new}}) } \right] \leq \delta
\tag{2}}


制約付き最大化問題

制約付き最大化問題なのでラグランジュの未定乗数法で解けないか、と考えます。

ラグランジュの未定乗数法(例:2変数の場合)

制約付き最大化問題
 \displaystyle{
{\text{maximize}} \ f(\theta_1, \theta_2) \\
{\text{subject to } \ g(\theta_1, \theta_2) = 0 }
}
の解θは
 \displaystyle{
L(\theta_1, \theta_2, \lambda) = f(\theta_1, \theta_2) - \lambda{g(\theta_1, \theta_2)}
} について
 \displaystyle{
\frac{\partial{L}}{\partial{\theta_1}} = \frac{\partial{L}}{\partial{\theta_2}} = \frac{\partial{L}}{\partial{\lambda}} =0
}

ラグランジュの未定乗数法についての解説はリンク先を参照:
ラグランジュの未定乗数法と例題 | 高校数学の美しい物語

これをそのまま適用すれば良さそうに思ってしまいますが、しかし残念ながら話はそう簡単ではなく、 \displaystyle{{\pi_{\theta_{new}}(a | s )}} 内でのパラメータ同士の絡み合いがきつすぎて実用的にはとても解けそうにありません。

そこで目的関数L(θ)および制約式についてテイラー展開による近似を行い単純化していきます。

具体的には目的関数L(θ)は一次までのテイラー展開、KLダイバージェンスについては二次までのテイラー展開を行うことで近似式とします。


テイラー展開(2次まで)

たとえば \displaystyle{
f(\theta)
} について \displaystyle{
\theta_{old}
} 周辺で2次までのテイラー展開を行うと下式のようになります。

 \displaystyle{
f(\theta) | _{\theta_{old}} \simeq f(\theta_{old}) + f^{\prime}(\theta_{old})(\theta - \theta_{old}) + \frac{1}{2}f^{\prime\prime}(\theta_{old})(\theta - \theta_{old})^{2}
}


目的関数L(θ)の1次近似

※注意: 見やすさのためにこれ以降は \displaystyle{\theta_{new}}を単に \displaystyle{\theta}と表記します。  \displaystyle{\theta_{old}}はそのまま \displaystyle{\theta_{old}}の表記です。

L(θ)は \displaystyle{\theta_{old}}周辺での1次までのテイラー展開により近似します。

 \displaystyle{
L(\theta) \simeq L(\theta_{old}) + g^{\mathrm{T}}(\theta - \theta_{old})
}
 \displaystyle{
\text{where } g = \nabla_{\theta} L(\theta) | _{\theta_{old}}
}

 \displaystyle{
\nabla_{\theta} L(\theta) | _{\theta_{old}}
} \displaystyle{
\theta = \theta_{old}
} におけるL(θ)の勾配です。

また、 \displaystyle{L(\theta)}の右辺第一項  \displaystyle{L(\theta_{old})} は定数なので最大化問題では無視してよいことに注意してください。

一次近似とはずいぶん大胆な近似なので、そんなに近似して大丈夫か?と思います。
実際大丈夫じゃないこともあるのでTRPOではパラメータ更新を確定する前に、更新後のパラメータが期待報酬を改善するか、KLダイバージェンス制約を満たしているかを確認します。この処理についての詳細はステップサイズの線形探索で説明します。


KLダイバージェンスの2次近似

資料3 : Taylor expansion of KL

L(θ)と同様にKL距離の制約式も \displaystyle{\theta_{old}}周辺でテイラー展開していきます。KL距離は2次まで展開します。

 \displaystyle{
{D_{KL}(\pi_{\theta_{old}} || \pi_{\theta})}  \simeq {D_{KL}(\pi_{\theta_{old}} || \pi_{\theta_{old}}) }  + {\nabla_\theta{D}_{KL}(\pi_{\theta_{old}} || \pi_{\theta})|_{\theta{old}}} ^{\mathrm{T}} (\theta - \theta_{old}) + \frac{1}{2}(\theta - \theta_{old})^{\mathrm{T}} {H }(\theta - \theta_{old})
}
 \displaystyle{
\text{where } H ={\nabla_{\theta}{^2} {D}_{KL}(\pi_{\theta_{old}} || \pi_{\theta})|_{\theta{old}}}
}

まず右辺第一項  \displaystyle{{D_{KL}(\pi_{\theta_{old}} || \pi_{\theta_{old}}) }} は0なので消えます。KLダイバージェンスの定義から当然ですね。

さらに右辺第二項  \displaystyle{{\nabla_\theta{D}_{KL}(\pi_{\theta_{old}} || \pi_{\theta})|_{\theta{old}}} ^{\mathrm{T}} (\theta - \theta_{old})} も0になって消えるのですがあまり直感的ではないですね。
そういうものだと放置したくない方は資料3 : Taylor expansion of KLの証明もしくは f-divergenceと汎関数微分 - れおなちずむ を参照ください。

さて、右辺第一項、第二項が消えた結果、結局のところ制約式は次式になります。

 \displaystyle{
{D_{KL}(\pi_{\theta_{old}} || \pi_{\theta})}  \simeq \frac{1}{2}(\theta - \theta_{old})^{\mathrm{T}} {H }(\theta - \theta_{old})
}

HはKLダイバージェンスの二階微分でありすなわちヘッセ行列(Hessian)です。


近似後の制約付き最大化問題

Efficiently Computing the Fisher Vector Product in TRPO – Telesens より

テイラー展開による近似によって目的関数と制約関数がかなり単純化されました。

 \displaystyle{
\underset{\theta} {\text{argmax }}  g^{\mathrm{T}}(\theta - \theta_{old})
}
 \displaystyle{
{\text{subject to }} \frac{1}{2}(\theta - \theta_{old})^{\mathrm{T}} {H }(\theta - \theta_{old}) \leq \delta
}

この制約付き最大化問題についてラグランジュ関数Gを作ります。
ラグランジュ関数はLで表記されることが多いですがすでにLは使っているのでGで表記します

 \displaystyle{
G = g^{\mathrm{T}}(\theta - \theta_{old}) - \lambda({\frac{1}{2}(\theta - \theta_{old})^{\mathrm{T}} {H }(\theta - \theta_{old}) - \delta})
}

ここで \displaystyle{(\theta - \theta_{old})=s}とおくと

 \displaystyle{
G = g^{\mathrm{T}}s - \lambda({\frac{1}{2}s^{\mathrm{T}} {H }s - \delta})
}

ここで、 \displaystyle{s=(\theta - \theta_{old})}はパラメータの更新方向となることに注意してください。


パラメータの更新方向と更新サイズ

まずはGをsについて微分します。

 \displaystyle{
\frac{\partial{G}}{\partial{s}} = {g} - {\lambda}{Hs} = 0
}

よって、

 \displaystyle{
s = \theta - \theta_{old} = \frac{1}{\lambda}{H^{-1}g}
}

 \displaystyle{\lambda} が解けていないので適切な更新サイズはわかりませんが、パラメータを  \displaystyle{
{H^{-1}g}
} 方向に更新すればよいことはわかりました。

そこで、パラメータの更新方向  \displaystyle{
{H^{-1}g}
} \displaystyle{
s^{\prime}} 、更新サイズを \displaystyle{
\beta} と置くと、適切な更新サイズsは、 \displaystyle{
s = \beta{s^{\prime}}
} と表せます。

 \displaystyle{\beta}はKL距離制約を満たすように決めればよいので、

 \displaystyle{
\frac{1}{2}{s}^{\mathrm{T}} {H }{s} = \frac{1}{2}\beta{s^{\prime}}^{\mathrm{T}} {H }\beta{s^{\prime}} = \delta
}

より更新サイズβが

 \displaystyle{
\beta = \sqrt{\frac{2{\delta}}{{s^{\prime}}^{\mathrm{T}} {H }{s^{\prime}}}}
}

と求まります。


θnewはどうなるか

ここまでで、方策パラメータθは更新方向を \displaystyle{s^{\prime} = H^{-1}g } にとり、
更新サイズを  \displaystyle{
\beta = \sqrt{\frac{2{\delta}}{{s^{\prime}}^{\mathrm{T}} {H }{s^{\prime}}}}
} に設定することで、KL制約を満たしつつL(θ)を最大化できることがわかりました。

つまり更新後のパラメータ  \displaystyle{\theta_{new}} は、

 \displaystyle{
\theta_{new} = \theta_{old} + \beta{s}^{\prime} = \theta_{old} + \sqrt{\frac{2{\delta}}{{s^{\prime}}^{\mathrm{T}} {H }{s^{\prime}}}}{H^{-1}g}
\tag{3}}

とすればよいということになります。


ステップサイズの線形探索

理想的には(3)式で更新すればよいのですが、テイラー展開による近似の影響で(3)式通りに更新しても [ KL制約を満たさない or L(θ) が改善されない]、ということがあります。そこでTRPOではまず最大のステップサイズ、つまり(3)式通りの更新で [KL制約を満たすか、L(θ)が改善されるか] を確認します。

もしそうでないのなら更新サイズβを縮小して同じことを繰り返します。たとえばbaselinesのTRPO実装では0.5倍ずつβを小さくしていって条件を満たすステップサイズを探索しています。


そして実装へ

TRPOの理論の説明はここまでで終わりですが、実際はパラメータが非常に多い深層学習で  \displaystyle{ {H^{-1}} } を計算するのは非常に計算コストが高い(パラメータ数の3乗オーダー)という実用上の問題が残っています。しかし \displaystyle{ {H^{-1}} } そのものではなく  \displaystyle{ {H^{-1}}g } であれば共役勾配法によって良い近似解を得ることができます。

次回は Pendulum-v0をtensorflow2で実装したTRPOで解くことを通してこのあたりの実装を解説していきます。

horomary.hatenablog.com


補足: 自然勾配法との関係

資料3 : Natural Gradient Descent, Fisher Information Matrix

KLダイバージェンスのヘシアンHはフィッシャー情報行列Fと近似できます。そうするとTRPOのパラメータ更新式は自然勾配法と同様のものになります。 自然勾配法との違いは線形探索によってステップサイズの検証をするかどうかです。

※KLダイバージェンスのヘシアンとフィッシャー情報行列が同等であることの証明は
強化学習 (機械学習プロフェッショナルシリーズ) , A.4.2 KLダイバージェンスとフィッシャー情報行列の関係性 を参照

ハムスターでもわかるTRPO ①基本編

強化学習初学者の鬼門であるTrust Region Policy Optimization (TRPO、信頼領域ポリシー最適化)を丁寧に解説し、tensorflow2で実装します。


[TRPOシリーズ一覧]

ハムスターでもわかるTRPO ①基本編 - どこから見てもメンダコ

ハムスターでもわかるTRPO ②制約付き最適化問題をどう解くか - どこから見てもメンダコ

ハムスターでもわかるTRPO ③tensorflow2での実装例 - どこから見てもメンダコ

関連: TRPOにおけるHessian-vector-productと共役勾配法 - どこから見てもメンダコ


はじめに

TRPO(Trust Region Policy Optimization) は自然方策勾配法の派生であり、DQN, A3Cと並んで近年の深層強化学習の最重要手法のひとつです。

PPO, ACKTRなど派生手法も多く、強化学習を学ぶ上では避けては通れない手法なのですが、 私のように数学力がハムスターレベルの人間にはなかなかに理解が困難で大変苦労したので同じような人のために解説と実装を残します。実用的にはTRPOではなく、実装が簡単で安定性も良好な後継手法PPOを使えばよいかもしれませんが、理解できるとはぐれメタルのような経験値が入るので逃げずに頑張る価値はあると思います。

※ 方策勾配定理はわかるけどTRPOはよくわからない人向けの解説です。方策勾配定理についての詳細説明はありません。

f:id:horomary:20200714004440p:plain:w300


※この解説は下記の資料を基に作成しています

[1502.05477] Trust Region Policy Optimization : TRPO論文。

資料1:TRPO論文著者の講義スライド

資料2:UCバークレーの講義資料。

資料3カーネギーメロン大学の講義資料。

資料4:個人ブログ。実装方法が詳細に解説されている

動画1:資料1の講義動画


方策勾配法とその問題点

資料1:Two Limitations of “Vanilla” Policy Gradient Methods
資料2:P4, Policy Gradients Review

割引報酬和の期待値を表す関数である  \displaystyle{
J(\theta)
} の勾配  \displaystyle{
g
} は下式のようになるというのが方策勾配定理でした。

 \displaystyle{
g = \nabla J(\theta) = E_{{\tau}\sim\pi_\theta}\left[ {\sum_{t=0}^{\infty}{\gamma^{t}\nabla_\theta\log\pi_{\theta}(a_t | s_t )A^{\pi_\theta}(s_t , a_t )}}\right]
}

 \displaystyle{{\tau}} :トラジェクトリ(状態とアクションのシーケンス)
 \displaystyle{{\tau}\sim\pi_\theta} :トラジェクトリは方策 \displaystyle{\pi_\theta}に従って集められた、の意味


方策勾配法の問題点は、 \displaystyle{
J(\theta)
}を改善するための方策パラメータθの更新方向の正しさしか保証してくれず、適切なステップ幅がわからないことです。

よって方策勾配定理に従って方策関数パラメータθを更新しても、パラメータを大きく更新しすぎると方策が劣化することがあります。さらに悪いことに、劣化した方策関数は報酬の少ない悪いサンプルを集めるようになるので次の更新ではさらに方策が悪化するという負のスパイラルに陥ります。この問題はステップ幅(学習率)を十分に小さく設定することで回避できますが、小さくしすぎると今度は学習が進みません。

この問題へのアプローチがTRPO(Trust region policy optimization, 信頼領域ポリシーの最適化)です。 方策勾配定理がパラメータの更新方向しか保証してくれないのに対して、TRPOではその名前の通り適切な更新方向とともに方策が改善すると信頼できる適切な更新サイズも保証してくれます。


Trust Region Policy Optimization

TRPOでは方策パラメータθの更新を、更新前/更新後方策間のKL距離制約つき目的関数L(θ)最大化問題と捉えます。

 \displaystyle{
\underset{\theta_{new}} {\text{argmax }}  L(\theta_{new}) = E_{{s}\sim{d^{\pi{θold}}}, \ {a}\sim\pi_{θold} }\left[ \frac{\pi_{\theta_{new}}(a | s )}{\pi_{\theta_{old}}(a | s )} A^{\pi}(s, a) \right]
}

 \displaystyle{
{\text{subject to }} E\left[ {D_{KL}(\pi_{\theta_{old}} || \pi_{\theta_{new}}) } \right] \leq \delta
}

更新前の方策関数 \displaystyle{ \pi_{\theta_{old}} }と更新後の方策パラメータ \displaystyle{ \pi_{\theta_{new}} }のKL距離が \displaystyle{
\delta
} 以下という制約のもとで、 目的関数 \displaystyle{
L(\theta)
} を最大化する \displaystyle{ \theta_{new}}を求める、というのがTRPOにおける方策の更新であり、勾配降下法で方策を更新する方策勾配法とはまったく異なったスキームとなっています。方策勾配法では勾配方向を選択してから固定ステップ幅*1で更新するのに対し、TRPOではKL距離で定義した信頼領域半径の中で目的関数を最大化できるように勾配方向とステップ幅を決定します。もっとわかりやすく言うとTRPOでは方策の更新にOptimizer(AdamとかRMSPropとか)を使わず毎回の更新ごとに制約付き最大化問題を解きます。

目的関数L(θ)がどこから湧いたのかなど混乱しますが、目的関数についてはひとまず置いといてまずはKL距離制約の意味から考えてみましょう。


KL制約項について

方策パラメータθの小さな変化≠方策関数の出力の小さな変化

前述の通り方策パラメータθを大きく更新しすぎると方策関数の出力が急激に変化し、結果として劣化する恐れがあります。 これは逆に、方策関数の出力が大きく変化しないならばパラメータθを大きく更新しても問題ない、ということでもあります。 パラメータθを安全に大きく更新できるならば少ないサンプルでも学習が効率的に進むので非常に有用です。

f:id:horomary:20200819223257p:plain:w700
方策パラメータθの小さな変化≠方策関数の出力の小さな変化
(資料2, The Problem is More Than Step Size より)


KL距離は方策関数の変化の大きさの指標

方策関数の更新前/更新後の出力のKL距離は、方策関数の変化の大きさの良い指標となります。
なぜなら方策関数はある状態sにおいてアクションaを選択する確率分布を出力します。そしてKL距離は2つの確率分布間の距離尺度であるからです。

KL距離の制約下で方策パラメータθを更新することで、一回の更新での方策関数の大きすぎる変質を防ぎつつ可能な最大の更新サイズで方策パラメータθを更新することができるようになります。直感的には、KL距離を制約に使うことで更新前パラメータ \displaystyle{\theta_{old}}周辺での方策関数の変化が急斜面なら慎重に、緩斜面なら大胆にパラメータ更新するようになるので安全にサンプル効率を向上させる効果がある、と理解するとよいのではないでしょうか。


目的関数:L(θ)

前述の通り、TRPOのパラメータ更新は目的関数L(θ)の制約付き最大化問題です。

 \displaystyle{
\underset{\theta_{new}} {\text{argmax }} L(\theta_{new}) =  E_{{s}\sim{d^{\pi{θold}}}, \ {a}\sim\pi_{θold} }\left[ \frac{\pi_{\theta_{new}}(a | s )}{\pi_{\theta_{old}}(a | s )} A^{\pi}(s, a) \right]
}

このL(θ)は後継手法である PPO にも登場しますが、この式は一体どこから湧いて出たのか、と混乱します。


θold からθnewへの更新による報酬期待値の変化を考える

資料2, P12, Proof of Relative Policy Performance Identity
資料3 Relating objectives of two policies

あるタイミングで方策パラメータθをθold からθnewに更新したときの期待報酬和Jの改善量 を考えると(1)式が導かれます。(証明は上記資料を参照)

 \displaystyle{
J(\pi_{\theta_{new}}) - J(\pi_{\theta_{old}}) = E_{{s}\sim{d^{\pi{θnew}}}, \ {a}\sim\pi_{θnew}}\left[ A^{\pi}(s, a) \right]
\tag{1}}


(1)式から、方策パラメータθの更新による方策改善の最大化するためには、上式の右辺を最大化するようなθnewを求めればよい、ということがわかります。しかし、(1)式の右辺はまだ計算不可能です。なぜならば学習に利用できるサンプルは更新前ポリシー \displaystyle{\pi_{\theta_{old}}}によって収集されたものなので

 \displaystyle{
{s}\sim {d^{\pi_{\theta_{old}}}}, \ {a}\sim\pi_{\theta_{old}}
}
 \displaystyle{
d^{\pi_{\theta_{old}}}
} \displaystyle{
{\pi_{\theta_{old}}}
}に従ってサンプルを集めた時の各状態sの出現確率

であるのに対して、(1)式右辺で必要なサンプルは更新後のポリシー \displaystyle{\pi_{\theta_{new}}}で収集されたものであること、つまり

 \displaystyle{
{s}\sim {d^{\pi_{\theta_{new}}}}, \ {a}\sim\pi_{\theta_{new}}
}
 \displaystyle{
d^{\pi_{\theta_{new}}}
} \displaystyle{
{\pi_{\theta_{new}}}
}に従ってサンプルを集めた時の各状態sの出現確率

であるためです。この問題は重点サンプリング法(Importance Sampling)というトリックを使うことで解決します。


重点サンプリング法

重点サンプリング(Importance Sampling)は、関心のある分布とは異なる分布から生成されたサンプルから、特定の分布の特性を推定する一般的な手法です。 Importance sampling - Wikipedia

ある確率分布Pから生成されたサンプルxから計算されるf(x)の平均値(など統計量)を計算したいが、手元には確率分布Qから生成されたサンプルしかない、というときに使用します。on-policy手法を疑似的にoff-policyにするトリックとも表現できます。

 \displaystyle{
E_{x\sim{p}}\left[f(x)\right] = E_{x\sim{q}}\left[\frac{p(x)}{q(x)} f(x)\right]
\tag{2}}

余談ですが重点サンプリングは分子動力学シミュレーションなんかでもよく使う手法で、例えば知りたいのはあるタンパク質の0℃での平均構造だが、低温だとシミュレーション効率が悪いので50℃でシミュレーションした結果を重点サンプリングで重みづけして0℃での平均構造を推定する、というような利用をします。


目的関数L(θ)の導出

(1)式に対して、(2)式と同様に重点サンプリング法を適用しactionの確率分布を \displaystyle{
{\pi_{\theta_{old}}}
}に挿げ替えます。

 \displaystyle{
\ E_{{s}\sim{d^{\pi{θnew}}}, \ {a}\sim\pi_{θnew}}\left[ A^{\pi}(s, a) \right]
}

 \displaystyle{
= E_{{s}\sim{d^{\pi{θnew}}}, \ {a}\sim\pi_{θold} }\left[ \frac{\pi_{\theta_{new}}(a | s )}{\pi_{\theta_{old}}(a | s )} A^{\pi}(s, a) \right]
}


 \displaystyle{
d^{\pi_{\theta_{new}}}
}については、ポリシーが十分に近いならば各状態sの出現確率は更新前後でほとんど変わらないだろう、という仮定*2を置いて、 \displaystyle{
d^{\pi_{\theta_{new}}}
} \displaystyle{
d^{\pi_{\theta_{new}}}
}と近似してしまいます。すると、

 \displaystyle{
= E_{{s}\sim{d^{\pi{θold}}}, \ {a}\sim\pi_{θold} }\left[ \frac{\pi_{\theta_{new}}(a | s )}{\pi_{\theta_{old}}(a | s )} A^{\pi}(s, a) \right]
}

 \displaystyle{
= L(\theta_{new})
}

となり最大化する目的関数L(θ)が導出できました。


② 近似と実装編へつづく

ここままでで、TRPOの方策更新とは目的関数L(θ)の制約付き最大化問題を解くことである、ということとL(θ)の意味を確認しました。

しかし、この最大化問題を解くためには目的関数および制約式にいくつかの近似を行う必要がありますので、②ではこの解説を行います。

horomary.hatenablog.com


補足: conservative policy iteration

TRPO論文のキモは Approximately Optimal Approximate Reinforcement Learningが提唱したconservative policy iteration (保守的なポリシー更新による単調改善)が 混合ポリシー更新でなくても成立することを証明したことであり、この貢献こそが今後もTRPOが重要な手法でありつづける根拠です。しかし、この話から始めるととんでもなく分かりにくくなるので(そしてTRPO論文がわかりにくい直接の原因でもあるので)、本稿では触れていません。 参考: Summary: Conservative Policy Iteration | by Zac Wellmer | Arxiv Bytes | Medium

*1:Adamとか使うなら実際は固定ステップではないがわかりやすさのため

*2:この近似ゆえにあまりにも離れたpolicy間で重点サンプリングを行うのは危険

TRPOにおける共役勾配法とHessian-free

[TRPOシリーズ一覧]

【強化学習】ハムスターでもわかるTRPO ①基本編 - どこから見てもメンダコ

【強化学習】ハムスターでもわかるTRPO ②制約付き最適化問題をどう解くか - どこから見てもメンダコ

【強化学習】ハムスターでもわかるTRPO ③tensorflow2での実装例 - どこから見てもメンダコ


はじめに

TRPO(trust region policy optimization)をはじめとする自然方策勾配派生の強化学習手法では、更新前の方策分布と更新後の方策分布のKLダイバージェンス  \displaystyle{
D_{KL}(\pi_{\theta_{old}} || \pi_{\theta})
} のヘシアン  \displaystyle{H}(≒Fisher情報行列) の逆行列と方策勾配ベクトル  \displaystyle{g}の積である

 \displaystyle{
H^{-1}g
} が更新すべきパラメータの方向となります。

しかし逆行列の計算はパラメータ数に対して計算量が  \displaystyle{O(N^{3})} ですので、深層学習で  \displaystyle{
H^{-1}
} を愚直に計算するのは現実的ではありません。

逆行列どころか、そもそもヘシアンそのものの計算すらしんどいです。


共役勾配法の利用

計算したい逆行列と勾配ベクトルの積を x と置きます。

 \displaystyle{
x = H^{-1}g
}

これを変形すると

 \displaystyle{
Hx = g
}

となり連立一次方程式  \displaystyle{
Ax = b
} の形になります。

この連立一次方程式の解xは共役勾配法によってよい近似解を得ることができます。 (参考資料)

つまりヘシアンの逆行列と特定のベクトルの積の結果だけ欲しいような状況(TRPOや自然勾配法など)であるなら、共役勾配法でxの数値解を求めることでヘシアンの逆行列を計算する必要がなくなります。


共役勾配法アルゴリズム

共役勾配法 \displaystyle{
Ax = b
}を解くアルゴリズムの実装自体は簡単で 英wikiに掲載されている疑似コード通りに実装すればOKです。

f:id:horomary:20200805223831p:plain:w500

上のコード中の \displaystyle{
A
}がヘシアン \displaystyle{
H
}に当たります。

これでヘシアンの逆行列を計算する必要は無くなりましたが、ヘシアンそのものの値は必要です。 逆行列ほどではありませんが、ヘシアンを求める計算は重く、やはりパラメータ数の多い深層学習では実用的ではありません。


Hessian vector product

疑似コードをよく見るとヘシアン \displaystyle{
H
}そのものではなく、ヘシアン \displaystyle{
H
} (=A) とベクトル \displaystyle{
p
}の積  \displaystyle{
Ap
} が分かれば共役勾配法は適用できます。

ヘシアンとベクトルの積 ( Hessian vector product ) であれば、数学的トリックによって効率よく計算できます。

f:id:horomary:20200805225547p:plain:w400
引用元: https://www.telesens.co/2018/06/09/efficiently-computing-the-fisher-vector-product-in-trpo/

このトリックと共役勾配法の合わせ技によって \displaystyle{
H^{-1}g
}を実用的な速度で計算できます。


Tensoflow2による実装

適当に設定した3パラメータ関数

 \displaystyle{
f(\theta) = { {\theta_1}^{3} + 2 {\theta_1}{\theta_2} + {\theta_2}^{2} - {\theta_1} + {\theta_2}^{3} }
} について

 \displaystyle{
\theta = (3, 1, 5)
} でのヘシアンの逆行列  \displaystyle{H^{-1}}と、これもやはり適当に設定した勾配ベクトル  \displaystyle{g = (3, 12, 6)} の積  \displaystyle{
 H^{-1}g
}共役勾配法で近似解を求めます

出力結果:

愚直な計算結果: tf.Tensor(
[[-0.5625]
 [ 6.5625]
 [ 3.    ]], shape=(3, 1), dtype=float32)

CGによる近似: tf.Tensor(
[[-0.56205505]
 [ 6.5587754 ]
 [ 2.9985006 ]], shape=(3, 1), dtype=float32)


備考

tensorflow2系について、同じ tf.GradientTape コンテキスト で勾配計算を複数回行うときは persistent=True を有効にしないと、

RuntimeError: GradientTape.gradient can only be called once on non-persistent tapes.

になります。また、persistent=Trueにした場合は明示的にdelする必要があります。

tf.GradientTape  |  TensorFlow Core v2.4.0


参考資料

baselines/cg.py at master · openai/baselines · GitHub

Efficiently Computing the Fisher Vector Product in TRPO – Telesens

TD3の解説・実装(強化学習)

Tensorflow2で連続値制御のための強化学習手法 TD3 (Twin Delayed DDPG)を実装し二足歩行を学習します。

f:id:horomary:20200629221713p:plain
画像元:https://starwars.disney.co.jp/character/at-at-walker.html

前提手法:DDPG, DQN

horomary.hatenablog.com

horomary.hatenablog.com


はじめに:TD3とは

[1802.09477] Addressing Function Approximation Error in Actor-Critic Methods

Twin Delayed DDPG — Spinning Up documentation

TD3 (Twin Delayed DDPG)はActor-Critic系強化学習手法であるDDPGの改良手法です。

基本的な流れはDDPGとほぼ同じですが、Double DQN論文が指摘したDQNでのQ関数の過大評価がActor-Criticでも生じることを示し、学習安定化のために下記の3つのテクニックを提案しました。

1. Clipped Double Q learning
2. Target Policy Smoothing
3. Delayed Policy Update


1. Clipped Double Q learning

オリジナルのDQNでは、TD誤差の計算における \displaystyle{
\max_{a'} Q_{target}(s', a')
}項で、行動選択を行うネットワークとQ値の評価を行うネットワークが同一であるため、Q(s, a)の過大見積りが発生する傾向があることがDouble Q learning論文で指摘されていました。

horomary.hatenablog.com

Double DQNでは、行動選択をq-networkに、Q値の推定をtarget-q-networkに行わせることにより過大評価の低減を狙いました。 一方、TD3ではQ関数を同時に2つ訓練し常にQ値が小さい方を採用することにより過大評価の低減を狙います

 DDPG: \displaystyle{
L_{critic} = {\frac{1}{N} \sum (r_t + \gamma  Q_{target}(s_{t+1}, \mu_{target}(s_{t+1})) - Q(s_{t}, a_t)) }^2
}

 TD3: \displaystyle{
L_{critic} = {\frac{1}{N} \sum (r_t + \gamma  \min_{i=1,2}({Q_{i, target}(s_{t+1}, \mu_{target}(s_{t+1}))}) - Q(s_{t}, a_t)) }^2
}


コードにするとこんな感じ。簡単ですね。


2. Target Policy Smoothing

 TD3 : \displaystyle{
L_{critic} = { \sum (r_t + \gamma  \min_{i=1,2}({Q_{i, target}(s_{t+1}, \mu_{target}(s_{t+1})+\mathcal{N}(0, \sigma))}) - Q(s_{t}, a_t)) }^2
}

target-valueの計算において、アクションに平均0のガウスノイズを乗せます。これによりQ関数が滑らかになり学習の頑健性の向上が期待できます

直感的には画像分類でdata augmentationとして学習画像を歪ませたりずらしたりするのと似たようなものでないでしょうか。

論文では \mathcal{N}(0, 0.2)からノイズをサンプリングし、ノイズの絶対値が0.5を超えないようにclipしています。※アクションの範囲が-1<x<1である場合を想定したチューニングです。


3. Delayed Policy Update

Q関数に比べてPolicyの更新は緩やかなのでPolicy更新頻度を下げましょう、というテクニック。具体的には論文ではQ関数2回更新するごとにPolicyを1回更新しています。そんなんで効果あるんかなと思うが論文のablation studyを見る限り効果あるらしい。


実装

コード全体はGithubを参照ください。

github.com

ネットワーク構造

CriticNetworkが1つのクラスに実質2つのCritic関数を内包していることがポイント。


更新処理

target_valuesの算出は前述の通りです。


BipedalWalker-v3での学習結果

ちょっと変なフォームではありますが安定した走りを学習しました。

f:id:horomary:20200629225007p:plain