Detectron2 を使ってみよう:Pre-Trainedモデルの使い方説明

こんにちは!うしじです。

Detectron2いいですね。さすがFacebook AI、いい仕事してます。

今回は、Detectron2を使ってみましたので、その使い方について書きたいと思います。


Detectron2 とは

Detectron2とは、Facebook AIが開発した、PyTorchベースの物体検出のライブラリです。 様々なモデルとそのPre-Trainedモデルが実装されており、下記のように、Bounding boxやInstance Segmentation等の物体検出を簡単に実装することができます。

Detectron2 Overview



Detectron2の詳細は、下記のFacebook AIのブログやDetectron2のGitHubをご確認ください。



チュートリアル

Detectron2 には、Google Colab上にわかりやすいチュートリアルがあります。 チュートリアルでは、下記の内容を学ぶことができます。

  • Detectron2のインストール
  • Pre-Trainedモデルの利用
  • カスタムデータセットでのトレーニング
  • 様々なモデルの利用
  • 動画への適用


Detectron2 を使ってみよう

チュートリアルの内容と被るところも多いですが、本記事では、様々なPre-Trainedモデルを用いて、自分で用意した画像で物体検出を行います。

また、環境としては、Google Colabのnotebook上で実行する想定で、コードを記載しています。

では、やってみましょう!


インストール、セットアップ

今回は、物体検出対象の画像は、Google Drive上に置きます。ColabでGoogle Drive上のデータを読み込む方法については、この記事を参照してください。


まずは、Google Driveをマウントしておきます。
また、対象の画像もGoogle Driveに格納しておきましょう。

from google.colab import drive
drive.mount('/content/drive')


次に、PyTorchやDetectron2等の必要なモジュールをインストール、インポートします。

# install dependencies: (use cu101 because colab has CUDA 10.1)
!pip install -U torch==1.5 torchvision==0.6 -f https://download.pytorch.org/whl/cu101/torch_stable.html 
!pip install cython pyyaml==5.1
!pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
import torch, torchvision
print(torch.__version__, torch.cuda.is_available())
!gcc --version
# opencv is pre-installed on colab

# install detectron2:
!pip install detectron2==0.1.3 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.5/index.html

# You may need to restart your runtime prior to this, to let your installation take effect
# Some basic setup:
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import numpy as np
import cv2
import random
from google.colab.patches import cv2_imshow

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog



画像のロード

今回利用する画像をロードします。"/content/drive/My Drive/Colab Data/image.jpg" のところは、自身がアップロードした画像のパスとファイル名に変更してください。

im = cv2.imread("/content/drive/My Drive/Colab Data/image.jpg")
cv2_imshow(im)

今回はこの画像です。(Free-PhotosによるPixabayからの画像) Input image




Faster R-CNN

ここから、Pre-Trainedモデルを用いて推論していきます。
Pre-Trainedモデルは、Detectron2のModel Zooから探せます。


まずは、Faster R-CNNを用いてBounding Boxを出してみます。 具体的なモデルは、"COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml" です。


# Detectron2のコンフィグとモデル固有のコンフィグを読み込みます。
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))

# threshold(閾値)を設定します。この閾値より確度の高いもののみ出力されます。
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7

# 今回利用するFaster R-CNNのトレーニング済みファイルを読み込みます。
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")

# 推論を実行します。
predictor = DefaultPredictor(cfg)
outputs = predictor(im)

# 結果を表示します。
v = Visualizer(im[:,:,::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
cv2_imshow(v.get_image()[:, :, ::-1])


推論結果です。うまく認識できていますね。 Faster R-CNN Result



次に、Thresholdを0.1に変更してみます。

# Thresholdを0.1に変更
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.1
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
predictor = DefaultPredictor(cfg)
outputs = predictor(im)
v = Visualizer(im[:,:,::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
cv2_imshow(v.get_image()[:, :, ::-1])


先ほどと違い、確度の低い、AIが推論に自信の無いものでも表示されており、ごちゃごちゃしているのがわかります。 Faster R-CNN Result




Instance Segmentation

Instance Segmentationでは、物体の認識をピクセルレベルで実施しており、Bounding Boxだけでなく、対象物の領域も認識することができます。
利用するファイルを変えるだけで、コードは、Faster R-CNNと同じで大丈夫です。

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")

predictor = DefaultPredictor(cfg)
outputs = predictor(im)

v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
cv2_imshow(v.get_image()[:, :, ::-1])

Instance Segmentation Result




Keypoint Detection

Keypoint Detectionでは、人の姿勢を認識することができます。
コードも、Instance Segmentationと同様に、利用するファイルを変えるだけで大丈夫です。

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml")

predictor = DefaultPredictor(cfg)
outputs = predictor(im)

v = Visualizer(im[:,:,::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
cv2_imshow(v.get_image()[:, :, ::-1])

Keypoint detection Result




Panoptic Segmentation

Panoptic Segmentationでは、全てのピクセルに対してラベルを振っており、Instance Segmentation等では出力されなかった、壁や天井といったものも分類されています。 Panoptic Segmentationのコードは、これまでと異なり、Thresholdや出力、表示のコードがこれまでと少し違いますので、注意してください。

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"))
cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = 0.7
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml")
predictor = DefaultPredictor(cfg)
panoptic_seg, segments_info = predictor(im)["panoptic_seg"]
v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
v = v.draw_panoptic_seg_predictions(panoptic_seg.to("cpu"), segments_info)
cv2_imshow(v.get_image()[:, :, ::-1])

Panoptic Segmentation Result



いかがでしたでしょうか。Detectron2では、多くのコードを書くことなく、対象ファイルを切り替えるだけで、様々なモデルを利用可能です。
Detectron2を使って、今後何か作ってみたいですね。




PyTorchについて学びたい人におすすめの書籍を紹介させていただきます。


つくりながら学ぶ! PyTorchによる発展ディープラーニング