IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【元学习】MER代码实现:Task/Class-IL增量场景下的Meta-Experience Replay详解 -> 正文阅读

[人工智能]【元学习】MER代码实现:Task/Class-IL增量场景下的Meta-Experience Replay详解

论文《Learning to learn without forgetting by maximizing transfer and minimizing interference》中提出了“将经验重放与元学习相结合“的增量学习方法:Meta-Experience Replay (MER)。
这里整理了一下MER的算法流程和代码实现,分别针对任务增量(Task-IL)和类增量(Class-IL)场景下。

论文解析可以戳这里:??????论文解析:Learning to learn without forgetting by maximizing transfer and minimizing interference

目录

1. 算法基础

1.1 Reservior Sampling (蓄水池采样)

1.2 Experience Replay (ER,经验回放方法)

1.3 Reptile

?2. Meta-Experience Replay 算法

?2.1 MER 算法详解

?2.2 任务增量下的代码注释

?2.3 类增量下的代码注释


1. 算法基础

1.1 Reservior Sampling (蓄水池采样)

Reservior Sampling 是基于经验重放的增量学习方法中常使用的等概率采样方法

(1) 原理

给出一个数据流,这个数据流的长度很大或者未知。并且对该数据流中数据只能访问一次。写出一个随机选择算法,使得数据流中所有数据被选中的概率相等。

(2) 方法

假设需要采样的数量为k。

首先构建一个可容纳k个元素的数据,将序列的前k个元素放入数据。

然后对第j个元素(j>k),以k/j的概率决定该元素是否被留下(替换到数组中,数组中的k个元素被替换的概率相同)。

?(3) 证明

?(4) 算法流程


1.2 Experience Replay (ER,经验回放方法)

(1) 学习目标

核心是保持对已经见过的exemplars的记忆

目标函数:

其中,M?为 memory buffer,current size =?M_{size }, maximum size =?M_{max}

原理:使用 Reservioe Sampling 更新 buffer,确保在每一个时间步长里,任何?N?个exemplars在 buffer 中被看见的概率都等于?M_{size} /N

?(2) 算法流程

ER 算法中,每看到新的样本,就对当前 exemplars 进行优先级排序。确保 current exemplars 与 replay buffer 中的例子交叉(因为在继续next example前,希望确保算法能够对current example进行优化,特别是当它还未加入到memory中)


1.3 Reptile

Reptile是元学习中最经典和常用的算法之一。具体的原理可以自行查阅相关文献。
本文的MER就是在Reptile基础上结合增量学习,Reptile基于SGD优化器和学习率\alpha,跨s批次顺序优化。

在a set of s batches上的优化目标为:

?算法流程:

?2. Meta-Experience Replay 算法

这里主要介绍论文中的 Algorithm 1,是单个样本的增量更新。(Algorithm 6 是对一个批次batch的增量更新,原理和代码相差不大。)

?2.1 MER 算法详解

原理:MER保持着 Experience Replay 的记忆M,通过 Reservior Sampling 采样。每次时间步提取包括从buffer中k-1个随机样本在内的s个batches。

流程:

1、黄色框为内部更新 inner update:
? ? ?在Reptile的基础流程。对于 s 个 batches,每个 batch 中的 k 个样本,都进行1次Reptile批处理。

2、绿色框为外部更新 outer update:
? ? ?根据 inner update 后的模型参数,更新原始模型参数。使用Reservior sampling来更新 memory buffer。

?


?2.2 任务增量下的代码注释

代码链接:MER/meralg1.py at master · mattriemer/MER · GitHub

(1) Draw batches from buffer:
当前的新样本为 (x,y),结合新样本和从 memory buffer?M?中取出的旧样本(经验回放),生成该批次要训练的样本:B_1,B_2...,B_s \leftarrow sample(x,y,s,k,M)

def getBatch(self, x, y, t):
    # (x,y): 新看到的样本
    xi = Variable(torch.from_numpy(np.array(x))).float().view(1, -1)
    yi = Variable(torch.from_numpy(np.array(y))).long().view(1)
    if self.cuda:
        xi = xi.cuda()
        yi = yi.cuda()

    # bxs, bys: 该批次要训练的样本
    bxs = [xi]
    bys = [yi]

    if len(self.M) > 0:
        order = [i for i in range(0, len(self.M))]
        osize = min(self.batchSize, len(self.M))
        for j in range(0, osize):
            shuffle(order)
            k = order[j]
            x, y, t = self.M[k]
            xi = Variable(torch.from_numpy(np.array(x))).float().view(1, -1)
            yi = Variable(torch.from_numpy(np.array(y))).long().view(1)
            # handle gpus if specified
            if self.cuda:
                xi = xi.cuda()
                yi = yi.cuda()
            bxs.append(xi)
            bys.append(yi)

    return bxs, bys

在 observe() 中调用:

# Draw batch from buffer
bxs,bys = self.getBatch(xi,yi,t)

(2) Inner update 中使用Reptile meta-update:

