どこから見てもメンダコ

軟体動物門頭足綱八腕類メンダコ科

DDPGでPendulum-v0(強化学習, tensorflow2)


はじめに

DDPG(決定論的方策勾配法, Deep Deterministic Policy Gradient)をtensorflow2で実装して連続値制御の基本タスクであるPendulum-v0を解きます。

f:id:horomary:20200625000158g:plain:w500

※DDPGはDQNのエッセンスを多く含むため、DQNの理解が前提となります。

horomary.hatenablog.com

DPG (Deterministic Policy Gradient)): Deterministic Policy Gradient Algorithms

DDPG (Deep DPG): [1509.02971] Continuous control with deep reinforcement learning


DDPG (Deep Deterministic Policy Gradient) とは

一般的な方策勾配法では獲得報酬期待値が最大化されるように方策を最適化します。一方、DDPGではQ値が最大化されるように方策を最適化します。直感的には、Q関数を環境のシミュレータとして利用して方策を最適化していると見なすことができます。

また、DDPGでは方策関数が確率分布ではなくスカラ値を出力します。

一般的な方策勾配法で連続値制御を行う場合は、方策関数(Actor)はある状態の入力に対してアクションの確率分布を出力し、その確率分布からのサンプリングによってアクションを決めていました。典型的には状態Sを入力として正規分布パラメータ (μ, σ) を出力し、この平均μ、標準偏差σ の正規分布からアクションをサンプリングします。これは、方策勾配法に基づいて方策関数のパラメータを更新するためにはある状態sにおいてあるアクションa が選択される確率  \pi (a | s) が必要であるためです。(このような方策関数を確率的方策と呼称します。)


 \displaystyle{
g = \nabla_\theta\log\pi_{\theta}(a_t | s_t )A^{\pi}(s_t , a_t )
}


大雑把な説明ですが、方策勾配法とは  (s_{t}, a_{t}) における  A(s_{t}, a_{t}) の値が良好ならば s_{t} におけるの a_{t} 選択確率 \pi (a | s) が大きくなるように、逆なら小さくなるように更新するという手法です。ここで、もし方策関数がアクションの確率分布でなくスカラ値を出力するならば、 s_{t} におけるの a_{t} の選択確率が計算できない(というか定義できない)ので方策勾配法が成立しません。

そのような方策勾配法の常識にも関わらず、DDPGの方策関数はある状態の入力に対してアクションが一意に決まるようなスカラ値を出力する関数なので決定論的方策と呼称されます。 決定論的方策では例えば状態sを入力としてスカラ値 a を出力します。

このような決定論的(スカラ値出力)方策 は、DPG論文が方策関数(Actor)のパラメータを、Q関数 (Critic) の出力値が大きくなるように更新するならば方策関数は確率分布 \pi (a | s) が必須ではないよね、ということを示したことにより発見されました。冒頭にも書きましたが、直感的にはQ関数を環境のシミュレータとして利用して方策を最適化していると見なすことができます。

f:id:horomary:20200625010116p:plain:w500

更新式内に、ある状態sにおいてあるアクションa が選択される確率  \pi (a | s) が式内に登場しないことに注意してください。

説明よりも実装を見た方が理解しやすいかもしれません。


DDPGのネットワーク構造

DQNスタイルのQ関数は入力が状態sのみでしたが、DDPGスタイルのQ関数(Critic)は状態sとアクションaの2つの入力を要求していることが上述の更新式からわかります。違いを図にしてみましょう。

左がDDPGスタイルのQ関数右がDQNスタイルのQ関数です。

f:id:horomary:20200626002801j:plain

tensorflow2での実装はこんな感じ。

※2つの入力を必ずしも直接concatする必要はありません。隠れ層を噛ませて次元を揃えた後に和をとるみたいな実装も見かけました。


Q関数 (Critic) の更新

上式内の状態行動価値関数 Q の目的関数はDQNとほぼ同様です。

  \displaystyle{
L_{critic} = MSE \left[{ r_t + \gamma  Q_{target}(s_{t+1}, \mu_{target}(s_{t+1})) - Q(s_{t}, a_t)  }\right]
}

実装はこんな感じ


学習を安定させるためのテクニック

DDPGでは学習安定化のためのkey techniqueが2つ提案されています。

1. Soft-Target

DQNでは学習の安定化のためにTarget networkを導入しました。Target networkは定期的にMain Networkと重みを同期します。

このようなtarget networkの同期方法をhard targetと呼称します。hard targetでは”定期的に”を具体的に何ステップごとにするかが学習の安定性を大きく左右するハイパーパラメータとなってしまっているという問題があります。

そこで、論文ではMain NetworkをTarget networkが緩やかに後追いするsoft targetという手法を提案しています。τは更新率のハイパーパラメータであり論文では0.001を採用しています。

f:id:horomary:20200625233451j:plain:w200

DQNではtargetの更新は10000stepごとに行っていましたが、DDPGではネットワークの更新頻度と同じくらい頻繁に少しずつ更新していきます。

この更新処理はtensorflow2.0では以下のように実装できます。


2. 探索ノイズ

DQNでは探索を促進するために、ε-Greedy(一定確率でランダム行動する)を採用していました。一方、DDPGでは探索を促進するために方策関数の出力であるアクションにノイズをのせます。論文ではOUノイズを採用していましたが、シンプルに平均ゼロのガウスノイズを乗せるだけでも問題なく機能するそうなのでこちらで実装しました。

Deep Deterministic Policy Gradient — Spinning Up documentation


DDPGの問題点

DQNに対するDouble DQN の指摘と同様に、DDPGは行動価値を過大評価することがTD3(Twin Delayed DDPG)論文で指摘されています。この問題は後継手法のTD3(Twin Delayed DDPG)では Double DQNと似たようなアプローチで解決することが提案されています。

horomary.hatenablog.com

また、決定論的方策ゆえにQ関数に似たような値ばかり渡すため、Q関数が過学習気味になり学習が不安定になりがちです。この問題はTD3ではQ関数の学習時には方策関数の出力したアクションにノイズを乗せることで解決を提案しています。


実装

コード全文はGithubを参照ください

github.com


結果:Pendulum-v0

Pendulum-v0では初期開始位置によって可能なスコアの上限が変わるのでこれくらいが実質的な限界パフォーマンスでしょう。

f:id:horomary:20200625014501p:plain

結果だけ見ると簡単そうですがハイパラ調整が地味に渋く、(MuJoCo向けに調整された)論文記載の学習率ではまったくスコアが上がりませんでした。


後継手法:TD3, SAC

horomary.hatenablog.com

horomary.hatenablog.com

horomary.hatenablog.com


備考: DDPGはoff-policy

DDPGはoff-policyです。off-policyなのでExperience Replayによる経験の使いまわしが可能なためサンプル効率が良いです。

しかし同じActor-CriticであるA2C/A3Cはon-policyです。なぜでしょう?

これはDDPGでは方策更新の指針となる  Q^\mu(s_{t}, \mu(a_{t})) を再評価できるが、A2C/A3Cでは方策更新の指針となるアドバンテージ A^\mu(s_{t}, a_{t}) = r_{t+1}  +  V(s_{t+1}) -V(s_{t}) を再評価できないためです。

もっと具体的にはA3Cの方策更新の指針であるアドバンテージ関数内の r_{t+1}過去の方策による行動選択の結果であるため現在の方策更新の指針には使えないのでon-policyであるが、DDPGの方策更新の指針であるQ関数なら方策が更新されたら行動選択をやり直してQ関数の値を再評価すればいいのでオフポリシーというわけです。

言われてみればそりゃそうだという感じですが、私は当初混乱しました。


A2CでのBreakout攻略 (multiprocessing利用)

はじめに

A2C (Advantage Actor Critic) は A3Cのバリアントであり、A3Cから非同期 (Asynchronous) 要素を除いた手法です。

A3Cはいろいろ盛り込んでて属性過多な手法だったので、手法の発表後にそれぞれの要素が性能にどの程度の寄与があったのかが検証されました。

