本系列为哔哩哔哩网站shuhuai008的【机器学习白板推导】的注解。
-
EM算法背景 EM算法应用于以下这类问题[1]: 有三枚硬币A、B、C,单独投掷正面朝上概率为a、b、c。下图为一个自动投掷机的逻辑,重复进行多次 (例:5次实验结果为[正,正,反,反,正]),估计a、b、c。
隐变量
正面
反面
投掷A
投掷B
投掷C
记录正反结果
上述问题是一个有隐变量的问题,硬币A的投掷结果是一个黑箱,我们观测不到。 把要估计参数的称为
θ
=
(
a
,
b
,
c
)
\boldsymbol{\theta}=(a,b,c)
θ=(a,b,c),把实验得到的一组正反称为样本
x
\bold{x}
x,最终目标 就是求:
arg?max
?
θ
P
(
x
∣
θ
)
=
arg?max
?
θ
log
?
P
(
x
∣
θ
)
=
log
?
∏
j
=
0
n
(
a
b
x
j
(
1
?
b
)
1
?
x
j
+
(
1
?
a
)
c
x
j
(
1
?
c
)
1
?
x
j
)
(1)
\begin{aligned} \argmax_{\boldsymbol{\theta}}P(\bold{x}|\boldsymbol{\theta})&=\argmax_{\boldsymbol{\theta}}\log{P(\bold{x}|\boldsymbol{\theta})} \\ &=\log{\prod_{j=0}^n(ab^{x_j}(1-b)^{1-x_j}+(1-a)c^{x_j}(1-c)^{1-x_j})}\\ \tag{1} \end{aligned}
θargmax?P(x∣θ)?=θargmax?logP(x∣θ)=logj=0∏n?(abxj?(1?b)1?xj?+(1?a)cxj?(1?c)1?xj?)?(1) 如何求
(
1
)
(1)
(1)式? 使用极大似然估计法对
θ
\boldsymbol{\theta}
θ各个参数进行求偏导令其等于0是没有解析解的[李航,统计学习方法,2012],因为对隐变量a求偏导得到的方程是复杂的。 PRML书[1]section9.2.1结尾给出了两种方法,一种是梯度下降方法,一种是EM算法。PRML介绍了用EM算法解决上述含有隐变量问题,其将EM算法描述为"elegant and powerful"。在胡浩基的机器学习课程的EM算法部分中指出EM算法相对于梯度下降法的优点是[1.不需要调节参数;2.编程简单]。 EM算法的思想是多次迭代,使每次迭代满足
log
?
P
(
x
∣
θ
[
i
+
1
]
)
≥
log
?
P
(
x
∣
θ
[
i
]
)
(2)
\log{P(\bold{x}|\boldsymbol{\theta}^{[i+1]})}\geq\log{P(\bold{x}|\boldsymbol{\theta}^{[i]})} \tag{2}
logP(x∣θ[i+1])≥logP(x∣θ[i])(2),
i
i
i表示第
i
i
i轮迭代,这样迭代多次后,似然越来越大,就能收敛,就能得到
(
1
)
(1)
(1)式的局部近似解(EM算法无法得到全局最优解)。 -
EM算法收敛证明 本系列根据B站【白板推导】顺序注解,所以先直接给出EM算法公式,证明是迭代收敛的,以后的章节再给出公式推导。 设隐变量为
z
\bold{z}
z,样本为
x
\bold{x}
x,要估计的参数为
θ
\boldsymbol{\theta}
θ,EM算法迭代公式为:
θ
[
i
+
1
]
=
arg?max
?
θ
∫
z
P
(
z
∣
x
,
θ
[
i
]
)
?
log
?
P
(
x
,
z
∣
θ
)
d
z
(3)
\boldsymbol{\theta}^{[i+1]}=\textcolor{FF0000}{\argmax_{\boldsymbol{\theta}}}\textcolor{0000FF}{\int_{\bold{z}}P(\bold{z}|\bold{x},\boldsymbol{\theta}^{[i]})*\log{P(\bold{x},\bold{z}|\boldsymbol{\theta})}d\bold{z}} \tag{3}
θ[i+1]=θargmax?∫z?P(z∣x,θ[i])?logP(x,z∣θ)dz(3) 蓝色的是期望E(Expectation),红色的是取最大M(Maximization),所以叫做EM算法。注意:
θ
\boldsymbol{\theta}
θ为变量,
θ
[
i
]
\boldsymbol{\theta}^{[i]}
θ[i]为数值,
θ
[
0
]
\boldsymbol{\theta}^{[0]}
θ[0]是人为给定的。 要证明EM算法迭代收敛,只要证明
(
2
)
(2)
(2)式成立即可。 根据条件概率公式,将
(
2
)
(2)
(2)式基本形式变形引出隐变量:
log
?
P
(
x
∣
θ
)
=
log
?
P
(
x
,
z
∣
θ
)
P
(
z
∣
x
,
θ
)
=
log
?
P
(
x
,
z
∣
θ
)
?
log
?
P
(
z
∣
x
,
θ
)
\log{P(\bold{x}|\boldsymbol{\theta})}=\log{\frac{P(\bold{x},\bold{z}|\boldsymbol{\theta})}{P(\bold{z}|\bold{x},\boldsymbol{\theta})}}=\log{P(\bold{x},\bold{z}|\boldsymbol{\theta})}-\log{P(\bold{z}|\bold{x},\boldsymbol{\theta})}
logP(x∣θ)=logP(z∣x,θ)P(x,z∣θ)?=logP(x,z∣θ)?logP(z∣x,θ) 对上式乘
P
(
z
∣
x
,
θ
[
i
]
)
P(\bold{z}|\bold{x},\boldsymbol{\theta}^{[i]})
P(z∣x,θ[i]),再对
z
\bold{z}
z积分,这步操作动机是为了利用
(
3
)
(3)
(3)式,得:
左
边
=
∫
z
P
(
z
∣
x
,
θ
[
i
]
)
?
log
?
P
(
x
∣
θ
)
d
z
=
log
?
P
(
x
∣
θ
)
右
边
=
∫
z
P
(
z
∣
x
,
θ
[
i
]
)
?
log
?
P
(
x
,
z
∣
θ
)
d
z
?
∫
z
P
(
z
∣
x
,
θ
[
i
]
)
?
log
?
P
(
z
∣
x
,
θ
)
d
z
\begin{aligned} 左边&=\int_{\bold{z}}P(\bold{z}|\bold{x},\boldsymbol{\theta}^{[i]})*\log{P(\bold{x}|\boldsymbol{\theta})}d\bold{z}=\log{P(\bold{x}|\boldsymbol{\theta})}\\ 右边&=\textcolor{008000}{\int_{\bold{z}}P(\bold{z}|\bold{x},\boldsymbol{\theta}^{[i]})*\log{P(\bold{x},\bold{z}|\boldsymbol{\theta})}d\bold{z}}-\textcolor{FFA500}{\int_{\bold{z}}P(\bold{z}|\bold{x},\boldsymbol{\theta}^{[i]})*\log{P(\bold{z}|\bold{x},\boldsymbol{\theta})}d\bold{z}} \end{aligned}
左边右边?=∫z?P(z∣x,θ[i])?logP(x∣θ)dz=logP(x∣θ)=∫z?P(z∣x,θ[i])?logP(x,z∣θ)dz?∫z?P(z∣x,θ[i])?logP(z∣x,θ)dz? 左边没有变化,记作
L
(
θ
)
L(\boldsymbol{\theta})
L(θ)。绿色部分记作
Q
(
θ
)
\textcolor{008000}{Q(\boldsymbol{\theta})}
Q(θ),橙色部分记作
H
(
θ
)
\textcolor{FFA500}{H(\boldsymbol{\theta})}
H(θ)。
(
2
)
(2)
(2)式其实就是:
L
(
θ
[
i
+
1
]
)
?
L
(
θ
[
i
]
)
=
(
Q
(
θ
[
i
+
1
]
)
?
Q
(
θ
[
i
]
)
)
?
(
H
(
θ
[
i
+
1
]
)
?
H
(
θ
[
i
]
)
)
≥
0
L(\boldsymbol{\theta}^{[i+1]})-L(\boldsymbol{\theta}^{[i]})=(Q(\boldsymbol{\theta}^{[i+1]})-Q(\boldsymbol{\theta}^{[i]}))-(H(\boldsymbol{\theta}^{[i+1]})-H(\boldsymbol{\theta}^{[i]}))\geq 0
L(θ[i+1])?L(θ[i])=(Q(θ[i+1])?Q(θ[i]))?(H(θ[i+1])?H(θ[i]))≥0 收敛性问题进一步转化为证明:
Q
(
θ
[
i
+
1
]
)
≥
Q
(
θ
[
i
]
)
H
(
θ
[
i
+
1
]
)
≤
H
(
θ
[
i
]
)
Q(\boldsymbol{\theta}^{[i+1]})\geq Q(\boldsymbol{\theta}^{[i]})\\ H(\boldsymbol{\theta}^{[i+1]})\leq H(\boldsymbol{\theta}^{[i]})
Q(θ[i+1])≥Q(θ[i])H(θ[i+1])≤H(θ[i])
Q
(
θ
)
Q(\boldsymbol{\theta})
Q(θ)就是故意凑出来的
(
3
)
(3)
(3)式蓝色部分,所以
Q
(
θ
[
i
+
1
]
)
≥
Q
(
θ
[
i
]
)
Q(\boldsymbol{\theta}^{[i+1]})\geq Q(\boldsymbol{\theta}^{[i]})
Q(θ[i+1])≥Q(θ[i])直接得证。 接下来证明
H
(
θ
)
H(\boldsymbol{\theta})
H(θ)不等式:
H
(
θ
[
i
+
1
]
)
?
H
(
θ
[
i
]
)
=
∫
z
P
(
z
∣
x
,
θ
[
i
]
)
?
log
?
P
(
z
∣
x
,
θ
[
i
+
1
]
)
P
(
z
∣
x
,
θ
[
i
]
)
d
z
H(\boldsymbol{\theta}^{[i+1]})-H(\boldsymbol{\theta}^{[i]})=\textcolor{9400D3}{\int_{\bold{z}}P(\bold{z}|\bold{x},\boldsymbol{\theta}^{[i]})*\log\frac{P(\bold{z}|\bold{x},\boldsymbol{\theta}^{[i+1]})}{P(\bold{z}|\bold{x},\boldsymbol{\theta}^{[i]})}d\bold{z}}
H(θ[i+1])?H(θ[i])=∫z?P(z∣x,θ[i])?logP(z∣x,θ[i])P(z∣x,θ[i+1])?dz 根据
E
(
log
?
猫
)
≤
log
?
E
(
猫
)
E(\log{猫})\leq \log{E(猫)}
E(log猫)≤logE(猫)(参考
log
?
a
+
log
?
b
2
≤
log
?
a
+
b
2
\frac{\log a+\log b}{2}\leq \log\frac{a+b}{2}
2loga+logb?≤log2a+b?,E就是相加除以2), 将紫色部分中
P
(
z
∣
x
,
θ
[
i
+
1
]
)
P
(
z
∣
x
,
θ
[
i
]
)
\frac{P(\bold{z}|\bold{x},\boldsymbol{\theta}^{[i+1]})}{P(\bold{z}|\bold{x},\boldsymbol{\theta}^{[i]})}
P(z∣x,θ[i])P(z∣x,θ[i+1])?看作猫,紫色部分为:
∫
z
P
(
z
∣
x
,
θ
[
i
]
)
?
log
?
猫
?
d
z
≤
log
?
∫
z
P
(
z
∣
x
,
θ
[
i
]
)
?
猫
?
d
z
=
消
去
分
母
log
?
∫
z
P
(
z
∣
x
,
θ
[
i
+
1
]
)
d
z
=
log
?
1
=
0
\begin{aligned} &\int_{\bold{z}}P(\bold{z}|\bold{x},\boldsymbol{\theta}^{[i]})*\log猫\,d\bold{z}\\ &\leq \log\int_{\bold{z}}P(\bold{z}|\bold{x},\boldsymbol{\theta}^{[i]})*猫\,d\bold{z}\xlongequal{消去分母} \log\int_{\bold{z}}P(\bold{z}|\bold{x },\boldsymbol{\theta}^{[i+1]})d\bold{z}\\&=\log{1}\\&=0\end{aligned}
?∫z?P(z∣x,θ[i])?log猫dz≤log∫z?P(z∣x,θ[i])?猫dz消去分母
log∫z?P(z∣x,θ[i+1])dz=log1=0? 所以
H
(
θ
[
i
+
1
]
)
≤
H
(
θ
[
i
]
)
H(\boldsymbol{\theta}^{[i+1]})\leq H(\boldsymbol{\theta}^{[i]})
H(θ[i+1])≤H(θ[i])得证。
Q
(
θ
)
Q(\boldsymbol{\theta})
Q(θ)和
H
(
θ
)
H(\boldsymbol{\theta})
H(θ)不等式都证明完毕后,得到
L
(
θ
[
i
+
1
]
)
≥
L
(
θ
[
i
]
)
L(\boldsymbol{\theta}^{[i+1]})\geq L(\boldsymbol{\theta}^{[i]})
L(θ[i+1])≥L(θ[i]),得到
(
2
)
(2)
(2)式得证,得到EM算法收敛得证。
参考文献: [1]李航. 统计学习方法[M]. 清华大学出版社, 2012. [2]Christopher M. Bishop. Pattern Recognition and Machine Learning[M]. Springer, 2006.
|