Pytorch加载模型
当在特殊环境下会出现模型无法通过设置pretrained=True的方式来完成自动加载,这时我们需要完成手动加载。我们假设模型设置为:
class Model(nn.Module):
def __init__(self, num_classes=21):
super(Model, self).__init__()
self.backbone = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True)
self.fc = nn.Linear(1000, num_classes)
def forward(self, x):
x = self.backbone(x)
x = self.fc(x)
return x
使用timm包下的SwinTransformer作为例子
model = Model()
model_dict = model.state_dict()
print(model_dict.keys()) ## 可以通过.keys()来查看模型结构的state_dict的名称
resume = True
if resume:
print("Resume from checkpoint...")
checkpoint = torch.load('/root/.cache/torch/hub/checkpoints/swin_tiny_patch4_window7_224.pth')
# 首先使用torch.load()加载模型的state_dict
print(checkpoint.key()) #通过.key()查看名称
model_param = checkpoint['model']
model_param = {"backbone." + k: v for k, v in model_param.items() if "backbone." + k in model_dict}
# "backbone." + k 中的"backbone."是根据模型参数和模型结构的state_dict差异决定的
model_dict.update(model_param)
model.load_state_dict(model_param) # 固定搭配
print("====>loaded checkpoint ")
else:
print("====>no checkpoint found.")
即可完成模型参数的加载。
同时也可以将上方的代码直接加入 class 中
class Model(nn.Module):
def __init__(self, num_classes=45):
super(Model, self).__init__()
model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False)
model_dict = model.state_dict()
resume = True
if resume:
print("Resume from checkpoint...")
checkpoint = torch.load('/root/.cache/torch/hub/checkpoints/swin_tiny_patch4_window7_224.pth')
model_param = checkpoint['model']
model_param = {k: v for k, v in model_param.items() if k in model_dict}
# 在此处就不需要添加多余的前缀
model_dict.update(model_param)
model.load_state_dict(model_param)
print("====>loaded checkpoint ")
else:
print("====>no checkpoint found.")
self.backbone = model
self.fc = nn.Linear(1000, num_classes)
def forward(self, x):
x = self.backbone(x)
x = self.fc(x)
return x
当前还存在一个问题,就是如果输入的模型numworks不设置为1000的话,就会导致在加载模型的最后一层报错 使用model.load_state_dict(model_param, False)也无法完成加载。
该问题尚未解决,期待解决方案!
|