IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> (二)GANomaly论文及代码解析 -> 正文阅读

[人工智能](二)GANomaly论文及代码解析


欢迎访问个人网络日志🌹🌹知行空间🌹🌹

论文: https://arxiv.org/abs/1805.06725

代码: https://github.com/samet-akcay/ganomaly

1.介绍

GANomaly是英国杜伦大学(Durham University,QS前100)Samet Akcay等发表在ACCV2018上的会议论文。

这篇文章是期望提出一种可以只在正常数据上训练,但却能识别异常图像的方法。其提出了由对抗训练框架组成的通用异常检测模型。本文作者在编码-解码-编码式的架构中使用了对抗自编码器,获取训练数据在图像和隐式向量中的分布。

这篇文章主要贡献:

  • 半监督异常检测,提出了基于编码-解码再编码架构的对抗式自编码器,获取训练数据在图像和隐式向量空间的分布,取得比其他基于GAN网络和自编码器异常检测方法更好的效果

  • 代码开源

2.GANomaly网络组成

2.1 GAN简介

生成对抗网络Generative Adversarial Networks (GAN)是蒙特利尔大学Université de MontréalIan Goodfellow 2014年发表的论文(作者Ian Goodfellow最近因不让居家办公了从Apple公司离职媒体正宣传的热闹),GAN属于无监督机器学习算法,原来GAN模型的目标是为了生成原始数据, 其结构包括在训练过程中对抗的两部分,生成器和对抗器,生成器负责生成与source data尽可能相似的数据,判别器负责尽可能的找出生成器生成的fake data。关于GAN的更多介绍可参考:(一)深度卷积对抗网络DCGAN

2.2 问题定义

异常检测问题的正式定义:

  • 数据集,训练数据集 D \mathcal{D} D为只能包含 M M M个正常类别的训练数据 D = { X 1 , X 2 , . . . , X M } \mathcal{D}=\{X_1,X_2,...,X_M\} D={X1?,X2?,...,XM?},测试数据集 D ^ \hat{\mathcal{D}} D^为包含 N N N个正常和异常数据的集合, N N N通常比 M M M小很多。 D ^ = { ( X 1 ^ , y 1 ) , . . . , ( X 2 ^ , y 2 ) ) \hat{\mathcal{D}}=\{(\hat{X_1},y_1),...,(\hat{X_2}, y_2)) D^={(X1?^?,y1?),...,(X2?^?,y2?))
  • 目标模型学习数据 D \mathcal{D} D中的大多数公共性质,训练后在推理阶段检测测试数据集 D ^ \hat{\mathcal{D}} D^中的异常数据,模型 f f f学习正常数据的分布并最小化正常数据输入时模型的异常数据评分输出 A ( x ) \mathcal{A}(x) A(x),对于一个测试数据 x ^ \hat{x} x^,模型输出的异常评分 A ( x ^ ) \mathcal{A}(\hat{x}) A(x^)越高,表示输入时异常数据的可能性越大。设置阈值 ( ? ) (\phi) (?),当 A ( x ^ ) > ? \mathcal{A}(\hat{x})\gt\phi A(x^)>?时即认为是异常数据输入。

2.3网络结构

在这里插入图片描述

网络结构如上图,GANormaly主要包括3部分,以个自编码器,一个编码器和一个生成器。第一部分是一个蝴蝶结形的自编码网络作为模型的生成器,生成器学习输入数据表征,通过编码和解码网络重建输入数据。生成器中编码网络的输出 z z z也被成为生成器的瓶颈特征,并被认为其代表了包含输入数据最好表征的最小维度。第二部分是编码网络E将生成器的输出 x ^ \hat{x} x^压缩成低维的 z ^ \hat{z} z^, E E E和生成器 G G G中的编码网络 G E G_E GE?有着相同的结构,但参数不同,因此 z z z z ^ 维 度 大 小 相 同 \hat{z}维度大小相同 z^。以往的方法都是通过瓶颈特征来最小化隐式向量,GANomaly通过增加一个编码网络,显式的学习最小化特征距离。第3部分是判别器网路D,其目标是判别输入 x x x和生成器的输出 x ^ \hat{x} x^real还是fake

3.模型训练

因训练时只使用了正常类别的数据,可以假设即使生成器的编码器可以将输入数据 X X X映射到隐向量 z z z,判别器却不能够判别异常。因此生成器的输出 X ^ \hat{X} X^将会是去掉异常特征后的图像数据,再通过编码器 E E E,将 X ^ \hat{X} X^映射到特征向量 z ^ \hat{z} z^上,此时生成器中的编码器 G E G_E GE?输出的隐向量 z z z z ^ \hat{z} z^因一个包含图像异常,一个不包含,因此对于异常数据两者将有较大的差异,故可以识别出输入的异常数据。

GANomaly模型包含 3 3 3部分,其损失函数也包含3部分,每一部分损失分别对应网络的相应结构。

3.1对抗损失

Adversarial Loss
其使用的是特征对齐的损失函数,而非基于判别器输出, f f f是一个函数, 可以根据输入 x x x选择判别器的中间层来计算生成器对应层输出的 L 2 L_2 L2?距离。

L a d v = E x ~ p x ∣ ∣ f ( x ) ? E x ~ p x f ( G ( x ) ) ∣ ∣ 2 L_{adv} =\mathop{E}\limits_{x\sim px}||f(x) - \mathop{E}\limits_{x\sim px}f(G(x))||_2 Ladv?=xpxE?f(x)?xpxE?f(G(x))2?

上式中 p x px px x x x的分布

