どこから見てもメンダコ

軟体動物門頭足綱八腕類メンダコ科

DQN(Deep Q Network)のtensorflow2実装

関連:

horomary.hatenablog.com



はじめに

[1312.5602] Playing Atari with Deep Reinforcement Learning

Human-level control through deep reinforcement learning | Nature

DQNDeepMind社によって2013年(nature版は2015年)に報告された深層CNNを行動価値関数Qの近似に用いる手法です。経験の自己相関低減のために遷移情報を一旦バッファに蓄積し、学習時はそこからランダムに選択してミニバッチを作成するExperience Replay(経験再生) と、ベルマンエラーの計算において重みを固定した過去のQ関数を(教師あり学習で言う)教師データとして使うTarget-networkというトリックを使用します。


準備:CartPole環境の作成

GymのCartPole-v1環境を呼び出し、DQNAgentへ渡します。

CartPole環境にはCartPole-v0CartPole-v1がありますが違いは最大ステップ数のみであり、v0は200ステップ継続で終了、v1は500ステップ継続で終了です。

CartPole環境の詳細については下記リンクを参照
https://gym.openai.com/envs/CartPole-v1/


アルゴリズム概要

各エピソードは以下のように進行します。

  1. 現在の状態(state)からactionを決定
    stateを行動価値関数Q(q_network)に入力して最も価値が高いアクションを採用する。
    ただし、探索率(self.epsilon)の確率で代わりにランダムなアクションをとる。

  2. アクションを実行
    アクション実行前の状態(state)、実行したアクション(action)、即時報酬(r)、アクション実行後の状態(next_state)を記録する。

  3. 行動価値関数Q(q_network)の更新

  4. Target Networkの更新(250ステップごと)

  5. 1-4をエピソード終了(ゲームオーバー)まで繰り返す


Q関数

ネットワーク自体はごくシンプルです。


Experiece Replayの実装

1ステップ分の経験(Experienceクラス)ごとにバッファ(self.experiences)に蓄積します。

Q関数の学習時はバッファからランダムに経験を取得しミニバッチとします。


ベルマンエラーの計算

DQNAgent.update_q_networkでは経験バッファより取得したミニバッチから r_t + \gamma \max (Q_{target}(s_{t+1}, a))を計算します。

ベルマンエラー r_t + \gamma \max (Q_{target}(s_{t+1}, a)) - Q(s_{t}, a_t) の平方二乗和をロスとして勾配を計算します。


結果

f:id:horomary:20200510121924p:plain