結果、各Agentによるネットワークの非同期更新ではなく、各Agentから受け取ったトラジェクトリによりネットワークをまとめて更新する同期更新実装でもパフォーマンスが落ちないことが発見されました。それなら実装が楽だし1台のGPUを効果的に利用できるしこっちの方がいいね、となったのがA2Cです。

名前だけ聞くとA2Cの方が先に発表されたように思えるので私は当初混乱しました。

詳細は OpenAI Baselines: ACKTR & A2C を参照



同期更新をどのライブラリで実装するか

A2Cの学習アルゴリズム自体はA3Cとまったく同じです。

horomary.hatenablog.com

ですので焦点はプロセス間の同期更新をどう実装するか、もっと言うとどのライブラリでプロセス並列化するかです。

本記事ではmultiprocessingライブラリでmulti-agentによる同期更新を実装しますが、もしマルチノードで学習させたいならRayが良いでしょう。

github.com

【更新2020年12月: rayでの実装例を公開しました】

horomary.hatenablog.com


A2Cの実装

コード全体はgituhubに置いています。

github.com


A3Cでは各agentがローカルネットワークを持ち自律的にゲームプレイを行っていましたが、A2Cでは各agentはmaster agent (Brain、あるいは中央指令室) による指示に従って動きます。

具体的には各agentは現在の状態をmaster agentに提示し、master agentは提示された状態から次のアクションを判断しagentにアクション指示を与えます。

f:id:horomary:20200530150556p:plain

Master Agent (Brain)

MsterAgentの役割は①各agentにアクション指示を出して遷移後の状態を受け取る、①を5step繰り返したらネットワークをアップデートする、というだけなのでシンプルに書けます。

Multi-agent環境

このコードはopenai-baselinesのmultiprocessingによるマルチエージェント実装をシンプルにしたものです。

baselines/subproc_vec_env.py at master · openai/baselines · GitHub

multiprocessing.PipeがMasterAgentと各環境への通信手段となります。

Pipe()は2つのコネクションオブジェクトを生成しますがこれは糸電話の両端のようなものと理解しましょう。片端をworker_funcに渡し、もう片端はSubprocVecEnvが保持するようにします。

たとえばMaster Agent がSubProcVecEnv(N並列設定)に対してN個のアクション指示を出すと、SubProcVecEnvが各Agent(worker_func)にaction指示を与えます。SubProcVecEnvは各Agentのアクション実行結果を集約してMaster Agentに返します。


Actor Critic Network

ネットワーク構造は最後にvalueとpolicyに分かれる以外はNature DQNと同じです。


学習結果

CPUのみ15並列で20時間ほどの結果です。かなり安定して学習が進んでいます。 f:id:horomary:20200530163332p:plain

※スコアは定期的(15000ステップごと)に行ったテストプレイ5episodeの平均スコアです。

DQNに比べてメモリ消費も少ないのでそこそこスペックのノートPCでも学習できます。爆熱になるからやらないけど。

A3CでCartPole (強化学習)

深層強化学習において分散並列学習の有用性を示した重要な手法であるA3Cの解説と Tensorflow 2 での実装を行います。

[1602.01783] Asynchronous Methods for Deep Reinforcement Learning


pythonの分散並列処理ライブラリのrayでa3cを実装し直しました(2020/12)

horomary.hatenablog.com


A3C: Asynchronous Actor Critic

A3Cとは、Vanilla Policy Gradient*1の学習を非同期分散並列で行う手法です。分散並列化されたエージェントが好き勝手にサンプル収集&学習行った結果(=NN重みの勾配)だけを中央のパラメータサーバに集めます。加えて、これ以前のActor-Critic系手法ではActorとCriticを別のネットワークとして実装するのが普通でしたが、A3CではActorとCriticを、入力に近い層を共有する双出力ネットワークとしてまとめるという工夫により学習の効率化を実現しています。これは画像入力系のタスク(ゲームとかね)においてとくに効果的なようです。

Vanilla Policy Gradient — Spinning Up documentation


Asynchronous (非同期) とは

A3Cの3つのAの先頭は Asynchronous(非同期)で、複数のAgentによる非同期並列学習を行うことに由来します。

具体的には並列化された各Agentが自律的にrollout (ゲームのプレイ) を実行 & 勾配計算を行い、その勾配情報だけをパラメータサーバ(global network)に送信します。各Agentは定期的に自分のネットワーク (local network) の重みをパラメータサーバ(global network) の重みと同期します。


f:id:horomary:20200523222341p:plain:w600
概要

並列分散Agentで学習を行うことは、単純にCPUリソースに応じて学習が高速化するという恩恵以上に、経験の自己相関を低減し学習を安定化する効果が期待できます。

経験の自己相関による学習の不安定化は強化学習が長く抱えてきた課題でした。この課題について、DQN (2013) は Experience Replay (経験再生) 機構 でバッファに蓄積した経験をランダムに取り出すことで経験の自己相関を低減することにより学習の安定化に成功しまさにエポックメイキングというにふさわしい手法となりました。しかし、経験再生は(基本的には)オフポリシー手法でしかとれないトリックです。

そこでオンポリシー手法であるA3Cではサンプルを集めるAgentを並列化することで自己相関を低減するという手段をとりました。この並列化アプローチは非常に効果的である上、他手法でも容易に転用可能なアイデアであるので、A3Cの発表後には強化学習分野には分散並列化ブームが到来することになりました。

ただし、Pythonは言語特性上、非同期並列処理を行うのがなかなか面倒であるという実装上の問題があります。また、各agentが自律的に学習するというアーキテクチャであるため Agentの数=GPUの数 のときに最大のパフォーマンスを発揮するという計算資源が豊富でない一般人/小規模ラボにはなかなか辛い手法です。


A3CとA2C

A3Cの後にA2Cという手法が発表されていますので、この2つの手法の違いについて解説しておきます。

上述した通り、Pythonの言語特性上の理由で並列Agentたちが自律的に学習し好き勝手なタイミングで共有ネットワークを更新する非同期 (Asynchronous ) 学習の実装は相当面倒です。

また、パフォーマンスを最大化しようと思うとAgentの数と同数のGPUが必要です。

しかしこれがもし同期処理でもよいなら、すなわち各Agentが中央指令室からアクションの指示を受けて一斉に1step進行する、中央指令室は各Agentから遷移先状態(next_state)の報告を受けて次のアクションを指示する、という処理で実装するならば推論 する (=GPUを使う) のは中央指令室だけなのでGPUが一つでOKです。 また、Pythonでよく使われるmultiprocessingライブラリなどで容易に実装可能です。このような非同期でないA3Cの実装をA2Cと呼びます。

A3Cが発表されたあとの検証研究により、A2Cの同期学習でもパフォーマンスが落ちないことがわかったので実装が楽なA2Cがよく使われるようになりました。

A2Cの実装は別記事を参照ください horomary.hatenablog.com


分岐型 Actor-Critic ネットワーク

典型的なActor-Criticアーキテクチャでは、方策ネットワークと価値ネットワークを別に定義して、それぞれ別のロス関数(方策勾配ロス/価値ロス)でネットワークを更新します。

一方、A3CのActor-Criticでは一つのネットワークが方策と価値を出力する分岐型のネットワークを実装し、後述するトータルロスでネットワークを更新します。

Actor関数でもCritic関数でも、観測情報 から情報を抽出する役割を持つInputに近い層は似たような重みになると思われるため、このようなパラメータ共有型のActor-Criticは画像のように高度な表現抽出処理が必要な場合に効果的と思われます。

一方で、今回ターゲットにするCartPoleのように生の観測情報(角度、加速度など)が十分に系の状態を表現している場合にはA3C型のパラメータ共有Actor-Criticの恩恵が受けにくいと考えられることに留意ください。


A3Cのロス関数

上述の通り、A3Cでは一つのネットワークが方策と価値を出力する分岐型のネットワークを実装し、一つのロス関数でネットワークを更新します。

