IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> (二) OGNet 论文笔记 -> 正文阅读

[人工智能](二) OGNet 论文笔记


欢迎访问个人网络日志🌹🌹知行空间🌹🌹


论文:Old is Gold: Redefining the Adversarially Learned One-Class Classifier Training Paradigm
代码:https://github.com/xaggi/OGNet
presentation

简介

这篇是韩国科技大学在2020年CVPR上发表文章,聚焦在使用生成对抗网络做单分类如异常检测的任务,在MNIST等数据上取得了不错的效果。

以前的基于生成对抗式网络做异常检测时,都是在训练时使用生成器和对抗器,在测试推理阶段,则只使用生成器,然后计算输入数据和生成器输出之间的差异性,来评估输入数据是否是异常数据.这种方法的前提假设是网络只在正常数据上进行训练,因此不管何种数据输入生成器后,生成器的输出都更像正常数据。这种方法有个漏洞,就是在推理时使用的生成器有可能可以比较好的重建没见过的数据,简而言之就是输入是异常数据时还能比较好的恢复异常数据,这时输入和生成器的输出差异较小,导致异常检测判断失效。

一个自然的想法是同时使用生成器和判别器来做异常检测,但同时使用判别器和生成器时训练时,使用 项指标来判断何时停止训练也是一个问题,同时使用判别器和生成器训练时,可以看到模型的评估结果振荡的也比较厉害.

将判别器的作用从判断生成器的输出是否是真实数据改成评估生成器重建图像的效果对于异常检测应该 更合适,因仅使用了正常数据训练,因此对于正常数据的重建效果应该更好.根据这种想法,这篇文档的 方法为,分两阶段训练two stage,先按普通的方法训练生成器,再训练判别器,训练判别器的数据有 重建效果好的数据如real data 和 生成的正常数据,重建效果差的数据如异常数据的生成数据,异常数据增强模块输出的数据.

stage one 中的 low-epoch Generator被当作 G^{old},用于生成stage two 中的训练数据 anomaly data, 不需要特定epoch中的G,stage two中对D的训练,只需要较少的迭代即可实现,因为其已经在stage one中预训练过 了,stage two训练时会冻结G的权重。

模块介绍

(1).模型整体结构

在这里插入图片描述

异常数据增强模块,pseudo-anomaly module,

在这里插入图片描述

(2).目标函数

  • Phase One

phase one 是训练生成对抗网络,与生成对抗卷积网络中使用的目标函数相同,
L G + D = m i n G m a x D ( E X ~ p t [ 1 ? l o g ( D ( X ) ) ] + E X ~ ~ p t + N σ [ l o g ( D ( G ( X ~ ) ) ) ] ) L_{G+D} = \mathop{min}\limits_{G}\mathop{max}\limits_{D}(\mathop{\mathbb{E}}\limits_{X\sim p_t}[1-log(D(X))] + \mathop{\mathbb{E}}\limits_{\tilde{X} \sim p_t + \N_\sigma}[log(D(G(\tilde{X})))]) LG+D?=Gmin?Dmax?(Xpt?E?[1?log(D(X))]+X~pt?+Nσ?E?[log(D(G(X~)))])

上式中 G G G是生成器, D D D是判别器, X X X是输入图像, X ~ \tilde{X} X~表示的是在 X X X上加上噪声 N σ N_\sigma Nσ?后得到的异常图像, p t p_t pt?表示的是输入数据的分布。

除了上面的常规GAN的目标函数,本文中还引入了均方误差作为生成器图像重建效果Reconstruction的衡量:

L R = m i n G ∣ ∣ X ? G ( X ) ∣ ∣ 2 L_R = \mathop{min}\limits_{G}||X - G(X)||^2 LR?=Gmin?X?G(X)2

综合方程(1)和(2),则Phase One使用的目标损失函数可写为:

L = L G + D + λ L R L = L_{G+D} + \lambda L_R L=LG+D?+λLR?

  • Phase Two

phase two冻结生成器G的参数,只更新判别器D的参数,以使判别器具备评估图像重建效果的能力。phase two训练使用的数据包括质量比较好的数据质量比较差的数据异常数据增强模块生成的数据质量比较好的数据由原始输入 X X X和生成器重建的 X ^ = G ( X ) \hat{X}=G(X) X^=G(X),质量比较差的数据包括使用保存的low epoch 生成器生成的低质量图像 X ^ l o w \hat{X}^{low} X^low,异常数据增强模块生成的数据 X ^ p s e u d o \hat{X}^{pseudo} X^pseudo指取 X i , X j , i ≠ j X_i,X_j,i\neq j Xi?,Xj?,i?=j经过 G o l d G^{old} Gold生成 X ^ i l o w , X ^ j l o w \hat{X}_i^{low},\hat{X}_j^{low} X^ilow?,X^jlow?,将二者求平均得 X ˉ ^ \hat{\bar{X}} Xˉ^,再使用 G o l d G^{old} Gold生成 X ^ p s e u d o \hat{X}^{pseudo} X^pseudo,异常数据生成模块的过程写成公式:
X ˉ ^ = G o l d ( X i ) + G o l d ( X j ) 2 = X ^ i l o w + X ^ j l o w 2 , i ≠ j X ^ p s e u d o = G ( X ˉ ^ ) \hat{\bar{X}} = \frac{G^{old}(X_i)+G^{old}(X_j)}{2} = \frac{\hat{X}_i^{low}+\hat{X}_j^{low}}{2}, i\neq j \\ \hat{X}^{pseudo} = G(\hat{\bar{X}}) Xˉ^=2Gold(Xi?)+Gold(Xj?)?=2X^ilow?+X^jlow??,i?=jX^pseudo=G(Xˉ^)

