どこから見てもメンダコ

軟体動物門頭足綱八腕類メンダコ科

Mask-RCNNで細胞画像のインスタンスセグメンテーション

f:id:horomary:20190728041918j:plain

はじめに

3D空間スキャンなどのソリューションを提供しているmatterport社がMask-RCNNの実装をOSSとしてgithubに公開してくれているので細胞画像のインスタンスセグメンテーションをやってみました。

github.com

matterport.com

この実装の最大の特徴は矩形情報を要求せず、mask情報から自動で適切な矩形を生成することです。 矩形を考慮しなくていいので、画像の回転や変形などのいわゆるaugmentationを自在に行うことができます。

領域塗分けのタスクでは教師画像作成にかかるコストが大きいので、少ない画像での学習を可能にするaugmentationを自在に行えることは強いメリットです。


Mask-RCNN

Mask-RCNNは2017年に発表され、ICCV2017でBest Paper Awardを獲得したインスタンスセグメンテーションのための手法です。

[1703.06870] Mask R-CNN

大雑把には、物体検出のための手法であるFaster-RCNNに領域塗分けのためのネットワークを追加した手法と言えます。発想としてはごく自然ですが言うほど簡単なタスクではないというが論文を読むとわかると思います。

f:id:horomary:20190728002316j:plain
転載:https://engineering.matterport.com/splash-of-color-instance-segmentation-with-mask-r-cnn-and-tensorflow-7c761e238b46

インスタンスセグメンテーションについては上の画像が概念をわかりやすく説明しています。

Semantic segmentationではある風船と別の風船を区別できない、Object detectionでは各風船を区別できるが風船の形状情報がないのに対して、Instance segmentationでは各風船を区別しつつ風船の形状情報も出力していることがわかります。

これらの手法は
- 細胞の特定の領域(核など)だけ検出したい => Semantic segmentation
- 細胞数のカウントがしたい => Object Detection
- 細胞数のカウントをしつつ各細胞の形状を知りたい =>Instance segmentation
というように興味の対象に応じて適切に使い分けることが実用上で重要なポイントです。


データセット

HL60 CELL LINE (FIXED CELLS)という30枚のデータセットを使用します。
ダウンロード: https://cbia.fi.muni.cz/datasets/

f:id:horomary:20190728011819j:plain
画像-マスクペアのデータセット

データセットは上のように元画像と細胞ごとに色分けされたマスク画像がセットになっています。このマスク画像をそのまま使用してもいいのですが、わかりやすさのために下のような二値化された白黒画像を使用することにしました。

f:id:horomary:20190728011708j:plain
レーニングに使用するimage-maskペア

※二値化後に隣接した細胞がくっつかないように細胞境界を黒く塗りつぶしています。


