Tensorflow2で連続値制御のための強化学習手法 TD3 (Twin Delayed DDPG)を実装し二足歩行を学習します。
前提手法:DDPG, DQN
はじめに:TD3とは
[1802.09477] Addressing Function Approximation Error in Actor-Critic Methods
Twin Delayed DDPG — Spinning Up documentation
TD3 (Twin Delayed DDPG)はActor-Critic系強化学習手法であるDDPGの改良手法です。
基本的な流れはDDPGとほぼ同じですが、Double DQN論文が指摘したDQNでのQ関数の過大評価がActor-Criticでも生じることを示し、学習安定化のために下記の3つのテクニックを提案しました。
1. Clipped Double Q learning
2. Target Policy Smoothing
3. Delayed Policy Update
1. Clipped Double Q learning
オリジナルのDQNでは、TD誤差の計算における項で、行動選択を行うネットワークとQ値の評価を行うネットワークが同一であるため、Q(s, a)の過大見積りが発生する傾向があることがDouble Q learning論文で指摘されていました。
Double DQNでは、行動選択をq-networkに、Q値の推定をtarget-q-networkに行わせることにより過大評価の低減を狙いました。 一方、TD3ではQ関数を同時に2つ訓練し常にQ値が小さい方を採用することにより過大評価の低減を狙います。
コードにするとこんな感じ。簡単ですね。
2. Target Policy Smoothing
target-valueの計算において、アクションに平均0のガウスノイズを乗せます。これによりQ関数が滑らかになり学習の頑健性の向上が期待できます。
直感的には画像分類でdata augmentationとして学習画像を歪ませたりずらしたりするのと似たようなものでないでしょうか。
論文ではからノイズをサンプリングし、ノイズの絶対値が0.5を超えないようにclipしています。※アクションの範囲が-1<x<1である場合を想定したチューニングです。
3. Delayed Policy Update
Q関数に比べてPolicyの更新は緩やかなのでPolicy更新頻度を下げましょう、というテクニック。具体的には論文ではQ関数2回更新するごとにPolicyを1回更新しています。そんなんで効果あるんかなと思うが論文のablation studyを見る限り効果あるらしい。
実装
コード全体はGithubを参照ください。
ネットワーク構造
CriticNetwork
が1つのクラスに実質2つのCritic関数を内包していることがポイント。
更新処理
target_values
の算出は前述の通りです。
BipedalWalker-v3での学習結果
ちょっと変なフォームではありますが安定した走りを学習しました。
BipedalWalker-v3 with TD3 #reinforcementlearing pic.twitter.com/13uUjJ9tE1
— めんだこ (@horromary) 2020年6月29日