论文《Learning to learn without forgetting by maximizing transfer and minimizing interference》中提出了“将经验重放与元学习相结合“的增量学习方法:Meta-Experience Replay (MER)。 这里整理了一下MER的算法流程和代码实现,分别针对任务增量(Task-IL)和类增量(Class-IL)场景下。
论文解析可以戳这里:??????论文解析:Learning to learn without forgetting by maximizing transfer and minimizing interference
目录
1. 算法基础
1.1 Reservior Sampling (蓄水池采样)
1.2 Experience Replay (ER,经验回放方法)
1.3 Reptile
?2. Meta-Experience Replay 算法
?2.1 MER 算法详解
?2.2 任务增量下的代码注释
?2.3 类增量下的代码注释
1. 算法基础
1.1 Reservior Sampling (蓄水池采样)
Reservior Sampling 是基于经验重放的增量学习方法中常使用的等概率采样方法。
(1) 原理
给出一个数据流,这个数据流的长度很大或者未知。并且对该数据流中数据只能访问一次。写出一个随机选择算法,使得数据流中所有数据被选中的概率相等。
(2) 方法
假设需要采样的数量为k。
首先构建一个可容纳k个元素的数据,将序列的前k个元素放入数据。
然后对第j个元素(j>k),以k/j的概率决定该元素是否被留下(替换到数组中,数组中的k个元素被替换的概率相同)。
?(3) 证明
?(4) 算法流程
1.2 Experience Replay (ER,经验回放方法)
(1) 学习目标
核心是保持对已经见过的exemplars的记忆
目标函数:
其中,?为 memory buffer,current size =?, maximum size =?
原理:使用 Reservioe Sampling 更新 buffer,确保在每一个时间步长里,任何??个exemplars在 buffer 中被看见的概率都等于?
?(2) 算法流程
ER 算法中,每看到新的样本,就对当前 exemplars 进行优先级排序。确保 current exemplars 与 replay buffer 中的例子交叉(因为在继续next example前,希望确保算法能够对current example进行优化,特别是当它还未加入到memory中)
1.3 Reptile
Reptile是元学习中最经典和常用的算法之一。具体的原理可以自行查阅相关文献。 本文的MER就是在Reptile基础上结合增量学习,Reptile基于SGD优化器和学习率,跨s批次顺序优化。
在a set of s batches上的优化目标为:
?算法流程:
这里主要介绍论文中的 Algorithm 1,是单个样本的增量更新。(Algorithm 6 是对一个批次batch的增量更新,原理和代码相差不大。)
?2.1 MER 算法详解
原理:MER保持着 Experience Replay 的记忆,通过 Reservior Sampling 采样。每次时间步提取包括从buffer中k-1个随机样本在内的s个batches。
流程:
1、黄色框为内部更新 inner update: ? ? ?在Reptile的基础流程。对于 s 个 batches,每个 batch 中的 k 个样本,都进行1次Reptile批处理。
2、绿色框为外部更新 outer update: ? ? ?根据 inner update 后的模型参数,更新原始模型参数。使用Reservior sampling来更新 memory buffer。
?
?2.2 任务增量下的代码注释
代码链接:MER/meralg1.py at master · mattriemer/MER · GitHub
(1) Draw batches from buffer: 当前的新样本为 (x,y),结合新样本和从 memory buffer??中取出的旧样本(经验回放),生成该批次要训练的样本:
def getBatch(self, x, y, t):
# (x,y): 新看到的样本
xi = Variable(torch.from_numpy(np.array(x))).float().view(1, -1)
yi = Variable(torch.from_numpy(np.array(y))).long().view(1)
if self.cuda:
xi = xi.cuda()
yi = yi.cuda()
# bxs, bys: 该批次要训练的样本
bxs = [xi]
bys = [yi]
if len(self.M) > 0:
order = [i for i in range(0, len(self.M))]
osize = min(self.batchSize, len(self.M))
for j in range(0, osize):
shuffle(order)
k = order[j]
x, y, t = self.M[k]
xi = Variable(torch.from_numpy(np.array(x))).float().view(1, -1)
yi = Variable(torch.from_numpy(np.array(y))).long().view(1)
# handle gpus if specified
if self.cuda:
xi = xi.cuda()
yi = yi.cuda()
bxs.append(xi)
bys.append(yi)
return bxs, bys
在 observe() 中调用:
# Draw batch from buffer
bxs,bys = self.getBatch(xi,yi,t)
(2) Inner update 中使用Reptile meta-update:
for step in range(0, self.steps):
weights_before = deepcopy(self.net.state_dict())
# Draw batch from buffer:
bxs, bys = self.getBatch(xi, yi, t)
loss = 0.0
for idx in range(len(bxs)):
# 单个样本进行元学习
self.net.zero_grad()
bx = bxs[idx]
by = bys[idx]
prediction = self.forward(bx, 0)
loss = self.bce(prediction, by)
loss.backward()
self.opt.step()
weights_after = self.net.state_dict()
# Within batch Reptile meta-update:
# 更新内部模型的参数
self.net.load_state_dict(
{name: weights_before[name] + ((weights_after[name] - weights_before[name]) * self.beta) for name in
weights_before})
(3) Outer update 中进行 Reptile 元更新和重新采样 第一步,将内部更新的元模型参数进行外部模型的更新:
# Across batch Reptile meta-update
self.net.load_state_dict({name : before[name] + ((after[name] - before[name]) * self.gamma) for name in before})
第二步,使用 Reservoir Sampling 更新 buffer memory:
# Reservoir sampling memory update:
if len(self.M) < self.memories:
self.M.append([xi, yi, t])
else:
p = random.randint(0, self.age)
if p < self.memories:
self.M[p] = [xi, yi, t]
?2.3 类增量下的代码注释
代码链接:La-MAML/meralg1.py at main · montrealrobotics/La-MAML · GitHub
(0) initialization初始化: 在__init__() 函数中,根据类增量的场景,重新设置了每个任务的类别数 nc_per_task
self.n_outputs = n_outputs
if self.is_cifar: # Class-IL
self.nc_per_task = n_outputs / n_tasks # 每个任务的类别不重叠
else: # Task -IL
self.nc_per_task = n_outputs # 每个任务的类别可以看作一样
(1) Draw batches from buffer: 与2.2中的如出一辙,但是多增加了任务t。将任务t也加入到了buffer memory中。
(2) Inner update 中使用Reptile meta-update:
在这里,类增量比2.2(任务增量)新增了一个 compute_offsets() 函数。主要是因为任务增量中,验证的是所有类别(每个任务的类别可以近似看作是一样的)的预测结果;而类增量中,验证的是当前任务中涉及到的类别(每个任务的类别都不重叠)的预测结果。 所以,使用 compute_offsets() 函数来框定只有在这次任务出现的类别:
def compute_offsets(self, task):
if self.is_cifar: # Class-IL
offset1 = task * self.nc_per_task
offset2 = (task + 1) * self.nc_per_task
else: # Task-IL
offset1 = 0
offset2 = self.n_outputs
return int(offset1), int(offset2)
同样,在 forward() 中也应用了compute_offsets() 函数来框定只有在这次任务中出现的类别的预测结果:
def forward(self, x, t):
output = self.netforward(x)
if self.is_cifar:
offset1, offset2 = self.compute_offsets(t)
# 不在offset1~offset2的预测结果都剔除
if offset1 > 0:
output[:, :offset1].data.fill_(-10e10)
if offset2 < self.n_outputs:
output[:, int(offset2):self.n_outputs].data.fill_(-10e10)
return output
在observe() 函数中,inner update内部更新流程:
for step in range(0, self.steps):
weights_before = deepcopy(self.net.state_dict())
##Check for nan
if weights_before != weights_before:
ipdb.set_trace()
# Draw batch from buffer:
bxs, bys, bts = self.getBatch(xi, yi, t)
loss = 0.0
total_loss = 0.0
for idx in range(len(bxs)):
self.net.zero_grad()
bx = bxs[idx]
by = bys[idx]
bt = bts[idx]
if self.is_cifar: # Class-IL
offset1, offset2 = self.compute_offsets(bt) # 获得当前任务的index
prediction = (self.netforward(bx)[:, offset1:offset2]) # 获得当前任务的预测结果
loss = self.bce(prediction,
by.unsqueeze(0) - offset1)
else: # Task-IL
prediction = self.forward(bx, 0)
loss = self.bce(prediction, by.unsqueeze(0))
if torch.isnan(loss):
ipdb.set_trace()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.args.grad_clip_norm)
self.opt.step()
total_loss += loss.item()
weights_after = self.net.state_dict()
if weights_after != weights_after:
ipdb.set_trace()
# Within batch Reptile meta-update:
self.net.load_state_dict(
{name: weights_before[name] + ((weights_after[name] - weights_before[name]) * self.beta) for name in
weights_before})
(3) Outer update 中进行 Reptile 元更新和重新采样
与2.2中的如出一辙,但是多增加了任务t。将任务t也加入到了buffer memory中。
以上是我对任务增量/类增量场景下的MER代码的一些理解,从代码上也可以看出任务增量和类增量的异同。如果有写的不对的地方,欢迎指出与讨论~
citation:M. Reimer, I. Cases, R. Ajemian, M. Liu, I. Rish, Y. Tu, G. Tesauro, Learning to learn without forgetting by maximizing transfer and minimizing interference, in: ICLR, 2019.
|