IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> MixMatch文章解读+算法流程+核心代码详解 -> 正文阅读

[人工智能]MixMatch文章解读+算法流程+核心代码详解

MixMatch

本博客仅做算法流程疏导,具体细节请参见原文

原文

查看原文点这里

Github代码

Github代码点这里

解读

MixMatch抓住了半监督算法的两个重要观点:第一是熵最小化;第二是一致性正则化。结合这两个观点的算法就形成了MixMatch。

熵最小化

半监督算法的一个常见假设就是分类的决策边界不应该通过数据分布的高密度区域。这句话简单的理解可以想象一个聚类模型,其决策边界一定是在簇与簇之间的稀疏边界上,不可能穿过一个簇的中心(高密度区域)。而实现这一点的一种方法就是要求分类器对未标记数据输出低熵预测。MixMatch中使用一个"sharpening"函数来隐式实现熵最小化。所谓熵最小化、低熵预测,都是指使输出概率分布比较有“偏向性”,而不希望输出一个“平均的预测”。熵在信息论中是不确定度的度量,根据离散模型的熵最大定理,可知在均匀分布时熵取得最大值,换句话说,出现一个确定的分布,即某一类的概率是1,其余类的概率是0时,熵为0。也就是说想要得到熵最小,就得使分类器输出后的模型预测概率集中分配给某一类。后面再介绍“sharpening”函数如何实现这一点。

一致性正则化

一致性正则化也是一个常见的半监督假设。VATMeanTeacher等其实都或多或少使用了这种假设。其核心在于,我们希望一个样本和其加扰版本(通常图像中称为Augment)通过分类器后,得到相似的输出。其实也就是说分类边界不应该穿过数据分布的高密度区域。如下图,红色点是原始样本,蓝色和绿色为其扰动版本,红色同心圆的虚线圆是我们期望的容差范围,即在这个区间类的都应该认为和其中心数据点为同一类。通过扰动数据点的加入,将决策边界推到合适的位置,使分类器的鲁棒性更强。

Consistency Regularization

一般而言,通过对原始样本和其扰动版本的分类器输出进行衡量,即可实现一致性正则化,常见的衡量方式有MSE、KL散度、JS散度等。在MixMatch中通过对图像的标准数据增强(水平翻转、裁剪)实现扰动(Augment),采用MSE准则方式衡量。

总得来说,算法有以下步骤:

屏幕截图 2021-07-31 154634

归结而言有五个步骤:

第一步,对数据进行扩增(Augment)。扩增分为对有标记数据集 X X X?的扩增和对无标记数据集 U U U?的扩增,分别记为 X ^ \hat{X} X^?和 U ^ \hat{U} U^?。对 X X X?扩增一次,对 U U U?扩增 K K K?次,文章中取 K = 2 K=2 K=2?。因为在取batch时, B a t c h S i z e U = B a t c h S i z e X Batch Size _U = BatchSize_X BatchSizeU?=BatchSizeX??,所以扩增后 B a t c h S i z e U ^ = K ? B a t c h S i z e X ^ Batch Size _{\hat{U}} = K\cdot BatchSize_{\hat{X}} BatchSizeU^?=K?BatchSizeX^?????。

第二步,计算平均预测分布。此步骤仅对数据集 U ^ \hat{U} U^???进行。即通过如下公式计算,其中 ( u b , k ^ , y ) (\hat{u_{b,k}},y) (ub,k?^?,y)?是 U ^ \hat{U} U^?的一个 B a t c h Batch Batch?:

q b ˉ = 1 K ∑ k P m o d e l ( y ∣ u b , k ^ ; θ ) \bar{q_b}=\frac{1}{K}\sum_kP_{model}(y|\hat{u_{b,k}};\theta) qb?ˉ?=K1?k?Pmodel?(yub,k?^?;θ)

值得注意的是, P m o d e l ( y ∣ u b , k ^ ; θ ) P_{model}(y|\hat{u_{b,k}};\theta) Pmodel?(yub,k?^?;θ) S o f t m a x Softmax Softmax?之后的预测概率分布。

第三步,通过 s h a r p e n i n g sharpening sharpening函数完成分布的锐化,其计算公式如下:

S h a r p e n ( p , T ) i = p i 1 T ∑ j = 1 L p j 1 T Sharpen(p,T)_i=\frac{p_i^{\frac{1}{T}}}{\sum^L_{j=1}p_j^{\frac{1}{T}}} Sharpen(p,T)i?=j=1L?pjT1??piT1????

当超参数 T → 0 T\to 0 T0?时, S h a r p e n ( p , T ) Sharpen(p,T) Sharpen(p,T)?趋向于 o n e ? h o t one-hot one?hot??分布,即其中一个类别的概率为1,其余概率为0;锐化后的概率分布作为 U ^ \hat{U} U^?的数据标签(pseudo label)。

