LIMEというモデル解釈ツールを使用してDeepChemで作成した回帰モデルの解析を行います。
本稿はDeepChemのモデル解釈チュートリアルを参考に作成しています。
※化合物ではなく一般的なテーブルデータから作成されたモデルにLIMEを適用する例は下記を参照ください
LIMEとは
LIME(local interpretable model-agnostic explanations)は機械学習モデルの解釈のためのpythonパッケージです。 github.com ※今回は回帰モデルに適用していますが、基本的には分類モデルに対しての適用を想定しているようです。
Interpretability
いくら予測精度が良かったとしても、中身の解釈できないモデルは依頼者にとって受け入れ難いことがしばしばあります。 研究者なんかは背景原理の考察を重視するため、その傾向が強いのではないでしょうか。
解釈可能なモデルは、予測の中身が直感に合わなかったとしても新たな実験の指針となるので研究者にとっては歓迎すべきものです。
準備
前回XGBで作成した、低分子化合物の水溶解度予測モデルをそのまま使用します。
LIMEによるモデル解釈
import numpy as np feature_names = dc.feat.RDKitDescriptors.allowedDescriptors num_unique_val = np.array([len(set(dataset.X[:,x])) for x in range(dataset.X.shape[1])]) categorical_features = np.argwhere(num_unique_val <= 10).flatten()
- feature_name:
各説明変数名のリスト - categorical features:
カテゴリ変数を示すインデックス。ユニークな値が10以下の説明変数をcategorical_featuresとしました。
### Deepchemのmodelのみ必要 def eval_model(my_model, transformers): def eval_closure(x): ds = dc.data.NumpyDataset(x, None, None, None) predictions = best_model.predict(ds) return predictions return eval_closure best_model = eval_model(best_model, transformers)
上のコードはDeepChemのmodelを使うときのみ必要。LIMEはsklearnのmodel.predict(X)
のようにXを渡してyが返ってくることを想定していますが、Deepchemのmodel.predict(dataset)ではXとyとidとtaskを一緒くたに渡す使用になっています。そこで、Xだけを抽出するようにモデルをラップする必要があります。
import lime import lime.lime_tabular explainer = lime.lime_tabular.LimeTabularExplainer(train_dataset.X, feature_names=feature_names, class_names=['Solubility'], categorical_features=categorical_features, verbose=True, mode='regression')
from rdkit.Chem.Draw import IPythonConsole from IPython.display import SVG from rdkit.Chem import rdDepictor from rdkit.Chem.Draw import rdMolDraw2D i = 15 mol = Chem.MolFromSmiles(test_dataset.ids[i])
それでは適当に抽出した上の化合物について予測モデルの説明を作成します。
### 注意 ### sklearnの場合は、 ### bet_model を model.predict (回帰) or model.predict_proba(分類)に書き換え exp = explainer.explain_instance(test_dataset.X[i], best_model, num_features=5) exp.show_in_notebook(show_table=True, show_all=False)
>>> exp.as_list() [('MinEStateIndex <= 1.42', 0.663314467748405), ('MolLogP > 64.38', -0.2120534031644591), ('EState_VSA10 > 402.73', -0.14520452044984966), ('SMR_VSA2 > 234.37', -0.1409951864572008), ('SlogP_VSA4 > 17.85', 0.07122605689712962)]
logP関連の説明変数の寄与が大きいのは当然ですね。
とりあえずは動いたので良し。