1. 数据模块接口
1.1 简介
- dataset_factory:数据模块工厂路由,找到不同的数据集读取逻辑(供训练调用)
- dataset_init:保存不同数据集的TFRecords格式读取功能
- utils:数据模块的共用组件(通用的函数、类)
- dataset_config:数据模块的一些数据集配置文件
- data_to_tfrecords:原始数据集格式转换逻0辑
1.2 代码编写
import os
import tensorflow as tf
from datasets.utils import dataset_utils
slim = tf.contrib.slim
class CommodityTFRecords(dataset_utils.TFRecordsReaderBase):
"""商品读取类
"""
def __init__(self, param):
self.param = param
def get_data(self, train_or_test, dataset_dir):
"""商品数据集获取
:param tran_or_test: 训练数据还是测试数据
:param dataset_dir: 数据集存放的目录.
"""
if train_or_test not in ['train', 'test']:
raise ValueError('训练测试数据集名称 % 指定有误' % train_or_test)
if not tf.gfile.Exists(dataset_dir):
raise ValueError("数据集目录不存在")
file_pattern = os.path.join(dataset_dir, self.param.FILE_PATTERN % train_or_test)
reader = tf.TFRecordReader
keys_to_features = {
'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/height': tf.FixedLenFeature([1], tf.int64),
'image/width': tf.FixedLenFeature([1], tf.int64),
'image/channels': tf.FixedLenFeature([1], tf.int64),
'image/shape': tf.FixedLenFeature([3], tf.int64),
'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
}
items_to_handlers = {
'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
'shape': slim.tfexample_decoder.Tensor('image/shape'),
'object/bbox': slim.tfexample_decoder.BoundingBox(
['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'),
'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated'),
}
decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
return slim.dataset.Dataset(
data_sources=file_pattern,
reader=reader,
decoder=decoder,
num_samples=self.param.SPLITS_TO_SIZES[train_or_test],
items_to_descriptions=self.param.ITEMS_TO_DESCRIPTIONS,
num_classes=self.param.NUM_CLASSES)
from datasets.dataset_init import commodity_2018
from datasets.dataset_config import Cm2018
datasets_map = {
'commodity_2018': commodity_2018.CommodityTFRecords
}
def get_dataset(name, train_or_test, dataset_dir):
"""
获取不同数据集数据
:param name: 数据集名字
:param train_or_test:
:param dataset_dir:
:return:
"""
if name not in datasets_map:
raise ValueError('数据集不支持转换 %s' % name)
return datasets_map[name](Cm2018).get_data(train_or_test, dataset_dir)
2. 模型接口
- nets_factory的实现
···python from nets.network import ssd_vgg_300
slim = tf.contrib.slim
networks_obj = { ‘ssd_300_vgg’: ssd_vgg_300.SSDNet, }
def get_network(name): “”“获取模型网络实例 “”” return networks_obj[name] ···
3. 预处理
3.1 预处理需求
- 在图像的深度学习中,对输入数据进行数据增强(Data Augmentation),为了丰富图像训练集,更好的提取图像特征,泛化模型(防止模型过拟合)
- 还有一个最根本的目的就是要把图片变成符合大小要求
- RCNN输入图片没有要求,但是网络当中卷积之前需要227 x 227大统一大小
- YOLO算法:输入图片大小变换为448 x 448
- SSD算法:输入图片大小变换为300 x 300
3.2 数据增强
- 指通过剪切、旋转/反射/翻转变换、缩放变换、平移变换、尺度变换、对比度变换、噪声扰动、颜色变换等一种或多种组合数据增强变换的方式来增加数据集的大小。
3.3 代码编写
import tensorflow as tf
from preprocessing import ssd_vgg_preprocessing
slim = tf.contrib.slim
preprocessing_fn_map = {
"ssd_vgg_300": ssd_vgg_preprocessing
}
def get_preprocessing(name, is_training=True):
"""预处理工厂获取不同的模型数据增强(预处理)方式
:param name: 模型预处理名称
:param is_training: 是否训练
:return: 返回预处理的函数
"""
if name not in preprocessing_fn_map:
raise ValueError("选择的预处理名称 %s 不在预处理模型库当中,请提供该模型预处理代码" % name)
def preprocessing_fn(image, labels, bboxes,
out_shape, data_format='NHWC', **kwargs):
return preprocessing_fn_map[name].preprocess_image(image, labels, bboxes,
out_shape, data_format=data_format,
is_training=is_training, **kwargs)
return preprocessing_fn
|