参考链接:[译] 理解 LSTM(Long Short-Term Memory, LSTM) 网络
RNN
- 输入为序列
X
=
{
x
0
,
x
1
,
.
.
.
,
x
t
,
.
.
.
}
X=\{x_0,x_1,...,x_t,...\}
X={x0?,x1?,...,xt?,...} ,输出为
H
=
{
h
0
,
h
1
,
.
.
.
,
h
t
,
.
.
.
}
H=\{h_0,h_1,...,h_t,...\}
H={h0?,h1?,...,ht?,...},将内部的隐藏层认为是 A。
- 当前输出不仅依赖于当前的输入,还依赖于上一次的输出:
h
t
=
t
a
n
h
(
W
[
h
t
?
1
,
x
t
]
+
b
)
h_t=tanh(W[h_{t-1},x_t]+b)
ht?=tanh(W[ht?1?,xt?]+b)。
- 存在的问题的是,对于长序列,如果相距过远,信息很难挖掘。
LSTM
相比于 RNN 直接将上一输出
h
t
?
1
h_{t-1}
ht?1? 交给A,还设计了一个称为 cell state 的东西,用来保存长期记忆,这里用
C
=
{
c
0
,
c
1
,
.
.
.
,
c
t
,
.
.
.
}
C=\{c_0,c_1,...,c_t,...\}
C={c0?,c1?,...,ct?,...} 表示。很自然地,确定当前状态
c
t
c_t
ct? 的值成了一个问题。
具体来说,有以下两个方面:
-
根据当前状态
h
t
(
x
t
,
h
t
?
1
,
c
t
?
1
)
h_t(x_t,h_{t-1},c_{t-1})
ht?(xt?,ht?1?,ct?1?),过去 cell state 即
c
t
?
1
c_{t-1}
ct?1? 的信息如何调整; -
在当前的输入
(
x
t
,
h
t
?
1
)
(x_t,h_{t-1})
(xt?,ht?1?)下,哪些特征是需要保留在 cell state 里的。
这就要提到 LSTM 里的一个重要的设计,称为“门”,类似信号处理里的“窗”,其作用是选择信息。LSTM 的门由 sigmoid 和乘操作实现。在一些解释里,可能会有遗忘门,记忆门这样的说法,这里稍作参考。
sigmoid 函数如下:
σ
(
x
)
=
1
1
+
e
?
x
\sigma(x)=\frac{1}{1+e^{-x}}
σ(x)=1+e?x1? 其取值范围在 (0,1),x 趋近于负无穷,值趋近于0,x趋近于正无穷,值趋近于1。这里就当作是按 0 和 1 来取值。这是一种毫不犹豫地取舍,1完全保留,0完全抛弃。经过sigmoid 的取舍后,可以得到一组 特征数据,表示着对各个维度的关注程度。
乘:在线性代数中,向量/矩阵的乘法其实是一种在不同的基下的映射/变换。所以,乘是 目标数据 在 特征数据 下,做了一次映射/变换。 目标数据 在 特征数据 所在的空间,进行投影。
以
c
t
?
1
c_{t-1}
ct?1? 为例,门的功能如下:
T
?
c
t
?
1
=
c
t
 ̄
T\cdot c_{t-1}=\overline{c_t}
T?ct?1?=ct??,其中,T 表示经过 sigmoid 选择后的特征数据,
T
=
σ
(
W
[
x
t
,
h
t
?
1
]
+
b
)
T=\sigma(W[x_t,h_{t-1}]+b)
T=σ(W[xt?,ht?1?]+b),其中的变量是
[
x
t
,
h
t
?
1
]
[x_t,h_{t-1}]
[xt?,ht?1?],W,b 是参数。
接下来,对 A 里面的一些动作进行分解。
-
c
t
c_{t}
ct? 的更新:
a)
c
t
 ̄
=
T
1
?
c
t
?
1
\overline{c_t}=T_1\cdot c_{t-1}
ct??=T1??ct?1?。这里的
T
1
?
T_1\cdot
T1?? 是常说的遗忘门,根据新的输入对当前的 cell state 进行调整; b)
c
~
=
T
2
?
t
a
n
h
(
W
[
x
t
,
h
t
?
1
]
+
b
)
\widetilde{c}=T_2\cdot tanh(W[x_t,h_{t-1}]+b)
c
=T2??tanh(W[xt?,ht?1?]+b)。这
T
2
?
T_2\cdot
T2?? 是常说的记忆门,对当前打算保存到 Cell state 的候选数据
t
a
n
h
(
W
[
x
t
,
h
t
?
1
]
+
b
)
tanh(W[x_t,h_{t-1}]+b)
tanh(W[xt?,ht?1?]+b) 进行选择; c)
c
t
=
c
t
 ̄
+
c
~
c_t=\overline{c_t}+\widetilde{c}
ct?=ct??+c
。至此,更新完成。 - 输出
h
t
h_t
ht?:
h
t
=
T
3
?
t
a
n
h
(
c
t
)
h_t=T_3\cdot tanh(c_t)
ht?=T3??tanh(ct?)。这里的
T
3
?
T_3\cdot
T3?? 就是常说的输出门。
|