IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 目标检测学习记录(三):项目实现 -> 正文阅读

[人工智能]目标检测学习记录(三):项目实现

1. 数据模块接口

1.1 简介

  • dataset_factory:数据模块工厂路由,找到不同的数据集读取逻辑(供训练调用)
  • dataset_init:保存不同数据集的TFRecords格式读取功能
  • utils:数据模块的共用组件(通用的函数、类)
  • dataset_config:数据模块的一些数据集配置文件
  • data_to_tfrecords:原始数据集格式转换逻0辑

1.2 代码编写

  • CommodityTFRecords类
import os
import tensorflow as tf
from datasets.utils import dataset_utils

slim = tf.contrib.slim


class CommodityTFRecords(dataset_utils.TFRecordsReaderBase):
    """商品读取类
    """
    def __init__(self, param):
        self.param = param

    def get_data(self, train_or_test, dataset_dir):
        """商品数据集获取
        :param tran_or_test: 训练数据还是测试数据
        :param dataset_dir: 数据集存放的目录.
        """
        if train_or_test not in ['train', 'test']:
            raise ValueError('训练测试数据集名称 % 指定有误' % train_or_test)

        if not tf.gfile.Exists(dataset_dir):
            raise ValueError("数据集目录不存在")

        # 构造第一个参数:数据目录+文件名
        file_pattern = os.path.join(dataset_dir, self.param.FILE_PATTERN % train_or_test)

        # 准备第二个参数:
        reader = tf.TFRecordReader

        # 准备第三个参数:decoder
        # 反序列化的格式
        keys_to_features = {
            'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
            'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
            'image/height': tf.FixedLenFeature([1], tf.int64),
            'image/width': tf.FixedLenFeature([1], tf.int64),
            'image/channels': tf.FixedLenFeature([1], tf.int64),
            'image/shape': tf.FixedLenFeature([3], tf.int64),
            'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
            'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
            'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
        }
        # 2、反序列化成高级的格式
        # 其中bbox框ymin [23] xmin [46],ymax [234] xmax[12]--->[23,46,234,13]
        items_to_handlers = {
            'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
            'shape': slim.tfexample_decoder.Tensor('image/shape'),
            'object/bbox': slim.tfexample_decoder.BoundingBox(
                ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
            'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
            'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'),
            'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated'),
        }

        # 构造decoder
        decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)

        return slim.dataset.Dataset(
            data_sources=file_pattern,
            reader=reader,
            decoder=decoder,
            num_samples=self.param.SPLITS_TO_SIZES[train_or_test],
            items_to_descriptions=self.param.ITEMS_TO_DESCRIPTIONS,
            num_classes=self.param.NUM_CLASSES)
  • dataset_factory中加入:
from datasets.dataset_init import commodity_2018


from datasets.dataset_config import Cm2018

datasets_map = {
    'commodity_2018': commodity_2018.CommodityTFRecords
}


def get_dataset(name, train_or_test, dataset_dir):
    """
    获取不同数据集数据
    :param name: 数据集名字
    :param train_or_test:
    :param dataset_dir:
    :return:
    """
    if name not in datasets_map:
        raise ValueError('数据集不支持转换 %s' % name)

    return datasets_map[name](Cm2018).get_data(train_or_test, dataset_dir)

2. 模型接口

  • nets_factory的实现
    ···python
    from nets.network import ssd_vgg_300

slim = tf.contrib.slim

networks_obj = {
‘ssd_300_vgg’: ssd_vgg_300.SSDNet,
}

def get_network(name):
“”“获取模型网络实例
“””
return networks_obj[name]
···


3. 预处理

3.1 预处理需求

  • 在图像的深度学习中,对输入数据进行数据增强(Data Augmentation),为了丰富图像训练集,更好的提取图像特征,泛化模型(防止模型过拟合)
  • 还有一个最根本的目的就是要把图片变成符合大小要求
    • RCNN输入图片没有要求,但是网络当中卷积之前需要227 x 227大统一大小
    • YOLO算法:输入图片大小变换为448 x 448
    • SSD算法:输入图片大小变换为300 x 300

3.2 数据增强

  • 指通过剪切、旋转/反射/翻转变换、缩放变换、平移变换、尺度变换、对比度变换、噪声扰动、颜色变换等一种或多种组合数据增强变换的方式来增加数据集的大小。
    在这里插入图片描述

3.3 代码编写

import tensorflow as tf
from preprocessing import ssd_vgg_preprocessing

slim = tf.contrib.slim

preprocessing_fn_map = {
    "ssd_vgg_300": ssd_vgg_preprocessing
}

def get_preprocessing(name, is_training=True):
    """预处理工厂获取不同的模型数据增强(预处理)方式
    :param name: 模型预处理名称
    :param is_training: 是否训练
    :return: 返回预处理的函数
    """
    if name not in preprocessing_fn_map:
        raise ValueError("选择的预处理名称 %s 不在预处理模型库当中,请提供该模型预处理代码" % name)

    # 返回一个处理的函数,后续再去调用这个函数
    def preprocessing_fn(image, labels, bboxes,
                         out_shape, data_format='NHWC', **kwargs):
        return preprocessing_fn_map[name].preprocess_image(image, labels, bboxes,
                                                           out_shape, data_format=data_format,
                                                           is_training=is_training, **kwargs)
    return preprocessing_fn
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-11-05 00:28:48  更:2022-11-05 00:29:37 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/28 3:21:56-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计