どこから見てもメンダコ

軟体動物門頭足綱八腕類メンダコ科

Soft Actor-Critic (SAC) ②tensorflow2による実装

連続値制御で大人気の強化学習手法であるSoft-Aactor-Criticのtensorflow2実装を解説します。
対象タスクはPendulum-v0とBipedalWalker-v3。

前記事: horomary.hatenablog.com

f:id:horomary:20201209001757j:plain:w800

SAC論文 ①: [1801.01290] Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor

SAC論文 ②: [1812.05905] Soft Actor-Critic Algorithms and Applications

SAC論文 ③: https://arxiv.org/pdf/1812.11103.pdf


前提手法:DDPG, DQN

horomary.hatenablog.com

horomary.hatenablog.com


ここまでの概要

SAC(Soft Actor Critic)はSoft-Q学習のactor-criticへの適用であり、累積報酬和と同時に方策エントロピーの期待値の最大化を目的関数としてます。

 \displaystyle{
J_{soft}(\pi) = E_{\pi} \left[{ \sum_{t=0}^{T}{(R(s_t, a_t)} -  \alpha \log\pi(a_{t} | s_{t}) )}\right]
}

SACの理論的根拠となっているSoft-Q学習は最大エントロピー強化学習に基づくQ学習であり、Q学習の課題である探索力の弱さを方策エントロピー項を目的関数に組み込むことによって解決する自然なアプローチです。また、SAC(およびSoft Q学習)はオフポリシー強化学習手法であり高いサンプル効率が期待できます。

SACの実装はDDPGおよびその改良であるTD3に非常によく似ていますが、Soft-Q学習に由来する探索力の高さから致命的なハイパーパラメータが少なく安定したパフォーマンスを期待できます。

【TF2】DDPGでPendulum-v0【強化学習】 - どこから見てもメンダコ

【強化学習】TD3の解説・実装【TF2】 - どこから見てもメンダコ


全体的なアルゴリズムの流れはDDPGとほぼ同じであり、環境から得た経験をReplayBufferに蓄積しそこからミニバッチを作成してネットワークの更新を行うというループを繰り返します。DDPGと同様にSACは方策とQについてニューラルネットワークによる関数近似を行いますが、状態に基づくアクションの決定は方策関数が担いQ関数は方策関数の更新のためにだけ使用されます。


Soft-Q関数について

SACはsoft-Q関数をニューラルネットワークによって関数近似します。連続値制御のために、QはDQNスタイルではなくDDPGスタイルの実装になります。すなわち、状態Sと行動Aを入力とし、単一のQ値を出力するように実装します。これに対してDQNスタイルのQは状態Sを入力とし、行動Aの次元数分のQを出力します。

f:id:horomary:20200626002801j:plain:w400
Qの関数近似:DDPGスタイル(左)とDQNスタイル(右)の違い

soft-Q関数の更新

Soft-Q学習におけるベルマン方程式通りに更新していけばOKです。 通常のQ学習における価値関数Vが、  \displaystyle{
V(s_{t}) = E_{\pi}\left[ Q(s_{t}, a_{t})\right]
} であるのに対して、 Soft-Q学習における価値関数Vは  \displaystyle{
V_{soft}(s_{t}) = E_{\pi}\left[ - \alpha \log\pi(a_{t} | s_{t}) + Q(s_{t}, a_{t})\right]
}というように状態に方策エントロピーがボーナスとして付加されています。

