1. Two ways to save weight parameters
The first: save the network model together with its parameters;
The second: keep the model and parameters separate, and save only the model's weight parameters (the state dict).
The second way is the officially recommended one: after the network model is modified, it makes it easy to extract the weight parameters of the corresponding layers.
1.1 Saving the model and parameters together
- Saving:
net123 = module.CustomModel()
torch.save(net123, "./weights/All_in.pth")
- Loading:
net123 = torch.load("./weights/All_in.pth")
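Note that loading a whole model this way unpickles it by class reference, so the original class definition (here `module.CustomModel`) must still be importable at load time; otherwise `torch.load` fails.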
1.2 Saving only the model's parameters – state_dict()
- Saving:
net123 = module.CustomModel()
torch.save(net123.state_dict(), './weights/epoch_weight.pth')
- Loading (the model must be instantiated first; the weights are then loaded into it):
net123 = module.CustomModel()
net123.load_state_dict(torch.load('./weights/epoch_weight.pth'))
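One variant worth knowing (a minimal sketch reusing the path above): if the weights were saved on a GPU machine and are loaded on a CPU-only one, pass `map_location`:

# Remap GPU-saved storages onto the CPU before loading
state = torch.load('./weights/epoch_weight.pth', map_location='cpu')
net123.load_state_dict(state)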
2. Removing layers from the network
2.1 Approach
The overall logic:
- First delete the layer from the network definition;
- Then make a copy of the original network weights and remove the weight parameters corresponding to the layer deleted in step 1 (see the sketch below).
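A minimal sketch of these two steps, using a hypothetical `TinyNet` that is not from the original text:

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.fc = nn.Linear(8, 2)  # suppose this is the layer to remove

weights = TinyNet().state_dict()      # copy of the original weights
for k in ['fc.weight', 'fc.bias']:    # step 2: drop the deleted layer's entries
    del weights[k]
# step 1 is the edited class itself: a copy of TinyNet with self.fc (and its forward call) removed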
2.2 Deleting the network layer
My concrete way of doing it:
- Copy the class of the original network model and rename the copy as your own class, `class CustomModel(nn.Module)`;
- Modify the renamed class to delete the network layer.
Two functions of this class need the most changes, the initializer and the forward function:
- In the initializer, comment out the attributes you no longer need;
- In the forward function, comment out the network layer you want to delete.
class CustomModel(nn.Module)
1. the initializer, `def __init__(self)`;
2. the forward function, `def forward(self, x)`.
For example, suppose we need to remove the `self.patch_embed` layer; then in the forward function we comment out that layer's forward call:
def forward_features(self, x):
    # x = self.patch_embed(x)  # the removed layer: its forward call is commented out
    cls_token = self.cls_token.expand(x.shape[0], -1, -1)
    if self.dist_token is None:
        x = torch.cat((cls_token, x), dim=1)
    else:
        x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
    x = self.pos_drop(x + self.pos_embed)
    x = self.blocks(x)
    x = self.norm(x)
    if self.dist_token is None:
        return self.pre_logits(x[:, 0])
    else:
        return x[:, 0], x[:, 1]
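With `patch_embed` removed, `forward_features` now expects `x` to already be a sequence of patch embeddings of shape `[batch, num_patches, embed_dim]` rather than a raw image, so whatever replaces the patch embedding must produce tensors of that shape.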
2.3 Deleting the weight parameters of the removed layer
- Load the original network weight file;
- Write out the parameter names corresponding to the layer to be removed.
(For this step, you can set a breakpoint at `self.net = your_model()` and inspect the name of every layer in the network:)
VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block
    ...
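Instead of a debugger, you can also print the parameter names directly (a small sketch using the standard `named_parameters()` API):

# Each printed name is a key in the state dict, e.g. patch_embed.proj.weight
for name, _ in self.net.named_parameters():
    print(name)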
- Delete the layer's weight parameters from the dict by key;
key: the parameter name of the removed layer, value: the corresponding weight tensor.
if self.args.weights is not None:
    assert os.path.exists(self.args.weights), "weights file: '{}' not exist.".format(self.args.weights)
    weights_dict = torch.load(self.args.weights, map_location='cuda:0')
    # remove the parameters belonging to the deleted patch_embed layer
    del_keys_patch_embed = ['patch_embed.proj.bias', 'patch_embed.proj.weight']
    for k in del_keys_patch_embed:
        del weights_dict[k]
    # also remove the classification head (plus pre_logits when there is no logits head)
    del_keys = ['head.weight', 'head.bias'] if self.net.has_logits \
        else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias']
    for k in del_keys:
        del weights_dict[k]
    # strict=False skips the deleted keys; print the result to see what was skipped
    print(self.net.load_state_dict(weights_dict, strict=False))
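`load_state_dict(..., strict=False)` returns a named tuple of `missing_keys` and `unexpected_keys`, so the `print` above is a quick sanity check that only the intended parameters were skipped.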
2.4 Another way to remove weight parameters
import torch
from collections import OrderedDict
import os
import torch.nn as nn
import torch.nn.init as init
from xxx import new_VGG

def init_weight(modules):
    # re-initialize layers whose pretrained weights were removed
    for m in modules:
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()

def copyStateDict(state_dict):
    # strip the 'module.' prefix left over from nn.DataParallel checkpoints
    if list(state_dict.keys())[0].startswith('module'):
        start_idx = 1
    else:
        start_idx = 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = '.'.join(k.split('.')[start_idx:])
        new_state_dict[name] = v
    return new_state_dict

state_dict = torch.load('/users/xxx/xxx.pth')
new_dict = copyStateDict(state_dict)

# drop every parameter whose name starts with 'conv_cls'
keys = []
for k, v in new_dict.items():
    if k.startswith('conv_cls'):
        continue
    keys.append(k)
new_dict = {k: new_dict[k] for k in keys}

net = new_VGG()
# load the remaining parameters; strict=False leaves the removed ones at their initial values
net.load_state_dict(new_dict, strict=False)
torch.save(net.state_dict(), '/users/xxx/xxx.pth')
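Filtering keys by prefix (`k.startswith('conv_cls')`) removes every parameter of that sub-module at once, which scales better than listing each key by hand as in 2.3; the trade-off is that the prefix must not accidentally match layers you want to keep.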
3. Modifying a specific layer
As an example, replace the first convolution of a pretrained ResNet-50 with a single-channel version:
import torch
import torch.nn as nn
from torchvision import models

class Net(nn.Module):
    def __init__(self, input_ch, num_class, pretrained=True):
        super(Net, self).__init__()
        self.model = models.resnet50(pretrained=pretrained)
        # replace the first conv layer with a 1-channel version
        conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.model.conv1 = conv1

    def forward(self, x):
        x = self.model(x)
        return x
With the code above, the pretrained weights of conv1 cannot be reused. To reuse them, we average the pretrained conv1 weights along dim=1 (the input-channel dimension) and repeat the averaged weights so that they match the shape of the new conv1:
class Net(nn.Module):
    def __init__(self, input_ch, num_class, pretrained=True):
        super(Net, self).__init__()
        self.model = models.resnet50(pretrained=pretrained)
        # average the pretrained 3-channel kernel along dim=1, then repeat it to input_ch channels
        conv1_weight = torch.mean(self.model.conv1.weight, dim=1, keepdim=True).repeat(1, input_ch, 1, 1)
        conv1 = nn.Conv2d(input_ch, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        model_dict = self.model.state_dict()
        self.model.conv1 = conv1
        model_dict['conv1.weight'] = conv1_weight  # overwrite with the adapted weights
        self.model.load_state_dict(model_dict)

    def forward(self, x):
        x = self.model(x)
        return x
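A quick sanity check (hypothetical input sizes; note that `num_class` is accepted but unused above, so the stock 1000-way ResNet head remains):

net = Net(input_ch=1, num_class=10)
x = torch.randn(2, 1, 224, 224)  # single-channel input now works
print(net(x).shape)              # torch.Size([2, 1000])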
4. Addressing layers by their index
import torchvision as tv
import torch.nn as nn

vgg16 = tv.models.vgg16(pretrained=True)
# append a new layer named 'extra' to the features sub-module
vgg16.features.add_module('extra', nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
# replace the layer at index 31 (the one just appended) by position
vgg16.features[31] = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# drop the last layer by rebuilding features from its children
vgg16.features = nn.Sequential(*list(vgg16.features.children())[:-1])
print(vgg16)
# freeze the weights of the layer at index 26
vgg16.features[26].weight.requires_grad = False
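To check which index corresponds to which layer before editing or freezing, you can enumerate the `Sequential` (a small sketch using the standard iteration API):

# each printed index is what vgg16.features[i] addresses
for idx, layer in enumerate(vgg16.features):
    print(idx, layer)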