IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Person_reID_baseline_pytorch 源码解析之 train.py -> 正文阅读

[人工智能]Person_reID_baseline_pytorch 源码解析之 train.py

脚本 train.py 是用来训练模型的脚本,训练模型首先需要载入数据集,然后开始训练过程,训练完成后可以根据训练结果绘制 loss 曲线图,并保存训练好的模型参数。本文将按照训练模型的流程,分别解析对应步骤的代码。

1. 载入数据集

通过执行数据处理脚本 prepare.py ,我们已经将数据集组织成了 datasets.ImageFolder 可以直接使用的数据集结构。要想将数据集载入模型还需要将数据集张量化并生成数据集迭代器。

1.1 数据集张量化

使用 datasets.ImageFolder 可以将图片格式的数据集变为 pytorch 支持的张量 tensor ,如果对 transform 参数进行设置,则会对数据集的图片进行数据增强等变换。

调用 datasets.ImageFolder 后生成了 pytorch 支持的数据集 image_datasets[‘train’] 和 image_datasets[‘val’] 。

image_datasets = {}
image_datasets['train'] = datasets.ImageFolder(os.path.join(data_dir, 'train'),
                                          data_transforms['train'])
image_datasets['val'] = datasets.ImageFolder(os.path.join(data_dir, 'val'),
                                          data_transforms['val'])

可以通过 pytorch 的 transforms 库引入 transform,针对训练集和测试集进行不同的 transform 变化

from torchvision import datasets, transforms
transform_train_list = [
        #transforms.RandomResizedCrop(size=128, scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3), #Image.BICUBIC)
        transforms.Resize((h, w), interpolation=3),
        transforms.Pad(10),
        transforms.RandomCrop((h, w)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]

transform_val_list = [
        transforms.Resize(size=(h, w),interpolation=3), #Image.BICUBIC
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]
        
data_transforms = {
    'train': transforms.Compose(transform_train_list),
    'val': transforms.Compose(transform_val_list),
}

1.2 数据集迭代器

训练模型时,一般不会一次性把所有数据都加载到模型中。通常采用 mini_batch 的方法,按照 batchsize 的大小将一个 batch 的数据载入到模型中。pytorch 框架支持用 torch.utils.data.DataLoader 作为 dataloader 载入数据。

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                             shuffle=True, num_workers=0, pin_memory=True) # 8 workers may work faster
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

将 image_datasets[‘train’] 和 image_datasets[‘val’] 输入 torch.utils.data.DataLoader 后,获得了两个迭代器 dataloaders[‘train’] and dataloaders[‘val’] 。

下面来介绍一下 torch.utils.data.DataLoader 的主要参数

class torch.utils.data.DataLoader(dataset, 
								batch_size=1, 
								shuffle=False, 
								sampler=None, 
								num_workers=0, 
								collate_fn=<function default_collate>, 
								pin_memory=False, 
								drop_last=False)

torch.utils.data.DataLoader 将返回一个数据迭代器。

参数说明:

  • dataset (Dataset) – 加载数据的数据集
  • batch_size (int) – 每个batch加载多少个样本(默认: 1)
  • shuffle (bool) – 设置为True时会在每个epoch重新打乱数据(默认: False)
  • sampler (Sampler) – 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数
  • num_workers (int) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
  • drop_last (bool, optional) – 如果数据集大小不能被 batch size 整除,则设置为 True 后可删除最后一个不完整的batch。如果设为 False 并且数据集的大小不能被 batch size 整除,则最后一个batch将更小。(默认: False)

2. 开始训练

在函数 train_model 中,实现了模型训练过程。网络模型一般会迭代多轮以达到一个很好的训练效果,通常通过循环执行一段训练代码来实现迭代训练。

2.1 训练代码