Soft-Q学習における更新式(ベルマンエラー)
 \displaystyle{
L = r(s, a) + \gamma V(s')  - Q(s, a) = r(s, a) + \gamma \left[-\log\pi( a' | s' ) + Q(s', a')  \right] - Q(s, a)
}
※a'は実際に行ったアクションではなく毎回の更新時に方策関数からサンプリングして決める


さらに、この更新式にTD3で提案されたClipped-Double-Qトリックを適用します。Clipped-Double-QとはQ学習のmax演算子(soft-Q学習ではsoftmax)に由来するQ値の過大評価を軽減するための手法であり、Double Q learning (https://arxiv.org/pdf/1509.06461.pdf) と同様のコンセプトの手法です。具体的には2つのQ関数を用意して小さい方の評価値を採用することでQ値の過大評価を打ち消します。

Soft-Q学習における更新式 with Clipped-double-Q
 \displaystyle{
L = r(s, a) + \gamma \left[-\log\pi( a' | s' ) + \min(Q_1(s', a'), Q_2(s', a')) \right] - Q(s, a)
}
※a'は実際に行ったアクションではなく毎回の更新時に方策関数からサンプリングして決める


プログラミング的にはClipped-Double-Qトリックのために実際に2つのQ関数インスタンスを作ってもよいのですが、コードをすっきりさせるためひとつのQ関数インスタンス内に(self.dualqnet)2つのq関数を内包させることにしました。


2つのq関数を内包するQ関数はtensorflow2で下記のように実装しました。レイヤー構成などはDDPG論文に準拠しています。


ソフトターゲット更新

※Soft-Q学習の”ソフト”とソフトターゲット更新の”ソフト”はとくに関係ありません

Target Networkは、DQN(2013)で提案されて以降、Q学習では標準的に使用される学習安定化のためのテクニックです。

DQNでは10000stepごとくらいの頻度でtarget Q関数とメインQ関数の重みを同期していましたが、この同期頻度の程度がそれなりに重要なハイパーパラメータになってしまっていました。これに対して、DDPGでは1-8step程度の高頻度で少しずつ重みを同期していくことによりtarget Q関数がメインQ関数をゆるやかに後追いするようにさせるSotf-Targetという手法を提案しました。SACでもこのSoft-Targetを使用します。

f:id:horomary:20201220230900p:plain:w500
Deep Deterministic Policy Gradient — Spinning Up documentation

Q関数をtensorflow.keras.Modelで作成しておけばget_weightsおよびset_weightsが使えるので実装上とくに難しいことはありません。


方策関数について

確率的方策は任意の形式が使用可能ですが、論文で単ガウス方策が使用されているのでこれに倣います。SAC論文の初期versionでは混合ガウス分布を使っていましたがmujuco環境では単ガウスでも混合ガウスでもあまりパフォーマンスに影響が無いようです。

https://openreview.net/pdf?id=HJjvxl-Cb


方策関数の更新

前記事(Soft-Actor-Critic (SAC) ①Soft-Q学習からSACへ - どこから見てもメンダコ)では、 方策関数はsoft-Q関数のSoftmax方策に似せていく、つまりKL距離を最小化することにより最適方策が得られるということを解説しました。

これを数式で表現するとSAC論文③ https://arxiv.org/pdf/1812.11103.pdfより

方策関数の更新:Qのsoftmax方策に似せる
f:id:horomary:20201220233643p:plain:w400
f:id:horomary:20201220233652p:plain:w300

この数式は一見難しそうに思えますが、式を書き下していくとそうでもないことがわかります。

 \displaystyle{
D_{KL} ( {\pi( \cdot | s) || \frac{\exp{\frac{1}{\alpha}Q(\cdot, s) } }{Z(s)} } )
}

について、まずはKL距離の定義通りに式変形して、

 \displaystyle{
=  \int \pi(a | s) \log \frac{\pi(a | s)}{ \frac{\exp{\frac{1}{\alpha}Q(\cdot, s)}}{Z(s)} }  da
}

ここで、 \displaystyle{p(x)} がxの確率密度関数であるとき、 \displaystyle{\int p(x)f(x) dx = E_{x \sim p}\left[ {f(x) } \right] } であることから、

 \displaystyle{
=  E_{a\sim\pi} \left[ {  \log \frac{\pi(a | s)}{ \frac{\exp{\frac{1}{\alpha}Q(\cdot, s)}}{Z(s)} }  } \right]
}

 \displaystyle{
=  E_{a\sim\pi} \left[ { \log\pi(a | s) - \frac{1}{\alpha}Q(a, s) + Z(s) } \right]
}

とすっきり変形できました。

よって、

 \displaystyle{
\underset{\pi}{\text{argmin}} \, D_{KL} ( {\pi( \cdot | s) || \frac{\exp{\frac{1}{\alpha}Q(\cdot, s) } }{Z(s)} } )
}

 \displaystyle{
=  \underset{\pi}{\text{argmin}} \, E_{a\sim\pi} \left[ { \log\pi(a | s) - \frac{1}{\alpha}Q(a, s) + Z(s) } \right]
}

ここで、Z(s)はπに依存しない規格化定数((softmax方策の規格化定数、統計力学で言えば分配関数)) なので \displaystyle{\underset{\pi}{\text{argmin}}}の中では無視できます。温度パラメータαもまたπに依存しない正の定数なので-αを掛けることで式を整理しつつ、最小化問題を最大化問題に変換します。

 \displaystyle{
=  \underset{\pi}{\text{argmax}} \, E_{a\sim\pi} \left[ { Q(a, s) -\alpha\log\pi(a | s) } \right]
}

というわけで結局のところ、Q値を最大化しつつ方策エントロピー \displaystyle{ H(\pi) = -\log\pi(a | s)} を最大化するというSoft-Q学習の定義通りの式が出てきました。更新式が単純なので実装も簡単です。


方策関数の実装

モデル構造自体はなんの変哲もないガウス方策ですが、状態sからの行動aのサンプリングにおいて”Reparameterization Trick” と "Squashed Gaussian Policy"という2つのテクニックが用いられています。


Reparameterization trick

上述の方策関数の更新を見れば明らかなように、SACでは方策ロスの計算において

  1. 状態sをガウス方策関数に与えてアクション分布の平均(μ)と標準偏差(σ)を得る
  2. 正規分布 N(μ、σ)からのサンプリングにより確率的にアクションaを決定する。
  3. 決定したアクションaと状態sをQ関数に与えてQ(s, a)を計算する

という処理があります。確率的な処理が誤差逆伝搬の通り道にいるとそこで自動微分が止まってしまうため、方策関数の重みを更新できません。*1

f:id:horomary:20201224002641p:plain:w800

この問題はVAE(Variational Auto Encoder)なんかでもお馴染みの Reparameterization Trick を使用することで解決します。 と言ってもやるべきことは簡単で、確率的処理を誤差逆伝搬の通り道から追い出すだけです。

f:id:horomary:20201224003944p:plain:w800

N(μ, σ) からのサンプリングと μ+σz (zは標準正規分布からサンプリングしたノイズ)は同じ結果になりますので確率的処理を通り道から追い出すことができます。いちおう実験しておきましょう。

f:id:horomary:20201224004808p:plain:w400


Squashed Gaussian Policy

ガウス分布は-∞から∞まであらゆる値をとりうる分布である一方で、たとえば今回のターゲットであるBipedalWalker-v3ではアクションの値を-1から1の範囲に制限する必要があります*2。このような場合にはガウス方策からサンプリングされたアクションにtanhを適用することで-1から1に出力されるアクションの数値範囲を制限することができます。

平均0.8, 標準偏差0.3のガウス分布からサンプリングされた値へtanh関数を適用したのが下図です。まさにガウス分布が-1から1の範囲に押しつぶされた(Squashed)ような分布になっていることがわかります。*3

f:id:horomary:20201224233659p:plain:w500

通常のガウス鵜方策の場合には、ロスの計算に必要な \displaystyle{\log\pi(a | s)}ガウス分布確率密度関数をそのまま使用すればよいですが、tanhによるガウス分布の押し潰しを行った場合  \displaystyle{\log\pi(a | s)} は下記のように計算します。

f:id:horomary:20201224230339p:plain:w400
SAC論文①より

数学強い人なら上の式みただけで理解できるのかもしれませんが私の数学力はハムスターレベルなので困惑しました。

まず、ガウス分布tanhで押しつぶされただけなので*4、破線区間での積分が等しくなるということが直感的にイメージできるでしょうか。

f:id:horomary:20201225002628p:plain:w600

 \displaystyle{ 0.54 = \tanh^{-1}(0.6) = 0.5 \log\frac{1+0.6}{1-0.6}}
 \displaystyle{ 1.09 = \tanh^{-1}(0.8) = 0.5 \log\frac{1+0.8}{1-0.8} }

これを数式で表現すると、tanhを適用されたガウス分布確率密度関数 \displaystyle{P_{\text{squash}}}ガウス分布確率密度関数 \displaystyle{P_{\text{gauss}}}として

 \displaystyle{ \int_{0.6}^{0.8} P_{\text{squash}}(a) da }

 \displaystyle{ = \int_{0.54}^{1.09} P_{\text{gauss}}(u) du }

ここまでイメージできればあとは下記の参考リンクを見れば大丈夫(丸投げ)。

normal distribution - Change of variables: Apply $\tanh$ to the Gaussian samples - Mathematics Stack Exchange

sinhx, coshx, tanhxの逆関数 | 高校数学の美しい物語


温度パラメータαの自動調整

SAC(というかSoft-Q学習)の目的関数では累積報酬和と方策エントロピーの期待値の最大化という多目的最適化問題を、単純な和をとることで一つの式にまとめ単目的最適化問題のように表現しています。しかし、Q値およびエントロピーのスケール感はタスクおよび学習の進み具合によって変わるため、エントロピー項の係数αを適切に設定することによって累積報酬和と方策エントロピーのバランスをとる必要があります。

 \displaystyle{
J_{soft}(\pi) = E_{\pi} \left[{ \sum_{t=0}^{T}{(R(s_t, a_t)} -  \alpha \log\pi(a_{t} | s_{t}) )}\right]
}

※αは論文では温度(temperature)パラメータと表現されています。これはSoft-Q学習における最適方策が統計力学におけるボルツマン分布と同じ形であり、αはボルツマン分布における温度Tに対応するためです。

最初のSAC論文ではこの係数αはハイパーパラメータとして調整されるべき値とされましたが、SAC論文② では係数αの自動調整手法が提案されています。具体的には下式に示すように、エントロピー下限値の制約付き報酬累積和の最大化問題と捉えることで適切なαを決定します。

f:id:horomary:20201226033442p:plain:w500
SAC論文②より

この双対問題を解くことによりαの更新式が得られます。しかし、真面目に最小化問題を解くのはpracticalでないので実際は右辺のEの中身をlossとしてSGDでαを更新していきます

f:id:horomary:20201226034212p:plain:w500
αの更新式

実装はこんな感じ。tf.Variableを直接定義してSGDって案外やる機会がないのでちょっと焦る。


直感的にはエントロピーが目標値Hより小さくなったらαを大きくし、逆に目標値Hより大きくなったらαを小さくして、というように適応的にαを更新していくと理解することができます。ただしエントロピーの目標値Hはやはりハイパーパラメータであり、"-1×アクションの次元数" が推奨値として提案されているものの、とくに理論的根拠があるわけではないのである程度ハイパラチューニングした方がよいと思われます。


結果

コード全文はgithubへ: https://github.com/horoiwa/deep_reinforcement_learning_gallery

Pendulum-v0, Bipedalwalker-v3ともにハイパラ調整の試行錯誤なしで良いパフォーマンスを得ることができました。

*1:実装が似ているDDPGは決定論的方策を用いるため、当然ではあるがこの問題が生じない

*2:ただし、BipedalWalker-v3は-1から1という制約はあるがこれに違反しても勝手にClippingされるのでエラーは吐かない

*3:Squashはレモンスカッシュのスカッシュ

*4:1対1写像であるので