IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> [半监督学习] MixMatch: A Holistic Approach to Semi-Supervised Learning -> 正文阅读

[人工智能][半监督学习] MixMatch: A Holistic Approach to Semi-Supervised Learning

统一用于半监督学习的主要方法, 以生成一种新算法 MixMatch, 该算法预测数据增强后未标记示例的低熵标签, 并使用 MixUp 混合标记和未标记数据. MixMatch 在许多数据集中获得了先进的结果.

半监督学习方法中的一类为添加一个损失项, 该损失项是在未标记的数据上计算的. 在许多工作中, 这个损失项包含以下三类:

  • 最小化熵. 它鼓励模型对未标记的数据输出可信的预测.
  • 一致性正则化. 它鼓励模型在其输入受到扰动时产生相同的输出分布.
  • 通用正则化. 它鼓励模型很好地泛化并避免过度拟合训练数据.

本文介绍了 MixMatch, 这是一种 SSL 算法, 它引入了一个单一的损失, 将这些主要方法优雅地统一到半监督学习中. 与以前的方法不同, MixMatch 一次针对所有属性, 其贡献如下:

  • 实验证明, MixMatch 在所有标准图像基准上都获得了最先进的结果, 并将 CIFAR-10 的错误率降低了 4 倍.
  • 在消融研究中进一步表明 MixMatch 大于其部分的总和.
  • MixMatch 对于差分隐私学习很有用, 使 PATE 框架中的 Students 能够获得新的最先进的结果, 同时加强隐私保证和准确性.

简而言之, MixMatch 为未标记数据引入了一个统一的损失项, 可以无缝地降低熵, 同时保持一致性并与传统的正则化技术保持兼容.

1. 损失项

1.1 一致性正则化(Consistency Regularization)

监督学习中一种常见的正则化技术是数据增强. 例如, 在图像分类中, 通常对输入图像进行弹性变形或添加噪声, 这可以在不改变其标签的情况下显着改变图像的像素内容. 这可以通过生成近乎无限的新修改数据流来人为地扩大训练集的大小. 一致性正则化将数据增强应用于半监督学习, 其强制一个未标记的示例 x x x 的分类应该与 D a t a A u g m e n t ( x ) {\rm DataAugment}(x) DataAugment(x) 的分类相同.

在我前面的几篇文章中介绍了利用一致性正则化的方法: Π \Pi Π-model, Temporal Ensembling, Mean-Teacher, Dual-Student, VAT, ICT, UDA.

1.2 熵最小化(Entropy Minimization)

许多半监督学习方法中一个常见的基本假设是分类器的决策边界不应穿过边缘数据分布的高密度区域. 强制执行此操作的一种方法是要求分类器输出对未标记数据的低熵预测. 损失项使未标记数据 x x x p m o d e l ( y ∣ x ; θ ) p_{model}(y | x; \theta) pmodel?(yx;θ) 的熵最小化. 这种形式的熵最小化与 VAT 相结合以获得更强的结果. "伪标签"通过对未标记数据的高置信度预测构建 one-hot 标签并将其用作标准交叉熵损失中的训练目标, 从而隐式地进行熵最小化. MixMatch 还通过在未标记数据的目标分布上使用"锐化"函数来隐式实现熵最小化.

1.3 通正则化(Traditional Regularization)

正则化是对模型施加约束, 希望其能更好地泛化数据. 使用权重衰减来惩罚模型参数的 L2 范数, 还在 MixMatch 中使用 MixUp. 将 MixUp 用作正则化器(应用于标记数据点)和半监督学习方法(应用于未标记数据点). MixUp 方法在之前就已经应用于半监督学习, 例如 ICT.

2. MixMatch

MixMatch 是一种"整体"方法, 它结合了 SSL 主流范式中的思想. 给定一个 batch 的 X \mathcal{X} X, 其标签为 one-hot 编码, 代表 L L L 个可能标签中的一个, 和一个相同大小 batch 的未标记示例 U \mathcal{U} U, MixMatch 生成一批经过处理的增强标记示例 X ′ \mathcal{X}' X 和一批具有"猜测"标签 U ′ \mathcal{U}' U 的增强未标记示例. U ′ \mathcal{U}' U X ′ \mathcal{X}' X 用于计算单独的标记和未标记损失项.

MixMatch 中使用的标签猜测过程如下图:
在这里插入图片描述
将随机数据增强(Stochastic data augmentation)应用于未标记的图像 K K K 次, 每个增强后的图像都通过分类器进行输入. 然后, 通过调整温度分布(distribution’s temperature)来"锐化"这 K K K 个预测的平均值.

MixMatch 整体算法如下:
在这里插入图片描述
MixMatch 具体步骤如下:

2.1 数据增强(Data Augmentation)

