関連:
CartPole-v1 with DQN pic.twitter.com/5OYfLzpV2S
— めんだこ (@horromary) 2020年5月10日
はじめに
[1312.5602] Playing Atari with Deep Reinforcement Learning
Human-level control through deep reinforcement learning | Nature
DQNはDeepMind社によって2013年(nature版は2015年)に報告された深層CNNを行動価値関数Qの近似に用いる手法です。経験の自己相関低減のために遷移情報を一旦バッファに蓄積し、学習時はそこからランダムに選択してミニバッチを作成するExperience Replay(経験再生) と、ベルマンエラーの計算において重みを固定した過去のQ関数を(教師あり学習で言う)教師データとして使うTarget-networkというトリックを使用します。
準備:CartPole環境の作成
GymのCartPole-v1
環境を呼び出し、DQNAgent
へ渡します。
CartPole環境にはCartPole-v0
とCartPole-v1
がありますが違いは最大ステップ数のみであり、v0は200ステップ継続で終了、v1は500ステップ継続で終了です。
CartPole環境の詳細については下記リンクを参照
https://gym.openai.com/envs/CartPole-v1/
アルゴリズム概要
各エピソードは以下のように進行します。
現在の状態(
state
)からaction
を決定
state
を行動価値関数Q(q_network
)に入力して最も価値が高いアクションを採用する。
ただし、探索率(self.epsilon
)の確率で代わりにランダムなアクションをとる。アクションを実行
アクション実行前の状態(state
)、実行したアクション(action
)、即時報酬(r
)、アクション実行後の状態(next_state
)を記録する。行動価値関数Q(
q_network
)の更新Target Networkの更新(250ステップごと)
1-4をエピソード終了(ゲームオーバー)まで繰り返す
Q関数
ネットワーク自体はごくシンプルです。
Experiece Replayの実装
1ステップ分の経験(Experience
クラス)ごとにバッファ(self.experiences
)に蓄積します。
Q関数の学習時はバッファからランダムに経験を取得しミニバッチとします。
ベルマンエラーの計算
DQNAgent.update_q_network
では経験バッファより取得したミニバッチからを計算します。
ベルマンエラー の平方二乗和をロスとして勾配を計算します。
結果