どこから見てもメンダコ

軟体動物門頭足綱八腕類メンダコ科

rayで実装する分散強化学習 ④R2D2

Ape-XにRNNを導入することでatari環境において圧倒的SotAを叩き出した分散強化学習手法 R2D2(Recurrent Experience Replay in Distributed Reinforcement Learningをtensorflow+pythonの分散並列処理ライブラリrayで実装します

Recurrent Experience Replay in Distributed Reinforcement Learning | OpenReview

rayで実装する分散強化学習
Pythonの分散並列処理ライブラリRayの使い方 - どこから見てもメンダコ
rayで実装する分散強化学習 ①A3C(非同期Advantage Actor-Critic) - どこから見てもメンダコ
rayで実装する分散強化学習 ②A2C(Advantage Actor-Critic) - どこから見てもメンダコ
rayで実装する分散強化学習 ③Ape-X DQN - どこから見てもメンダコ


はじめに

R2D2(Recurrent Experience Replay in Distributed Reinforcement Learning) とは Ape-XLSTMを導入した手法と表現できます。

DQNにRNNを導入すればエージェントのパフォーマンス向上するのでは? というのは(私ですら思いつく)ごく自然な発想ですが、学習の難しさからか目立った結果を残せていませんでした*1R2D2 (2018) はこのDQN+LSTMアーキテクチャにおける学習を安定化するテクニックを確立し、atari環境において圧倒的SotAを達成しました。RNNを導入するという発想は自然でも学習が困難で実現できていなかったという意味では、強化学習+CNNにおける学習安定化トリックを確立したDQNと似たような立ち位置とも言えます。

f:id:horomary:20210506013517p:plain:w600


RNNの必要性

Q学習はMDP(マルコフ決定過程)を前提としています。MDPとは乱暴に言うなら適切な行動決定に必要な情報はすべて現在の状態観測に含まれている、という仮定が成立するような系です。atari環境では現在の状態観測とはゲーム画面1フレームにあたりますが、しかし1フレームだけではアクション決定には情報がまったく不十分であることは明らかです。

たとえばBreakout(ブロック崩し環境)では1フレームだけの観測情報ではボールの進行方向がわかりません。

f:id:horomary:20210125000147p:plain:w400
1フレームではボールの進行方向がわからない

適切な行動選択のためにはより過去の観測情報も考慮する必要があります。このような系をPOMDP(部分観測マルコフ決定過程)と言います。そこで、DQN(2013)では直近4フレームの観測を重ねてQネットワークの入力とすることで、atariの多くのゲームをPOMDPからMDPっぽい系にすることに成功し、エポックメイキングな手法となりました。

とはいえ、DQNで考慮できる過去とは所詮直近4フレームまでです*2。 直近4フレームはボールの進行方向を判断する程度なら十分ですが、たとえばMs. Pacmanにおいて”そろそろパワーエサ状態が切れそうだな”というような数秒スケールの判断を適切に行うには全く不十分です。 この課題に対する有望なアプローチは Deep Q-networkへ時系列情報を考慮できる Recurrent Neural Network (RNN, 再帰ニューラルネットワーク) を導入することです。R2D2ではRNNファミリーの中でもよく使われるLSTMを採用しています。


RNN(LSTM)の困難

前述の通りPOMDP打破のためにRNNを使うというアイデアは何ら独創的なものではないので、過去にも同様の検討がされてきましたが華々しい結果とはなっていませんでした。これはRNNに関する2つの困難により学習が不安定化するためであると考えられます。


困難①:経験再生時の初期LSTM状態をどうするか?

LSTMの入力は3つであり(下図)、すなわち 入力  \displaystyle{
x_{t}
}, 1step前の出力  \displaystyle{
h_{t-1}
}, そして1step前のセル記憶  \displaystyle{
c_{t-1}
} です。LSTMを持つネットワークで推論するときには当然これらすべてを入力する必要があります。また以下ではc,hをまとめてLSTM状態と呼称します。

※エピソード開始時、つまりt=1の  \displaystyle{
h_{0}
},  \displaystyle{
c_{0}
} はゼロ行列です。

f:id:horomary:20210508235320p:plain:w400
LSTMの構造(Long short-term memory - Wikipedia より)

R2D2では連続する40遷移のセグメントを1サンプルとしてreplay bufferに格納します。ここで、通常のDQNのように遷移情報として(s, a, r, s')だけを蓄積していると、セグメントが再生されたときに(そのセグメントからエピソード開始される場合を除いて)、 \displaystyle{
h_{t-1}
}および  \displaystyle{
c_{t-1}
} が無い=LSTMの初期状態が無いため困ってしまいます。この問題へのもっとも単純な対応策は、エピソード全体を保存しておいてt=0からunrollする(タイムステップを進めていく)ことで対応するセグメントへの初期入力を作ることです。この方法は正確なLSTMの初期状態が得られる一方で、しかし計算量が酷いことになるので実用的ではありません。

そこでR2D2が採用しているのがStored stateトリックです。このトリックでは経験バッファにセグメントの初期LSTM状態 \displaystyle{
(c_{t-1}, h_{t-1})
} も保存しておくことで、セグメントが再生されたときは保存されている初期LSTM状態 \displaystyle{
(c_{t-1}, h_{t-1})
} をLSTMへの初期入力として使用し、t=0からの愚直なunrollを回避します。


困難②:ネットワーク更新によるStored LSTM state の陳腐化

Stored state トリックだけでは経験再生時の初期LSTM状態の問題は解決していません。なぜならば保存されているLSTM状態は過去のQネットワークによって計算されたLSTM状態であり、現在のQネットワークでLSTM状態を計算しなおすと異なる値になるはずだからです。

この保存されたLSTM状態の陳腐化問題を軽減するためにR2D2で提案されたのがBurn-inトリックです。これはStored state トリックで保存された初期LSTM状態を初期入力に使うものの、Stored stateによる入力に近いところでは実際のLSTM Stateとの乖離が大きいと予想されるため、しばらくタイムステップを進めてから学習に使うことで鮮度の低いLSTM状態の問題を軽減しようというアイデアです(下図)。よってburn-inを最大限長くした場合は上述したt=0からの愚直なunrollと同じになります。

f:id:horomary:20210509220012p:plain:w600
burn-inフェーズはtimestepを進めるだけでネットワーク更新に使わない


余談ですが日本語を当てるなら、burn-inの意味・使い方・読み方 | Weblio英和辞書 に例文として記載されている”ならし運転”がしっくりきます。

f:id:horomary:20210510232659p:plain:w700


LSTM+大規模分散学習

上述したStored state & Burn-in トリックを使っても古すぎるセグメントの初期LSTM状態を再現することは難しいと考えられるため、経験バッファにはできるだけ鮮度の高い(on-policynessの高い)セグメントが蓄積されていることが望ましいはずです。

単純には経験バッファのサイズを小さくすれば全体の鮮度が高まることが期待できますが、そうするとサンプル多様性が失われ学習が不安定化することが予想されます。この問題を力押しで解決するのがApe-X で提案された大規模並列分散マルチ方策学習です。分散並列による圧倒的なサンプル投入速度とマルチ方策(異なる探索率ε)エージェントによって経験バッファ内のサンプル多様性を確保します。

ただし、分散並列の効果についてApe-X論文のFig.6でやってたような検証実験が無いので確実なところはわかりません。


その他の重要なトリック

R2D2はLSTMに目が行きますが、パフォーマンスに大きな影響を与えうる(Ape-Xには無かった)トリックがいくつか追加されています。

報酬クリッピングの廃止

atari環境ではいかなる報酬でも (-1, 0, 1) にクリップする reward clippingトリックが長らく使われてきました。これは多くのゲームで学習を安定化させる有用なトリックである一方、一部のゲームの学習を困難にしてしまいます。

[1805.11593] Observe and Look Further: Achieving Consistent Performance on Atari ではその分かりやすい例として、

For example, the agent no longer differentiates between striking a single pin or all ten pins in Bowling.
ー たとえば、agentはボーリングゲームでピンを1本倒すことと10本倒すことを区別できなくなります。

と述べています。もう少し親しみのあるゲームで言えば、Pacmanで通常クッキーを食べるのもオバケを倒すのも同じ+1点になってしまいます。そこで、同論文ではこの問題低減のためによりソフトな報酬(というか target-Q の)スケーリング関数を提案しており、R2D2でもこれを採用しています。


f:id:horomary:20210509172158p:plain:w600
R2D2論文より

f:id:horomary:20210509171851p:plain:w600
Observe and Look Further より


割引率 γ=0.997

これも同様に [1805.11593] Observe and Look Further: Achieving Consistent Performance on Atari で報告されていることですが、R2D2ではγ=0.997という従来(γ=0.99とか)よりかなり高い割引率を採用することでパフォーマンスを向上させています。ablation studyは Fig.7を参照。


Life loss as episode end の廃止

残機を使い切ることではなく、残機が1減ることをエピソード終了と見なすトリックは、報酬クリッピングと共にatari環境のヒューリスティックスとして長らく使われてきましたがR2D2ではこれを廃止しています。ablation studyを見るとこれによって必ずしもパフォーマンスが向上するわけではないようですが、少なくともヒューリスティクスを一つ排除してSotAを達成したことは重要な成果です。

f:id:horomary:20210509222401p:plain:w600
life loss (roll) が従来のやり方


R2D2の実装(CartPole-v0)

ここからはtensorflow+rayによる実装レベルの解説です。まずは単純なCartPole環境でR2D2の実装を確認してみます。ただし、ここでは簡単のためにDueling-network, n-step return, およびvalue-rescalingは省略しています。分散学習部分はApe-X DQN とほぼ同じなので過去記事も併せて参照ください horomary.hatenablog.com

コード全文:
github.com

分散学習の流れ

前述の通り、分散学習の流れ自体はApe-Xと何も変わりません。


R2D2のネットワーク構造

DQNアーキテクチャのDense層がLSTMに変更されただけです。このネットワークはLSTM状態(c, h)に加えて前のアクションの入力も要求することに留意ください。前ステップのアクションはonehot化したうえでconv層からの出力とconcatします。

※論文ではさらに前ステップのrewardも入力すると書いていますが省略しました。


Actor

各セグメントはepisode-endを跨がないという設定から、rolloutは1episode区切りにすると実装が楽です。1episode分のrolloutが終わったらセグメントの切り出しを行い、優先度付き経験再生のための初期優先度を算出したうえでセグメントを送信します。セグメントへの優先度の割り当てはR2D2論文にて提案された方法です。


ActorはEpisodeBufferに1episode分の遷移を蓄積します。


Replay

Actorから受け取ったセグメントを蓄積するSegmentReplayBufferは、対象がセグメントであること以外はApe-Xとまったく同じ優先度つき経験再生バッファなので掲載を省略します。

rayで実装する分散強化学習 ③Ape-X DQN - どこから見てもメンダコ


Learner

Learnerは16セグメントで構成されるミニバッチを16セット受け取りネットワークを更新します。Actorで初期優先度割り当てとほぼ同じ処理ですが、ターゲットネットワーク(target_q_network)はオンラインネットワーク(q_network)とは別にburn-inする必要があるため計算量が増えています。

学習結果

CartPoleでLSTM使う意味はほぼありませんが、問題なく学習出来ています。

f:id:horomary:20210511005946p:plain:w400
x軸:Leaner.update_networkが呼ばれた回数


R2D2の実装(Breakout)

Breakoutでの実装はCartPoleのコードに

  • N step-return

  • Dueling network

  • Value function rescaling

  • RAM節約のためのsegment圧縮

を追加したものとなっていますが、Ape-Xと同様にコードを直接掲載するには多すぎるので結果だけ示します。 詳細はGithubを参照ください。

github.com

学習結果

BreakoutDeterministic-v4環境(ブロック崩し)を、GCPで 24-vCPU/128GB RAM/GPU T4 のプリエンティブルVMインスタンスを使って24時間学習しました。actor数は論文では256であるのに対してここでは20と圧倒的に少ないですが、なんとか正常に学習出来ているっぽくはあります。プリエンティブルインスタンスの24時間制限によりパフォーマンスが急激に向上してきたところで時間切れとなってしまいました。

f:id:horomary:20210514004442p:plain:w600

速度パフォーマンスは論文記載の20%程度しかでていなかったので単純計算でR2D2論文の5時間時点相当くらいの更新回数になっています。プロファイリングしたところLearnerのネットワーク更新がボトルネックになっていたので、GPUをもっと性能が良いものにするかGPU利用効率の良い実装を考える必要があります。


次: Agent57

R2D2をベースに内発的報酬を追加し、さらにエージェントへの方策割り当てをバンディット問題と捉えることでついにすべてのatariゲームで人間超えを達成した手法。そのうち。

*1:https://arxiv.org/abs/1507.06527

*2:NoFrameSkip環境でなければ実質16フレーム