IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Pytorch在finetune和multi-task任务的一个小坑(与BN层推理相关的坑) -> 正文阅读

[人工智能]Pytorch在finetune和multi-task任务的一个小坑(与BN层推理相关的坑)

最近博主在尝试多任务网络,简而言之就是网络中有一个backbone和多个head,每个head对应不同的任务。训练多任务网络,有一种训练方法是固定住backbone,每个head单独训练,这样的话head之间互不影响 ,是有利于提高单个任务的精度的。
但是博主在写代码时却发现,虽然已经小心翼翼了,但是head之间依然相互影响了,于是开始了漫长的debug过程。。。。。

以下是博主训练时的伪代码~

import pyorch
class MTL(nn.Module):
	self.backbone = resnet18()
	self.head1 = Linear()
	self.head2 = Linear()

	def forward(x, task):
		if task = "1":
			return self.head1(self.backbone(x))
		elif task = "2":
			return self.head2(self.backbone(x))

model = MTL()
model.load_weight()

## check pretained accuracy
task_1_pred = []
task_2_pred = []
model.eval()
for data in test_dataloader:
	x, y = data
	task_1_pred.append(y, model(x, "1"))
	task_2_pred.append(y, model(x, "2"))
acc_task_1_pretained = metric_task_1(task_1_pred)
acc_task_2_pretained = metric_task_2(task_2_pred)

## train head1 and freeze the others
 model.train()
 # 期望网络只更新model.head1的参数,其他的参数不更新
 model.backbone.freeze()
 model.head2.freeze()
 optimizer = SGD(filter(lambda x: x.requires_grad), model.parameters()), 
 				 lr=1e-2, 
 				 momentum=0.9)
for data in train_dataloader:
	optimizer.zero_grad()
	x, y = data
	pred = model(x, task="1")
	loss = criterion(pred, y)
  	loss.backward()
  	optimizer.step()
## check accuracy
task_1_pred = []
task_2_pred = []
model.eval()
for data in test_dataloader:
	x, y = data
	task_1_pred.append(y, model(x, "1"))
	task_2_pred.append(y, model(x, "2"))
acc_task_1_finetune = metric_task_1(task_1_pred)
acc_task_2_finetune = metric_task_2(task_2_pred)

 # 由于网络只更新model.head1的参数,因此task2的accuracy应该是不变的,
 # 但是会出现AssertError,说明somehow,train head1影响了head2的输出。
assert acc_task_2_finetune == acc_task_2_pretained

Debug过程:

  1. 由于实际代码远比伪代码复杂,一开始没有怀疑到模型头上,检查了dataloader,metric函数等等,发现并没有什么错误。
  2. 怀疑freeze没有起作用,于是计算了train之前和train之后backbone、head2的参数的变化,发现变化为0。即freeze起作用了。
#train之前
store = {}
for name, val in model.backbone.named_parameters():
	store[name] = val.clone().detach()
for name, val in model.head2.named_parameters():
	store[name] = val.clone().detach()
# train之后
for name, val in model.backbone.named_parameters():
	print(torch.mean(torch.abs(val-store[name])))
for name, val in model.head2.named_parameters():
	print(torch.mean(torch.abs(val-store[name])))
  1. 开始怀疑人生。。。。。。
  2. 继续怀疑人生。。。。。。
  3. 在某个论坛上看到BN层在推理时一些坑后,然后看了BN层实现的源码后,豁然开朗了。。。

Bug出现的原因:

  1. BN层在训练和测试时有两个参数的行为是不一样,即running_mean和running_var在训练和测试时是不一样。训练时,这两个参数是用当前batch计算出来;测试时,这两个参数与当前batch无关,而是使用了整个数据集的mean和var。这也是为什么在进行模型测试时,需要使用model.eval()命令的原因。
  2. BN层的running_mean和running_var这两个参数是统计值,梯度的反向传播与它们无关,使用 model.named_parameters()也无法获取到这两个参数的值。但是,这两个参数在模型训练过程中是切切实实在变化的,因为BN层需要不断更新这两个参数来对数据集的mean和var进行统计。
  3. 所以问题出现在,即使我们使用了freeze命令,模型的backbone的BN层的running_mean和running_var这两个参数也出现了更新。由于head2需要用到backbone的输出,因此task2的accuracy出现了变化。

解决办法:

很简单,添加model.backbone.eval()即可,这样BN层的running_mean和running_var这两个参数就不会更新了。debug后的伪代码如下:

## train head1 and freeze the others
 model.train()
 # 期望网络只更新model.head1的参数,其他的参数不更新
 model.backbone.freeze()
 model.backbone.eval() # 固定住BN层的running_mean和running_var这两个参数
 model.head2.freeze()
 model.head2.eval() #保险起见,head2也做设置
 optimizer = SGD(filter(lambda x: x.requires_grad), model.parameters()), 
 				 lr=1e-2, 
 				 momentum=0.9)
.....
.....
.....
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-21 15:22:26  更:2021-08-21 15:23:15 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 23:40:29-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码