具体的にこのA3Cのロス(Total loss)はアドバンテージ方策勾配, Value loss, 方策エントロピー, の3項に分けて次式のように表せます。

Total loss  = - アドバンテージ方策勾配  + \alpha Value loss  -\beta 方策エントロピー

※方策勾配と方策エントロピーは最大化したいので-1を掛けます。
※係数αとβはハイパーパラメータです

① アドバンテージ方策勾配項

アドバンテージ方策勾配項は、名前の通りアドバンテージ関数Aで評価する方策勾配です。

 \displaystyle{
\log{\pi(a_t | s_t )} A^{\pi}(s_t, a_t)
}

AはAdvantage項であり状態行動価値Q から価値のベースラインとも言える状態価値を差し引いたものと定義され、もっともシンプルには下式のように実装できます。

 \displaystyle{
A(s_t, a_t) = Q(a_t , s_t) - V(s_t) = r + V({s_{t+1}}) - V(s_t)
}

方策勾配の重みづけに状態行動価値  \displaystyle{
Q(a_t , s_t) = r + V({s_{t+1}})
} をそのまま使用するのではなく、価値のベースラインであるV(s_t)を引くことで分散が小さくなり学習の安定化が期待できます

直感的にはアクションの価値(状態行動価値 Q)はしばしば現在の状態(tex:V(s_t))に大きく依存するので分散が小さくなる、と考えると理解しやすいでしょう。

f:id:horomary:20200524010842p:plain:w500
アクションの価値が現在の状態に大きく依存する具体例

アドバンテージ関数について上ではわかりやすさのために1step後までの即時報酬しか使用しないもっともシンプルな例を紹介しましたが、Advantageの実装にはいくつかのパターンがあります。

A3Cのbaselines実装では1-5step程度分までの即時報酬を使用するmulti-stepアドバンテージ(※名称合っているかわからない)を採用しているので今回はこの方法で実装します。これ以外ではGAE (Generalized Advantage Estimation)という手法がよく用いられます。

GAE: [1506.02438] High-Dimensional Continuous Control Using Generalized Advantage Estimation


Value Loss 項

Valueloss 項は DQNとほぼ同じです。もっともシンプルには  r + V({s_{t+1}}) をターゲットとして学習します。

 r + V({s_{t+1}}) - V(s_t)

式から明らかなようにAdvantage関数とA3CにおけるValuelossは同じものになります。

ただし、アドバンテージ方策勾配項におけるAdvantageは定数(勾配を流さない)として扱うのに対して、ValueLoss項では  r + V({s_{t+1}})は定数として扱うが  V({s_{t}}) は勾配が流れるようしなければならないことに注意してください。具体的には適切にtf.stop_gradientするのですが詳細は下記の実装を参照ください。


方策エントロピー

 -\sum_a{\pi(a_t | s_t )} \log{\pi(a_t | s_t  )}

たとえば、ある状態sの入力について出力であるアクションの採用確率が (a1, a2, a3, a4) = [0.25 0.25, 0.25, 0.25] のときと (a1, a2, a3, a4) = [0.85, 0.05, 0.05, 0.05] のときでは前者のほうが方策のエントロピーが大きい状態となります。

方策エントロピー項の追加は、方策関数の正則化効果が期待できます。

具体的には方策のエントロピーが大きくなることにボーナスを与えることで、方策関数の早すぎる収束による局所最適化を防ぎ学習を安定化します。

エントロピー項の係数βは探索の度合いを調整するハイパーパラメータです。


実装

この実装はTensorflow blog に掲載されたのA3C実装 (tensorflow1系での実装) を参考にしています。

Deep Reinforcement Learning: Playing CartPole through Asynchronous Advantage Actor Critic (A3C) with tf.keras and eager execution — The TensorFlow Blog

上述した通り、Pythonでプロセス非同期処理を実装するのはたいへん面倒なため、threadingモジュールを使いスレッド間での非同期処理で実装しています。
threadingでは並列処理による高速化は望めません。


コード全文はGithubへ:

github.com


Asynchronousの実装

スレッド間非同期並列処理のコードがこちら。
各Agent(スレッド)はglobal_counterglobal_ACNetを共有します。

グローバルActorCriticNetworkをbuild()することを忘れると学習が全く進まないことに注意。 tensorflow2.0はdefine by run なわけですが、A3Cではグローバルネットワークは自ら推論することが訓練中一度もないので明示的にbuildしないとdefineされないためです。


Actor Critic ネットワーク

Actor Critic Networkは状態を入力されるとValueとaction確率(softmaxする前なのでlogit)の2つを出力します。

アクションの決定はtensorflow_probabilityで行います。今回の例はアクションが離散値なので役に立ってませんが、連続値アクションをサンプリングするときにはコードがすっきりします。


Agentの挙動

A3CのAgentの動作はざっくりとは下記のような感じ

  1. 最大 N step分 ゲームをプレイ(play_n_steps())しtrajectoryを取得
    もし途中でゲームオーバーになった場合はその時点で2へ。
    N step 進んだらそこでゲーム中断し2へ

  2. 1で得た最大Nステップ分のtrajectoryからロスを計算・勾配情報を取得

  3. 2.で得た勾配情報を共有ネットワーク(global network)に適用する。

  4. ローカルネットワークとグローバルネットワークを同期(重みをコピー)する

  5. 1-4を繰り返す


ロスの計算


この実装では最大N step 先の即時報酬まで使用してAdvantageを計算(multi-step Advantage)するので、trajectoryの何番目のステップかによって先読みする長さが異なることに注意。

f:id:horomary:20200524182451p:plain


学習結果

最後の方はCartPole-v1環境の満点である500点を安定して取れるようになっています。

f:id:horomary:20200524181543p:plain


そしてA2C

horomary.hatenablog.com

*1:この手法名はOpenAIが便宜上こう呼んでいるだけで論文などで正式に発表されたものではないと思う

DQN(Deep Q Network)のtensorflow2実装

関連:

horomary.hatenablog.com



はじめに

[1312.5602] Playing Atari with Deep Reinforcement Learning

Human-level control through deep reinforcement learning | Nature

DQNDeepMind社によって2013年(nature版は2015年)に報告された深層CNNを行動価値関数Qの近似に用いる手法です。経験の自己相関低減のために遷移情報を一旦バッファに蓄積し、学習時はそこからランダムに選択してミニバッチを作成するExperience Replay(経験再生) と、ベルマンエラーの計算において重みを固定した過去のQ関数を(教師あり学習で言う)教師データとして使うTarget-networkというトリックを使用します。


準備:CartPole環境の作成

GymのCartPole-v1環境を呼び出し、DQNAgentへ渡します。

CartPole環境にはCartPole-v0CartPole-v1がありますが違いは最大ステップ数のみであり、v0は200ステップ継続で終了、v1は500ステップ継続で終了です。

CartPole環境の詳細については下記リンクを参照
https://gym.openai.com/envs/CartPole-v1/


アルゴリズム概要

各エピソードは以下のように進行します。

  1. 現在の状態(state)からactionを決定
    stateを行動価値関数Q(q_network)に入力して最も価値が高いアクションを採用する。
    ただし、探索率(self.epsilon)の確率で代わりにランダムなアクションをとる。

  2. アクションを実行
    アクション実行前の状態(state)、実行したアクション(action)、即時報酬(r)、アクション実行後の状態(next_state)を記録する。

  3. 行動価値関数Q(q_network)の更新

  4. Target Networkの更新(250ステップごと)

  5. 1-4をエピソード終了(ゲームオーバー)まで繰り返す


Q関数

ネットワーク自体はごくシンプルです。


Experiece Replayの実装

1ステップ分の経験(Experienceクラス)ごとにバッファ(self.experiences)に蓄積します。

Q関数の学習時はバッファからランダムに経験を取得しミニバッチとします。


ベルマンエラーの計算

