どこから見てもメンダコ

軟体動物門頭足綱八腕類メンダコ科

Deepchm③(グラフ畳み込みで回帰)

分子グラフの畳み込みモデルでdelaney-水溶解度データセットの回帰を行います。

csvデータのロード

例のごとくdelaney水溶解度データセットをロードします。 MPNNモデルやCoulomb matrix featurizerなどを使いたい場合は立体構造が必要ですが、ConvMolFeaturizerはsmiles構造でOKです。

import deepchem as dc
from rdkit import Chem
import numpy as np

csv = "datasets/delaney-processed.csv"

featurizer = dc.feat.ConvMolFeaturizer()
loader = dc.data.CSVLoader(
      tasks=['measured log solubility in mols per litre'],
      smiles_field="smiles",
      featurizer=featurizer)
    
dataset = loader.featurize(csv)
>>>dataset.X.shape
(1128, 1024)
>>>dataset.y.shape
(1128, 1)

データセットの分割

splitter = dc.splits.ScaffoldSplitter()
train_dataset, valid_dataset, test_dataset = splitter.train_valid_test_split(dataset,frac_train=0.8,
                                                                             frac_valid=0.1,frac_test=0.1)


transformers = [dc.trans.NormalizationTransformer(transform_y=True,
                                                  dataset=train_dataset)]

for transformer in transformers:
    train_dataset = transformer.transform(train_dataset)
    valid_dataset = transformer.transform(valid_dataset)
    test_dataset = transformer.transform(test_dataset)

モデルの訓練

from deepchem.models.tensorgraph.models.graph_models import GraphConvModel
model = GraphConvModel(n_tasks=1, batch_size=64, uncertainty=False, mode='regression')
model.fit(train_dataset, nb_epoch=30)

精度の評価

metric = dc.metrics.Metric(dc.metrics.r2_score)

print("Evaluating model")
train_scores = model.evaluate(train_dataset, [metric], transformers)
valid_scores = model.evaluate(valid_dataset, [metric], transformers)

print("Train scores")
print(train_scores)

print("Validation scores")
print(valid_scores)
### out
Evaluating model
computed_metrics: [0.9236492543579347]
computed_metrics: [0.5783662081038962]
Train scores
{'r2_score': 0.9236492543579347}
Validation scores
{'r2_score': 0.5783662081038962}

可視化

まあまあの精度でしょうか。

import matplotlib.pyplot as plt
%matplotlib inline

predicted_test = model.predict(test_dataset)
true_test = test_dataset.y
plt.scatter(predicted_test, true_test)
plt.xlabel('Predicted log-solubility in mols/liter')
plt.ylabel('True log-solubility in mols/liter')
plt.title('GCNN-predicted vs. true log-solubilities')

f:id:horomary:20181021121652p:plain

参考・引用

Graph Convolutions For Tox21

Model Interpretability