第四步,通过 M i x U p MixUp MixUp完成新数据集的构建。先将第一步扩增后的 X ^ \hat{X} X^ U ^ \hat{U} U^进行拼接再打乱顺序,得到 W = S h u f f l e ( C o n c a t ( X ^ , U ^ ) ) W=Shuffle(Concat(\hat{X},\hat{U})) W=Shuffle(Concat(X^,U^)),然后再将 W W W分为两部分,第一部分大小与 X ^ \hat{X} X^相同(也与 X X X相同),记为 W x W_x Wx?;第二部分大小与 U ^ \hat{U} U^相同(也与 U U U相同),记为 W u W_u Wu?。然后将 W x W_x Wx? X ^ \hat{X} X^进行 M i x U p MixUp MixUp W u W_u Wu? U ^ \hat{U} U^进行 M i x U p MixUp MixUp,得到 X ′ X' X U ′ U' U?。 M i x U p MixUp MixUp步骤如下:
λ ~ B e t a ( α , α ) \lambda\sim Beta(\alpha,\alpha) λBeta(α,α)
λ ′ = m a x ( λ , 1 ? λ ) \lambda'=max(\lambda,1-\lambda) λ=max(λ,1?λ)
x ′ = λ ′ x 1 + ( 1 ? λ ′ ) x 2 x'=\lambda'x_1+(1-\lambda')x_2 x=λx1?+(1?λ)x2?
p ′ = λ ′ p 1 + ( 1 ? λ ′ ) p 2 p'=\lambda'p_1+(1-\lambda')p_2 p=λp1?+(1?λ)p2?

第五步,计算半监督损失函数,分为在标记数据集 X ′ X' X?上的损失函数 L x L_x Lx?和在无标记数据集 U ′ U' U上的损失函数 L u L_u Lu?,公式如下:

L x = 1 ∣ X ′ ∣ ∑ x , p ∈ X ′ H ( p , P m o d e l ( y ∣ x ; θ ) ) L_x=\frac{1}{|X'|}\sum_{x,p\in X'}H(p,P_{model}(y|x;\theta)) Lx?=X1?x,pX?H(p,Pmodel?(yx;θ))

L u = 1 L ∣ U ′ ∣ ∑ u , q ∈ U ′ ∣ ∣ q ? P m o d e l ( y ∣ u ; θ ) ∣ ∣ 2 2 L_u=\frac{1}{L|U'|}\sum_{u,q\in U'}||q-P_{model}(y|u;\theta)||^2_2 Lu?=LU1?u,qU?q?Pmodel?(yu;θ)22?

L = L x + λ U L u L=L_x+\lambda_UL_u L=Lx?+λU?Lu?

其中 H ( ? ) H(\cdot) H(?)?是 C o r s s E n t r o p y L o s s CorssEntropyLoss CorssEntropyLoss?; L u L_u Lu?其实就是 M S E MSE MSE准则下的误差项。?

反向梯度传播即可完成整个MixMatch算法

核心代码详解

图像的水平翻转、裁剪实现 A u g m e n t Augment Augment

transform_train = transforms.Compose([
    dataset.RandomPadandCrop(32),
    dataset.RandomFlip(),
    dataset.ToTensor(),
])

transform_val = transforms.Compose([
    dataset.ToTensor(),
])

这里是在迭代过程中,手动取迭代器中的batch,而不是直接使用Dataloader。这种做法在最近的几篇文章代码复现中都遇见了,其主要目的是为了在一个epoch中可以迭代指定次数,而直接使用Dataloader只能迭代最多 c e i l ( 样 本 总 数 B a t c h S i z e ) ceil(\frac{样本总数}{BatchSize}) ceil(BatchSize?)次,其中 c e i l ( ? ) ceil(\cdot) ceil(?)是上取整函数,如果 d r o p l a s t drop_last dropl?ast,则只能迭代 样 本 总 数 B a t c h S i z e \frac{样本总数}{BatchSize} BatchSize?次。代码中的两个try except是为了保证迭代器完全迭代一次后,重新加载迭代器,继续迭代,直到达到指定次数才跳转下一个epoch。

for batch_idx in range(args.train_iteration):
     try:
         inputs_x, targets_x = labeled_train_iter.next()
     except:
         labeled_train_iter = iter(labeled_trainloader)
         inputs_x, targets_x = labeled_train_iter.next()

     try:
         (inputs_u, inputs_u2), _ = unlabeled_train_iter.next()
     except:
         unlabeled_train_iter = iter(unlabeled_trainloader)
         (inputs_u, inputs_u2), _ = unlabeled_train_iter.next()

因为文章中取 K = 2 K=2 K=2,所以进行两次扩增,求输出概率的均值,其中output_uoutput_u2分别为两次扩增后的模型输出结果:

outputs_u = model(inputs_u)
outputs_u2 = model(inputs_u2)
p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2  # 求两次的平均值

求Sharpening结果:

pt = p**(1/args.T)
targets_u = pt / pt.sum(dim=1, keepdim=True)
targets_u = targets_u.detach()

完成 M i x U p MixUp MixUp:

all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2], dim=0)
all_targets = torch.cat([targets_x, targets_u, targets_u], dim=0)
l = np.random.beta(args.alpha, args.alpha)
l = max(l, 1-l)
idx = torch.randperm(all_inputs.size(0))
input_a, input_b = all_inputs, all_inputs[idx]
target_a, target_b = all_targets, all_targets[idx]
mixed_input = l * input_a + (1 - l) * input_b
mixed_target = l * target_a + (1 - l) * target_b

然后计算损失函数:

logits = [model(mixed_input[0])]
for input in mixed_input[1:]:
    logits.append(model(input))
# put interleaved samples back
logits = interleave(logits, batch_size)
logits_x = logits[0]
logits_u = torch.cat(logits[1:], dim=0)

Lx, Lu, w = criterion(logits_x, mixed_target[:batch_size], logits_u, mixed_target[batch_size:], epoch+batch_idx/args.train_iteration)

loss = Lx + w * Lu

反向梯度传播,结束。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-01 14:30:37  更:2021-08-01 14:31:07 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/22 10:42:23-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码