DQNAgent.update_q_networkでは経験バッファより取得したミニバッチから r_t + \gamma \max (Q_{target}(s_{t+1}, a))を計算します。

ベルマンエラー r_t + \gamma \max (Q_{target}(s_{t+1}, a)) - Q(s_{t}, a_t) の平方二乗和をロスとして勾配を計算します。


結果

f:id:horomary:20200510121924p:plain


JupyterLab で D3.js × Python

f:id:horomary:20200224215656p:plain

D3.jsが役に立つケース

Pythonにおいてインタラクティブデータ可視化のほとんどのユースケースbokehyやPlotly などのライブラリにより実現可能です。 しかしいくつかのユースケースではライブラリの提供する自由度の制限により本当に表現したいことが実現困難ということもあります。

たとえば以下のような複数グラフ間でのインタラクティブな表現はD3が簡単です。 f:id:horomary:20200224222621g:plain

D3.jsの利用はJupyter Notebookではやや煩雑な手順を踏む必要がありました。 (notebookではpy_d3ライブラリを使うことを推奨します)

しかしJupyterLabでは気軽に利用可能となっています。 ですので、"Bokehで実装するのは難しいから"という理由でデータ可視化を諦める必要はありません。

※ただしコーディングを始める前に、”その可視化は本当に必要か?”と疑うことは重要です。


D3ってなに?

D3.jsは、データに基づいてドキュメントを操作するためのJavaScriptライブラリです。D3は、HTML、SVG、およびCSSを使用してデータを実現するのに役立ちます。D3はWeb標準に重点を置いているため、独自のフレームワークに縛られることなく、最新のブラウザーの全機能を利用できます。強力な視覚化コンポーネントとDOM操作に対するデータ駆動型アプローチを組み合わせています。 D3.js - Data-Driven Documents


Installation

Jupytelabがないならインストールしましょう。 また、pythonからHTMLへデータを渡すためにテンプレートエンジンのjinjaを、 HTMLをJupyterlab上で表示するためにPanelを使用するのでこれらをインストールします。

conda install -c conda-forge jupyterlab
conda install -c pyviz panel
pip install jinja2
#: jupyterlabでのpanelの使用はlabextensionのインストールが必要
jupyter labextension install @pyviz/jupyterlab_pyviz


jupyterlabで可視化するための流れ

pandas.DataFrame.to_jsonで生成されるjson形式のテキストを、jinjaテンプレートエンジンを使用してあらかじめ準備したhtmlファイルに埋め込みます。

これをpanel.pane.HTMLに渡すことによりjupyterlabでD3を作図ライブラリのように使用することが可能になります。

panel.pane.HTMLIPython.display.HTMLでもある程度代用可能です。同様にjinjastring.Templateでも代用可能ですが、jinjaの方が便利です。


今回のサンプルは以下のようなフォルダ構造で作成しました。 jupyter_d3.ipynbがJupyterlabで可視化を実行しているNotebookです。

working_dir
│ ─ correlogram.py
│ ─ jupyter_d3.ipynb
│ ─ simple_scatter.py
│
└─templates
   │ ─  correlogram.html
   └─   simple_scatter.html


シンプルな散布図の例

まずはシンプルな例から始めましょう。

jupyterlab上のコード

Cell 1 :
- PanelをLabで使用するためにpanel.extension()の実行
- jinjaで読み込むためのhtmlテンプレートの置き場所設定。
- ブラウザにD3.jsを読み込むための関数であるinit_d3を定義し、実行。

Cell 2 :
- サンプルデータセットを‘pd.DataFrame‘で用意

f:id:horomary:20200224231655p:plain

Cell 2 :
- simple_scatter.pyに定義されているSImpleScatterクラスをインスタンス化し、サンプルデータセットを渡す
- SImpleScatter.show()で散布図を表示

f:id:horomary:20200224232613p:plain

うまくD3.jsによる作図をPythonでラップできていることがわかると思います。

実際のD3へのデータの受け渡しはsimple_scatter.py内で行っています。

# simple_scatter.py

import panel as pn


class SimpleScatter:

    def __init__(self, df, env):

        self.data = df.to_json(orient="records")

        self.template = env.get_template("simple_scatter.html")

    def show(self, width=400, height=400, marker_size=6):
        html = self.template.render({"DATASET": self.data,
                                     "WIDTH": width,
                                     "HEIGHT": height,
                                     "MARKER_SIZE": marker_size})

        pane_html = pn.pane.HTML(html)

        return pane_html

SimpleScatterはデータフレームをJson形式文字列に変換し、これをsimple_scatter.htmlに埋め込み、panel.pane.HTMLオブジェクトとして返すだけのクラスです。このときマーカサイズなどの各種設定も一緒に埋め込みます。

D3.jsのコードはすべてsimple_scatter.htmlに記述されています。

<!--simple_scatter.html-->

<head>
    <meta charset="UTF-8">
<style>
    #chart{
      margin: 4px;
      box-shadow: 0px 0px 4px lightgray;
      background-color: white;
      border-radius: 10px;
    }
    .d3tip{
        position: absolute;
        text-align: center;
        width: auto;
        height: auto;
        padding: 5px;
        font-size: 10px;
        background: white;
        box-shadow: 0px 0px 10px lightgray;
        visibility: hidden;
      }
</style>
</head>

<body>
<div id="chart"></div>

<script src="https://d3js.org/d3.v5.js"></script>
<script>

    var width = {{ WIDTH }};
    var height= {{ HEIGHT }};
    var margin = {top: 40, right: 40, bottom: 40, left: 40};
    var RADIUS = {{ MARKER_SIZE }};

    var chart_width = width - margin.left - margin.right;
    var chart_height = height- margin.top - margin.bottom;

    var fontsize = 10;
    var fontfamily = "Meiryo UI";

    var DATASET = {{ DATASET }}
    var XNAME = "NOX"
    var YNAME = "AGE"
    var TITLE = "ScatterPlot"


    var tooltip = d3.select("body").append("div").attr("class", "d3tip");

    var svg = d3.select("#chart")
        .append("svg")
        .attr("width", width)
        .attr("height", height)
        .append("g")
        .attr("transform", `translate(${margin.left}, ${margin.top})`);

    // chart title
    svg.append("text")
            .attr("x", (width/ 2) - margin.left)
            .attr("y", 0 - (margin.top/4))
            .attr("text-anchor", "middle")
            .style("font-size", "16px")
            .style("fill", "dimgray")
            .style("font-weight", "bold")
            .style("text-decoration", "underline")
            .text(`${TITLE}`);
    var x_scale = d3.scaleLinear()
        .domain(getScaleMargin(
                 min=d3.min(DATASET.map((o)=>{return o[XNAME]})),
                 max=d3.max(DATASET.map((o)=>{return o[XNAME]})),
                ))
        .range([0, chart_width]);

    var y_scale = d3.scaleLinear()
        .domain(getScaleMargin(
                  min=d3.min(DATASET.map((o)=>{return o[YNAME]})),
                  max=d3.max(DATASET.map((o)=>{return o[YNAME]})),
                 ))
        .range([0, chart_height]);

    var x_axis = d3.axisBottom(x_scale);
    var y_axis = d3.axisLeft(y_scale);

    svg.append("g")
        .attr("class", "xaxis")
        .attr("transform", `translate(0, ${chart_height})`)
        .call(x_axis)
        .append("g")
        .attr("class", "xlabel")
        .append("text")
        .attr("fill", "dimgrey")
        .style("font-size", "16px")
        .style('font-weight', 'bold')
        .attr("x", chart_width)
        .attr("y", -6)
        .style("text-anchor", "end")
        .text(`${XNAME}`);


    svg.append("g")
        .attr("class", "yaxis")
        .call(y_axis)
        .append("g")
        .attr("class", "ylabel")
        .append("text")
        .attr("fill", "dimgrey")
        .style("font-size", "16px")
        .style('font-weight', 'bold')
        .attr("transform", "rotate(-90)")
        .attr("y", 6)
        .attr("dy", ".71em")
        .style("text-anchor", "end")
        .text(`${YNAME}`);


    svg.selectAll(".circle")
        .data(DATASET)
        .enter()
        .append("circle")
        .attr("class", "circle")
        .attr("cx", (d) => {
            return x_scale(d[XNAME]);
        })
        .attr("cy", (d) => {
            return y_scale(d[YNAME]);
        })
        .attr("fill", "steelblue")
        .attr("r", RADIUS)
        .on("mouseover", function(d) {
          tooltip
          .style("visibility", "visible")
          .html(`x: ${d[XNAME]}<br>y: ${d[YNAME]}`);
        })
        .on("mousemove", function(d) {
            tooltip
            .style("top", (d3.event.pageY - 20) + "px")
            .style("left", (d3.event.pageX + 10) + "px");
        })
        .on("mouseout", function(d) {
            tooltip.style("visibility", "hidden");
        });


    function getScaleMargin(min, max){
        let mergin = (max - min) * 0.1;
        return [min - mergin, max + mergin];
    };

