在 FixMatch 中, 对所有类别使用预定义的常量阈值来选择有助于训练的未标记数据, 因此无法考虑不同类别的不同学习状态和学习难度, UDA 也是如此. 为解决这个问题, 提出了课程伪标签(Curriculum Pseudo Labeling, CPL), 这是一种根据模型的学习状态来利用未标记数据的课程学习方法. CPL 的核心是在不同时刻灵活地调整不同类别的阈值.
FlexMatch 使用了 CPL, CPL 是一种课程学习(Curriculum Learning)策略, 考虑到半监督学习中不同的学习状态, CPL 将预定义的阈值替换为灵活的阈值. FlexMatch 只需不到 FixMatch 训练时间的1/5就可以达到最终精度.
课程学习(Curriculum Learning)
根据样本的难易程度, 给不同难度的训练样本分配不同的权重. 初始阶段, 给简单样本的权重最高, 随着训练过程的持续, 较难样本的权重将会逐渐被调高. 将权重动态分配的过程称之为课程(Curriculum), 课程初始阶段简易样本居多, 课程末尾阶段样本难度增加, 即"先易后难".
针对不同的实际问题可以设置不同的样本难易程度评价标准. 例如对于一个原始样本, 对其进行强扰动后, 样本的就由简单变向复杂.
课程伪标签(Curriculum Pseudo Labeling, CPL)
根据学习状态来动态确定阈值并非易事. 最理想的方法是计算每个类的评估准确度并使用它们来缩放阈值:
τ
t
(
c
)
=
a
t
(
c
)
?
τ
(1)
\tau_t(c)=a_t(c) \cdot\tau \tag{1}
τt?(c)=at?(c)?τ(1) 其中
τ
t
(
c
)
\tau_t(c)
τt?(c) 是
t
t
t 时刻
c
c
c 类别的灵活阈值,
a
t
(
c
)
a_t(c)
at?(c) 是相应的评估精度. 由于不能在模型学习过程中使用评估集, 因此必须从训练集中分离一个额外的验证集来进行准确性评估. 但是在 SSL 中, 标记数据原本就十分稀缺, 不能再剥离一部分出去. 其次, 为了在训练过程中动态调整阈值, 必须连续在每个时刻
t
t
t 进行准确度评估, 这将大大减慢训练速度.
为解决上述问题, CPL 使用另一种方法来估计学习状态, 它不引入额外的推理过程, 也不需要额外的验证集. 其关键假设是, 通过预测属于该类且高于阈值的样本数量来反映一个类的学习效果, 然后使用它们来调整阈值
τ
τ
\tau_τ
ττ?. 如下图所示: 如果一个类具有较少样本且其预测置信度达到阈值, 则称其具有较大的学习难度或较差的学习状态:
σ
t
(
c
)
=
∑
n
=
1
N
1
(
max
?
(
p
m
,
t
(
y
∣
u
n
)
)
>
τ
)
?
1
(
arg?max
?
(
p
m
,
t
(
y
∣
u
n
)
)
=
c
)
(2)
\sigma_t(c)=\sum_{n=1}^N \mathbb{1}(\max(p_{m,t}(y\vert u_n))>\tau) \cdot \mathbb{1}(\argmax(p_{m,t}(y\vert u_n))=c) \tag{2}
σt?(c)=n=1∑N?1(max(pm,t?(y∣un?))>τ)?1(argmax(pm,t?(y∣un?))=c)(2) 其中
σ
t
(
c
)
\sigma_t(c)
σt?(c) 反映了类
c
c
c 在
t
t
t 时刻的学习效果.
p
m
,
t
(
y
∣
u
n
)
p_{m,t}(y\vert u_n)
pm,t?(y∣un?) 是模型在
t
t
t 时刻对未标记数据
u
n
u_n
un? 的预测,
N
N
N 是未标记数据的总数. 当未标记数据集是平衡的(即属于不同类别的未标记数据的数量相等或接近)时, 较大的
σ
t
(
c
)
\sigma_t(c)
σt?(c) 表示更好的学习效果. 通过对
σ
t
(
c
)
\sigma_t(c)
σt?(c) 应用以下归一化使其范围在 0 到 1 之间, 然后可以使用它来缩放固定阈值
τ
\tau
τ:
β
t
(
c
)
=
σ
t
(
c
)
max
?
c
σ
t
(3)
\beta_t(c)=\frac{\sigma_t(c)}{\underset{c}{\max}\sigma_t} \tag{3}
βt?(c)=cmax?σt?σt?(c)?(3)
τ
t
(
c
)
=
β
t
(
c
)
?
τ
(4)
\tau_t(c)=\beta_t(c) \cdot \tau \tag{4}
τt?(c)=βt?(c)?τ(4) 随着学习的进行, 学习状态良好的类的阈值会提高, 以选择性地提取更高质量的样本. 最终, 当所有类都达到可靠的准确度时, 阈值都将接近
τ
\tau
τ. 不过阈值并不总是增长态, 如果未标记的数据在后面的迭代中被分类到不同的类别, 阈值也可能会降低. 这个新阈值用于计算 FlexMatch 中的无监督损失, 可以表示为:
L
u
,
t
=
1
μ
B
∑
b
=
1
μ
B
1
(
max
?
(
q
b
)
≥
τ
t
)
H
(
q
^
b
,
p
m
(
y
∣
A
(
u
b
)
)
)
(5)
\mathcal{L}_{u,t}=\frac{1}{\mu B} \sum_{b=1}^{\mu B} \mathbb{1}(\max(q_b)\geq \tau_t) \mathrm{H}(\hat{q}_b,p_m(y\vert \mathcal{A}(u_b))) \tag{5}
Lu,t?=μB1?b=1∑μB?1(max(qb?)≥τt?)H(q^?b?,pm?(y∣A(ub?)))(5) 其中
q
b
=
p
m
(
y
∣
α
(
u
b
)
)
q_b=p_m(y\vert \alpha(u_b))
qb?=pm?(y∣α(ub?)), 这份损失的形式结构与 FixMatch 基本一致. 最后, FlexMatch 中的损失表示为有监督和无监督损失的加权组合:
L
t
=
L
s
+
λ
L
u
,
t
(6)
\mathcal{L}_t=\mathcal{L}_s+\lambda\mathcal{L}_{u,t} \tag{6}
Lt?=Ls?+λLu,t?(6) 其中
L
s
\mathcal{L}_s
Ls? 为有监督损失:
L
s
=
1
B
∑
b
=
1
B
H
(
y
b
,
p
m
(
y
∣
α
(
x
b
)
)
)
(7)
\mathcal{L}_{s}=\frac{1}{B} \sum_{b=1}^{B}\mathrm{H}(y_b,p_m(y\vert \alpha(x_b))) \tag{7}
Ls?=B1?b=1∑B?H(yb?,pm?(y∣α(xb?)))(7)
其他
为避免早阶段训练可能出现的盲目预测, 将式(3)改写为:
β
t
(
c
)
=
σ
t
(
c
)
max
?
{
max
?
c
σ
t
,
N
?
∑
c
σ
t
}
(8)
\beta_t(c)=\frac{\sigma_t(c)}{\max \{ \underset{c}{\max}\sigma_t,N-\underset{c}{\sum}\sigma_t \}\tag{8}}
βt?(c)=max{cmax?σt?,N?c∑?σt?}σt?(c)?(8) 这确保了在训练开始时, 所有估计的学习效果从 0 逐渐上升, 直到未使用的未标记数据的数量
N
?
∑
c
σ
t
N-\underset{c}{\sum}\sigma_t
N?c∑?σt? 不再占主导地位.
同时, 还提出一个非线性映射函数
M
\mathcal{M}
M, 当
β
t
(
c
)
\beta_t(c)
βt?(c) 均匀地从 0 到 1 范围内变化时, 使阈值具有非线性的增加曲线:
τ
t
(
c
)
=
M
(
β
t
(
c
)
)
?
τ
(9)
\tau_t(c)=\mathcal{M}(\beta_t(c)) \cdot \tau \tag{9}
τt?(c)=M(βt?(c))?τ(9) 显然, 如果
M
\mathcal{M}
M 为恒等函数时, 式(9)与式(4)相同. 并且映射函数是单调递增的, 最大值不大于
1
/
τ
1/\tau
1/τ. 在文献中, 选择凸函数
M
(
x
)
=
x
2
?
x
\mathcal{M}(x) = \frac{x}{2?x}
M(x)=2?xx? 作为映射函数.
FlexMatch 完整算法如下: 代码地址: https://github.com/TorchSSL/TorchSSL
|