どこから見てもメンダコ

軟体動物門頭足綱八腕類メンダコ科

TRPOにおける共役勾配法とHessian-free

[TRPOシリーズ一覧]

【強化学習】ハムスターでもわかるTRPO ①基本編 - どこから見てもメンダコ

【強化学習】ハムスターでもわかるTRPO ②制約付き最適化問題をどう解くか - どこから見てもメンダコ

【強化学習】ハムスターでもわかるTRPO ③tensorflow2での実装例 - どこから見てもメンダコ


はじめに

TRPO(trust region policy optimization)をはじめとする自然方策勾配派生の強化学習手法では、更新前の方策分布と更新後の方策分布のKLダイバージェンス  \displaystyle{
D_{KL}(\pi_{\theta_{old}} || \pi_{\theta})
} のヘシアン  \displaystyle{H}(≒Fisher情報行列) の逆行列と方策勾配ベクトル  \displaystyle{g}の積である

 \displaystyle{
H^{-1}g
} が更新すべきパラメータの方向となります。

しかし逆行列の計算はパラメータ数に対して計算量が  \displaystyle{O(N^{3})} ですので、深層学習で  \displaystyle{
H^{-1}
} を愚直に計算するのは現実的ではありません。

逆行列どころか、そもそもヘシアンそのものの計算すらしんどいです。


共役勾配法の利用

計算したい逆行列と勾配ベクトルの積を x と置きます。

 \displaystyle{
x = H^{-1}g
}

これを変形すると

 \displaystyle{
Hx = g
}

となり連立一次方程式  \displaystyle{
Ax = b
} の形になります。

この連立一次方程式の解xは共役勾配法によってよい近似解を得ることができます。 (参考資料)

つまりヘシアンの逆行列と特定のベクトルの積の結果だけ欲しいような状況(TRPOや自然勾配法など)であるなら、共役勾配法でxの数値解を求めることでヘシアンの逆行列を計算する必要がなくなります。


共役勾配法アルゴリズム

共役勾配法 \displaystyle{
Ax = b
}を解くアルゴリズムの実装自体は簡単で 英wikiに掲載されている疑似コード通りに実装すればOKです。

f:id:horomary:20200805223831p:plain:w500

上のコード中の \displaystyle{
A
}がヘシアン \displaystyle{
H
}に当たります。

これでヘシアンの逆行列を計算する必要は無くなりましたが、ヘシアンそのものの値は必要です。 逆行列ほどではありませんが、ヘシアンを求める計算は重く、やはりパラメータ数の多い深層学習では実用的ではありません。


Hessian vector product

疑似コードをよく見るとヘシアン \displaystyle{
H
}そのものではなく、ヘシアン \displaystyle{
H
} (=A) とベクトル \displaystyle{
p
}の積  \displaystyle{
Ap
} が分かれば共役勾配法は適用できます。

ヘシアンとベクトルの積 ( Hessian vector product ) であれば、数学的トリックによって効率よく計算できます。

f:id:horomary:20200805225547p:plain:w400
引用元: https://www.telesens.co/2018/06/09/efficiently-computing-the-fisher-vector-product-in-trpo/

このトリックと共役勾配法の合わせ技によって \displaystyle{
H^{-1}g
}を実用的な速度で計算できます。


Tensoflow2による実装

適当に設定した3パラメータ関数

 \displaystyle{
f(\theta) = { {\theta_1}^{3} + 2 {\theta_1}{\theta_2} + {\theta_2}^{2} - {\theta_1} + {\theta_2}^{3} }
} について

 \displaystyle{
\theta = (3, 1, 5)
} でのヘシアンの逆行列  \displaystyle{H^{-1}}と、これもやはり適当に設定した勾配ベクトル  \displaystyle{g = (3, 12, 6)} の積  \displaystyle{
 H^{-1}g
}共役勾配法で近似解を求めます

出力結果:

愚直な計算結果: tf.Tensor(
[[-0.5625]
 [ 6.5625]
 [ 3.    ]], shape=(3, 1), dtype=float32)

CGによる近似: tf.Tensor(
[[-0.56205505]
 [ 6.5587754 ]
 [ 2.9985006 ]], shape=(3, 1), dtype=float32)


備考

tensorflow2系について、同じ tf.GradientTape コンテキスト で勾配計算を複数回行うときは persistent=True を有効にしないと、

RuntimeError: GradientTape.gradient can only be called once on non-persistent tapes.

になります。また、persistent=Trueにした場合は明示的にdelする必要があります。

tf.GradientTape  |  TensorFlow Core v2.4.0


参考資料

baselines/cg.py at master · openai/baselines · GitHub

Efficiently Computing the Fisher Vector Product in TRPO – Telesens