</script>
</body>

以上、簡単ですね。


Appendix: 複数グラフ間でのインタラクティブデータ可視化

冒頭の例もSImpleScatterと全く同様に実装しています。単純にHTMLに記述しているD3のコード量が長くなるだけです。

f:id:horomary:20200224222621g:plain

右側の相関図内の相関の大きさを表す円をクリックすると、左側の散布図が対応するデータにアップデートされるというグラフです。

このグラフは探索的データ解析において、1. 相関を確認する -> 2. 散布図を確認する という作業の繰り返しがあまりにもダルいことがモチベーションになり作成しました。

html以外はほぼ同じなので詳細な解説は省き、コードだけ掲載します。
※D3.js初心者が試行錯誤しながら作成したコードです。わりとクソコードであることに留意してご参考ください。

#correlogram.py
import pandas as pd
import panel as pn


class Correlogram:

    def __init__(self, df, env):

        self.df = df

        self.template = env.get_template("correlogram.html")

    def show(self, width=400, height=400, marker_size=6,
             font_size=12, max_col=10):

        df = self.df.iloc[:, :max_col]

        data_json = df.to_json(orient="records")
        corr_json = df.corr().to_json(orient="index")

        html = self.template.render({"CORR_JSON": corr_json,
                                     "DATA_JSON": data_json,
                                     "WIDTH": width,
                                     "HEIGHT": height,
                                     "MARKER_SIZE": marker_size,
                                     "FONTSIZE": font_size})

        pane_html = pn.pane.HTML(html)

        return pane_html
<!--correlogram.html-->

<head>
    <meta charset="UTF-8">
    <style>
    .dashbord{
      width: 100%;
      display: flex;
      flex-wrap: normal;
      background-color: whitesmoke;
    }

    #chart1{
      margin: 4px;
      box-shadow: 0px 0px 4px lightgray;
      background-color: white;
      border-radius: 10px;
    }

    #chart2{
      margin: 4px;
      box-shadow: 0px 0px 4px lightgray;
      background-color: white;
      border-radius: 10px;
    }

    .xaxis path, .xaxis line{
      display: none ;
    }

    .yaxis path, .yaxis line{
      display: none ;
    }

    .tooltip {
        position: absolute;
        text-align: center;
        width: auto;
        height: auto;
        padding: 5px;
        font-size: 10px;
        background: white;
        box-shadow: 0px 0px 10px lightgray;
        visibility: hidden;
      }
      </style>
