IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 《Pytorch 模型推理及多任务通用范式》第二节作业 -> 正文阅读

[人工智能]《Pytorch 模型推理及多任务通用范式》第二节作业

1 课程学习

本节课主要对于大白AI课程:https://mp.weixin.qq.com/s/STbdSoI7xLeHrNyLlw9GOg
《Pytorch 模型推理及多任务通用范式》课程中的第二节课进行学习。

2 作业题目

  • 必做题:

(1) 从 torchvision 中加载 resnet18 模型结构,并载入预训练好的模型权重 ‘resnet18-5c106cde.pth’ (在物料包的 weights 文件夹中)。
(2) 将(1)中加载好权重的 resnet18 模型,保存成 onnx 文件。
(3) 以 torch.rand([1,3,224,224]).type(torch.float32)作为输入,求 resnet18 的模型计算量和参数量。

1.82 GFLOPS and 11.69M parameters

(4) 以 torch.rand([1,3,448,448]).type(torch.float32)作为输入,求 resnet18 的模型计算量和参数量。

7.27 GFLOPS and 11.69M parameters

import torch
import torchvision.models as models
from thop import profile
# 加载模型结构(计算图)
model = models.resnet18()
# 读取权重
pretrained_state_dict = torch.load('./weights/resnet18-5c106cde.pth')
# 载入权重到模型
model.load_state_dict(pretrained_state_dict,strict=True)

model.to(device='cuda')
# 模型改为推理状态,,是因为网络中的 dropout 层和 batchnorm层在训练状态和推理状态的处理方式是不同的 
model.eval()
# 构建一个项目推理时需要的输入大小的单精度tensor,并且放置模型所在的设备
inputs1 = torch.ones([1,3,224,224]).type(torch.float32).to(torch.device('cuda'))
# 生成onnx
torch.onnx.export(model, inputs1, './weights/resnet18.onnx',verbose=True)

# 不同输入大小,模型计算量和参数量
flops1, params1 = profile(model=model, inputs=(inputs1,) )
inputs2 = torch.ones([1,3,448,448]).type(torch.float32).to(torch.device('cuda'))
flops2, params2 = profile(model=model, inputs=(inputs2,))
print('inputs1 Model: {:.2f} GFLOPS and {:.2f}M parameters'.format(flops1/1e9, params1/1e6))
print('inputs2 Model: {:.2f} GFLOPS and {:.2f}M parameters'.format(flops2/1e9, params2/1e6))
  • 思考题:

(1) 比较必做题中的(3)和(4)的结果,有什么规律?

模型参数量与输入大小无关
模型计算量,输入尺寸越大,计算量越大,计算量约随着图片尺寸成倍增长

(2) 尝试用 netron 可视化 resnet18 的 onnx 文件
在这里插入图片描述

(3) model 作为 torch.nn.Module 的子类,除了用 model.state_dict()查看网络层外,还可以用model.named_parameters()和 model.parameters()。它们三儿有啥不同?
参考《pytorch中state_dict()和named_parameters()有何差别》

它们的差异主要体现在3方面:
(1) 返回值类型不同
(2 )存储的模型参数的种类不同
(3) 返回的值的require_grad属性不同
第一,这很简单,model.state_dict()是将layer_name : layer_param的键值信息存储为dict形式,而model.named_parameters()则是打包成一个元祖然后再存到list当中;
第二,model.state_dict()存储的是该model中包含的所有layer中的所有参数;而model.named_parameters()则只保存可学习、可被更新的参数,model.buffer()中的参数不包含在model.named_parameters()中
最后,model.state_dict()所存储的模型参数tensor的require_grad属性都是False,而model.named_parameters()的require_grad属性都是True

  • model.state_dict返回的是一个字典,分别对应模型中的名字和参数
import torch
import torchvision.models as models
# 加载模型结构(计算图)
model = models.resnet18()
model_state_dict = model.state_dict()
for key, value in model_state_dict.items():
    print('prame name:', key)
    print('param shape:', value)
    print('-'*10)

输出:
在这里插入图片描述

  • model.named_parameters()返回的是一个迭代器,该方法可以输出模型的参数和该参数对应层的名字
model_named_param = model.named_parameters()
for i in model_named_param:
    print(i)

输出:
在这里插入图片描述

for layer, paramters in model_named_param:
    print('model_named_param layer:', layer)
    print('model_named_param paramters:', paramters.shape)
    print('_'*10)

输出:
![在这里插入图片描述](https://img-blog.csdnimg.cn/0c8c68ccd6be42d3bd69ce19698c99d5.png?x-oss-process=image/watermark,type_ZHJvaWRzYW5zZmFsbGJhY2s,shadow_50,text_Q1NETiBAbHVja3kteHh5eXh4,size_20,color_FFFFFF,t_70,g_se,x_1model.parameters()返回的是一个迭代器,只包含输出模型的参数

model_parameters = model.parameters()
for i in model_parameters:
    print(i)

输出:
在这里插入图片描述

(4) 加载模型权重时用的的 model.load_state_dict(字典, strict=True),里面的 strict 参数什么情况下要赋值 False?

将权重数据从文件中加载到模型中时,如果参数不完全对应,那么必须传入参数strict=False,否则程序报错。strict=True必须要保证两者的参数必须完全一致。如果参数不完全一致,并且strict=False时,函数返回参数匹配失败的信息,包括missing_keys(表示模型中存在但不在权重文件中的参数)和unexpected_keys(表示不出现在模型中但是出现在权重文件中的参数)。

讨论PyTorch中模型加载时参数不一致的情况

model.load_state_dict(state_dict, strict=False)

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-17 12:00:06  更:2021-10-17 12:02:13 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/12 6:37:35-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码