提出插值一致性训练(Interpolation consistency training, ICT), 这是一种简单且效率高的算法, 用于在半监督学习范式中训练深度神经网络. 在分类问题中, ICT 将决策边界移动到数据分布的低密度区域.
低密度分离假设(聚类假设)启发了许多一致性正则化半监督学习技术, 包括
Π
\Pi
Π-model, Temporal ensembling, Mean-Teacher, VAT. 另外, 在 ICT 这篇论文之后又出现了 UDA, 其效果比 ICT 还要好.
有研究表明: 对抗性扰动训练会损害泛化性能. 为了克服这个问题, 便提出了插值一致性训练(ICT), 简单来说, ICT 通过在未标记点
u
1
u_1
u1? ,
u
2
u_2
u2? 的插值
α
u
1
+
(
1
?
α
u
2
)
\alpha u_1 + (1-\alpha u_2)
αu1?+(1?αu2?) 上的一致性预测
f
(
α
u
1
+
(
1
?
α
u
2
)
)
=
α
f
(
u
1
)
+
(
1
?
α
)
f
(
u
2
)
f(\alpha u_1 + (1-\alpha u_2))=\alpha f(u_1)+(1-\alpha)f(u_2)
f(αu1?+(1?αu2?))=αf(u1?)+(1?α)f(u2?) 来规范半监督学习.
与监督学习相比, ICT 的决策边界穿越低密度区域, 这将更好地反映未标记数据的结构. 对比结果如下图所示:
插值一致性训练(ICT)
根据论文 Mixup: Beyond empirical risk minimization. In International conference on learning representations. 给出 mixup 式子:
M
i
x
λ
(
a
,
b
)
=
λ
a
+
(
1
?
λ
)
b
(1)
{\rm Mix}_{\lambda}(a,b)=\lambda a + (1-\lambda)b \tag{1}
Mixλ?(a,b)=λa+(1?λ)b(1) ICT 训练分类器
f
θ
f_\theta
fθ? 以在未标记点的插值中提供一致性预测:
f
θ
(
M
i
x
λ
(
u
j
,
u
k
)
)
≈
M
i
x
λ
(
f
θ
′
(
u
j
)
,
f
θ
′
(
u
k
)
)
(2)
f_\theta({\rm Mix}_{\lambda}(u_j,u_k))\approx {\rm Mix}_\lambda(f_{\theta'}(u_j),f_{\theta'}(u_k)) \tag{2}
fθ?(Mixλ?(uj?,uk?))≈Mixλ?(fθ′?(uj?),fθ′?(uk?))(2) 其中
θ
′
\theta'
θ′ 时
θ
\theta
θ 的滑动平均. 那么为什么未标记样本之间的插值为半监督训练提供了良好的一致性扰动呢? 有如下解释.
应该应用一致性正则化的最有用的样本是靠近决策边界的样本. 向这种低边距未标记样本
u
j
u_j
uj? 添加一个小的扰动
δ
\delta
δ 可能会将
u
j
+
δ
u_j + \delta
uj?+δ 推到决策边界的另一侧. 这将违反低密度分离假设, 就使
u
j
+
δ
u_j + \delta
uj?+δ 成为应用一致性正则化的好位置.
回到低边距未标记点
u
j
u_j
uj?, 如何找到一个扰动
δ
\delta
δ, 使得
u
j
u_j
uj? 和
u
j
+
δ
u_j +\delta
uj?+δ 位于决策边界的相对两侧? 而使用随机扰动是一种低效的策略, 因为接近决策边界的方向子集只是环境空间的一小部分. 那么, 可以考虑对第二个随机选择的未标记样本
u
k
u_k
uk? 进行插值
u
j
+
δ
=
M
i
x
λ
(
u
j
,
u
k
)
u_j + \delta = {\rm Mix}_\lambda(u_j , u_k)
uj?+δ=Mixλ?(uj?,uk?). 而两个未标记的样本
u
j
u_j
uj? 和
u
k
u_k
uk? 存在下面三种情况:
- (1) 位于同一蔟中
- (2) 位于不同的蔟中, 但属于同一类别
- (3) 位于不同的蔟上, 属于不同的类别
由聚类假设, (1)的概率随类别数的增加而降低. 如果假设每个类别的聚类数是平衡的, 则(2)的概率较低;最后, (3)的概率最高. 然后, 假设
(
u
j
,
u
k
)
(u_j,u_k)
(uj?,uk?) 之一位于决策边界附近(它是执行一致性的一个很好的候选者), 则由于(3)的概率很高, 朝向
u
k
u_k
uk? 的插值可能指向低密度区域, 其次是另一类的聚类. 由于这是移动决策的不错的方向, 因此对于基于一致性的正则化, 插值是一个很好的扰动.
到目前为止, 随机未标记样本之间的插值可能会落在低密度区域, 因此, 这种插值可以应用基于一致性的正则化. 但是应该如何标记这些插值呢? 与单个未标记样本
u
j
u_j
uj? 的随机或对抗性扰动不同, ICT 涉及两个未标记示例 $(u_j,u_k). 直观地说, 我们希望将决策边界尽可能地推离类别边界, 因为众所周知, 具有大边距的决策边界可以更好地泛化.
在监督学习环境中, mixup 是实现大边距决策边界的一种方法. 在 mixup 中, 通过强制预测模型在样本之间线性变化, 将决策边界推离类别边界, 通过式(2)来完成. 在这里, 通过训练模型
f
θ
f_\theta
fθ? 来预测
M
i
x
λ
(
u
j
,
u
k
)
{\rm Mix}_λ(u_j,u_k)
Mixλ?(uj?,uk?) 的"假标签"
M
i
x
λ
(
f
θ
′
(
u
j
)
,
f
θ
′
(
u
k
)
)
{\rm Mix}_\lambda (f_{\theta'}(u_j),f_{\theta'}(u_k))
Mixλ?(fθ′?(uj?),fθ′?(uk?)) 来将 mixup 扩展到半监督学习. 其中
θ
′
\theta'
θ′ 是
θ
\theta
θ 的滑动平均, 同 Mean-Teacher 中 Teacher model 的
θ
′
\theta'
θ′ 的计算一样.
ICT 模型如下图所示:
- 从联合分布
P
X
Y
(
X
,
Y
)
P_{XY}(X,Y)
PXY?(X,Y) 提取标记样本
(
x
i
,
y
i
)
(x_i,y_i)
(xi?,yi?) 记为
D
L
\mathcal{D}_L
DL?.
- 从边缘分布
P
X
(
X
)
=
P
X
Y
(
X
,
Y
)
P
Y
∣
X
(
Y
∣
X
)
P_X(X)=\frac{P_{XY}(X,Y)}{P_{Y\vert X}(Y \vert X)}
PX?(X)=PY∣X?(Y∣X)PXY?(X,Y)? 提取未标记样本
u
j
u_j
uj?,
u
k
u_k
uk? 记为
D
u
l
\mathcal{D}_{ul}
Dul?.
- 学习目标是训练一个模型
f
θ
f_\theta
fθ?, 能够从
X
X
X 预测
Y
Y
Y. 通过使用随机梯度下降, 在每次迭代
t
t
t 时, 更新参数
θ
\theta
θ 以最小化损失函数
L
=
L
S
+
w
(
t
)
L
U
S
L=L_S+w(t)L_{US}
L=LS?+w(t)LUS?. 其中
L
S
L_S
LS? 为标记样本上的交叉熵损失,
L
U
S
L_{US}
LUS? 为新的插值上的一致性正则化损失. 两个损失都是在 minibatch 上进行计算的, 每次迭代后
w
(
t
)
w(t)
w(t) 都会 ramp up, 以增加
L
U
S
L_{US}
LUS? 的重要性.
为了计算
L
U
S
L_{US}
LUS?, 对两个小批量未标记点
u
j
u_j
uj? 和
u
k
u_k
uk? 进行采样, 并计算它们的假标签
y
^
j
=
f
θ
′
(
u
j
)
\hat{y}_j=f_{\theta'}(u_j)
y^?j?=fθ′?(uj?) 和
y
^
k
=
f
θ
′
(
u
k
)
\hat{y}_k=f_{\theta'}(u_k)
y^?k?=fθ′?(uk?), 然后, 计算插值
u
m
=
M
i
x
λ
(
u
j
,
u
k
)
u_m={\rm Mix}_λ(u_j,u_k)
um?=Mixλ?(uj?,uk?), 以及该位置的模型预测
y
^
m
=
f
θ
′
(
u
m
)
\hat{y}_m=f_{\theta'}(u_m)
y^?m?=fθ′?(um?). 接着, 更新参数
θ
\theta
θ 以使预测
y
^
m
\hat{y}_m
y^?m? 更接近于假标签的插值
M
i
x
λ
(
y
^
j
,
y
^
k
)
{\rm Mix}_λ(\hat{y}_j,\hat{y}_k)
Mixλ?(y^?j?,y^?k?). 预测
y
^
m
\hat{y}_m
y^?m? 和
M
i
x
λ
(
y
^
j
,
y
^
k
)
{\rm Mix}_λ(\hat{y}_j,\hat{y}_k)
Mixλ?(y^?j?,y^?k?) 之间的差异可以使用任何损失来衡量, 在本文实验中, 使用的是均方误差. 对于式(1), 在每次更新时, 从
B
e
t
a
(
α
,
α
)
{\rm Beta}(\alpha,\alpha)
Beta(α,α) 中随机抽取一个
λ
\lambda
λ.
综上,
L
U
L
L_{UL}
LUL? 可被写为:
L
U
L
=
E
u
j
,
u
k
E
λ
l
(
f
θ
(
M
i
x
λ
(
u
j
,
u
k
)
)
,
M
i
x
λ
(
f
θ
′
(
u
j
)
,
f
θ
′
(
u
k
)
)
)
(3)
\mathcal{L}_{UL}=\underset{u_j,u_k}{\mathbb{E}}\underset{\lambda}{\mathbb{E}} \mathcal{l}(f_\theta({\rm Mix}_{\lambda}(u_j,u_k)),{\rm Mix}_\lambda(f_{\theta'}(u_j),f_{\theta'}(u_k))) \tag{3}
LUL?=uj?,uk?E?λE?l(fθ?(Mixλ?(uj?,uk?)),Mixλ?(fθ′?(uj?),fθ′?(uk?)))(3)
ICT 算法流程如下: 算法代码: https://github.com/vikasverma1077/ICT
|