综上,phase two的目标函数写为:

m a x D ( α E X [ l o g ( 1 ? D ( X ) ) ] + ( 1 ? α ) E X ^ [ l o g ( 1 ? l o g ( D ( X ^ ) ) ) ] + β E X ^ l o w [ l o g ( D ( X ^ l o w ) ) ] + ( 1 ? β ) E X ^ p s e u d o [ l o g ( D ( X ^ p s e u d o ) ) ] ) \mathop{max}\limits_{D}(\alpha\mathop{\mathbb{E}}\limits_{X}[log(1-D(X))]+(1-\alpha)\mathop{\mathbb{E}}\limits_{\hat{X}}[log(1-log(D(\hat{X})))] + \beta\mathop{\mathbb{E}}\limits_{\hat{X}^{low}}[log(D(\hat{X}^{low}))] + (1-\beta)\mathop{\mathbb{E}}\limits_{\hat{X}^{pseudo}}[log(D(\hat{X}^{pseudo}))]) Dmax?(αXE?[log(1?D(X))]+(1?α)X^E?[log(1?log(D(X^)))]+βX^lowE?[log(D(X^low))]+(1?β)X^pseudoE?[log(D(X^pseudo))])

3.模型测试

测试时对于单分类任务,使用判别器的输出 c o n f i d e n c e = D ( G ( X ) ) confidence = D(G(X)) confidence=D(G(X))作为评分, c o n f i d e n c e > τ confidence \gt \tau confidence>τ为异常类,否则为正常类。
O C C = { n o r m a l c l a s s , i f D ( G ( X ) < τ a n o m a l y c l a s s , o t h e r w i s e OCC = \left\{\begin{matrix} normal class, if D(G(X) < \tau\\ anomaly class, otherwise \end{matrix}\right. OCC={normalclass,ifD(G(X)<τanomalyclass,otherwise?

总结一下本文的工作

  • 网络架构,分两阶段训练将判别器用于衡量图像的重建效果
  • 训练数据增强方法,使用low epoch G o l d G^{old} Gold作为增强的异常数据

实验

本文原作者在Caltech-256,MNIST,USCD Ped2数据集上都做了实验,取得了SOTA的结果。在MNIST数据集上对论文进行了复现,但对论文开源的代码稍做了修改。原文中在MNIST数据集上取0这个类别作为normal class,其余每个类别取一定的数据作为anomaly class,验证效果如下图:

在这里插入图片描述

复现的结果:
在这里插入图片描述

源码分析

github上作者开源的代码使用的时pytroch 1.2,版本比较老了。

模型的训练使用的时model.py中的train方法

d_fake_output = self.d(g_output)
d_real_output = self.d(input)
d_fake_loss = F.binary_cross_entropy(torch.squeeze(d_fake_output), fake)
d_real_loss = F.binary_cross_entropy(torch.squeeze(d_real_output), valid)
d_sum_loss = 0.5 * (d_fake_loss + d_real_loss)
d_sum_loss.backward(retain_graph=True)
d_optim.step()
g_optim.zero_grad()

##############################################
g_recon_loss = F.mse_loss(g_output, input)
g_adversarial_loss = F.binary_cross_entropy(d_fake_output.squeeze(), valid)
g_sum_loss = (1-self.adversarial_training_factor)*g_recon_loss + self.adversarial_training_factor*g_adversarial_loss
g_sum_loss.backward()
g_optim.step()

这段代码在新版本的pytorch上会报错,因d_optim.step会更新判别器的参数,而g_sum_loss中使用了d_fake_outputg_sum_loss.backward() 时会去计算判别器参数的梯度,因判别器的参数已被更新,还使用旧的输出计算梯度。将得到错误的梯度,故将报错。

本文中GAN的实现与pytorch 给出的DCGAN示例中的实现方式有所不同

DCGAN例子中在更新生成器参数时使用的d_fake_out是基于更新后的判别器参数重新计算的,即

d_fake_output = self.d(g_output.detach()) # mutation 1
d_real_output = self.d(input)
d_fake_loss = F.binary_cross_entropy(torch.squeeze(d_fake_output), fake)
d_real_loss = F.binary_cross_entropy(torch.squeeze(d_real_output), valid)
d_sum_loss = 0.5 * (d_fake_loss + d_real_loss)
d_sum_loss.backward() # mutation 2
d_optim.step()
g_optim.zero_grad()

##############################################
d_fake_output = self.d(g_output) # mutation 3
g_recon_loss = F.mse_loss(g_output, input)
g_adversarial_loss = F.binary_cross_entropy(d_fake_output.squeeze(), valid)
g_sum_loss = (1-self.adversarial_training_factor)*g_recon_loss + self.adversarial_training_factor*g_adversarial_loss
g_sum_loss.backward()
g_optim.step()

上述代码对原作者开源的代码做了三处改动,其实这里有个疑问,更新生成器参数计算梯度时,会计算判别器的梯度,而计算的梯度并没有用来更新判别器的参数,在进行下一个iteration的训练之前又被.zero_grad的置0,这应该造成了资源的浪费。

此外,在复现论文结果时,使用的方法与原代码有所不同,原作者的方式是每个一定的epoch进行一次phase two的训练,复现中使用的是先取phase one100epoch训练过程中AUC最大的权重,再基于此训练phase two,取phase twoAUC最大的权重作为最终的训练结果,测试效果见上图,在单分类上只使用正常数据训练取得这样的结果,算十分不错的。

参考资料


欢迎访问个人网络日志🌹🌹知行空间🌹🌹


  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-24 18:10:12  更:2022-05-24 18:13:32 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 5:40:06-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码