</head>
<body>
    <div class="dashbord">
        <div id="chart1"></div>
        <div id="chart2"></div>
    </div>
    <script src="https://d3js.org/d3.v5.js"></script>
    <script>
    var width = {{ WIDTH }};
    var height = {{ HEIGHT }};
    var margin = {top: 40, right: 40, bottom: 40, left: 40};
    var fontsize = {{FONTSIZE}};
    var fontfamily = "sans-serif";

    var chart_width = width - margin.left - margin.right;
    var chart_height = height - margin.top - margin.bottom;

    //var rawData = {"CRIM":{"CRIM":1.0,"ZN":-0.2004692197,"INDUS":0.4065834114,"CHAS":-0.0558915822},"ZN":{"CRIM":-0.2004692197,"ZN":1.0,"INDUS":-0.5338281863,"CHAS":-0.0426967193},"INDUS":{"CRIM":0.4065834114,"ZN":-0.5338281863,"INDUS":1.0,"CHAS":0.0629380275},"CHAS":{"CRIM":-0.0558915822,"ZN":-0.0426967193,"INDUS":0.0629380275,"CHAS":1.0}}
    var rawData = {{ CORR_JSON }}
    //var scatterData = [{"CRIM": 1,"ZN": 4,"INDUS": 6,"CHAS": 3},{"CRIM": 9,"ZN": 2,"INDUS": 4, "CHAS": 5}, {"CRIM": 7, "ZN": 3, "INDUS": 1, "CHAS": 0}]
    var scatterData = {{ DATA_JSON }}
    var indices = Object.keys(rawData)

    var xname = indices[0]
    var yname = indices[1]
    var scatterR = {{ MARKER_SIZE }}


    var upperData = [];
    for (let i=0; i<indices.length; i++){
        for (j=i+1; j<indices.length; j++){
            let d = {};
            d.x = indices[i];
            d.y = indices[j];
            d.corr = -1 * rawData[d.x][d.y];
            upperData.push(d);
        }
    }

    var lowerData= [];
    for (let i=0; i<indices.length; i++){
        for (j=i+1; j<indices.length; j++){
            let d = {};
            d.y = indices[i];
            d.x = indices[j];
            d.corr = -1 * rawData[d.x][d.y];
            lowerData.push(d);
        }
    }

    var middleData = [];
    for (let i=0; i<indices.length; i++){
            let d = {};
            d.x = indices[i];
            d.y = indices[i];
            d.corr = rawData[d.x][d.y];
            middleData.push(d);
    }

    var x_scale = d3.scaleBand()
        .domain(indices)
        .range([0, chart_width]);

    var y_scale = d3.scaleBand()
        .domain(indices)
        .range([0, chart_height]);

    var csize = d3.scaleSqrt()
        .domain([0, 1])
        .range([x_scale.bandwidth()/10, x_scale.bandwidth()/4]);

    var ccolor = d3.scaleLinear()
        .domain([-1, 0, 1])
        .range(["#000080", "#fff", "#B22222"]);

    // tooltip
    var tooltip = d3.select("body").append("div").attr("class", "tooltip");

    var svg = d3.select("#chart1")
        .append("svg")
        .attr("width", width)
        .attr("height", height)
        .append("g")
        .attr("transform", `translate(${margin.left}, ${margin.top})`);

    svg.append("text")
            .attr("x", (width/ 2) - margin.left)
            .attr("y", 0 - (margin.top/4))
            .attr("text-anchor", "middle")
            .style("font-size", "16px")
            .style("fill", "dimgray")
            .style("font-weight", "bold")
            .style("text-decoration", "underline")
            .text("Correlogram");


    svg.selectAll(".fname")
        .data(middleData)
        .enter()
        .append("text")
        .attr("class", "fname")
        .text((d) => {
            if (d.x.length < 5){
                return d.x;
                }
            else{
                return d.x.slice(0, 7);
                }
            })
        .attr("x", (d) => {
            return x_scale(d.x) + x_scale.bandwidth()/2;
        })
        .attr("y", (d)=>{
            return y_scale(d.y) + y_scale.bandwidth()/2;
        })
        .style("fill", "dimgrey")
        .style("text-anchor", "middle")
        .style("font-size", fontsize)
        .style('font-weight', 'bold')
        .on("mouseover", function(d) {
          tooltip
            .style("visibility", "visible")
            .html(d.x);
        })
        .on("mousemove", function(d) {
            tooltip
            .style("top", (d3.event.pageY - 20) + "px")
            .style("left", (d3.event.pageX + 10) + "px");
        })
        .on("mouseout", function(d) {
          tooltip.style("visibility", "hidden");
        });

    svg.selectAll(".corCircle")
        .data(lowerData)
        .enter()
        .append("circle")
        .attr("class", "corCircle")
        .attr("cx", (d) => {
            return x_scale(d.x) + x_scale.bandwidth()/2;
        })
        .attr("cy", (d)=>{
            return y_scale(d.y) + y_scale.bandwidth()/2;
        })
        .attr("fill", (d) => {
            return ccolor(d.corr);
        })
        .attr("r", (d) => {
            return csize(Math.abs(d.corr));
        })
        .on("mouseover", function(d){
            d3.select(this)
                .style("r", (d) =>{
                    return csize(Math.abs(d.corr)*3);
                });
            tooltip
              .style("visibility", "visible")
              .html(`x: ${d.x}<br> y:${d.y} <br> r: ${d.corr.toFixed(2)}`);
            })
        .on("mousemove", function(d) {
            tooltip
            .style("top", (d3.event.pageY - 20) + "px")
            .style("left", (d3.event.pageX + 10) + "px");
        })
        .on("mouseout", function(d){
            d3.select(this)
                .style("r", (d) => {
                    return csize(Math.abs(d.corr));
                });
            tooltip.style("visibility", "hidden");
        })
        .on("click", updateScatter);


    svg.selectAll(".corstr")
        .data(upperData)
        .enter()
        .append("text")
        .attr("class", "corstr")
        .text((d) => {
            return d.corr.toFixed(2);
        })
        .attr("x", (d) => {
            return x_scale(d.x) + x_scale.bandwidth()/2;
        })
        .attr("y", (d)=>{
            return y_scale(d.y) + y_scale.bandwidth()/2;
        })
        .style("fill", (d) => {
            if (d.corr >= 0){
                return "red";
            } else {
                return "royalblue";
            }
        })
        .style("text-anchor", "middle")
        .style("font-size", fontsize)
        .style('font-weight', 'bold');


    // Create chart2

    var svg2 = d3.select("#chart2")
        .append("svg")
        .attr("width", width)
        .attr("height", height)
        .append("g")
        .attr("transform", `translate(${margin.left}, ${margin.top})`);

    // chart title
    svg2.append("text")
            .attr("x", (width/ 2) - margin.left)
            .attr("y", 0 - (margin.top/4))
            .attr("text-anchor", "middle")
            .style("font-size", "16px")
            .style("fill", "dimgray")
            .style("font-weight", "bold")
            .style("text-decoration", "underline")
            .text("Scatter Plot");

    svg2.append("clipPath")
        .attr("id", "plot-area-scatter")
        .append("rect")
        .attr("x", 0)
        .attr("y", 0)
        .attr("width", chart_width)
        .attr("height", chart_height)

    var x_scale2 = d3.scaleLinear()
        .domain(getScaleMergin(
                 min=d3.min(scatterData.map((o)=>{return o[xname]})),
                 max=d3.max(scatterData.map((o)=>{return o[xname]})),
                ))
        .range([0, chart_width]);

    var y_scale2 = d3.scaleLinear()
        .domain(getScaleMergin(
                  min=d3.min(scatterData.map((o)=>{return o[yname]})),
                  max=d3.max(scatterData.map((o)=>{return o[yname]})),
                 ))
        .range([0, chart_height]);

    var x_axis2 = d3.axisBottom(x_scale2);
    var y_axis2 = d3.axisLeft(y_scale2);

    svg2.append("g")
        .attr("class", "xaxis2")
        .attr("transform", `translate(0, ${chart_height})`)
        .call(x_axis2)
        .append("g")
        .attr("class", "xlabel")
        .append("text")
        .attr("fill", "dimgrey")
        .style("font-size", "16px")
        .style('font-weight', 'bold')
        .attr("x", chart_width)
        .attr("y", -6)
        .style("text-anchor", "end")
        .text(`${xname}`);


    svg2.append("g")
        .attr("class", "yaxis2")
        .call(y_axis2)
        .append("g")
        .attr("class", "ylabel")
        .append("text")
        .attr("fill", "dimgrey")
        .style("font-size", "16px")
        .style('font-weight', 'bold')
        .attr("transform", "rotate(-90)")
        .attr("y", 6)
        .attr("dy", ".71em")
        .style("text-anchor", "end")
        .text(`${yname}`);


    svg2.append("g")
        .attr("id", "plot-area-scatter")
        .attr("clip-path", "url(#plot-area-scatter)")
        .selectAll(".points")
        .data(scatterData)
        .enter(j)
        .append("circle")
        .attr("class", "points")
        .attr("cx", (d) => {
            return x_scale2(d[xname]);
        })
        .attr("cy", (d) => {
            return y_scale2(d[yname]);
        })
        .attr("fill", "steelblue")
        .attr("r", scatterR);


    function getScaleMergin(min, max){
        let mergin = (max - min) * 0.15;
        return [min - mergin, max + mergin];
    };
    function updateScatter(d){
        xname = d.x;
        yname = d.y;

        x_scale2.domain(getScaleMergin(
                         min=d3.min(scatterData.map((o)=>{return o[xname]})),
                         max=d3.max(scatterData.map((o)=>{return o[xname]}))
                         ));

        y_scale2.domain(getScaleMergin(
                          min=d3.min(scatterData.map((o)=>{return o[yname]})),
                          max=d3.max(scatterData.map((o)=>{return o[yname]})))
                        );

        svg2.selectAll(".xlabel")
            .selectAll("text")
            .text(`${xname}`);

        svg2.selectAll(".ylabel")
            .selectAll("text")
            .text(`${yname}`);

        svg2.selectAll("circle")
            .transition()
            .delay(function(d, i){return 1})
            .attr("cx", (d) => {
                return x_scale2(d[xname]);
            })
            .attr("cy", (d) => {
                return y_scale2(d[yname]);
            });

    }


    </script>
</body>



Dataclassをjson形式でシリアライズ

Data ClassesはPython3.7からの新機能です。その名の通りデータを保持するためのクラスを簡潔に記述することができます。

Dataclassはdataclasses_jsonパッケージを使うことによりお手軽にjson形式へ変換できます。

github.com

json形式でシリアライズできると人間が直接編集したり非Pythonの外部アプリケーションにデータを渡したりする際にとても都合がよいです。

※追記:2022年現在ではpydantic.dataclasses.dataclassも有力な選択肢です

pydantic-docs.helpmanual.io

[目次]


インストール

pip install dataclasses_json

シンプルなDataclassの場合

PermissionConfigは何かのアプリケーションの権限設定をイメージしたサンプルクラスです。

dataclasses_jsonモジュールを Dataclassへ適用するには通常の@dataclassデコレータに@dataclass_jsonデコレータを重ねるだけでOKです。

from datetime import datetime
from dataclasses import dataclass
from typing import List, Dict
from uuid import uuid4, UUID

from dataclasses_json import dataclass_json


@dataclass_json
@dataclass
class PermissionConfig:
    uid : UUID = uuid4()
    updated: datetime = datetime.now()
    owner: str = None
    users : List[str] = None
    user_group: Dict[str, List[str]] = None
    note: str = "ensure_asciiをFalseにしないと日本語は文字化けする"
         

datetime.datetime型やuuid.UUID型にも対応しているのが嬉しいところです。 ではこのクラスへ具体的なデータを格納していきます。

>>> config = PermissionConfig()
>>> config.owner = "Alice"
>>> config.users = ["Bob", "Carol", "Dave", "Eve"]
>>> config.user_group = {"admin": ["Alice", "Bob"], "user":["Carol", "Dave", "Eve"]}
>>> print(config)
PermissionConfig(uid=UUID('f429b01b-3d72-416e-8909-9370b4a34770'), updated=datetime.datetime(2019, 11, 24, 3, 10, 13, 453036), owner='Alice', users=['Bob', 'Carol', 'Dave', 'Eve'], user_group={'admin': ['Alice', 'Bob'], 'user': ['Carol', 'Dave', 'Eve']}, note='ensure_asciiをFalseにしないと日本語は文字化けする')

