IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【Pytorch深度学习50篇】·······第五篇:【YOLO】【3】-----训练篇 -> 正文阅读

[人工智能]【Pytorch深度学习50篇】·······第五篇:【YOLO】【3】-----训练篇

一周过去了,我赶在一周的尾巴上来继续写文章了,上周不算忙,至少不用去驻场了,日子好过多了。不过苦了小兄弟们了,他们去驻场了。

来这个现在这个公司快两年了,大概是今年5月份开始,我们一个初创公司开始和这个行业的巨头开始对着干,老板也是真相信我们啊,别人公司什么体量,我们公司什么体量,根本不是一个数量级,就这样还初生牛犊不怕虎,他不怕,我怕,但是又能怎么着呢,环境就是这么残酷,你不硬着头皮上,就只能灰溜溜的走。负重前行吧。明知山有虎,偏向虎上行的精神。

好了,废话说多了,回归主题,YOLO的训练,之前我们网络模型和数据准备都完事了,所以今天的任务就是把它训练起来,我们先上训练代码

3.训练

from models.yolov3 import yolov3
from models.yolov5 import yolov5s
from models.yolov4 import yolov4
from torch.utils.data import DataLoader
import torch

import config
import dataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Yolo_Train:
    def __init__(self, img_file_path, anno_file_path, model_size='small', pretrain_flag=True, data_augmentation=False):
        self.class_num = len(config.class_name)
        self.net, self.save_name = self.select_model(model_size, pretrain_flag)
        self.train_data = dataset.Yolo_Dataset(img_file_path, anno_file_path,data_augmentation)
        self.train_loader = DataLoader(self.train_data, batch_size=config.hy_params['batch_size'], shuffle=True,
                                       num_workers=2)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=config.hy_params['lr'])
        self.epoches = config.hy_params['epoch']

    def train(self):
        self.net.train()
        for epoch in range(self.epoches):
            losss = 0
            for index, (img_tensor, label_32, label_16, label_8) in enumerate(self.train_loader):
                label_32 = label_32.to(device)
                label_16 = label_16.to(device)
                label_8 = label_8.to(device)
                img = img_tensor.to(device)

                output32, output16, output8 = self.net(img)
                loss32 = self.loss_function(output32, label_32, 0.84)
                loss16 = self.loss_function(output16, label_16, 0.96)
                loss8 = self.loss_function(output8, label_8, 0.96)

                loss = loss32 + loss16 + loss8
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                losss += loss.item()
                print('epoch: [%s/%s] ----iteration[%s/%s]----------- loss:%.8f' % (
                    epoch + 1, self.epoches, index + 1, len(self.train_loader), losss / (index + 1)))
            torch.save(self.net, './weights/' + self.save_name)

    def loss_function(self, output, target, alpha):
        # output = [batch,21,13,13]    target = [batch,13,13,3,7]
        output = output.permute(0, 2, 3, 1)  # [batch,13,13,21]
        output = output.reshape(output.size(0), output.size(1), output.size(2), 3, -1)  # [batch,13,13,3,7]

        mask_obj = target[..., 0] > 0  # 置信度大于0,有目标
        mask_noobj = target[..., 0] == 0  # 没有目标

        loss_obj = torch.mean((output[mask_obj] - target[mask_obj]) ** 2)
        loss_noobj = torch.mean((output[mask_noobj][:, 0] - target[mask_noobj][:, 0]) ** 2)

        loss = loss_obj * alpha + loss_noobj

        return loss


    def select_model(self, model_size, pretrain_flag):
        assert model_size in ['small', 'middle', 'big']
        if model_size == 'small':
            if pretrain_flag:
                print('加载预训练模型')
                net = torch.load('./weights/yolov5s_net.pth').to(device)
            else:
                net = yolov5s.YOLO(nc=self.class_num).to(device)

            return net, 'yolov5s_net.pth'

        elif model_size == 'middle':
            if pretrain_flag:
                print('加载预训练模型')
                net = torch.load('./weights/yolov4_net.pth').to(device)
            else:
                net = yolov4.yolo4(num_class=self.class_num).to(device)
            return net, 'yolov4_net.pth'

        elif model_size == 'big':
            if pretrain_flag:
                print('加载预训练模型')
                net = torch.load('./weights/yolov3_net.pth').to(device)
            else:
                net = yolov3.yolo3(num_class=self.class_num).to(device)
            return net, 'yolov3_net.pth'


