预训练的模型可以很好的学到特征,我们想利用预训练的模型再加上分类器,实现分类任务,需要在自己的模型上添加新的层,直接使用 model.load_state_dict(state_dict) 会报错:
RuntimeError: Error(s) in loading state_dict for Model: ?? ?Missing key(s) in state_dict
可以逐个的查找自己的模型如果和预训练的模型有相同的参数名称,那么就复制过来,这样就可以避免直接全部加载出现的错误。
如果预训练模型中有多余的网络结构,可以用del 删去。
# 自己的模型
model = Model()
#加载预训练的模型
checkpoint = torch.load('checkpoint_{:04d}.pth.tar'.format(args.epochs))
state_dict = checkpoint['state_dict']
#删去不必要的结构
for k in list(state_dict.keys()):
# 删去预训练模型中mlp开头的所有结构
if k.startswith('mlp'):
del state_dict[k]
for name, param in model.named_parameters():
# print(name,param.size())
# 将所有不是predictor开头的结构冻结
if not (name.startswith('predictor')):
param.requires_grad = False
# 查看model初始化的参数
print("original param: {} {} ".format(name,param))
# 从预训练模型中加载
if name in list(state_dict.keys()):
param = state_dict[name]
print("new param from checkpoint: {} {} ".format(name,param))
|