Posted on

TensorFlow目標檢測API – 程式使用範例

事前作業

先在這一篇選定想要使用的模型,然後下載後解壓縮在專案的資料夾內,然後參考這篇文章設定環境: TensorFlow 的目標檢測API – 設定環境

程式實作範例

這個範例需要一個攝影機,使用的是SSD MobileNet V2 FPNLite 640×640的預設資料集

import numpy as np
import tensorflow as tf
import VideoStream
import cv2

from object_detection.utils import label_map_util

# 載入模型
PATH_TO_SAVED_MODEL = "./ssd_mobilenet/saved_model"
# 載入模型
detect_fn = tf.saved_model.load(PATH_TO_SAVED_MODEL)

# 載入標籤
PATH_TO_LABELS = './models/research/object_detection/data/mscoco_label_map.pbtxt'
category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)

import time

# 設定攝影機
videostream = VideoStream.VideoStream((1920, 1080), 30, 0).start()
cam_quit = 0
while cam_quit == 0:
    imageSource = videostream.read()
    smallImg = cv2.resize(imageSource, (1920//5, 1080//5))
    input_tensor = np.expand_dims(smallImg, 0)
    start_time = time.time()
    detections = detect_fn(input_tensor)
    end_time = time.time()
    num_detections = int(detections.pop('num_detections'))
    elapsed_time = end_time - start_time
    print('Done! Took {} seconds'.format(elapsed_time))
    detections = {key: value[0, :num_detections].numpy()
                  for key, value in detections.items()}
    for detection_boxes, detection_classes, detection_scores in \
            zip(detections['detection_boxes'], detections['detection_classes'], detections['detection_scores']):
        if detection_scores > 0.3:
            y_min_pixel = int(detection_boxes[0] * 1080)
            x_min_pixel = int(detection_boxes[1] * 1920)
            y_max_pixel = int(detection_boxes[2] * 1080)
            x_max_pixel = int(detection_boxes[3] * 1920)
            cv2.rectangle(imageSource, (x_min_pixel, y_min_pixel), (x_max_pixel, y_max_pixel), (255, 0, 0), 2)
    cv2.imshow('frame', imageSource)
    key = cv2.waitKey(1) & 0xFF
    if key == ord("q"):
        cam_quit = 1

cv2.destroyAllWindows()
videostream.stop()

執行的速度很不錯,一秒可以有25FPS,適合用於實時串流

執行的畫面如下,因為懶惰所以還沒有把標記的label標上去,若想要標記,可直接使用detections['detection_classes']作為分類的index,從category_index去取分類名稱