if __name__ == '__main__':
    img_file_path = r'D:\DATAS\face_mask\JPEGImages'
    anno_file_path = r'D:\DATAS\face_mask\Annotations'
    Trainer = Yolo_Train(img_file_path, anno_file_path,model_size='small')
    Trainer.train()

我这里把之前提到的v3,v4,v5模型都涉及进来了。训练的时候我们可以通过model_size来更改模型

我们先来看初始化函数

def __init__(self, img_file_path, anno_file_path, model_size='small', pretrain_flag=True, data_augmentation=False):
        self.class_num = len(config.class_name)
        self.net, self.save_name = self.select_model(model_size, pretrain_flag)
        self.train_data = dataset.Yolo_Dataset(img_file_path, anno_file_path,data_augmentation)
        self.train_loader = DataLoader(self.train_data, batch_size=config.hy_params['batch_size'], shuffle=True,
                                       num_workers=2)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=config.hy_params['lr'])
        self.epoches = config.hy_params['epoch']

我们定义了类别的数量、选择了网络模型、定义了dataloader(也就是上一篇文章的数据准备)、定义了优化器和训练的epoch数量。

选择网络模型我也定义了一个函数

def select_model(self, model_size, pretrain_flag):
        assert model_size in ['small', 'middle', 'big']
        if model_size == 'small':
            if pretrain_flag:
                print('加载预训练模型')
                net = torch.load('./weights/yolov5s_net.pth').to(device)
            else:
                net = yolov5s.YOLO(nc=self.class_num).to(device)

            return net, 'yolov5s_net.pth'

        elif model_size == 'middle':
            if pretrain_flag:
                print('加载预训练模型')
                net = torch.load('./weights/yolov4_net.pth').to(device)
            else:
                net = yolov4.yolo4(num_class=self.class_num).to(device)
            return net, 'yolov4_net.pth'

        elif model_size == 'big':
            if pretrain_flag:
                print('加载预训练模型')
                net = torch.load('./weights/yolov3_net.pth').to(device)
            else:
                net = yolov3.yolo3(num_class=self.class_num).to(device)
            return net, 'yolov3_net.pth'

可以加载yolov3、yolov4、yolov5的模型文件和预训练模型

然后就是训练函数

def train(self):
        self.net.train()
        for epoch in range(self.epoches):
            losss = 0
            for index, (img_tensor, label_32, label_16, label_8) in enumerate(self.train_loader):
                label_32 = label_32.to(device)
                label_16 = label_16.to(device)
                label_8 = label_8.to(device)
                img = img_tensor.to(device)

                output32, output16, output8 = self.net(img)
                loss32 = self.loss_function(output32, label_32, 0.84)
                loss16 = self.loss_function(output16, label_16, 0.96)
                loss8 = self.loss_function(output8, label_8, 0.96)

                loss = loss32 + loss16 + loss8
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                losss += loss.item()
                print('epoch: [%s/%s] ----iteration[%s/%s]----------- loss:%.8f' % (
                    epoch + 1, self.epoches, index + 1, len(self.train_loader), losss / (index + 1)))
            torch.save(self.net, './weights/' + self.save_name)

又是一顿常规操作,损失函数,我选择了一个不是特别的好的MSE,这个玩意貌似不太好,还应该crossentropy和IOUloss,之后我来更新出来吧。

整个项目我在下一篇的推理的时候,会完全给大家开源出来的。

至此,敬礼,solute!!!!

老规矩,上咩咩。

?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-12-05 12:02:52  更:2021-12-05 12:02:54 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 0:07:24-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码