どこから見てもメンダコ

軟体動物門頭足綱八腕類メンダコ科

TD3の解説・実装(強化学習)

Tensorflow2で連続値制御のための強化学習手法 TD3 (Twin Delayed DDPG)を実装し二足歩行を学習します。

f:id:horomary:20200629221713p:plain
画像元:https://starwars.disney.co.jp/character/at-at-walker.html

前提手法:DDPG, DQN

horomary.hatenablog.com

horomary.hatenablog.com


はじめに:TD3とは

[1802.09477] Addressing Function Approximation Error in Actor-Critic Methods

Twin Delayed DDPG — Spinning Up documentation

TD3 (Twin Delayed DDPG)はActor-Critic系強化学習手法であるDDPGの改良手法です。

基本的な流れはDDPGとほぼ同じですが、Double DQN論文が指摘したDQNでのQ関数の過大評価がActor-Criticでも生じることを示し、学習安定化のために下記の3つのテクニックを提案しました。

1. Clipped Double Q learning
2. Target Policy Smoothing
3. Delayed Policy Update


1. Clipped Double Q learning

オリジナルのDQNでは、TD誤差の計算における \displaystyle{
\max_{a'} Q_{target}(s', a')
}項で、行動選択を行うネットワークとQ値の評価を行うネットワークが同一であるため、Q(s, a)の過大見積りが発生する傾向があることがDouble Q learning論文で指摘されていました。

horomary.hatenablog.com

Double DQNでは、行動選択をq-networkに、Q値の推定をtarget-q-networkに行わせることにより過大評価の低減を狙いました。 一方、TD3ではQ関数を同時に2つ訓練し常にQ値が小さい方を採用することにより過大評価の低減を狙います

 DDPG: \displaystyle{
L_{critic} = {\frac{1}{N} \sum (r_t + \gamma  Q_{target}(s_{t+1}, \mu_{target}(s_{t+1})) - Q(s_{t}, a_t)) }^2
}

 TD3: \displaystyle{
L_{critic} = {\frac{1}{N} \sum (r_t + \gamma  \min_{i=1,2}({Q_{i, target}(s_{t+1}, \mu_{target}(s_{t+1}))}) - Q(s_{t}, a_t)) }^2
}


コードにするとこんな感じ。簡単ですね。


2. Target Policy Smoothing

 TD3 : \displaystyle{
L_{critic} = { \sum (r_t + \gamma  \min_{i=1,2}({Q_{i, target}(s_{t+1}, \mu_{target}(s_{t+1})+\mathcal{N}(0, \sigma))}) - Q(s_{t}, a_t)) }^2
}

target-valueの計算において、アクションに平均0のガウスノイズを乗せます。これによりQ関数が滑らかになり学習の頑健性の向上が期待できます

直感的には画像分類でdata augmentationとして学習画像を歪ませたりずらしたりするのと似たようなものでないでしょうか。

論文では \mathcal{N}(0, 0.2)からノイズをサンプリングし、ノイズの絶対値が0.5を超えないようにclipしています。※アクションの範囲が-1<x<1である場合を想定したチューニングです。


3. Delayed Policy Update

Q関数に比べてPolicyの更新は緩やかなのでPolicy更新頻度を下げましょう、というテクニック。具体的には論文ではQ関数2回更新するごとにPolicyを1回更新しています。そんなんで効果あるんかなと思うが論文のablation studyを見る限り効果あるらしい。


実装

コード全体はGithubを参照ください。

github.com

ネットワーク構造

CriticNetworkが1つのクラスに実質2つのCritic関数を内包していることがポイント。


更新処理

target_valuesの算出は前述の通りです。


BipedalWalker-v3での学習結果

ちょっと変なフォームではありますが安定した走りを学習しました。

f:id:horomary:20200629225007p:plain