for step in range(0, self.steps):
    weights_before = deepcopy(self.net.state_dict())
    # Draw batch from buffer:
    bxs, bys = self.getBatch(xi, yi, t)
    loss = 0.0
    for idx in range(len(bxs)):
        # 单个样本进行元学习
        self.net.zero_grad()
        bx = bxs[idx]
        by = bys[idx]
        prediction = self.forward(bx, 0)
        loss = self.bce(prediction, by)
        loss.backward()
        self.opt.step()

    weights_after = self.net.state_dict()

    # Within batch Reptile meta-update:
    # 更新内部模型的参数
    self.net.load_state_dict(
        {name: weights_before[name] + ((weights_after[name] - weights_before[name]) * self.beta) for name in
         weights_before})

(3) Outer update 中进行 Reptile 元更新和重新采样
第一步,将内部更新的元模型参数进行外部模型的更新:

 # Across batch Reptile meta-update
self.net.load_state_dict({name : before[name] + ((after[name] - before[name]) * self.gamma) for name in before})

第二步,使用 Reservoir Sampling 更新 buffer memory:

# Reservoir sampling memory update: 
if len(self.M) < self.memories:
    self.M.append([xi, yi, t])
else:
    p = random.randint(0, self.age)
    if p < self.memories:
        self.M[p] = [xi, yi, t]

?2.3 类增量下的代码注释

代码链接:La-MAML/meralg1.py at main · montrealrobotics/La-MAML · GitHub

(0) initialization初始化:
在__init__() 函数中,根据类增量的场景,重新设置了每个任务的类别数 nc_per_task

self.n_outputs = n_outputs
if self.is_cifar:  # Class-IL
    self.nc_per_task = n_outputs / n_tasks  # 每个任务的类别不重叠
else:  # Task -IL
    self.nc_per_task = n_outputs  # 每个任务的类别可以看作一样

(1) Draw batches from buffer:
与2.2中的如出一辙,但是多增加了任务t。将任务t也加入到了buffer memory中。

(2) Inner update 中使用Reptile meta-update:

在这里,类增量比2.2(任务增量)新增了一个 compute_offsets() 函数。主要是因为任务增量中,验证的是所有类别(每个任务的类别可以近似看作是一样的)的预测结果;而类增量中,验证的是当前任务中涉及到的类别(每个任务的类别都不重叠)的预测结果。
所以,使用 compute_offsets() 函数来框定只有在这次任务出现的类别:

def compute_offsets(self, task):
    if self.is_cifar: # Class-IL
        offset1 = task * self.nc_per_task
        offset2 = (task + 1) * self.nc_per_task
    else: # Task-IL
        offset1 = 0
        offset2 = self.n_outputs
    return int(offset1), int(offset2)

同样,在 forward() 中也应用了compute_offsets() 函数来框定只有在这次任务中出现的类别的预测结果:

def forward(self, x, t):
    output = self.netforward(x)
    if self.is_cifar:
        offset1, offset2 = self.compute_offsets(t)
        
        # 不在offset1~offset2的预测结果都剔除
        if offset1 > 0:
            output[:, :offset1].data.fill_(-10e10)
        if offset2 < self.n_outputs:
            output[:, int(offset2):self.n_outputs].data.fill_(-10e10)
    return output

在observe() 函数中,inner update内部更新流程:

for step in range(0, self.steps):
    weights_before = deepcopy(self.net.state_dict())
    ##Check for nan
    if weights_before != weights_before:
        ipdb.set_trace()
    # Draw batch from buffer:
    bxs, bys, bts = self.getBatch(xi, yi, t)
    loss = 0.0
    total_loss = 0.0
    for idx in range(len(bxs)):

        self.net.zero_grad()
        bx = bxs[idx]
        by = bys[idx]
        bt = bts[idx]

        if self.is_cifar:  # Class-IL
            offset1, offset2 = self.compute_offsets(bt)  # 获得当前任务的index
            prediction = (self.netforward(bx)[:, offset1:offset2])  # 获得当前任务的预测结果
            loss = self.bce(prediction,
                            by.unsqueeze(0) - offset1)
        else:  # Task-IL
            prediction = self.forward(bx, 0)
            loss = self.bce(prediction, by.unsqueeze(0))
        if torch.isnan(loss):
            ipdb.set_trace()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.args.grad_clip_norm)
        self.opt.step()
        total_loss += loss.item()
    weights_after = self.net.state_dict()
    if weights_after != weights_after:
        ipdb.set_trace()

    # Within batch Reptile meta-update:
    self.net.load_state_dict(
        {name: weights_before[name] + ((weights_after[name] - weights_before[name]) * self.beta) for name in
         weights_before})

(3) Outer update 中进行 Reptile 元更新和重新采样

与2.2中的如出一辙,但是多增加了任务t。将任务t也加入到了buffer memory中。

以上是我对任务增量/类增量场景下的MER代码的一些理解,从代码上也可以看出任务增量和类增量的异同。如果有写的不对的地方,欢迎指出与讨论~

citation:M. Reimer, I. Cases, R. Ajemian, M. Liu, I. Rish, Y. Tu, G. Tesauro, Learning to learn without forgetting by maximizing transfer and minimizing interference, in: ICLR, 2019.

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-27 14:05:27  更:2021-09-27 14:05:31 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 16:01:51-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码