ここまでは普通にDataclassですね。Json形式への変換は以下のように行います。

>>> config_json = config.to_json(indent=4, ensure_ascii=False)
>>> print(type(config_json)
<class 'str'>
>>> print(config_json)
{
    "uid": "f429b01b-3d72-416e-8909-9370b4a34770",
    "updated": 1574532613.453036,
    "owner": "Alice",
    "users": [
        "Bob",
        "Carol",
        "Dave",
        "Eve"
    ],
    "user_group": {
        "admin": [
            "Alice",
            "Bob"
        ],
        "user": [
            "Carol",
            "Dave",
            "Eve"
        ]
    },
    "note": "ensure_asciiをFalseにしないと日本語は文字化けする"
}

ensure_ascii=Falseにしないと日本語が文字化けすることに注意。

さらにこのJson形式文字列からPermissionConfigを復元します。

>>> config_from_json = PermissionConfig.from_json(setting_json)
>>> print(config_from_json)
PermissionConfig(uid=UUID('c6c0df3d-f702-436a-830a-5b51ec3c909b'), updated=datetime.datetime(2019, 11, 24, 1, 56, 25, 67129, tzinfo=datetime.timezone(datetime.timedelta(seconds=32400), '???? (?W?\x80??)')), owner='Alice', users=['Bob', 'Carol', 'Dave', 'Eve'], user_group={'admin': ['Bob', 'Alice'], 'user': ['Carol', 'Dave', 'Eve']}, note='')

正しく型を復元できているか確認します。

>>>print(type(config_from_json.uid),
           type(config_from_json.updated),
           type(config_from_json.owner),
           type(config_from_json.users),
           type(config_from_json.user_group))
<class 'uuid.UUID'> <class 'datetime.datetime'> <class 'str'> <class 'list'> <class 'dict'>

完璧ですね。

ネストしたDataclassの場合

Dataclassを保持するDataclassも使用可能です。

以下は先ほどのPermissionConfigPersonデータクラスを保持する例です。

@dataclass_json
@dataclass
class Person:
    name: str
    age: int


@dataclass_json
@dataclass
class PermissionConfig:
    uid : UUID = uuid4()
    updated: datetime = datetime.now()

    owner: Person = None
    users : List[Person] = None
    user_group: Dict[str, List[Person]] = None
    note: str = "ensure_asciiをFalseにしないと日本語は文字化けする"

先ほどと同様にデータを格納します。

>>> config = PermissionConfig()
>>> config.owner = Person(name="Alice", age=22)
>>> config.users = [Person(name="Bob", age=22), Person(name="Carol", age=35), Person(name="Dave", age=42), Person(name="Eve", age=27)]
>>> config.user_group = {"admin": [Person(name="Alice", age=22), Person(name="Bob", age=22)], "user":[Person(name="Carol", age=35), Person(name="Dave", age=42), Person(name="Eve", age=27)]}

>>> config_json = config.to_json(indent=4, ensure_ascii=False)
>>> print(type(config_json))
{
    "uid": "d683282a-0cb1-455f-a822-9d9d3cbca3f1",
    "updated": 1574533718.833665,
    "owner": {
        "name": "Alice",
        "age": 22
    },
    "users": [
        {
            "name": "Bob",
            "age": 22
        },
        {
            "name": "Carol",
            "age": 35
        },
        {
            "name": "Dave",
            "age": 42
        },
        {
            "name": "Eve",
            "age": 27
        }
    ],
    "user_group": {
        "admin": [
            {
                "name": "Alice",
                "age": 22
            },
            {
                "name": "Bob",
                "age": 22
            }
        ],
        "user": [
            {
                "name": "Carol",
                "age": 35
            },
            {
                "name": "Dave",
                "age": 42
            },
            {
                "name": "Eve",
                "age": 27
            }
        ]
    },
    "note": "ensure_asciiをFalseにしないと日本語は文字化けする"
}

jsonへのシリアライズはOK。ではデシリアライズは?

>>> config_from_json = PermissionConfig.from_json(config_json)
>>> print(config_from_json)
PermissionConfig(uid=UUID('aa2f5ff0-eee8-4e20-b711-ead988d90844'), updated=datetime.datetime(2019, 11, 24, 2, 11, 55, 430387, tzinfo=datetime.timezone(datetime.timedelta(seconds=32400), '???? (?W?\x80??)')), owner=Person(name='Alice', age=22), users=[Person(name='Bob', age=22), Person(name='Carol', age=35), Person(name='Dave', age=42), Person(name='Eve', age=27)], user_group={'admin': [Person(name='Alice', age=22), Person(name='Bob', age=22)], 'user': [Person(name='Carol', age=35), Person(name='Dave', age=42), Person(name='Eve', age=27)]}, note='ensure_asciiをFalseにしないと日本語は文字化けする')

OK!!

Python初心者講習のためのJupyterHub

JupyterHubを利用して管理者が用意したPythonチュートリアル用 jupyter notebookをユーザーがブラウザからすぐに実行できる環境を構築します。

  • JupyterHubによる、不特定多数のユーザーがブラウザからアクセスするだけでPythonを実行できる環境の提供
  • Dockerコンテナによる、ユーザーごとに隔離された実行サーバ
  • Githubによる、チュートリアル資料の管理

github.com


f:id:horomary:20191107004015p:plain


Pythonワークショップを楽にやりたい

近年のAIブームにより多くの企業においてPython社内教育のニーズが急激に高まっているようですが、外部サービスを利用するならともかく、Pythonによる機械学習研修やワークショップを社内で開催するというのはなかなかに面倒です。

教材を作成するのはまあよいとして、Python環境の構築でコケる人が必ず何人かは出ますし、チュートリアル教材となるjupyter notebookの配布やユーザに教材のバージョンを更新してもらうことなどは大変面倒です。

ここではそのような面倒ごとを解決するために、ユーザー側の事前準備不要で管理者の用意したJupyterNotebookを実行できる環境をJupyterHubによって提供します。

Google ColaboratoryとGoogleCloudによるnotebookの共有機能を利用できる環境ならそっちを利用した方がはるかに楽です。


JupyterHubとは

ブラウザからアクセス/実行できるユーザ認証機能つきJupyterサーバーです。

JupyterHubは、ユーザーグループにノートブックのパワーをもたらします。ユーザーにインストールおよび保守タスクを負担することなく、計算環境とリソースへのアクセスを提供します。ユーザー(学生、研究者、データサイエンティストを含む)は、システム管理者が効率的に管理できる共有リソース上の独自のワークスペースで作業を完了できます。
JupyterHubはクラウドまたは独自のハードウェアで実行され、事前に構成されたデータサイエンス環境を世界中のユーザーに提供することができます。カスタマイズ可能でスケーラブルであり、小規模および大規模なチーム、アカデミックコース、および大規模なインフラストラクチャに適しています。

Google翻訳 from Project Jupyter | JupyterHub

主要な機能は、

  • ユーザ認証機能

  • ユーザごとのsingle jupyter notebook サーバーの生成

の2つですが、これらについてはユースケースに応じて多様なオプションが用意されています。

認証についてはLDAPGithubアカウント認証など多くの選択肢が用意されていますし、single notebook serverの生成についてもDockerを使ったりKubernetesでリソースを管理したりといろいろあります。

今回はPythonワークショップでの使用=使い捨て環境を想定しているので、
NativeAuthenticator (ユーザー名とパスワードでアカウント作成し、管理者が許可する方式)
によってシンプルな認証管理を行い、
DockerSpawner (ユーザーごとにnotebook実行サーバをdockerコンテナで走らせる方式)
によってユーザーごとに使い捨てのnotebook実行サーバーを立ち上げる構成を採用します。


実行環境

テスト環境として

Google Cloud PlatformCompute EngineVMインスタンスで構築します。

VMインスタンスUbuntu 16.04 LTS
ポート開放: TCP 22, 80, 443, 8000


環境準備

関連パッケージのインストール

はじめはpipからjupyterhubをインストールしようとしたのですがnode.jsとのバージョン依存などでハマったのでcondaの使用を推奨します。

#: Minicondaのインストール(AnacondaでもOK)
$ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh

$ sh Miniconda3-latest-Linux-x86_64.sh

$ source .bashrc

$ conda create -n jupyterhub37 python=3.7

$ conda activate jupyterhub37 

#: jupyterhubと関連パッケージのインストール
$ conda install -c conda-forge jupyterhub

$ conda install jupyter_client

$ git clone https://github.com/jupyterhub/nativeauthenticator.git

$ pip install -e nativeauthenticator/

$ pip install dockerspawner


Dockerのインストール

各ユーザーのnotebook実行サーバは独立したdockerコンテナ内で走るのでdocker-CEをインストールします。 インストール方法は公式ドキュメントの通りです。

$ sudo apt-get -y update && sudo apt-get -y upgrade

$ sudo apt-get remove docker docker-engine docker.io containerd runc

$ sudo apt-get -y install \
    apt-transport-https \
    ca-certificates \
    curl \
    gnupg-agent \
    software-properties-common

$ curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add -

$ sudo add-apt-repository \
   "deb [arch=amd64] https://download.docker.com/linux/ubuntu \
   $(lsb_release -cs) \
   stable"

$ sudo apt-get -y update

$ sudo apt-get -y install docker-ce docker-ce-cli containerd.io


Notebookサーバ用Docker imageのbuild

各ユーザーのNotebook実行サーバーを走らせるためのdocker imageをbuildします。

以下の内容でDockerfileを作成し、sudo docker build -t hub/test1 . でbuildします。

FROM jupyter/scipy-notebook
USER root
RUN pip install optuna keras
RUN git clone https://github.com/zhiyzuo/python-tutorial.git work
CMD ["jupyterhub-singleuser", "--allow-root"]

Python講習を想定しているのではじめにチュートリアル用のnotebook集githubからcloneしてきます。"python tutorial github "で検索して上の方にあったリポジトリをサンプルに使用しています。

また、scipy-notebookはデータサイエンスに最低限必要なライブラリしかインストールされていないので、この例ではkerasとOptunaを追加インストールしてみます。

rootに切り替えているのは、通常ユーザのままではNotebookサーバーは立ち上がるが新規notebookの作成ができないという状態になったためです。

このトラブルの対処はJupyterHub + dockerspawnerでユーザごとにマウントするボリュームを変える - Qiitaを参考にさせていただきました。


テスト実行

ここまできたらいったんテスト起動してみましょう。

$ mkdir ~/jupyterhub && cd ~/jupyterhub

$ sudo jupyterhub

http;//SERVER_IP:8000へアクセスしてログイン画面が出ればOKです。
httpsではなくhttpであることに注意。

f:id:horomary:20191109163352p:plain

sudo jupyterhubcommand not foundと言われる場合は下記リンクを参考に、visudoで/etc/sudoersをいじりましょう。

sudo したときに「コマンドが見つかりません」と怒られた場合 - 約束の地


jupyterhub_config.pyの作成

まだバニラ状態のjupyterhubが動いただけで、認証はデフォルトのPAM認証(システムユーザー名とパスワードで認証)でありdockerも一切使用していません。


jupyterhubは起動ディレクトリ内にjupyterhub_config.pyというファイルを置くことで認証方法などの詳細な設定が可能になります。

一度起動した後の~/jupyterhubディレクトリを見るとjupyterhub_cookie_secretjupyterhub.sqliteの二つのファイルが生成されているはずです。

このディレクトリ内にさらにjupyterhub_config.pyを下記の内容で追加しましょう。

#: 認証方法と管理ユーザ名の指定
c.JupyterHub.authenticator_class = 'nativeauthenticator.NativeAuthenticator'
c.Authenticator.admin_users = {'admin'}


from jupyter_client.localinterfaces import public_ips
ip = public_ips()[0]
c.JupyterHub.hub_ip = ip

#: SSLを使うなら鍵と証明書の位置を指定
#c.JupyterHub.ssl_key = '/etc/ssl/server.key'
#c.JupyterHub.ssl_cert = '/etc/ssl/server.crt'


c.JupyterHub.spawner_class = 'dockerspawner.DockerSpawner'

#: Jupyter labを使う
c.Spawner.default_url = '/lab'
c.DockerSpawner.default_url = '/lab'


"""ユーザーの作業ディレクトリの指定
jupyter/scipy-notebook イメージにはjovyanというユーザディレクトリがデフォルトで存在するのでそこを使用する
作業ディレクトリ(/home/jovyan/work)をdocker volumeマウントすることで作業内容が永続化されるように設定する

# jovyanって誰?という疑問についてはこちらを参照
# [https://github.com/jupyter/docker-stacks/issues/358]
"""
notebook_dir = '/home/jovyan/work'
c.DockerSpawner.notebook_dir = notebook_dir
c.DockerSpawner.volumes = {'jupyterhub-user-{username}': notebook_dir}

#: single notebookサーバ用dockerイメージの選択
c.DockerSpawner.container_image = 'hub/test1'

ここで改めてjupyterhubを起動します。

$ sudo jupyterhub
[I 2019-11-09 08:20:02.459 JupyterHub app:2120] Using Authenticator: nativeauthenticator.nativeauthenticator.NativeAuthenticator
[I 2019-11-09 08:20:02.459 JupyterHub app:2120] Using Spawner: dockerspawner.dockerspawner.DockerSpawner-0.11.1
...

ログから設定内容が反映されていることがわかります。

http;//SERVER_IP:8000へアクセスすると、先ほどとは異なりログイン画面下に"Signup!"が追加されています。

f:id:horomary:20191109172436p:plain

まずは"Signup!"からユーザー登録を行いましょう。

config内で管理者に登録されているユーザーも"Signup!"からパスワードを登録する必要があります。

一般ユーザーはアカウントを作成した後、管理者の承認を受ける必要があります(やり方は後述)。

f:id:horomary:20191109172701p:plain

ログインするとJupyterLabが起動します。また、githubからcloneしてきた教材が初期状態で配置されていることが確認できます。

f:id:horomary:20191109173214p:plain

新規ユーザー承認

一般ユーザーはアカウントを作成した後、管理者の承認を受けるとログイン可能になります。

承認のためにはまず管理ユーザーでログインした状態でhttp://SERVER_IP:8000/hub/authorizeにURLの直接打ち込みでアクセスします。やや面倒ですがNativeAuthenticatorではGUIアクセスは実装されていないっぽいです。

f:id:horomary:20191109174717p:plain

承認ボタンを押せば承認作業完了です。


おわりに

ブラウザからアクセスできるPythonのマルチユーザー環境は今回想定したようなワークショップ以外にも、グループでの共同開発や出先でのデモンストレーションなど様々なユースケースがあるかと思います。

公式リポジトリに様々なオプションが用意されているので用途に合わせた最高のJupyterhub環境を構築しましょう。

参考

CentOS 7にJupyterHub 1.0を導入(SSL対応/NativeAuthenticator/DockerSpawner) - hiroki-sawano's diary

JupyterHub + dockerspawnerでユーザごとにマウントするボリュームを変える - Qiita

DockerSpawnerをカスタムして細かいフォルダ共有制御を実現している例

おまけ:ログアウト

ログアウトの方法はわかりづらいのですが、"FIles"にあります。

f:id:horomary:20191109173812p:plain

"Hub Control Panel"からGUIアクセスできる"Admin"ダッシュボードはNativeAuthenticatorを使用している場合は飾りなので無視してください。

おまけ:Volume

dockerの基本的な話ではありますが永続化されたフォルダのvolumeは以下のように確認できます。

$ sudo docker volume ls
DRIVER              VOLUME NAME
local               jupyterhub-user-admin
local               jupyterhub-user-testuser1