最近博主在尝试多任务网络,简而言之就是网络中有一个backbone和多个head,每个head对应不同的任务。训练多任务网络,有一种训练方法是固定住backbone,每个head单独训练,这样的话head之间互不影响 ,是有利于提高单个任务的精度的。 但是博主在写代码时却发现,虽然已经小心翼翼了,但是head之间依然相互影响了,于是开始了漫长的debug过程。。。。。
以下是博主训练时的伪代码~
import pyorch
class MTL(nn.Module):
self.backbone = resnet18()
self.head1 = Linear()
self.head2 = Linear()
def forward(x, task):
if task = "1":
return self.head1(self.backbone(x))
elif task = "2":
return self.head2(self.backbone(x))
model = MTL()
model.load_weight()
task_1_pred = []
task_2_pred = []
model.eval()
for data in test_dataloader:
x, y = data
task_1_pred.append(y, model(x, "1"))
task_2_pred.append(y, model(x, "2"))
acc_task_1_pretained = metric_task_1(task_1_pred)
acc_task_2_pretained = metric_task_2(task_2_pred)
model.train()
model.backbone.freeze()
model.head2.freeze()
optimizer = SGD(filter(lambda x: x.requires_grad), model.parameters()),
lr=1e-2,
momentum=0.9)
for data in train_dataloader:
optimizer.zero_grad()
x, y = data
pred = model(x, task="1")
loss = criterion(pred, y)
loss.backward()
optimizer.step()
task_1_pred = []
task_2_pred = []
model.eval()
for data in test_dataloader:
x, y = data
task_1_pred.append(y, model(x, "1"))
task_2_pred.append(y, model(x, "2"))
acc_task_1_finetune = metric_task_1(task_1_pred)
acc_task_2_finetune = metric_task_2(task_2_pred)
assert acc_task_2_finetune == acc_task_2_pretained
Debug过程:
- 由于实际代码远比伪代码复杂,一开始没有怀疑到模型头上,检查了dataloader,metric函数等等,发现并没有什么错误。
- 怀疑freeze没有起作用,于是计算了train之前和train之后backbone、head2的参数的变化,发现变化为0。即freeze起作用了。
store = {}
for name, val in model.backbone.named_parameters():
store[name] = val.clone().detach()
for name, val in model.head2.named_parameters():
store[name] = val.clone().detach()
for name, val in model.backbone.named_parameters():
print(torch.mean(torch.abs(val-store[name])))
for name, val in model.head2.named_parameters():
print(torch.mean(torch.abs(val-store[name])))
- 开始怀疑人生。。。。。。
- 继续怀疑人生。。。。。。
- 在某个论坛上看到BN层在推理时一些坑后,然后看了BN层实现的源码后,豁然开朗了。。。
Bug出现的原因:
- BN层在训练和测试时有两个参数的行为是不一样,即running_mean和running_var在训练和测试时是不一样。训练时,这两个参数是用当前batch计算出来;测试时,这两个参数与当前batch无关,而是使用了整个数据集的mean和var。这也是为什么在进行模型测试时,需要使用model.eval()命令的原因。
- BN层的running_mean和running_var这两个参数是统计值,梯度的反向传播与它们无关,使用 model.named_parameters()也无法获取到这两个参数的值。但是,这两个参数在模型训练过程中是切切实实在变化的,因为BN层需要不断更新这两个参数来对数据集的mean和var进行统计。
- 所以问题出现在,即使我们使用了freeze命令,模型的backbone的BN层的running_mean和running_var这两个参数也出现了更新。由于head2需要用到backbone的输出,因此task2的accuracy出现了变化。
解决办法:
很简单,添加model.backbone.eval()即可,这样BN层的running_mean和running_var这两个参数就不会更新了。debug后的伪代码如下:
model.train()
model.backbone.freeze()
model.backbone.eval()
model.head2.freeze()
model.head2.eval()
optimizer = SGD(filter(lambda x: x.requires_grad), model.parameters()),
lr=1e-2,
momentum=0.9)
.....
.....
.....
|