3.2 上下文损失

Contextual Loss
为了学习输入数据中的上下文信息,增加衡量输出数据 x x x和生成器重建数据 x ^ \hat{x} x^误差的上下文损失 L c o n L_{con} Lcon?

L c o n = E x ~ p x ∣ ∣ x ? G ( x ) ∣ ∣ 1 L_{con} = \mathop{E}\limits_{x\sim px}||x-G(x)||_1 Lcon?=xpxE?x?G(x)1?

3.3 编码器损失

Encoder Loss,前面两个损失函数不仅可以让生成的数据尽量真实,还能保存数据的上下文信息。引入Encoder Loss是为了使 G E G_E GE?输出的隐向量 z z z E E E生成的特征向量 z ^ \hat{z} z^的距离最小。

L e n c = E x ~ p x ∣ ∣ G E ( x ) ? E ( G ( X ) ) ∣ ∣ 2 L_{enc} = \mathop{E}\limits_{x\sim px}||G_E(x) - E(G(X))||_2 Lenc?=xpxE?GE?(x)?E(G(X))2?

最终,GANomaly的损失函数为:

L = ω a d v L a d v + ω c o n L c o n + ω e n c L e n c L = \omega_{adv}L_{adv}+\omega_{con}L_{con} + \omega_{enc}L_{enc} L=ωadv?Ladv?+ωcon?Lcon?+ωenc?Lenc?

4.模型测试

测试阶段模型使用 L e n c L_enc Le?nc作为一个输入图像异常程度的评分。因为通过训练阶段最小化 L e n c L_{enc} Lenc?,则对于异常图像 z z z z ^ \hat{z} z^差异会比较大。

A ( x ) = ∣ ∣ G E ( x ^ ) ? E ( G ( x ^ ) ) ∣ ∣ 1 \mathcal{A(x)} = ||G_E(\hat{x})-E(G(\hat{x}))||_1 A(x)=GE?(x^)?E(G(x^))1?

为了评估整体的异常性能,对测试数据集 D ^ \hat{\mathcal{D}} D^中的每个 x ^ \hat{x} x^计算其异常评分 A ( x ) \mathcal{A(x)} A(x),得测试数据集上每个数据对应的评分集合 S = { s i : A ( x i ^ ) , x i ^ ∈ D ^ } \mathcal{S}=\{s_i:\mathcal{A(\hat{x_i})}, \hat{x_i}\in\hat{\mathcal{D}}\} S={si?:A(xi?^?),xi?^?D^},对 S \mathcal{S} S中的元素缩放到 [ 0 , 1 ] [0,1] [0,1]

s i ′ = s i ? m i n ( S ) m a x ( S ) ? m i n ( S ) s'_i = \frac{s_i-min(S)}{max(S)-min(S)} si?=max(S)?min(S)si??min(S)?

5.代码分析

GANomaly的代码基于pytorch实现,代码使用方法说明的很清晰。

5.1数据加载

GANomaly数据加载使用的torchvision提供的ImageFolder类,只需按

Custom Dataset
├── test
│   ├── 0.normal
│   │   └── normal_tst_img_0.png
│   │   └── normal_tst_img_1.png
│   │   ...
│   │   └── normal_tst_img_n.png
│   ├── 1.abnormal
│   │   └── abnormal_tst_img_0.png
│   │   └── abnormal_tst_img_1.png
│   │   ...
│   │   └── abnormal_tst_img_m.png
├── train
│   ├── 0.normal
│   │   └── normal_tst_img_0.png
│   │   └── normal_tst_img_1.png
│   │   ...
│   │   └── normal_tst_img_t.png

这样的格式将数据存放好即可。

"""
加载自定义数据的代码
"""
splits = ['train', 'test']
drop_last_batch = {'train': True, 'test': False}
shuffle = {'train': True, 'test': True}
transform = transforms.Compose([transforms.Resize(opt.isize),
                              transforms.CenterCrop(opt.isize),
                              transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])

dataset = {x: ImageFolder(os.path.join(opt.dataroot, x), transform) for x in splits}
dataloader = {x: torch.utils.data.DataLoader(dataset=dataset[x],
                                             batch_size=opt.batchsize,
                                             shuffle=shuffle[x],
                                             num_workers=int(opt.workers),
                                             drop_last=drop_last_batch[x],
                                             worker_init_fn=(None if opt.manualseed == -1
                                             else lambda x: np.random.seed(opt.manualseed)))
               for x in splits}
return dataloader

5.2 损失定义

""" Backpropagate through netG
"""
self.err_g_adv = self.l_adv(self.netd(self.input)[1], self.netd(self.fake)[1])
self.err_g_con = self.l_con(self.fake, self.input)
self.err_g_enc = self.l_enc(self.latent_o, self.latent_i)
self.err_g = self.err_g_adv * self.opt.w_adv + \
               self.err_g_con * self.opt.w_con + \
               self.err_g_enc * self.opt.w_enc
self.err_g.backward(retain_graph=True)

损失函数的使用如上述代码

6.测试效果

  • 数据量
NomalyAbnomaly
TES290747
TRAIN291
  • 测试结果

在这里插入图片描述

  • 可以看到准确率只有91%,效果在自定义的数据集上还不太好,不容易应用

注,上图分类评估指标可参考(二)sklearn.metrics.classification_report中的Micro/Macro/Weighted Average指标求得。

参考资料


欢迎访问个人网络日志🌹🌹知行空间🌹🌹


  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-14 09:57:29  更:2022-05-14 09:58:41 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 5:51:26-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码