pytorch自带有一些高级的复杂模型,我们可以通过 torchvision.models 调用,例如torchvision.models.densenet169(pretrained=True) 就调用了densenet169的预训练模型。
1 查看网络结构
model = models.densenet169()
print(model)
out: 结构很长,所以只截图了最后一部分,我们可以发现这个模型的输出单元有1000个,也就是说这是一个1000分类的模型。
2 修改模型(分类数目)
由上一章我们得知这是一个1000分类的模型,那么如果我想训练一个二分类的模型,该如何进行修改?
我们可以看到,最后一层的名字叫做 classifier ,因为我们可以通过引用最后一层进行修改:
in_features = model.classifier.in_features
model.classifier = nn.Linear(in_features, 2)
print(model)
就好了: 可以看到最后一层输出单元数变成2了。
同理,我要是想修改中间的某一层,只要我知道它的名称,就可以进行修改了。
3 读取模型参数
和正常模型一样,用 model.state_dict() 读取所有参数。
model.state_dict()
out: 需要注意的是,他这里生产的词典的 key 值,有特殊的命名方式,如果我想查看第2章中最后一层的 bias,可以这样操作:
4 保存模型
见保存pytorch模型
参考: https://blog.csdn.net/weixin_36670529/article/details/105910572 https://blog.csdn.net/whut_ldz/article/details/78845947 https://blog.csdn.net/u012494820/article/details/79068625
|