IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 与变分自编码器(Variational Auto-Encoder)的决一死战 -> 正文阅读

[人工智能]与变分自编码器(Variational Auto-Encoder)的决一死战

前言

关于VAE,其实我在第一次自学李宏毅老师的课程时就已经写过一篇文章,之后在学校的Machine Learning2课程中也将原始paper-Auto-Encoding Variational Bayes 作为我的期末报告的对象。最近因为毕业论文又需要用到VAE,因此打算在这里对VAE再做一次复习与总结。

回归原文

0.问题

现在我们有一批数据{x1,x2,…,xn}, 采样自某随机变量x。我们希望构建一个模型来探究x的生成过程。首先作者假设这个生成过程分为2步,其中涉及某个隐变量z。即
1)某个随机变量z从某个z的先验分布pθ (z)得到。
2)然后根据得到的z,从条件概率分布pθ(x|z)采样得到x的值。
在这里插入图片描述
但显然这个思路过于理想化,整个过程的大部分信息,比如后验概率pθ(z|x),隐变量z的值,均是隐藏的,因此我们无法单单依靠这个建模思路来求解模型参数。

当然,作者也提到了可以对整个过程,比如后验概率或者边际概率做一些简化性的假设来帮助模型的求解。但是作者并没有这么做,而是希望通过设计一个更一般性的算法来求解这个问题,下面我们就来看一看作者究竟是怎么做的。

1.基本思路

既然我无法直接对真实的后验概率pθ(z|x) (绝大多数情况下非常复杂)进行操作,那么我们就引入另一个分布qφ(z|x)来近似真实的后验概率。这里其实就使用了变分推断的思想。

整个模型的基本框架可以用原文中的这张图来进行说明。
在这里插入图片描述
其中实线代表生成式模型部分,即pθ(z)pθ(x|z),虚线则代表使用qφ(z|x)来变分近似未知的真实后验概率分布。其中的模型参数φ会在后续与θ一起进行优化。

2.如何定义pθ(z|x) 与 qφ(z|x)的近似程度?

既然我们需要一个近似的分布来代替真实的后验分布,那么我们必然需要一个度量手段来度量这两个分布之间的相似程度。这里我们使用KL散度,常用于度量两个分布之间的相似程度。
如qφ(z|x) 与 p(z|x)的KL散度定义如下。
在这里插入图片描述

3.使用qφ(z|x)代替pθ(z|x)进行目标函数推导

现在我们回到我们的目标—极大化边际似然函数p(x)(取log)。
通过引入qφ(z|x),我们可以将其表达为如下形式。(下图取自原文)
在这里插入图片描述
即我们可以将边际似然转化为两个部分,上式右边第一部分就是我上面提到的引入的变分近似分布qφ(z|x)与真实的后验概率分布pθ(z|x)的KL散度。第二部分被称为变分下限(the varitional lower bound)。根据KL散度的定义,我们有KL散度恒大于0的结论,因此可以得到如下这个重要的关系式。
在这里插入图片描述
这也是我们为什么将该部分称为变分下限的原因,所以我们就可以通过优化我们的变分下限来间接地优化我们的目标似然函数。那么我们现在就将注意力放在如何优化我们的变分下限函数L上。

事实上,它可以进一步被拆分成两个部分。如下
在这里插入图片描述
我们先来看后面红色的一项,这一项非常好理解,本质上就是希望模型能基于q分布(encoder)产生的隐变量z,最大概率得通过p分布 (decoder)还原得到x

然后来看蓝色这一部分,这部分代表的是我们引入的变分近似分布qφ(z|x)与隐变量z的分布之间的负KL散度,为了让最终结果最大化,这项KL散度要尽可能地小,即我们希望让q分布和隐变量z的分布尽可能靠近。本质上这一项可以看作是正则项,如果没有这一项的约束,单纯地优化红色部分,那么模型会倾向于消除所有的噪声,过分地保证训练集中的输入与输出的高度一致,这不是我们想要的,我们希望地是模型更加具有创造性。

哦对了,说了这么多好像还没有说我们是如何从最初的对数似然函数推导到我们的变分下限函数的。以下是简单的推导过程。
在这里插入图片描述
最后一行的前两项就是我们的L(θ,φ;x )。

4.如何优化新的目标函数L(θ,φ;x )?

在前一部分,我们已经得到了我们的新目标函数,即变分下限函数L(θ,φ;x )。这里我们需要优化的参数有两个,即θ,φ。但是其中φ的求解存在一些问题。如果我们想直接利用梯度来优化φ是不行的。因为参数φ决定了隐变量的分布以及之后的采样结果,但是注意,采样这个过程是不可导的,因此梯度的反向传播也会因此受阻。

