shartoo

# Training on your own data with the TensorFlow Object Detection API


## 1 Data preparation

First, the data is organized as follows:

### 1.1 The images directory

The files under the images directory look like this:

(screenshot: contents of the images directory)

### 1.2 Annotation files

A typical xml annotation file looks like this:

(screenshot: an xml annotation file)

The txt annotation files are not required.

### 1.3 The label_map.pbtxt file

The contents of xx_label_map.pbtxt look like this (note that ids start at 1; id 0 is reserved for the background class):

item {
  id: 1
  name: 'Abyssinian'
}

item {
  id: 2
  name: 'american_bulldog'
}

item {
  id: 3
  name: 'american_pit_bull_terrier'
}

### 1.4 Creating the tf_record files

First create a create_xx_tf_record.py file dedicated to preparing the training data. You can copy it directly from create_pascal_tf_record.py in the object_detection project (or from create_pet_tf_record.py if each image contains only one class).

Start by modifying the flag configuration at the top of the file:

Modify dict_to_tf_example

Some parts need to change to match the structure of your own xml annotation files.

(screenshot: changes to dict_to_tf_example)
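The exact edits depend on your xml layout, but the heart of dict_to_tf_example is reading the annotation and normalizing the box coordinates. A minimal, hedged sketch of that step, assuming PASCAL-VOC-style tags (size, object, bndbox); parse_voc_xml is an illustrative name, not part of the API:

```python
# A minimal sketch of the annotation-parsing step inside dict_to_tf_example.
# Assumes PASCAL-VOC-style xml; tag names may differ in your own files.
import xml.etree.ElementTree as ET

def parse_voc_xml(xml_string):
    """Return (width, height, boxes, labels) from a VOC-style annotation."""
    root = ET.fromstring(xml_string)
    size = root.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)
    boxes, labels = [], []
    for obj in root.findall('object'):
        labels.append(obj.find('name').text)
        bb = obj.find('bndbox')
        # Normalized [0, 1] coordinates, as the TF Object Detection API expects.
        boxes.append((float(bb.find('xmin').text) / width,
                      float(bb.find('ymin').text) / height,
                      float(bb.find('xmax').text) / width,
                      float(bb.find('ymax').text) / height))
    return width, height, boxes, labels
```

In the real script these values are then packed, together with the encoded jpeg bytes and the class ids looked up from label_map.pbtxt, into a tf.train.Example.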

Modify main

(screenshot: changes to main)

Make sure the annotation directory and the image directory match the paths in your configuration, and check whether trainval.txt exists under the annotation directory; you have to generate this file yourself. The list I generated (note: no file extensions) looks like this:

(screenshot: trainval.txt)
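trainval.txt is just the list of example names, one per line, with the extension stripped. One hedged way to generate it (write_trainval is an illustrative helper, and the filter assumes jpg/png images):

```python
# Generate trainval.txt: one image basename per line, without the extension.
import os

def write_trainval(images_dir, output_path):
    names = sorted(
        os.path.splitext(f)[0]
        for f in os.listdir(images_dir)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )
    with open(output_path, 'w') as out:
        out.write('\n'.join(names))
    return names
```

Point it at your images directory and write the result into the annotations directory your tf_record script reads from.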

After the script finishes, the tf_record files are generated in the corresponding directory.

1.5 创建 .config 配置文件

目录tensorflow\models\object_detection\samples\configs下有各种配置文件,当前工程使用的是 faster_rcnn_inception_resnet_v2_robot.config,将其修改为适应当前数据的配置。

主要修改了这些参数:
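The list itself did not survive in this post, but for a Faster R-CNN sample config the fields that typically need editing are the ones below (the paths are illustrative, matching the directories used elsewhere in this article):

```
model {
  faster_rcnn {
    num_classes: 3    # must match the number of items in the label map
    ...
  }
}
train_config {
  fine_tune_checkpoint: "D:/data/robot_auto_seller/pretrained/model.ckpt"
  ...
}
train_input_reader {
  tf_record_input_reader {
    input_path: "D:/data/robot_auto_seller/tf_record/train.record"
  }
  label_map_path: "D:/data/robot_auto_seller/config/robot_label_map.pbtxt"
}
```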

## 2 Training

To train, just run train.py. A few arguments have to be passed in; you can use the invocation from the official docs:

python object_detection/train.py \
    --logtostderr \
    --pipeline_config_path=${PATH_TO_YOUR_PIPELINE_CONFIG} \
    --train_dir=${PATH_TO_TRAIN_DIR}

I run it under PyCharm, so I add the arguments in Run -> Edit Configurations. The arguments to specify are:

--logtostderr --pipeline_config_path=D:/data/robot_auto_seller/config/faster_rcnn_inception_resnet_v2_robot.config --train_dir=D:/data/robot_auto_seller/tf_ckpt

Once training has run, the result looks roughly like this:

(screenshot: training log)

If training is going well, you can watch the training curves with TensorBoard (e.g. tensorboard --logdir=D:/data/robot_auto_seller/tf_ckpt):

(screenshot: TensorBoard)

Then open http://localhost:6006/#scalars in a browser:

(screenshot: TensorBoard scalars)

## 3 Converting the weight files

After training, the checkpoint directory typically contains files such as checkpoint, model.ckpt-XXXX.index, model.ckpt-XXXX.meta and model.ckpt-XXXX.data-00000-of-00001. Mine looked like this:

(screenshot: checkpoint files)

These files cannot be used directly: the prediction code below loads a frozen .pb graph, so a conversion step is needed. The object_detection project already includes the tool export_inference_graph.py for this; run it with:

python object_detection/export_inference_graph.py \
    --input_type image_tensor \
    --pipeline_config_path ${PIPELINE_CONFIG_PATH} \
    --trained_checkpoint_prefix ${TRAIN_PATH} \
    --output_directory output_inference_graph.pb

The arguments I used were:

--input_type image_tensor --pipeline_config_path D:/data/aa/config/faster_rcnn_inception_resnet_v2_robot.config --trained_checkpoint_prefix D:/data/aa/tf_ckpt/model.ckpt-6359  --output_directory  D:/data/aa/robot_inference_graph

The generated output:

(screenshot: exported .pb files)

## 4 Prediction

The prediction code:

# coding: utf-8
import numpy as np
import os
import tensorflow as tf

from matplotlib import pyplot as plt
from PIL import Image
import cv2
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

# (The webcam version opens the camera here with cv2.VideoCapture(0);
#  not needed for the still-image version.)
PATH_TO_CKPT = 'D:/data/aa/robot_inference_graph/frozen_inference_graph.pb'
PATH_TO_LABELS = os.path.join('D:/data/aa', 'robot_label_map.pbtxt')
NUM_CLASSES = 3

# Load a (frozen) Tensorflow model into memory.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES,
                                                            use_display_name=True)
category_index = label_map_util.create_category_index(categories)

def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)


# # Detection
PATH_TO_TEST_IMAGES_DIR = 'D:/data/aa/images'
TEST_IMAGE_PATHS = [os.path.join(PATH_TO_TEST_IMAGES_DIR, '000{}.jpg'.format(i)) for i in range(109, 115)]
# Size, in inches, of the output images.
IMAGE_SIZE = (12, 8)

with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        #while True:  # for image_path in TEST_IMAGE_PATHS:    #changed 20170825
        # Definite input and output Tensors for detection_graph
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        # Each box represents a part of the image where a particular object was detected.
        detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        # Each score represent how level of confidence for each of the objects.
        # Score is shown on the result image, together with the class label.
        detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
        detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        for image_path in TEST_IMAGE_PATHS:
            image = Image.open(image_path)
            # the array based representation of the image will be used later in order to prepare the
            # result image with boxes and labels on it.
            image_np = load_image_into_numpy_array(image)
            # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
            image_np_expanded = np.expand_dims(image_np, axis=0)
            # Actual detection.
            (boxes, scores, classes, num) = sess.run(
                [detection_boxes, detection_scores, detection_classes, num_detections],
                feed_dict={image_tensor: image_np_expanded})
            # Visualization of the results of a detection.
            print(boxes)
            vis_util.visualize_boxes_and_labels_on_image_array(
                image_np,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=8)
            plt.figure(figsize=IMAGE_SIZE)
            # image_np is RGB; OpenCV expects BGR, so convert before saving
            cv2.imwrite('D:/data/robot_auto_seller/' + os.path.basename(image_path),
                        cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
            plt.imshow(image_np)
            plt.show()

There are two versions of this detection script: one reads frames from the webcam, the other detects still images. The code above is the still-image version; the change from the original tutorial is the commented-out while True: loop (marked changed 20170825), replaced by a loop over TEST_IMAGE_PATHS.

The webcam version is included in the attachment.

References: the official TensorFlow tutorial, the raccoon detector tutorial (English), and a post on generating .pb files with TensorFlow.