正如许多 SSL 方法中的典型情况一样, 对标记和未标记的数据都使用数据增强. 对 X \mathcal{X} X 中的每个 x b x_b xb?, 生成一个转换后的版本 x ^ b = A u g m e n t ( x b ) \hat{x}_b = \mathrm{Augment}(x_b) x^b?=Augment(xb?), 在上面 algorithm 1的第3行. 对未标记数据 U \mathcal{U} U 中的每个 u b u_b ub?, 生成 K K K 个增强 u ^ b , k = A u g m e n t ( u b ) , k ∈ ( 1 , … , K ) \hat{u}_{b,k} = \mathrm{Augment}(u_b), k \in (1, \dots ,K) u^b,k?=Augment(ub?),k(1,,K), 在上面 algorithm 1的第5行. 使用这些单独的增强来为每个 u b u_b ub? 生成一个"猜测标签" q b q_b qb?. 其中 batch_size= B B B.

2.2 标签猜测(Label Guessing)

对于 U \mathcal{U} U 中的每个未标记示例, 增强后使用 MixMatch 模型的预测为示例的标签生成一个"猜测"值. 这个猜测在后来被用在无监督损失项中. 为此, 计算其平均值, 在上面 algorithm 1的第7行:
q  ̄ b = 1 K ∑ K = 1 K P m o d e l ( y ∣ u ^ b , k ; θ ) (1) \overline{q}_b=\frac{1}{K} \sum_{K=1}^K P_{model}(y \vert \hat{u}_{b,k}; \theta) \tag{1} q?b?=K1?K=1K?Pmodel?(yu^b,k?;θ)(1)
锐化操作(Sharpening): 在生成标签猜测时, 执行一个额外的步骤, 即对于给定增强的平均预测 q  ̄ b \overline{q}_b q?b?, 应用锐化函数来减少标签分布的熵. 在实践中, 调整分类的"温度"分布是常用方法, 在上面 algorithm 1的第8行, 它被定义为:
S h a r p e n ( p , T ) i = p i 1 T / ∑ j = 1 L p j 1 T (2) \mathrm{Sharpen}(p,T)_i=p_i^{\frac{1}{T}} \bigg/ \sum_{j=1}^L p_j^{\frac{1}{T}} \tag{2} Sharpen(p,T)i?=piT1??/j=1L?pjT1??(2)
其中 p p p 是一些输入分类的分布(特别是在 MixMatch 中, p p p 是对增强 q  ̄ b \overline{q}_b q?b? 的平均类别预测, T T T 是超参数. 随着 T → 0 T \rightarrow 0 T0, S h a r p e n ( p , T ) \mathrm{Sharpen}(p, T) Sharpen(p,T) 的输出将接近 Dirac(“one-hot”)分布.

2.3 MixUp

使用 MixUp 进行半监督学习, 与过去的 SSL 工作不同, 将标记示例和未标记示例与猜测标签混合在一起. 为了与损失兼容, 这里定义了一个修改过的 MixUp 版本. 对于具有相应标签概率的两个示例 ( x 1 , p 1 ) (x_1, p_1) (x1?,p1?), ( x 2 , p 2 ) (x_2, p_2) (x2?,p2?), 通过以下式计算 ( x ′ , p ′ ) (x', p') (x,p):
λ ~ B e t a ( α , α ) (3) \lambda \sim \mathrm{Beta}(\alpha,\alpha) \tag{3} λBeta(α,α)(3)
λ ′ = max ? ( λ , 1 ? λ ) (4) \lambda'=\max(\lambda,1-\lambda) \tag{4} λ=max(λ,1?λ)(4)
x ′ = λ ′ x 1 + ( 1 ? λ ′ ) x 2 (5) x'=\lambda'x_1+(1-\lambda')x_2 \tag{5} x=λx1?+(1?λ)x2?(5)
p ′ = λ ′ p 1 + ( 1 ? λ ′ ) p 2 (6) p'=\lambda'p_1+(1-\lambda')p_2 \tag{6} p=λp1?+(1?λ)p2?(6)
其中 α \alpha α 是超参数. 将 ( x ′ , p ′ ) (x', p') (x,p) 作为增强数据或者虚拟训练数据. 为了应用 MixUp, 见上面 algorithm 1的第10-11行, 首先将所有带标签的增强标记示例和所有带猜测标签的未标记示例收集到:
X ^ = ( ( x ^ b , p b ) ; b ∈ ( 1 , … , B ) ) (7) \hat{\mathcal{X}}=((\hat{x}_b,p_b);b\in(1,\dots,B)) \tag{7} X^=((x^b?,pb?);b(1,,B))(7)
U ^ = ( ( u ^ b , k , q b ) ; b ∈ ( 1 , … , B ) , k ∈ ( 1 , … , K ) ) (8) \hat{\mathcal{U}}=((\hat{u}_{b,k},q_b);b\in(1,\dots,B),k\in(1,\dots,K)) \tag{8} U^=((u^b,k?,qb?);b(1,,B),k(1,,K))(8)
然后, 组合这些集合并将结果打乱以形成 W \mathcal{W} W, 它将作为 MixUp 的数据源, 见上面 algorithm 1的第12行. 对于 X ^ \hat{\mathcal{X}} X^ 中的第 i i i 个带标签的示例对, 计算 M i x U p ( X ^ i , W i ) {\rm MixUp}(\hat{\mathcal{X}}_i, \mathcal{W}_i) MixUp(X^i?,Wi?) 其中 i ∈ ( 1 , … , ∣ X ^ ∣ ) i \in (1, \dots, \vert \hat{\mathcal{X}}\vert) i(1,,X^), 并将结果添加到 X ′ \mathcal{X}' X 里, 见上面 algorithm 1的第13行. 计算 U i ′ = M i x U p ( U ^ i , W i + ∣ X ^ ∣ ) \mathcal{U}'_i = \mathrm{MixUp}( \hat{\mathcal{U}}_i, \mathcal{W}_{i+\vert \hat{\mathcal{X}}\vert}) Ui?=MixUp(U^i?,Wi+X^?) 其中 i ∈ ( 1 , … , ∣ U ^ ∣ ) i \in (1, \dots, \vert \hat{\mathcal{U}}\vert) i(1,,U^), 见上面 algorithm 1的第14行.

总而言之, MixMatch 将 X \mathcal{X} X 转换为 X ′ \mathcal{X}' X, 这是一组应用了数据增强和 MixUp 的标记示例. 类似地, U \mathcal{U} U 被转换为 U ′ \mathcal{U}' U, 即每个未标记示例的多个增强的集合, 并带有相应的猜测标签. X ′ \mathcal{X}' X, U ′ \mathcal{U}' U 是对有标签和无标签数据进行增强之后得到的新训练数据.

2.4 损失函数(Loss Function)

组合损失 L \mathcal{L} L 定义如下:
X ′ , U ′ = M i x M a t c h ( X , U , T , K , α ) (9) \mathcal{X}',\mathcal{U}'=\mathrm{MixMatch}(\mathcal{X},\mathcal{U},T,K,\alpha) \tag{9} X,U=MixMatch(X,U,T,K,α)(9)
L X = 1 ∣ X ′ ∣ ∑ x , p ∈ X ′ H ( p , P m o d e l ( y ∣ x ; θ ) ) (10) \mathcal{L}_{\mathcal{X}}=\frac{1}{\vert \mathcal{X}' \vert} \sum_{x,p \in \mathcal{X}'} \mathrm{H}(p,P_{model}(y \vert x; \theta)) \tag{10} LX?=X1?x,pX?H(p,Pmodel?(yx;θ))(10)
L U = 1 L ∣ U ′ ∣ ∑ u , q ∈ U ′ ∥ q ? P m o d e l ( y ∣ u ; θ ) ∥ 2 2 (11) \mathcal{L}_{\mathcal{U}}=\frac{1}{L\vert \mathcal{U}' \vert} \sum_{u,q \in \mathcal{U}'} \lVert q-P_{model}(y \vert u; \theta)\rVert_2^2 \tag{11} LU?=LU1?u,qU?q?Pmodel?(yu;θ)22?(11)
L = L X + λ U L U (12) \mathcal{L}=\mathcal{L}_{\mathcal{X}}+\lambda_{\mathcal{U}}\mathcal{L}_{\mathcal{U}} \tag{12} L=LX?+λU?LU?(12)
其中 H ( p , q ) \mathrm{H}(p, q) H(p,q) 是分布 p p p q q q 之间的交叉熵, T T T K K K α \alpha α λ U \lambda_{\mathcal{U}} λU? 是超参数. 式子(12)将来自 X ′ \mathcal{X}' X 的有标签数据之间的交叉熵损失与来自 U ′ \mathcal{U}' U 的预测和猜测标签的平方 L 2 L_2 L2? 损失相结合.

2.5 一些超参数设置

在实践中发现, MixMatch 的大多数超参数都可以固定, 不需要在每个实验或每个数据集的基础上进行调整. 具体来说, 对于所有实验, 设置 T = 0.5 T = 0.5 T=0.5, K = 2 K = 2 K=2. 此外, 仅在每个数据集的基础上更改 α \alpha α λ U \lambda_{\mathcal{U}} λU?. 文中发现 α = 0.75 \alpha = 0.75 α=0.75 λ U = 100 \lambda_{\mathcal{U}}=100 λU?=100 是调整的良好起点.

TensorFlow 版本代码: https://github.com/google-research/mixmatch
PyTorch 版本代码: https://github.com/YU1ut/MixMatch-pytorch

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-02-03 01:12:38  更:2022-02-03 01:13:41 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 20:27:56-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码