| |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
-> 人工智能 -> MixMatch文章解读+算法流程+核心代码详解 -> 正文阅读 |
|
[人工智能]MixMatch文章解读+算法流程+核心代码详解 |
MixMatch本博客仅做算法流程疏导,具体细节请参见原文 原文Github代码解读MixMatch抓住了半监督算法的两个重要观点:第一是熵最小化;第二是一致性正则化。结合这两个观点的算法就形成了MixMatch。 熵最小化半监督算法的一个常见假设就是分类的决策边界不应该通过数据分布的高密度区域。这句话简单的理解可以想象一个聚类模型,其决策边界一定是在簇与簇之间的稀疏边界上,不可能穿过一个簇的中心(高密度区域)。而实现这一点的一种方法就是要求分类器对未标记数据输出低熵预测。MixMatch中使用一个"sharpening"函数来隐式实现熵最小化。所谓熵最小化、低熵预测,都是指使输出概率分布比较有“偏向性”,而不希望输出一个“平均的预测”。熵在信息论中是不确定度的度量,根据离散模型的熵最大定理,可知在均匀分布时熵取得最大值,换句话说,出现一个确定的分布,即某一类的概率是1,其余类的概率是0时,熵为0。也就是说想要得到熵最小,就得使分类器输出后的模型预测概率集中分配给某一类。后面再介绍“sharpening”函数如何实现这一点。 一致性正则化一致性正则化也是一个常见的半监督假设。VAT、MeanTeacher等其实都或多或少使用了这种假设。其核心在于,我们希望一个样本和其加扰版本(通常图像中称为Augment)通过分类器后,得到相似的输出。其实也就是说分类边界不应该穿过数据分布的高密度区域。如下图,红色点是原始样本,蓝色和绿色为其扰动版本,红色同心圆的虚线圆是我们期望的容差范围,即在这个区间类的都应该认为和其中心数据点为同一类。通过扰动数据点的加入,将决策边界推到合适的位置,使分类器的鲁棒性更强。 一般而言,通过对原始样本和其扰动版本的分类器输出进行衡量,即可实现一致性正则化,常见的衡量方式有MSE、KL散度、JS散度等。在MixMatch中通过对图像的标准数据增强(水平翻转、裁剪)实现扰动(Augment),采用MSE准则方式衡量。 总得来说,算法有以下步骤: 归结而言有五个步骤: 第一步,对数据进行扩增(Augment)。扩增分为对有标记数据集 X X X?的扩增和对无标记数据集 U U U?的扩增,分别记为 X ^ \hat{X} X^?和 U ^ \hat{U} U^?。对 X X X?扩增一次,对 U U U?扩增 K K K?次,文章中取 K = 2 K=2 K=2?。因为在取batch时, B a t c h S i z e U = B a t c h S i z e X Batch Size _U = BatchSize_X BatchSizeU?=BatchSizeX??,所以扩增后 B a t c h S i z e U ^ = K ? B a t c h S i z e X ^ Batch Size _{\hat{U}} = K\cdot BatchSize_{\hat{X}} BatchSizeU^?=K?BatchSizeX^?????。 第二步,计算平均预测分布。此步骤仅对数据集 U ^ \hat{U} U^???进行。即通过如下公式计算,其中 ( u b , k ^ , y ) (\hat{u_{b,k}},y) (ub,k?^?,y)?是 U ^ \hat{U} U^?的一个 B a t c h Batch Batch?: q b ˉ = 1 K ∑ k P m o d e l ( y ∣ u b , k ^ ; θ ) \bar{q_b}=\frac{1}{K}\sum_kP_{model}(y|\hat{u_{b,k}};\theta) qb?ˉ?=K1?k∑?Pmodel?(y∣ub,k?^?;θ) 值得注意的是, P m o d e l ( y ∣ u b , k ^ ; θ ) P_{model}(y|\hat{u_{b,k}};\theta) Pmodel?(y∣ub,k?^?;θ)是 S o f t m a x Softmax Softmax?之后的预测概率分布。 第三步,通过 s h a r p e n i n g sharpening sharpening函数完成分布的锐化,其计算公式如下: S h a r p e n ( p , T ) i = p i 1 T ∑ j = 1 L p j 1 T Sharpen(p,T)_i=\frac{p_i^{\frac{1}{T}}}{\sum^L_{j=1}p_j^{\frac{1}{T}}} Sharpen(p,T)i?=∑j=1L?pjT1??piT1???? 当超参数 T → 0 T\to 0 T→0?时, S h a r p e n ( p , T ) Sharpen(p,T) Sharpen(p,T)?趋向于 o n e ? h o t one-hot one?hot??分布,即其中一个类别的概率为1,其余概率为0;锐化后的概率分布作为 U ^ \hat{U} U^?的数据标签(pseudo label)。 第四步,通过
M
i
x
U
p
MixUp
MixUp完成新数据集的构建。先将第一步扩增后的
X
^
\hat{X}
X^和
U
^
\hat{U}
U^进行拼接再打乱顺序,得到
W
=
S
h
u
f
f
l
e
(
C
o
n
c
a
t
(
X
^
,
U
^
)
)
W=Shuffle(Concat(\hat{X},\hat{U}))
W=Shuffle(Concat(X^,U^)),然后再将
W
W
W分为两部分,第一部分大小与
X
^
\hat{X}
X^相同(也与
X
X
X相同),记为
W
x
W_x
Wx?;第二部分大小与
U
^
\hat{U}
U^相同(也与
U
U
U相同),记为
W
u
W_u
Wu?。然后将
W
x
W_x
Wx?和
X
^
\hat{X}
X^进行
M
i
x
U
p
MixUp
MixUp,
W
u
W_u
Wu?和
U
^
\hat{U}
U^进行
M
i
x
U
p
MixUp
MixUp,得到
X
′
X'
X′和
U
′
U'
U′?。
M
i
x
U
p
MixUp
MixUp步骤如下: 第五步,计算半监督损失函数,分为在标记数据集 X ′ X' X′?上的损失函数 L x L_x Lx?和在无标记数据集 U ′ U' U′上的损失函数 L u L_u Lu?,公式如下: L x = 1 ∣ X ′ ∣ ∑ x , p ∈ X ′ H ( p , P m o d e l ( y ∣ x ; θ ) ) L_x=\frac{1}{|X'|}\sum_{x,p\in X'}H(p,P_{model}(y|x;\theta)) Lx?=∣X′∣1?x,p∈X′∑?H(p,Pmodel?(y∣x;θ)) L u = 1 L ∣ U ′ ∣ ∑ u , q ∈ U ′ ∣ ∣ q ? P m o d e l ( y ∣ u ; θ ) ∣ ∣ 2 2 L_u=\frac{1}{L|U'|}\sum_{u,q\in U'}||q-P_{model}(y|u;\theta)||^2_2 Lu?=L∣U′∣1?u,q∈U′∑?∣∣q?Pmodel?(y∣u;θ)∣∣22? L = L x + λ U L u L=L_x+\lambda_UL_u L=Lx?+λU?Lu? 其中 H ( ? ) H(\cdot) H(?)?是 C o r s s E n t r o p y L o s s CorssEntropyLoss CorssEntropyLoss?; L u L_u Lu?其实就是 M S E MSE MSE准则下的误差项。? 反向梯度传播即可完成整个MixMatch算法 核心代码详解图像的水平翻转、裁剪实现 A u g m e n t Augment Augment:
这里是在迭代过程中,手动取迭代器中的batch,而不是直接使用Dataloader。这种做法在最近的几篇文章代码复现中都遇见了,其主要目的是为了在一个epoch中可以迭代指定次数,而直接使用Dataloader只能迭代最多 c e i l ( 样 本 总 数 B a t c h S i z e ) ceil(\frac{样本总数}{BatchSize}) ceil(BatchSize样本总数?)次,其中 c e i l ( ? ) ceil(\cdot) ceil(?)是上取整函数,如果 d r o p l a s t drop_last dropl?ast,则只能迭代 样 本 总 数 B a t c h S i z e \frac{样本总数}{BatchSize} BatchSize样本总数?次。代码中的两个try except是为了保证迭代器完全迭代一次后,重新加载迭代器,继续迭代,直到达到指定次数才跳转下一个epoch。
因为文章中取 K = 2 K=2 K=2,所以进行两次扩增,求输出概率的均值,其中output_u和output_u2分别为两次扩增后的模型输出结果:
求Sharpening结果:
完成 M i x U p MixUp MixUp:
然后计算损失函数:
反向梯度传播,结束。 |
|
|
上一篇文章 下一篇文章 查看所有文章 |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 | -2024/12/22 10:42:23- |
|
网站联系: qq:121756557 email:121756557@qq.com IT数码 |