環境作成

  1. 依存パッケージのインストール
    matterport/MASK_RCNNリポジトリをclone(https://github.com/matterport/Mask_RCNN)
    pip install -r requirements.txtで必要ライブラリのインストール

  2. pycocotoolsのインストール
    cocoリポジトリをclone(https://github.com/waleedka/coco)
    PythonAPIディレクトリでpython setup.py build_ext install

  3. matterport/MASK_RCNNのインストール
    MASK_RCNNの最上位ディレクトリでpython setup.py install


学習実行

モデル訓練のコードです。cocoデータセット学習済み重みをモデルへロードした後、DatasetクラスとConfigクラスをモデルへ渡して訓練開始。
DatasetクラスとConfigクラスの詳細は後述

    import os

    import mrcnn.model as modellib
    from mrcnn import utils


    TRAIN_DATASET = os.path.join('dataset', 'train')
    dataset_train = OneClassDataset()
    dataset_train.load_dataset(TRAIN_DATASET)
    dataset_train.prepare()

    VALID_DATASET = os.path.join('datset', 'valid')
    dataset_val = OneClassDataset()
    dataset_val.load_dataset(VALID_DATASET)
    dataset_val.prepare()

    config = OneClassConfig()


    model = modellib.MaskRCNN(mode="training", config=config,
                              model_dir="logs/model")

    COCO_MODEL_PATH = 'mask_rcnn_coco.h5'
    if not os.path.exists(COCO_MODEL_PATH):
        utils.download_trained_weights(COCO_MODEL_PATH)
    
    model.load_weights(COCO_MODEL_PATH, by_name=True,
                       exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",
                                "mrcnn_bbox", "mrcnn_mask"])

   #: ネットワークのhead部分のみの訓練
    model.train(dataset_train, dataset_val,
                learning_rate=0.001,
                epochs=10,
                layers='heads')


画像を格納するdatasetディレクトリ構成

#: 対応するimage-maskペアは同じ名前にした
dataset
  |-train
  |   |-image
  |   |   |- 1.jpg
  |   |   |- 2.jpg
  |   |   └ …
  |   |   
  |   └ mask
  |       |- 1.jpg
  |       |- 2.jpg
  |       └ …
  |
  └ valid
      |- image
      └- mask


Configクラスについて

OneClassConfigではmrcnn.configからConfigクラスを継承して訓練の設定を行います。
※バッチ当たりの画像数は増やしすぎるとGPUメモリに乗りません。

from mrcnn.config import Config


class OneClassConfig(Config):

    #: config名
    NAME = "cell_dataset"

    #: バッチあたり画像数 (GPUのメモリが大きいなら増やしてもよい
    IMAGES_PER_GPU = 1

    # クラス数 = 背景 + 検出クラス数
    NUM_CLASSES = 1 + 1

    # エポックあたりステップ数
    STEPS_PER_EPOCH = 50

    VALIDATION_STEPS = 5

    # 提案領域のconfidenceが90%以下なら物体検出フェイズをスキップ
    DETECTION_MIN_CONFIDENCE = 0.9


Datasetクラスについて

OneClassDatasetではmrcnn.utilsからDatasetクラスを継承して画像読み込みメソッドの定義ととmask生成のメソッドのオーバーライドを行います。

import pathlib

import cv2
from PIL import Image
from mrcnn import utils
from mrcnn.model import log


class OneClassDataset(utils.Dataset):

    def load_dataset(self, dataset_dir):
        """ データセットを登録
        """
        #: データセット名、クラスID、クラス名
        self.add_class('cell_dataset', 1, 'cell')

        images = glob.glob(os.path.join(dataset_dir, "image", "*.jpg"))
        masks = glob.glob(os.path.join(dataset_dir, "mask", "*.jpg"))

        for image_path, mask_path in zip(images, masks):
            image_path = pathlib.Path(image_path)
            mask_path = pathlib.Path(mask_path)
            assert image_path.name == mask_path.name, 'データセット名不一致'

            image = Image.open(image_path)
            height = image.size[0]
            width = image.size[1]

            mask = Image.open(mask_path)
            assert image.size == mask.size, "サイズ不一致"

            self.add_image(
                'cell_dataset',
                path=image_path,
                image_id=image_path.stem,
                mask_path=mask_path,
                width=width, height=height)

    def load_mask(self, image_id):
        """マスクデータとクラスidを生成する
        """.
        image_info = self.image_info[image_id]
        if image_info["source"] != 'cell_dataset':
            return super(self.__class__, self).load_mask(image_id)

        mask_path = image_info['mask_path']
        mask, cls_idxs = blob_detection(str(mask_path))

        return mask, cls_idxs

    def image_reference(self, image_id):
        """Return the path of the image."""
        info = self.image_info[image_id]
        if info["source"] == 'cell_dataset':
            return info
        else:
            super(self.__class__, self).image_reference(image_id)


def blob_detection(mask_path):
    mask = cv2.imread(mask_path, 0)
   #: 念のためもう一度二値化
    _, mask = cv2.threshold(mask, 150, 255, cv2.THRESH_BINARY)

    label = cv2.connectedComponentsWithStats(mask)
    data = copy.deepcopy(label[1])

    labels = []
    for label in np.unique(data):
        #: ラベル0は背景
        if label == 0:
            continue
        else:
            labels.append(label)

    mask = np.zeros((mask.shape)+(len(labels),), dtype=np.uint8)

    for n, label in enumerate(labels):
        mask[:, :, n] = np.uint8(data == label)

    cls_idxs = np.ones([mask.shape[-1]], dtype=np.int32)

    return mask, cls_idxs

load_datasetメソッドではデータセットを読み込んでself.add_image()によってデータセットとして登録します。これはオーバライドでなく追加メソッドなので好きな名前でOKです。

add_image()で必須のなのはデータセット名、元画像へのパス、任意の画像IDだけですがload_mask()メソッドでマスク生成するのに必要なのでmask画像へのパスも登録しておきます。


load_maskメソッドは親クラスからのオーバーライドです。このメソッドは学習時に呼び出され画像一枚についてのマスク情報とクラス情報を返します。

マスク情報とは、shape == [height, width, 画像内の細胞数] のndarrayであり各チャネルが細胞ひとつについてのマスク画像です。 load_maskメソッドでは[height, width, 1]の二値化マスク画像を受け取り、opencvブロブ検出関数を利用することでマスク情報を生成しています。(参考図)

f:id:horomary:20190728035159j:plain
load_maskメソッドの処理内容

マスク情報と同時にクラスIDのリストも生成します。
今回はすべて同じクラス(cell)なので長さ=画像内の細胞数で要素がすべて1のリストを返すだけです。 (self.add_class('cell_dataset', 1, 'cell')で'cell'のidは1としたため)


適切にマスク情報を生成するようにload_maskメソッドを定義さえできればどのような形式のデータセットでも使用可能なので柔軟性が高いですね。


結果の確認

validation データセットについて結果の確認を行います。

    import random
    from mrcnn import visualize

    VALID_DATASET = os.path.join('dataset', 'valid')
    dataset_val = OneClassDataset()
    dataset_val.load_dataset(VALID_DATASET)
    dataset_val.prepare()

    config = InferenceConfig()
    MODEL_DIR = os.path.join("logs", "model")
    model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR,
                              config=config)
    
    weights_path = model.find_last()
    print("Loading weights ", weights_path)
    model.load_weights(weights_path, by_name=True)

    
    image_id = random.choice(dataset_val.image_ids)
    image, image_meta, gt_class_id, gt_bbox, gt_mask = modellib.load_image_gt(dataset_val, config, image_id)
    info = dataset_val.image_info[image_id]
    
    results = model.detect([image], verbose=1)

    # Display results
    ax = get_ax(1)
    r = results[0]
    visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'],
                                dataset_val.class_names, r['scores'], ax=ax,
                                title="Predictions")
class InferenceConfig(OneClassConfig):
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    DETECTION_MIN_CONFIDENCE = 0.5

f:id:horomary:20190728041918j:plain

難易度の高いセグメンテーションではないですが動作テストとしてはokayishな感じですね。

さいごに

より複雑なデータセットに挑む場合、matterport/MASK_RCNNの細胞検出のサンプルを参考にしましょう。

Mask_RCNN/nucleus.py at master · matterport/Mask_RCNN · GitHub

今回のようなネットワークのheadだけのトレーニングだけでなく、バックボーンの特徴抽出器まで含めたトレーニング実行や動的なimage augmentationのやり方などが学べます。