| |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
-> 人工智能 -> Federated Reconnaissance: Efficient Distributed Class-Incremental Learning 论文阅读+代码解析 -> 正文阅读 |
|
[人工智能]Federated Reconnaissance: Efficient Distributed Class-Incremental Learning 论文阅读+代码解析 |
一. 介绍论文中,作者提出了联合侦察,这是一类新的学习问题,分布式模型应该能够独立地学习新的概念,并有效地共享这些知识。通常在联合学习中,单个静态类集由每个客户端学习。相反,联邦侦察要求每个客户机可以单独学习一组不断增长的类,并与其他客户机有效地交流之前观察到的和新的类的知识。这种关于学习类的交流可以从客户那里获得知识;然后期望最终合并的模型支持每个客户机已公开的类的超集。然后可以将合并的模型部署回客户端进行进一步的学习。 1.1 早期的工作持续学习: 不断学习新概念是一个开放的和长期的问题在机器学习和人工智能没有表面上的一个统一的解决方案。虽然深度神经网络已被证明在广泛的任务中是非常有效的,但持续整合新信息的可用方法,同时记住以前学到的概念会变得效率低下。在这项工作中,我们假设访问一组训练前的数据,并探索算法,允许高效和准确的学习新类的顺序。 1.2 贡献一个有效的联邦侦查系统必须解决新类的高效学习与知识的保存转移。因此,作者将普通的随机梯度下降作为下界,iCaRL算法用于联邦侦查的比较以及将所有客户的所有训练数据联合分布的SGD作为上界。 二. 联邦侦查问题陈述2.1 系统需求联邦侦察需要对每个客户端设备进行持续的学习、高效的通信和知识合并。受到在大量分布式客户端设备上学习新类的应用程序的启发,我们定义了联邦侦察学习系统的以下需求:
联邦侦察的实际实现的具体要求当然将决定每个需求的细节和相对重要性。 2.2 问题定义联邦侦查由一组客户端组成
C
:
=
{
c
i
∣
i
∈
1...
C
}
\mathbb{C}:=\{c_i|i\in 1...C\}
C:={ci?∣i∈1...C},每一个客户端都经历着类不断增加的情况
M
i
,
t
:
=
{
p
(
y
=
j
∣
x
)
∣
j
∈
1...
M
j
}
\mathbb{M}_{i,t}:=\{p(y=j|x)|j\in 1...M_j\}
Mi,t?:={p(y=j∣x)∣j∈1...Mj?}。其中
C
C
C表示客户端的总数,
M
i
M_i
Mi?则表示一个客户端能区分的类的总数,一个类由概率
p
(
y
=
j
∣
x
)
p(y=j|x)
p(y=j∣x)通过标签j和x进行表示。中央服务器的工作是合并客户机关于类的知识
M
t
=
?
i
=
1
C
M
i
,
t
\mathbb{M}_t=\bigcup^C_{i=1}\mathbb{M}_{i,t}
Mt?=?i=1C?Mi,t?然后部署更新的模型并将模型
M
t
\mathbb{M}_t
Mt?返回给
C
\mathbb{C}
C。一个客户端
C
i
C_i
Ci?可以通过直接使用一组标记的例子进行训练从而接触到一个新的类
{
(
x
,
y
)
∣
(
x
,
y
)
∈
X
j
×
Y
j
}
\{(x,y)|(x,y)\in X_j \times Y_j\}
{(x,y)∣(x,y)∈Xj?×Yj?},或者交换经过压缩的知识,使得客户端近似估计
p
(
y
=
j
∣
x
)
p(y=j|x)
p(y=j∣x)。 三. 方法3.1 学习的算法(这里作者说了自己比较了哪些方法,由于本博客主要是学习思想,因此就不写了) 3.2 联邦原型网络(Federated Prototypical Networks)我们提出使用原型网络来有效地循序学习新类。由于原型网络在测试时不是基于梯度的,因此在学习新类时,通过对足够多的类进行判别性预训练,可以使它们对灾难性遗忘具有鲁棒性。当在联邦侦察基准上进行评估时,我们可以通过简单地存储之前的原型(方差)和用于计算之前原型的示例数量来计算每个类的均值(如果需要的话,还有方差)的无偏估计。我们根据定义了原型网络: 四.关键代码解读代码地址点这里 4.1 元训练部分首先就是我们需要定义prototypical网络,也就是计算出z来。
看着很复杂,其实就是4个卷积层构成的,作者这里的hidden_size为64(64个3*3的卷积核),因此假如说一个batch的x为[1,25,1,28,28](类似于元学习,第一个1表示有一个任务,25表示一个任务包含的数据量,这里是5way5shot所以是25,第三个1表示通道,之后为图片),经过encorder之后变为:[25,64,1,1],之后变了一下形状变为encorder = [1,25,64]方便之后计算。
首先对我们的5way5shot拆分encorder,变为[1,5,5,64],之后对我们的每一个类中的样本求平均(也就是5个example进行平均),算出来的z就为:[1,5,64]
不断迭代更新出我们的encorder层中的参数即可。 4.2 元测试部分——类增量再完成训练后,我们的enocrder参数达到最佳
θ
?
\theta^*
θ?,此时我们来进行元测试,进行类增量的测试。
计算z,由于只有当前类的信息,因此我们需要存储之前出现类的z,然后一起concat即可。
之后计算距离,再计算两者距离最小的下标即为预测。
|
|
|
上一篇文章 下一篇文章 查看所有文章 |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 | -2024/11/26 22:29:05- |
|
网站联系: qq:121756557 email:121756557@qq.com IT数码 |