はじめに
3D空間スキャンなどのソリューションを提供しているmatterport社がMask-RCNNの実装をOSSとしてgithubに公開してくれているので細胞画像のインスタンスセグメンテーションをやってみました。
この実装の最大の特徴は矩形情報を要求せず、mask情報から自動で適切な矩形を生成することです。 矩形を考慮しなくていいので、画像の回転や変形などのいわゆるaugmentationを自在に行うことができます。
領域塗分けのタスクでは教師画像作成にかかるコストが大きいので、少ない画像での学習を可能にするaugmentationを自在に行えることは強いメリットです。
Mask-RCNN
Mask-RCNNは2017年に発表され、ICCV2017でBest Paper Awardを獲得したインスタンスセグメンテーションのための手法です。
大雑把には、物体検出のための手法であるFaster-RCNNに領域塗分けのためのネットワークを追加した手法と言えます。発想としてはごく自然ですが言うほど簡単なタスクではないというが論文を読むとわかると思います。
インスタンスセグメンテーションについては上の画像が概念をわかりやすく説明しています。
Semantic segmentationではある風船と別の風船を区別できない、Object detectionでは各風船を区別できるが風船の形状情報がないのに対して、Instance segmentationでは各風船を区別しつつ風船の形状情報も出力していることがわかります。
これらの手法は
- 細胞の特定の領域(核など)だけ検出したい => Semantic segmentation
- 細胞数のカウントがしたい => Object Detection
- 細胞数のカウントをしつつ各細胞の形状を知りたい =>Instance segmentation
というように興味の対象に応じて適切に使い分けることが実用上で重要なポイントです。
データセット
HL60 CELL LINE (FIXED CELLS)という30枚のデータセットを使用します。
ダウンロード: https://cbia.fi.muni.cz/datasets/
データセットは上のように元画像と細胞ごとに色分けされたマスク画像がセットになっています。このマスク画像をそのまま使用してもいいのですが、わかりやすさのために下のような二値化された白黒画像を使用することにしました。
※二値化後に隣接した細胞がくっつかないように細胞境界を黒く塗りつぶしています。
環境作成
依存パッケージのインストール
matterport/MASK_RCNNリポジトリをclone(https://github.com/matterport/Mask_RCNN)
pip install -r requirements.txt
で必要ライブラリのインストールpycocotoolsのインストール
cocoリポジトリをclone(https://github.com/waleedka/coco)
PythonAPIディレクトリでpython setup.py build_ext install
matterport/MASK_RCNNのインストール
MASK_RCNNの最上位ディレクトリでpython setup.py install
学習実行
モデル訓練のコードです。cocoデータセット学習済み重みをモデルへロードした後、Dataset
クラスとConfig
クラスをモデルへ渡して訓練開始。
※Dataset
クラスとConfig
クラスの詳細は後述
import os import mrcnn.model as modellib from mrcnn import utils TRAIN_DATASET = os.path.join('dataset', 'train') dataset_train = OneClassDataset() dataset_train.load_dataset(TRAIN_DATASET) dataset_train.prepare() VALID_DATASET = os.path.join('datset', 'valid') dataset_val = OneClassDataset() dataset_val.load_dataset(VALID_DATASET) dataset_val.prepare() config = OneClassConfig() model = modellib.MaskRCNN(mode="training", config=config, model_dir="logs/model") COCO_MODEL_PATH = 'mask_rcnn_coco.h5' if not os.path.exists(COCO_MODEL_PATH): utils.download_trained_weights(COCO_MODEL_PATH) model.load_weights(COCO_MODEL_PATH, by_name=True, exclude=["mrcnn_class_logits", "mrcnn_bbox_fc", "mrcnn_bbox", "mrcnn_mask"]) #: ネットワークのhead部分のみの訓練 model.train(dataset_train, dataset_val, learning_rate=0.001, epochs=10, layers='heads')
画像を格納するdatasetディレクトリ構成
#: 対応するimage-maskペアは同じ名前にした dataset |-train | |-image | | |- 1.jpg | | |- 2.jpg | | └ … | | | └ mask | |- 1.jpg | |- 2.jpg | └ … | └ valid |- image └- mask
Configクラスについて
OneClassConfig
ではmrcnn.config
からConfigクラスを継承して訓練の設定を行います。
※バッチ当たりの画像数は増やしすぎるとGPUメモリに乗りません。
from mrcnn.config import Config class OneClassConfig(Config): #: config名 NAME = "cell_dataset" #: バッチあたり画像数 (GPUのメモリが大きいなら増やしてもよい IMAGES_PER_GPU = 1 # クラス数 = 背景 + 検出クラス数 NUM_CLASSES = 1 + 1 # エポックあたりステップ数 STEPS_PER_EPOCH = 50 VALIDATION_STEPS = 5 # 提案領域のconfidenceが90%以下なら物体検出フェイズをスキップ DETECTION_MIN_CONFIDENCE = 0.9
Datasetクラスについて
OneClassDataset
ではmrcnn.utils
からDatasetクラスを継承して画像読み込みメソッドの定義ととmask生成のメソッドのオーバーライドを行います。
import pathlib import cv2 from PIL import Image from mrcnn import utils from mrcnn.model import log class OneClassDataset(utils.Dataset): def load_dataset(self, dataset_dir): """ データセットを登録 """ #: データセット名、クラスID、クラス名 self.add_class('cell_dataset', 1, 'cell') images = glob.glob(os.path.join(dataset_dir, "image", "*.jpg")) masks = glob.glob(os.path.join(dataset_dir, "mask", "*.jpg")) for image_path, mask_path in zip(images, masks): image_path = pathlib.Path(image_path) mask_path = pathlib.Path(mask_path) assert image_path.name == mask_path.name, 'データセット名不一致' image = Image.open(image_path) height = image.size[0] width = image.size[1] mask = Image.open(mask_path) assert image.size == mask.size, "サイズ不一致" self.add_image( 'cell_dataset', path=image_path, image_id=image_path.stem, mask_path=mask_path, width=width, height=height) def load_mask(self, image_id): """マスクデータとクラスidを生成する """. image_info = self.image_info[image_id] if image_info["source"] != 'cell_dataset': return super(self.__class__, self).load_mask(image_id) mask_path = image_info['mask_path'] mask, cls_idxs = blob_detection(str(mask_path)) return mask, cls_idxs def image_reference(self, image_id): """Return the path of the image.""" info = self.image_info[image_id] if info["source"] == 'cell_dataset': return info else: super(self.__class__, self).image_reference(image_id) def blob_detection(mask_path): mask = cv2.imread(mask_path, 0) #: 念のためもう一度二値化 _, mask = cv2.threshold(mask, 150, 255, cv2.THRESH_BINARY) label = cv2.connectedComponentsWithStats(mask) data = copy.deepcopy(label[1]) labels = [] for label in np.unique(data): #: ラベル0は背景 if label == 0: continue else: labels.append(label) mask = np.zeros((mask.shape)+(len(labels),), dtype=np.uint8) for n, label in enumerate(labels): mask[:, :, n] = np.uint8(data == label) cls_idxs = np.ones([mask.shape[-1]], dtype=np.int32) return mask, cls_idxs
load_dataset
メソッドではデータセットを読み込んでself.add_image()
によってデータセットとして登録します。これはオーバライドでなく追加メソッドなので好きな名前でOKです。
add_image()
で必須のなのはデータセット名、元画像へのパス、任意の画像IDだけですがload_mask()
メソッドでマスク生成するのに必要なのでmask画像へのパスも登録しておきます。
load_mask
メソッドは親クラスからのオーバーライドです。このメソッドは学習時に呼び出され画像一枚についてのマスク情報とクラス情報を返します。
マスク情報とは、shape == [height, width, 画像内の細胞数] のndarrayであり各チャネルが細胞ひとつについてのマスク画像です。
load_mask
メソッドでは[height, width, 1]の二値化マスク画像を受け取り、opencvのブロブ検出関数を利用することでマスク情報を生成しています。(参考図)
マスク情報と同時にクラスIDのリストも生成します。
今回はすべて同じクラス(cell)なので長さ=画像内の細胞数で要素がすべて1のリストを返すだけです。
(self.add_class('cell_dataset', 1, 'cell')
で'cell'のidは1としたため)
適切にマスク情報を生成するようにload_maskメソッドを定義さえできればどのような形式のデータセットでも使用可能なので柔軟性が高いですね。
結果の確認
validation データセットについて結果の確認を行います。
import random from mrcnn import visualize VALID_DATASET = os.path.join('dataset', 'valid') dataset_val = OneClassDataset() dataset_val.load_dataset(VALID_DATASET) dataset_val.prepare() config = InferenceConfig() MODEL_DIR = os.path.join("logs", "model") model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config) weights_path = model.find_last() print("Loading weights ", weights_path) model.load_weights(weights_path, by_name=True) image_id = random.choice(dataset_val.image_ids) image, image_meta, gt_class_id, gt_bbox, gt_mask = modellib.load_image_gt(dataset_val, config, image_id) info = dataset_val.image_info[image_id] results = model.detect([image], verbose=1) # Display results ax = get_ax(1) r = results[0] visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'], dataset_val.class_names, r['scores'], ax=ax, title="Predictions")
class InferenceConfig(OneClassConfig): GPU_COUNT = 1 IMAGES_PER_GPU = 1 DETECTION_MIN_CONFIDENCE = 0.5
難易度の高いセグメンテーションではないですが動作テストとしてはokayishな感じですね。
さいごに
より複雑なデータセットに挑む場合、matterport/MASK_RCNNの細胞検出のサンプルを参考にしましょう。
Mask_RCNN/nucleus.py at master · matterport/Mask_RCNN · GitHub
今回のようなネットワークのheadだけのトレーニングだけでなく、バックボーンの特徴抽出器まで含めたトレーニング実行や動的なimage augmentationのやり方などが学べます。