因此为了使得采样过后,参数φ能够依然留在表达式中,使得我们依旧能够对参数
φ求梯度,这里作者使用了一个小trick,叫做重参数法(the reparameterization trick)。即我们再引入一个辅助变量ε,其遵循某个已知的独立分布p,比如标准正态分布。然后让其参与到z的产生过程。具体来说,我们可以将其表达为下图中黄色的表达式。
在这里插入图片描述
这里解释一下整个逻辑,首先我们的采样过程由辅助变量ε来完成,然后参数φ负责建立从采样结果ε到我们需要的采样结果z的映射关系。这样一来即使我们利用蒙特卡罗方法用采样均值替代期望后,参数φ依然留在表达式中,这样一来我们就可以使用对其求梯度来进行正常优化了。
在这里插入图片描述
利用这个方法作者就成功构建了SGVB estimator来近似我们的变分下限目标函数L(θ,φ;x ),使其依旧对参数θ,φ可微。我们来看一下其具体的表达式,如下。
在这里插入图片描述
说白了,就是用蒙特卡罗估计(L个样本的采样平均)代替原来的期望。前一项KL divergence是有表达式的,因此不需要做近似。

事实上,在原文中变分下限函数L(θ,φ;x )有两种表达式,因此SGVB estimator自然也有对应的两种,这里我列举了第二种(B), 也是因为与我上文列出的变分下限表达式保持一致。

Anyway, 现在我们已经提出了对目标参数θ,φ可微的目标函数表达式,接下来我们可以来引出完整的算法了,也是这篇文章的核心算法—AEVB algorithm
直接copy原文:)
在这里插入图片描述
其实核心就是得到SGVB estimator,之前的步骤包括初始化参数,分割minibatch,辅助变量的采样,后面的步骤就是使用梯度更新参数。这些更多是实践层面的东西,这里就不再展开了。

5.重要的应用VAE

这篇文章提出了一个一般性的算法,用于求解包含隐变量的有向概率图模型。其中非常重要的一个应用就是VAE。

在上述的生成式概率图模型中,我们引入了一个变分近似qφ (z|x),用于近似表示未知的后验概率pθ(z|x)。这个部分可以被看作是一个编码(encoding)的过程,即将原始输入x转化为对应隐变量z的分布参数。在VAE中,我们使用一个神经网络来拟合这个过程。但是光这样还是不够的,因为我们对pθ(z|x)是完全未知的,因此能够使用的近似分布qφ (z|x)的形式存在太多不同的可能了。所以作者在这里做了一个假设,真实的后验概率为高斯分布。下面是原文中的相关描述。
在这里插入图片描述
从而我们神经网络的构建方向也清晰了: 使用一个多层感知机,输出为高斯分布的均值和方差,当然在高维情况下,为均值和方差向量。
在这里插入图片描述
然后就是借助上文中提到的reparameterization技巧借助辅助变量ε进行z的采样。
偷个懒,直接拷贝原文的描述:)
在这里插入图片描述
完成z的采样之后,接下来就来到了解码过程了(decoding),即通过隐变量z来还原x。同样的,这个解码器(decoder)也使用神经网络来拟合。在原文中,作者针对输出数据的不同分布,一种是伯努利分布,一种是高斯分布,提出了两种对应的解码器结构。
在这里插入图片描述
理论部分差不多就是这样了,接下来在最后一部分,用pytorch写一个简单的VAE。

6.Pytorch实现一个简单的VAE结构

import torch
import torch.nn as nn
import torch.nn.function

#basic VAE
class simple_vae(nn.Module):
    def __init__(self,org_size,input_size,hid_dim,lat_dim):
        super(VAE,self).__init__()
        self.org_size=org_size # the original size of x
        self.input_size=input_size # the reshaped size of x
        self.hid_dim=hid_dim # hidden layer dimension
        self.lat_dim=lat_dim # latent variable dimension
        self.encoder=nn.Sequential(nn.Linear(input_size,hid_dim),
                                   nn.ReLU())
        # mu vector
        self.mu=nn.Linear(hid_dim,lat_dim)
        # log_var vector : make sure var is always positive
        self.log_var=nn.Linear(hid_dim,lat_dim) 
        
        self.decoder=nn.Sequential(nn.Linear(lat_dim,hid_dim),
                                   nn.ReLU(),
                                   nn.Linear(hid_dim,input_size))
       
    def encode(self,x):
        x=x.view(x.size(0),self.input_size) # reshape to (batch_size,input_size)
        hid_out=self.encoder(x)
        return self.mu(hid_out),self.log_var(hid_out)
    
    def reparameterize(self,mu,log_var):
        # calculate std
        std=torch.exp(log_var/2)
        # produce random auxiliary variable eps with the same size of mu (or std)
        eps=torch.randn_like(mu) 
        return mu+eps*std
    
    def decode(self,z):
        out=self.decoder(z)
        return F.sigmoid(out)
    
    def forward(self,x):
        mu,log_var=self.encode(x)
        z=self.reparameterize(mu,log_var)
        output=self.decode(z)
        
        return output 

这里只是一个示例代码,单纯是为了博客完整性而写的,我并没有在实际数据集上做任何实验。如果你想直接使用的话还需要注意以下问题:

1上面示例中我的output是没有最终还原到原始数据尺寸的。
2.最终解码时我使用sigmoid激活函数将其输出分布压缩到0-1之间,但是在实际使用中还是需要依据你的数据分布来进行修改。
3.这样简单的结构一般在实际使用中很难有好的效果。

以上就是这篇文章的全部内容了,谢谢大家的阅读。

参考文章:
1.Auto-Encoding Variational Bayes
2.《Auto-Encoding Variational Bayes》论文阅读
3.变分自编码器(一):原来是这么一回事

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-10 13:25:17  更:2021-08-10 13:28:03 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/12 2:51:52-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码