下面对主要的训练代码进行解析:

			# Iterate over data.
            for data in dataloaders[phase]:
                # 载入一个 batch 的输入
                # 数据迭代器返回一个 batch 的图像及其标签
                inputs, labels = data
                now_batch_size,c,h,w = inputs.shape
                if now_batch_size<opt.batchsize: # skip the last batch
                    continue
                # print(inputs.shape)
                # 变量化输入
                if use_gpu:
                    inputs = Variable(inputs.cuda())
                    labels = Variable(labels.cuda())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)
				# 开始训练
                # 将梯度参数置零
                optimizer.zero_grad()
				
				# 前向传播,计算损失
                #-------- forward --------
                outputs = model(inputs)
                # preds 是 softmax 概率最大的类别的索引, 即模型预测的类别
                _, preds = torch.max(outputs.data, 1)
                loss = criterion(outputs, labels)
				
				# 只在 train 模式下执行,反向传播,梯度下降优化, 
                #-------- backward + optimize -------- 
                # only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

训练过程中,还可以使用 warm_up 等学习率策略。

2.2 模型加载

模型训练过程中,还会涉及到模型加载。在训练模式下,模型的网络参数会发生改变;而在验证模式下,一般不进行梯度下降反向传播等操作,我们希望网络参数保持不变。此时会考虑使用 model.load_state_dict 加载最佳模型参数进行验证。

注意
model.load_state_dict 是深拷贝,可以保证加载的是最佳模型参数
model.state_dict 是浅拷贝,保存的是最后一轮训练的模型参数

另外使用预训练迁移模型的部分层参数时,记得令 strict=False,即
model.load_state_dict(state_dict, strict=False)。strict 默认为 True,表示严格按照名称加载参数,如果出现未定义的名称,就会报错。如果将 strict=False,则会忽略未定义的名称,不会报错。

            # deep copy the model
            if phase == 'val':
                last_model_wts = model.state_dict()
                if epoch%10 == 9:
                    save_network(model, epoch)
                draw_curve(epoch)
            if phase == 'train':
               scheduler.step()
        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    #print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(last_model_wts)
    save_network(model, 'last')

3. 结果保存

训练过程中,一般会保存训练好的模型参数,方便下次训练时加载模型。为了监控训练过程,一般还会绘制 loss 曲线。

3.1 模型保存

baseline 通过 torch.save 实现模型参数的保存,具体代码如下:

# Save model
#---------------------------
def save_network(network, epoch_label):
    save_filename = 'net_%s.pth'% epoch_label
    # save_path = os.path.join('./model',name,save_filename)
    save_path = os.path.join('model', name, save_filename)
    torch.save(network.cpu().state_dict(), save_path)
    if torch.cuda.is_available():
        network.cuda(gpu_ids[0])

pytorch 一般使用如下代码实现模型的保存和加载

# save
torch.save(model.state_dict(), PATH)

# load
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

3.2 loss 曲线绘制

使用 pyplot 库可以实现绘图,loss 曲线绘制代码如下:

# Draw Curve
#---------------------------
import matplotlib.pyplot as plt
x_epoch = []
fig = plt.figure()
ax0 = fig.add_subplot(121, title="loss")
ax1 = fig.add_subplot(122, title="top1err")
def draw_curve(current_epoch):
    x_epoch.append(current_epoch)
    ax0.plot(x_epoch, y_loss['train'], 'bo-', label='train')
    ax0.plot(x_epoch, y_loss['val'], 'ro-', label='val')
    ax1.plot(x_epoch, y_err['train'], 'bo-', label='train')
    ax1.plot(x_epoch, y_err['val'], 'ro-', label='val')
    if current_epoch == 0:
        ax0.legend()
        ax1.legend()
    # fig.savefig( os.path.join('./model',name,'train.jpg'))
    fig.savefig(os.path.join('model', name, 'train.jpg'))

参考文献

  1. 从零开始行人重识别
  2. Person_reID_baseline_pytorch
  3. torch.max()使用讲解
  4. 源码详解Pytorch的state_dict和load_state_dict
  5. Pytorch踩坑记:赋值、浅拷贝、深拷贝三者的区别以及model.state_dict()和model.load_state_dict()的坑点
  6. torch.load_state_dict()函数的用法总结
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-11-24 07:56:48  更:2021-11-24 07:58:42 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 4:18:44-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码