Dive Into MindSpore–LSTM Operator For Network Construction
MindSpore易点通·精讲系列–网络构建之LSTM算子
本文开发环境
本文内容提要
1. 原理介绍
LSTM,Long Short Term Memory,又称长短时记忆网络。原始RNN存在一个严重的缺陷:训练过程中经常会出现梯度爆炸和梯度消失的问题,以至于原始的RNN很难处理长距离的依赖,为了解决(缓解)这个问题,研究人员提出了LSTM。
1.1 LSTM公式
LSTM的公式表示如下: 其中 σ 是sigmoid激活函数, * 是乘积。 W, b 是公式中输出和输入之间的可学习权重。
1.2 LSTM结构
为方便理解,1.1 中的公式的结构示意图如下:
1.3 LSTM门控
1.3.1 遗忘门
遗忘门公式为:
f
t
=
σ
(
W
f
x
x
t
+
b
f
x
+
W
f
h
h
(
t
?
1
)
+
b
f
h
)
f_t = \sigma(W_{fx} x_t + b_{fx} + W_{fh} h_{(t-1)} + b_{fh})
ft?=σ(Wfx?xt?+bfx?+Wfh?h(t?1)?+bfh?)
解读:
“遗忘门”决定之前状态中的信息有多少应该舍弃。它会读取
h
t
?
1
h_{t-1}
ht?1? 和
x
t
x_t
xt?的内容,
σ
\sigma
σ符号代表Sigmoid函数,它会输出一个0到1之间的值。其中0代表舍弃之前细胞状态
C
t
?
1
C_{t-1}
Ct?1?中的内容,1代表完全保留之前细胞状态
C
t
?
1
C_{t-1}
Ct?1?中的内容。0、1之间的值代表部分保留之前细胞状态
C
t
?
1
C_{t-1}
Ct?1?中的内容。
1.3.2 输入门
输入门公式为:
i
t
=
σ
(
W
i
x
x
t
+
b
i
x
+
W
i
h
h
(
t
?
1
)
+
b
i
h
)
c
~
t
=
tanh
?
(
W
c
x
x
t
+
b
c
x
+
W
c
h
h
(
t
?
1
)
+
b
c
h
)
i_t = \sigma(W_{ix} x_t + b_{ix} + W_{ih} h_{(t-1)} + b_{ih}) \\ \tilde{c}_t = \tanh(W_{cx} x_t + b_{cx} + W_{ch} h_{(t-1)} + b_{ch})
it?=σ(Wix?xt?+bix?+Wih?h(t?1)?+bih?)c~t?=tanh(Wcx?xt?+bcx?+Wch?h(t?1)?+bch?)
解读:
“输入门”决定什么样的信息保留在细胞状态
C
t
C_t
Ct?中,它会读取
h
t
?
1
h_{t-1}
ht?1? 和
x
t
x_t
xt?的内容,
σ
\sigma
σ符号代表Sigmoid函数,它会输出一个0到1之间的值。
和“输入门”配合的还有另外一部分,这部分输入也是
h
t
?
1
h_{t-1}
ht?1? 和
x
t
x_t
xt?,不过采用tanh激活函数,将这部分标记为
c
~
(
t
)
\tilde c^{(t)}
c~(t),称作为“候选状态”。
1.3.3 细胞状态
细胞状态公式为:
c
t
=
f
t
?
c
(
t
?
1
)
+
i
t
?
c
~
t
c_t = f_t * c_{(t-1)} + i_t * \tilde{c}_t
ct?=ft??c(t?1)?+it??c~t?
解读:
由
C
t
?
1
C_{t-1}
Ct?1? 计算得到
C
t
C_t
Ct?
旧“细胞状态”
C
t
?
1
C_{t-1}
Ct?1?和“遗忘门”的结果进行计算,决定旧的“细胞状态”保留多少,忘记多少。接着“输入门”
i
(
t
)
i^{(t)}
i(t)和候选状态
c
~
(
t
)
\tilde c^{(t)}
c~(t)进行计算,将所得到的结果加入到“细胞状态”中,这表示新的输入信息有多少加入到“细胞状态中”。
1.3.4 输出门
输出门公式为:
o
t
=
σ
(
W
o
x
x
t
+
b
o
x
+
W
o
h
h
(
t
?
1
)
+
b
o
h
)
h
t
=
o
t
?
tanh
?
(
c
t
)
o_t = \sigma(W_{ox} x_t + b_{ox} + W_{oh} h_{(t-1)} + b_{oh}) \\ h_t = o_t * \tanh(c_t)
ot?=σ(Wox?xt?+box?+Woh?h(t?1)?+boh?)ht?=ot??tanh(ct?)
解读:
和其他门计算一样,它会读取
h
t
?
1
h_{t-1}
ht?1? 和
x
t
x_t
xt?的内容,然后计算Sigmoid函数,得到“输出门”的值。接着把“细胞状态”通过tanh进行处理(得到一个在-1到1之间的值),并将它和输出门的结果相乘,最终得到确定输出的部分
h
t
h_t
ht?,即新的隐藏状态。
特别说明:
在上述公式中,xt为当前的输入,h(t-1)为上一步的隐藏状态,c(t-1)为上一步的细胞状态。
当t=1时,可知h(t-1)为h0,c(t-1)为c0。
一般来说,h0/c0设置为0或1,或固定的随机值。
2. 文档说明
下面来看看官网文档说明,主要看参数部分:
从官方文档可知,MindSpore中的LSTM算子支持多层双向设置,同时可接受输入数据第一维为非batch_size的情况,而且自带dropout。
下面通过案例来对该算子的输入和输出进行讲解。
3. 案例解说
3.1 单层正向LSTM
本示例中随机生成了[4, 8, 4]数据,该数据batch_size为4,固定seq_length为8,输入维度为4。
本示例采用单层单向LSTM,隐层大小为8。
本示例中LSTM调用时进行对比测试,一个seq_length 为默认值None,一个为有效长度input_seq_length 。
示例代码如下:
import numpy as np
from mindspore import dtype
from mindspore import Tensor
from mindspore.nn import LSTM
def single_layer_lstm():
random_data = np.random.rand(4, 8, 4)
seq_length = [3, 8, 5, 1]
input_seq_data = Tensor(random_data, dtype=dtype.float32)
input_seq_length = Tensor(seq_length, dtype=dtype.int32)
batch_size = 4
input_size = 4
hidden_size = 8
num_layers = 1
bidirectional = False
num_bi = 2 if bidirectional else 1
lstm = LSTM(
input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
has_bias=True, batch_first=True, dropout=0.0, bidirectional=bidirectional)
h0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))
c0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))
output_0, (hn_0, cn_0) = lstm(input_seq_data, (h0, c0))
output_1, (hn_1, cn_1) = lstm(input_seq_data, (h0, c0), input_seq_length)
print("====== single layer lstm output 0 shape: {} ======\n{}".format(output_0.shape, output_0), flush=True)
print("====== single layer lstm hn0 shape: {} ======\n{}".format(hn_0.shape, hn_0), flush=True)
print("====== single layer lstm cn0 shape: {} ======\n{}".format(cn_0.shape, cn_0), flush=True)
print("====== single layer lstm output 1 shape: {} ======\n{}".format(output_1.shape, output_1), flush=True)
print("====== single layer lstm hn1 shape: {} ======\n{}".format(hn_1.shape, hn_1), flush=True)
print("====== single layer lstm cn1 shape: {} ======\n{}".format(cn_1.shape, cn_1), flush=True)
示例代码输出内容如下:
对输出内容进行分析:
- output_0和output_1维度都是[4, 8, 8],即batch_size, seq_length和hidden_size
- output_0对应的是调用时seq_length为None的情况,即默认有效seq_length为8,可以看到output_0各个长度输出数值皆非全零。
- output_1对应的是调用时seq_length为设定值[3, 8, 5, 1],可以看到output_1超过有效长度的输出部分皆为全零。
- hn和cn分别为隐层状态和细胞状态输出。下面以hn_1和cn_1为例进行讲解。
- hn_1维度为[1, 4, 8],1代表单向单层(1*1),4代表batch_size,8代表hidden_size。
- 仔细观察可以看出,hn_1的输出与output_1最后一维的输出一致,即与有效长度内最后一个的输出保持一致。
- cn_1为有效最后一步的细胞状态。
====== single layer lstm output 0 shape: (4, 8, 8) ======
[[[ 0.13193643 0.31574252 0.21773982 0.359429 0.23590101
0.28213733 0.24443595 0.37388077]
[-0.02988351 0.1415896 0.15356182 0.2834958 -0.00328176
0.3491612 0.12643641 0.142024 ]
[-0.09670443 0.03373189 0.1445203 0.19673887 0.06278481
0.33509392 -0.02579015 0.07650157]
[-0.15380219 -0.04781847 0.07795938 0.15893918 0.01305779
0.33979264 -0.00364386 0.04361304]
[-0.16254447 -0.06737433 0.05285644 0.10944269 0.01782622
0.34567034 -0.04204851 0.01285298]
[-0.21082401 -0.09526701 0.0265205 0.10617667 -0.03112434
0.33731762 -0.02207689 -0.00955394]
[-0.23450094 -0.09586379 0.02365175 0.09352495 -0.03744857
0.33376914 -0.04699665 -0.03528202]
[-0.24089803 -0.06166056 0.02839395 0.09916345 -0.04156012
0.31369895 -0.08876226 -0.0487675 ]]
[[ 0.10673305 0.30631748 0.22279048 0.35392687 0.270858
0.2800686 0.21576329 0.37215734]
[ 0.07373721 0.07924869 0.20754944 0.2059646 0.12672944
0.35556036 0.05576535 0.2124105 ]
[-0.09233213 0.02507205 0.11608997 0.23507075 0.0269099
0.3196378 0.00475359 0.05898073]
[-0.14939436 -0.04166775 0.07941992 0.15797664 0.02167228
0.34059638 -0.02956495 0.00525782]
[-0.18659307 -0.08790994 0.04543061 0.12085741 0.01649844
0.33063915 -0.03531799 -0.01156766]
[-0.22867033 -0.10603286 0.03872797 0.11688479 0.01904946
0.3056394 -0.05695718 -0.01623933]
[-0.21695574 -0.11095987 0.03115554 0.08672465 0.04249544
0.3152427 -0.07418983 -0.02036544]
[-0.21967101 -0.10076816 0.01712734 0.08198812 0.02862469
0.31535396 -0.09173042 -0.05647325]]
[[ 0.1493079 0.28768584 0.2575181 0.3199168 0.30599245
0.28865623 0.16678075 0.41237575]
[ 0.01445133 0.13631815 0.18265024 0.2577204 0.09361918
0.3227448 0.04080902 0.17163058]
[-0.1164555 0.05409181 0.1229048 0.24406306 0.02090637
0.31171325 -0.02868806 0.06015658]
[-0.12215493 -0.04073931 0.09229688 0.13461691 0.05322267
0.34697118 -0.04028781 0.05017967]
[-0.16058712 -0.02990636 0.06711683 0.13881728 0.04944531
0.30471358 -0.08764775 0.01227296]
[-0.17542893 -0.04518626 0.06441598 0.12666796 0.1039256
0.29512212 -0.12625514 -0.01764686]
[-0.18198647 -0.06205402 0.05437353 0.12312049 0.11571115
0.27589387 -0.13898477 -0.00659172]
[-0.18840623 -0.03089028 0.02871101 0.13332503 0.02779378
0.2934873 -0.12758468 -0.02508291]]
[[ 0.16055782 0.28248906 0.24979302 0.3381475 0.28849283
0.3085897 0.21882199 0.3911534 ]
[ 0.03212452 0.10363571 0.18571742 0.25555134 0.11808199
0.33315352 0.0612903 0.16566488]
[-0.09707587 0.08886775 0.130165 0.23324937 0.0596167
0.28433815 -0.05993269 0.06611289]
[-0.15705962 -0.00274712 0.09360209 0.18597823 0.04157853
0.32279128 -0.07580574 0.01155218]
[-0.15376413 -0.07929687 0.06302985 0.11465057 0.07184268
0.3261627 -0.05871713 0.04223134]
[-0.18791473 -0.07859336 0.02364462 0.12526496 -0.02513029
0.33071572 -0.03542359 -0.00976665]
[-0.23625109 -0.03007499 0.03267653 0.15940045 -0.08530897
0.30445266 -0.0852924 -0.04507463]
[-0.23499809 -0.07687293 0.03790941 0.08663946 -0.00264841
0.33423126 -0.06512782 0.01413365]]]
====== single layer lstm hn0 shape: (1, 4, 8) ======
[[[-0.24089803 -0.06166056 0.02839395 0.09916345 -0.04156012
0.31369895 -0.08876226 -0.0487675 ]
[-0.21967101 -0.10076816 0.01712734 0.08198812 0.02862469
0.31535396 -0.09173042 -0.05647325]
[-0.18840623 -0.03089028 0.02871101 0.13332503 0.02779378
0.2934873 -0.12758468 -0.02508291]
[-0.23499809 -0.07687293 0.03790941 0.08663946 -0.00264841
0.33423126 -0.06512782 0.01413365]]]
====== single layer lstm cn0 shape: (1, 4, 8) ======
[[[-0.72842515 -0.10623126 0.07748945 0.23840414 -0.0663506
0.82394135 -0.20612013 -0.11983471]
[-0.6431069 -0.17861958 0.04168103 0.20188545 0.0463764
0.73273325 -0.21914008 -0.13169488]
[-0.61163914 -0.05123866 0.07892742 0.32583922 0.04181815
0.79872614 -0.2969701 -0.0625343 ]
[-0.58037984 -0.15040846 0.09998614 0.24211554 -0.0044073
0.8616534 -0.1546249 0.03137078]]]
====== single layer lstm output 1 shape: (4, 8, 8) ======
[[[ 0.13193643 0.31574252 0.21773985 0.35942894 0.23590101
0.28213733 0.24443595 0.37388077]
[-0.02988352 0.1415896 0.15356182 0.28349578 -0.00328175
0.34916118 0.12643641 0.142024 ]
[-0.09670443 0.0337319 0.14452031 0.19673884 0.06278481
0.33509392 -0.02579015 0.07650157]
[ 0. 0. 0. 0. 0.
0. 0. 0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. ]]
[[ 0.10673306 0.30631748 0.22279048 0.35392687 0.27085796
0.2800686 0.21576326 0.37215734]
[ 0.07373722 0.0792487 0.20754944 0.2059646 0.12672943
0.35556036 0.05576536 0.2124105 ]
[-0.09233214 0.02507207 0.11608997 0.23507075 0.02690989
0.3196378 0.00475359 0.05898073]
[-0.14939436 -0.04166774 0.07941992 0.15797664 0.02167228
0.34059638 -0.02956495 0.00525782]
[-0.18659307 -0.08790994 0.04543061 0.12085741 0.01649844
0.33063915 -0.03531799 -0.01156766]
[-0.22867033 -0.10603285 0.03872797 0.11688479 0.01904945
0.3056394 -0.05695718 -0.01623933]
[-0.21695574 -0.11095986 0.03115554 0.08672465 0.04249543
0.3152427 -0.07418983 -0.02036544]
[-0.21967097 -0.10076815 0.01712734 0.08198812 0.02862468
0.31535396 -0.09173042 -0.05647324]]
[[ 0.1493079 0.28768584 0.25751814 0.3199168 0.30599245
0.28865623 0.16678077 0.41237575]
[ 0.01445133 0.13631816 0.18265024 0.25772038 0.09361918
0.3227448 0.04080902 0.17163058]
[-0.1164555 0.05409183 0.1229048 0.24406303 0.02090637
0.31171325 -0.02868806 0.06015658]
[-0.12215493 -0.0407393 0.09229688 0.1346169 0.05322267
0.3469712 -0.0402878 0.05017967]
[-0.16058712 -0.02990635 0.06711683 0.13881728 0.0494453
0.30471358 -0.08764775 0.01227296]
[ 0. 0. 0. 0. 0.
0. 0. 0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. ]]
[[ 0.16055782 0.2824891 0.24979301 0.33814746 0.28849283
0.30858967 0.21882202 0.3911534 ]
[ 0. 0. 0. 0. 0.
0. 0. 0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. ]]]
====== single layer lstm hn1 shape: (1, 4, 8) ======
[[[-0.09670443 0.0337319 0.14452031 0.19673884 0.06278481
0.33509392 -0.02579015 0.07650157]
[-0.21967097 -0.10076815 0.01712734 0.08198812 0.02862468
0.31535396 -0.09173042 -0.05647324]
[-0.16058712 -0.02990635 0.06711683 0.13881728 0.0494453
0.30471358 -0.08764775 0.01227296]
[ 0.16055782 0.2824891 0.24979301 0.33814746 0.28849283
0.30858967 0.21882202 0.3911534 ]]]
====== single layer lstm cn1 shape: (1, 4, 8) ======
[[[-0.22198828 0.05788375 0.38487202 0.5277796 0.10692163
0.88817626 -0.06333658 0.15489307]
[-0.6431068 -0.17861956 0.04168103 0.20188545 0.04637639
0.73273325 -0.21914008 -0.13169487]
[-0.44337854 -0.05043292 0.17615467 0.36942852 0.0769525
0.8138213 -0.22219141 0.02737183]
[ 0.50136805 0.47527558 0.8696786 0.7511291 0.37594885
0.9162327 0.5345433 0.6333548 ]]]
3.2 单层双向LSTM
本示例中随机生成了[4, 8, 4]数据,该数据batch_size为4,固定seq_length为8,输入维度为4。
本示例采用单层双向LSTM,隐层大小为8。
本示例中LSTM调用时进行对比测试,一个seq_length 为默认值None,一个为有效长度input_seq_length 。
示例代码如下:
import numpy as np
from mindspore import dtype
from mindspore import Tensor
from mindspore.nn import LSTM
def single_layer_bi_lstm():
random_data = np.random.rand(4, 8, 4)
seq_length = [3, 8, 5, 1]
input_seq_data = Tensor(random_data, dtype=dtype.float32)
input_seq_length = Tensor(seq_length, dtype=dtype.int32)
batch_size = 4
input_size = 4
hidden_size = 8
num_layers = 1
bidirectional = True
num_bi = 2 if bidirectional else 1
lstm = LSTM(
input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
has_bias=True, batch_first=True, dropout=0.0, bidirectional=bidirectional)
h0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))
c0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))
output_0, (hn_0, cn_0) = lstm(input_seq_data, (h0, c0))
output_1, (hn_1, cn_1) = lstm(input_seq_data, (h0, c0), input_seq_length)
print("====== single layer bi lstm output 0 shape: {} ======\n{}".format(output_0.shape, output_0), flush=True)
print("====== single layer bi lstm hn0 shape: {} ======\n{}".format(hn_0.shape, hn_0), flush=True)
print("====== single layer bi lstm cn0 shape: {} ======\n{}".format(cn_0.shape, cn_0), flush=True)
print("====== single layer bi lstm output 1 shape: {} ======\n{}".format(output_1.shape, output_1), flush=True)
print("====== single layer bi lstm hn1 shape: {} ======\n{}".format(hn_1.shape, hn_1), flush=True)
print("====== single layer bi lstm cn1 shape: {} ======\n{}".format(cn_1.shape, cn_1), flush=True)
示例代码输出内容如下:
对输出内容进行分析:
- output_0和output_1维度都是[4, 8, 16],即batch_size, seq_length和hidden_size * 2,这里乘2是因为是双向输出。
- output_0对应的是调用时seq_length为None的情况,即默认有效seq_length为8,可以看到output_0各个长度输出数值皆非全零。
- output_1对应的是调用时seq_length为设定值[3, 8, 5, 1],可以看到output_1超过有效长度的输出部分皆为全零。
- hn和cn分别为隐层状态和细胞状态输出。下面以hn_1和cn_1为例进行讲解。
- hn_1维度为[2, 4, 8],2代表双向单层(2*1),4代表batch_size,8代表hidden_size。
- 仔细观察可以看出,hn_1中第一维度第0索引的正向输出部分与output_1最后一维输出前hidden_size数值一致,即与有效长度内最后一个的输出的前hidden_size数值保持一致。
- 仔细观察可以看出,hn_1中第一维度第1索引的反向输出部分与output_1开始一维输出后hidden_size数值一致。
- cn_1为有效最后一步的细胞状态。
====== single layer bi lstm output 0 shape: (4, 8, 16) ======
[[[ 0.11591419 0.29961097 0.3425573 0.4287143 0.17212108
0.07444338 0.43271446 0.15715674 0.08194006 0.11577142
-0.09744498 -0.02763127 0.09280778 0.08716499 0.02522062
0.33181873]
[-0.01308823 0.13623668 0.19448121 0.37028143 0.22777143
0.00628781 0.39128026 0.15501572 0.08111142 0.11017906
-0.12316822 -0.00816909 0.09567513 0.05021677 0.08249568
0.33742255]
[-0.05627449 0.04682723 0.15380071 0.3137156 0.26430035
-0.046514 0.35723254 0.16584632 0.10204285 0.10223756
-0.13232729 -0.00190703 0.11279006 0.07007243 0.07809626
0.36085904]
[-0.09489179 -0.00705127 0.1340199 0.24711385 0.27097055
-0.05539801 0.29088783 0.180727 0.13702057 0.07165765
-0.15263684 -0.02301912 0.14440101 0.09643525 0.04434848
0.32824463]
[-0.13192342 -0.09842218 0.13483751 0.2363211 0.2714419
-0.06301905 0.23002718 0.12190706 0.1600955 0.0820565
-0.13324322 0.00847512 0.15308659 0.12757084 0.06873622
0.3726861 ]
[-0.16037701 -0.12437794 0.12642992 0.23676534 0.29797453
-0.04277696 0.24219972 0.16359471 0.16195399 0.07269616
-0.1250204 -0.0185749 0.19040069 0.12709007 0.12064856
0.30454746]
[-0.1353235 -0.12385159 0.1025193 0.23867385 0.30110353
-0.03195428 0.2832907 0.18136714 0.19130123 0.09153596
-0.05207976 0.02430173 0.2524703 0.22256352 0.17788586
0.3196903 ]
[-0.15227936 -0.16710246 0.11279354 0.2324703 0.3158889
-0.05391366 0.28967926 0.21905534 0.34464788 0.06061291
0.10662059 0.08228769 0.38103724 0.44488934 0.22631703
0.38864976]]
[[ 0.07946795 0.30921736 0.35205007 0.37194842 0.2058839
0.09482588 0.4332572 0.2775039 0.10343523 0.07151344
-0.13616626 -0.04245609 0.10985457 0.06919786 0.0364913
0.31924048]
[-0.04591701 0.14795585 0.20307627 0.35713255 0.21074952
0.03478044 0.36047992 0.15351431 0.11235587 0.07168273
-0.11715946 -0.02380875 0.11772131 0.11803672 0.00387634
0.33266184]
[-0.09412251 0.02499678 0.17255405 0.3178058 0.23692454
-0.03471331 0.26576498 0.10732022 0.14581609 0.07355653
-0.12852795 0.01927058 0.13053373 0.14796041 0.01590303
0.3854578 ]
[-0.09348419 0.00631614 0.1466178 0.22848201 0.22966608
-0.05388562 0.14963126 0.08823045 0.15729474 0.0657778
-0.15222837 -0.01835432 0.15758416 0.17561477 -0.03188463
0.3511778 ]
[-0.15382743 -0.04836275 0.14573918 0.22835778 0.2532363
-0.03674607 0.1401736 0.09852327 0.17570393 0.04582136
-0.13850203 0.00081276 0.16863164 0.14211492 0.04397457
0.33833435]
[-0.14028388 -0.08847751 0.13194019 0.21878807 0.28851762
-0.06432837 0.15592363 0.16226491 0.20294866 0.04400881
-0.11535563 0.04870296 0.22049154 0.17808373 0.09339966
0.34441146]
[-0.1683049 -0.16189072 0.1318028 0.22591397 0.3027075
-0.07447627 0.15145044 0.1329806 0.2544369 0.06014252
-0.01793557 0.11026148 0.2146467 0.3118566 0.12141219
0.39812002]
[-0.19805393 -0.17752953 0.12876241 0.21628919 0.3038769
-0.036511 0.1357605 0.10460708 0.3527281 0.07156999
0.1540587 0.09252883 0.35960466 0.54258245 0.16377062
0.40849966]]
[[ 0.08452003 0.3159105 0.3420099 0.3319746 0.20285761
0.08632328 0.3581056 0.27760154 0.14828831 0.04973472
-0.18127252 -0.02664946 0.11601479 0.06740937 0.0379785
0.342705 ]
[-0.0266434 0.16035607 0.18312001 0.31999707 0.22840345
0.01311543 0.3133277 0.20360778 0.12191478 0.06214391
-0.16598006 -0.03916245 0.10791545 0.06448431 0.03113508
0.33138022]
[-0.10794992 0.03787376 0.16952753 0.2500641 0.24685495
-0.05109966 0.20483223 0.18794663 0.16794644 0.03811646
-0.17785533 0.00866746 0.13491729 0.06493596 0.055873
0.3487326 ]
[-0.11205798 -0.04663825 0.13637729 0.2688466 0.2944545
-0.06623676 0.24580626 0.1894824 0.12357055 0.08545923
-0.13890322 0.02125055 0.12671538 0.05041068 0.10938939
0.37651145]
[-0.14464049 -0.11277611 0.12929943 0.2506328 0.32429394
-0.06989705 0.26676533 0.22626272 0.14871088 0.06151669
-0.14160013 0.01764496 0.15616798 0.06309532 0.11477884
0.3533678 ]
[-0.1919359 -0.14934857 0.12687694 0.2482472 0.30332044
-0.02129422 0.24142255 0.19039477 0.1872613 0.05607529
-0.10981983 0.02655923 0.19725962 0.15991098 0.08460074
0.32532936]
[-0.15997384 -0.16905244 0.12601317 0.24978957 0.3109707
-0.05129525 0.25644392 0.18721735 0.23115595 0.07164647
-0.04363466 0.09616573 0.23608637 0.23462081 0.16639999
0.36137852]
[-0.17784727 -0.19330868 0.12555353 0.25036657 0.3237954
-0.05024423 0.27374345 0.16953917 0.3444527 0.074378
0.12866443 0.11058272 0.34053382 0.47292238 0.20279881
0.42136478]]
[[ 0.09268619 0.35032618 0.34263822 0.33635783 0.19130397
0.089779 0.3541034 0.26252666 0.15370639 0.05593391
-0.16430146 -0.00316385 0.14068598 0.13546935 -0.01566708
0.32892445]
[ 0.00249528 0.16723414 0.19037648 0.32905748 0.20670214
-0.01093364 0.22814633 0.10346357 0.14574584 0.08942283
-0.13508694 0.02989143 0.13283192 0.155128 -0.00928066
0.38435996]
[-0.09191902 0.02066077 0.1762495 0.2693505 0.2615397
-0.07361222 0.17539641 0.12341685 0.14845897 0.06833903
-0.15054268 0.02503714 0.12414654 0.08736143 0.07049443
0.35888508]
[-0.08116069 -0.0288023 0.12298302 0.24174306 0.3107592
-0.07053182 0.23929915 0.17529318 0.09909797 0.10476568
-0.13906275 -0.0065798 0.12028767 0.09093229 0.08531829
0.33838242]
[-0.08996075 -0.04482763 0.10432535 0.18569301 0.29469466
-0.064595 0.21119419 0.19096416 0.15567164 0.06260847
-0.15861334 -0.01660161 0.17961282 0.14018227 0.05389842
0.32480207]
[-0.13079894 -0.12208281 0.11661161 0.20262218 0.31364897
-0.09002802 0.23725566 0.21705934 0.20321131 0.03772969
-0.12727125 0.04301733 0.21097985 0.16362298 0.12457186
0.3570657 ]
[-0.14077222 -0.14493458 0.10797977 0.20154148 0.32082993
-0.06558356 0.24276899 0.20433648 0.23955566 0.04574178
-0.03365875 0.05299059 0.26905897 0.3059458 0.11437013
0.3523326 ]
[-0.20353709 -0.20380074 0.12652008 0.19772139 0.28259847
-0.04320877 0.1549557 0.12743628 0.37037018 0.04201189
0.16136979 0.10812846 0.3535916 0.573114 0.14248823
0.42301312]]]
====== single layer bi lstm hn0 shape: (2, 4, 8) ======
[[[-0.15227936 -0.16710246 0.11279354 0.2324703 0.3158889
-0.05391366 0.28967926 0.21905534]
[-0.19805393 -0.17752953 0.12876241 0.21628919 0.3038769
-0.036511 0.1357605 0.10460708]
[-0.17784727 -0.19330868 0.12555353 0.25036657 0.3237954
-0.05024423 0.27374345 0.16953917]
[-0.20353709 -0.20380074 0.12652008 0.19772139 0.28259847
-0.04320877 0.1549557 0.12743628]]
[[ 0.08194006 0.11577142 -0.09744498 -0.02763127 0.09280778
0.08716499 0.02522062 0.33181873]
[ 0.10343523 0.07151344 -0.13616626 -0.04245609 0.10985457
0.06919786 0.0364913 0.31924048]
[ 0.14828831 0.04973472 -0.18127252 -0.02664946 0.11601479
0.06740937 0.0379785 0.342705 ]
[ 0.15370639 0.05593391 -0.16430146 -0.00316385 0.14068598
0.13546935 -0.01566708 0.32892445]]]
====== single layer bi lstm cn0 shape: (2, 4, 8) ======
[[[-0.48307976 -0.40690032 0.24048738 0.49366224 0.5961513
-0.13565473 0.5191028 0.48418468]
[-0.55306923 -0.41890883 0.31527558 0.4081013 0.5560535
-0.10868378 0.22270739 0.224445 ]
[-0.5595058 -0.5172409 0.28816614 0.4680259 0.6353333
-0.1406159 0.45408633 0.39424264]
[-0.55914015 -0.42366728 0.29431793 0.42468843 0.5133875
-0.11134674 0.27713037 0.2564772 ]]
[[ 0.13141792 0.26979685 -0.20174497 -0.06629345 0.16831748
0.14618596 0.05280813 0.84774 ]
[ 0.16957031 0.19068424 -0.28012666 -0.10653219 0.1932735
0.12457087 0.07286038 0.91865647]
[ 0.25553685 0.1275407 -0.37673476 -0.06495219 0.21608156
0.11330918 0.07597075 0.97954106]
[ 0.2739099 0.14198926 -0.342751 -0.00778307 0.25392675
0.23573248 -0.03052862 0.89955646]]]
====== single layer bi lstm output 1 shape: (4, 8, 16) ======
[[[ 0.11591419 0.299611 0.3425573 0.4287143 0.17212108
0.07444337 0.43271446 0.15715674 0.14267941 0.11772849
-0.08396029 -0.0199183 0.17602898 0.19761203 0.06850712
0.30409858]
[-0.01308823 0.1362367 0.19448121 0.3702814 0.22777143
0.00628781 0.39128026 0.1550157 0.19404428 0.11392959
-0.04281732 0.02546077 0.24461909 0.24037687 0.16997418
0.30728906]
[-0.05627449 0.04682725 0.15380071 0.3137156 0.26430035
-0.04651401 0.3572325 0.1658463 0.32523182 0.10201547
0.12631407 0.07232428 0.37344953 0.46444228 0.22052252
0.38782993]
[ 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. ]]
[[ 0.07946795 0.30921736 0.35205007 0.37194842 0.2058839
0.09482589 0.4332572 0.27750388 0.10343523 0.07151344
-0.13616627 -0.04245608 0.10985459 0.06919786 0.0364913
0.31924048]
[-0.04591701 0.14795585 0.20307627 0.35713255 0.21074952
0.03478044 0.36047992 0.1535143 0.11235587 0.07168273
-0.11715946 -0.02380875 0.11772133 0.11803672 0.00387635
0.33266184]
[-0.09412251 0.02499679 0.17255405 0.31780577 0.23692457
-0.03471331 0.265765 0.10732021 0.14581607 0.07355653
-0.12852795 0.01927058 0.13053373 0.14796041 0.01590303
0.38545772]
[-0.09348419 0.00631614 0.14661779 0.228482 0.2296661
-0.05388563 0.14963126 0.08823042 0.15729474 0.0657778
-0.15222837 -0.01835432 0.15758416 0.17561477 -0.03188463
0.35117778]
[-0.15382743 -0.04836275 0.14573918 0.22835778 0.25323635
-0.03674608 0.14017357 0.09852324 0.17570391 0.04582136
-0.13850203 0.00081274 0.16863164 0.14211491 0.04397457
0.33833435]
[-0.14028388 -0.08847751 0.13194019 0.21878807 0.28851762
-0.06432837 0.15592363 0.16226488 0.20294866 0.04400881
-0.11535563 0.04870294 0.22049154 0.17808372 0.09339967
0.34441146]
[-0.1683049 -0.16189072 0.1318028 0.22591396 0.30270752
-0.07447628 0.15145041 0.13298061 0.2544369 0.06014251
-0.01793558 0.11026147 0.2146467 0.31185657 0.1214122
0.39812005]
[-0.19805394 -0.17752953 0.12876241 0.21628918 0.30387694
-0.036511 0.1357605 0.10460708 0.3527281 0.07156998
0.1540587 0.09252883 0.35960466 0.54258245 0.16377063
0.40849966]]
[[ 0.08452003 0.31591052 0.3420099 0.3319746 0.2028576
0.08632328 0.3581056 0.2776015 0.16127887 0.05090985
-0.18798977 -0.03278283 0.14869703 0.09618111 0.05077953
0.32884052]
[-0.0266434 0.16035606 0.18312001 0.31999707 0.22840345
0.01311543 0.31332764 0.20360778 0.14828573 0.06162609
-0.16532603 -0.04184524 0.17109753 0.11741111 0.05272176
0.31123316]
[-0.10794992 0.03787376 0.16952753 0.2500641 0.24685495
-0.05109966 0.2048322 0.18794663 0.21637706 0.03754523
-0.15342048 0.0159312 0.2186653 0.17495207 0.09126361
0.32591543]
[-0.11205798 -0.04663826 0.13637729 0.2688466 0.2944545
-0.06623676 0.24580622 0.1894824 0.21777555 0.08560579
-0.0555483 0.0522357 0.2504716 0.23061936 0.18061498
0.34555358]
[-0.14464049 -0.11277609 0.12929943 0.2506328 0.32429394
-0.06989705 0.26676533 0.22626273 0.34267974 0.06394035
0.10800922 0.07929072 0.38286424 0.44688055 0.22619261
0.38621217]
[ 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. ]]
[[ 0.09268619 0.35032618 0.34263822 0.33635783 0.19130397
0.089779 0.3541034 0.26252666 0.34620598 0.06714007
0.13512857 0.04233981 0.42014182 0.5216394 0.18838547
0.3683127 ]
[ 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. ]
[ 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. 0. 0. 0. 0.
0. ]]]
====== single layer bi lstm hn1 shape: (2, 4, 8) ======
[[[-0.05627449 0.04682725 0.15380071 0.3137156 0.26430035
-0.04651401 0.3572325 0.1658463 ]
[-0.19805394 -0.17752953 0.12876241 0.21628918 0.30387694
-0.036511 0.1357605 0.10460708]
[-0.14464049 -0.11277609 0.12929943 0.2506328 0.32429394
-0.06989705 0.26676533 0.22626273]
[ 0.09268619 0.35032618 0.34263822 0.33635783 0.19130397
0.089779 0.3541034 0.26252666]]
[[ 0.14267941 0.11772849 -0.08396029 -0.0199183 0.17602898
0.19761203 0.06850712 0.30409858]
[ 0.10343523 0.07151344 -0.13616627 -0.04245608 0.10985459
0.06919786 0.0364913 0.31924048]
[ 0.16127887 0.05090985 -0.18798977 -0.03278283 0.14869703
0.09618111 0.05077953 0.32884052]
[ 0.34620598 0.06714007 0.13512857 0.04233981 0.42014182
0.5216394 0.18838547 0.3683127 ]]]
====== single layer bi lstm cn1 shape: (2, 4, 8) ======
[[[-0.16340391 0.12338591 0.36321753 0.60983956 0.4963916
-0.14528881 0.61422133 0.37583172]
[-0.5530693 -0.41890883 0.31527558 0.40810126 0.5560536
-0.10868377 0.22270739 0.22444502]
[-0.46137562 -0.27004397 0.27595642 0.5348579 0.62363803
-0.18086377 0.46610427 0.4973321 ]
[ 0.23746979 0.6868869 0.56339467 0.96855223 0.39346337
0.32335475 0.7259624 0.4185825 ]]
[[ 0.22938183 0.2952913 -0.17549752 -0.05000385 0.33509728
0.3336044 0.14473113 0.7370499 ]
[ 0.16957031 0.19068426 -0.2801267 -0.10653219 0.19327351
0.12457087 0.07286038 0.91865647]
[ 0.27940926 0.13317151 -0.39137632 -0.081429 0.28198367
0.16170114 0.10146889 0.91004795]
[ 0.6180897 0.28882137 0.28748003 0.15160248 0.7991137
0.90929043 0.45457762 0.8128108 ]]]
3.3 双层双向LSTM
本示例中随机生成了[4, 8, 4]数据,该数据batch_size为4,固定seq_length为8,输入维度为4。
本示例采用双层双向LSTM,隐层大小为8。
本示例中LSTM调用时进行对比测试,一个seq_length 为默认值None,一个为有效长度input_seq_length 。
示例代码如下:
import numpy as np
from mindspore import dtype
from mindspore import Tensor
from mindspore.nn import LSTM
def double_layer_bi_lstm():
random_data = np.random.rand(4, 8, 4)
seq_length = [3, 8, 5, 1]
input_seq_data = Tensor(random_data, dtype=dtype.float32)
input_seq_length = Tensor(seq_length, dtype=dtype.int32)
batch_size = 4
input_size = 4
hidden_size = 8
num_layers = 2
bidirectional = True
num_bi = 2 if bidirectional else 1
lstm = LSTM(
input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
has_bias=True, batch_first=True, dropout=0.0, bidirectional=bidirectional)
h0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))
c0 = Tensor(np.ones([num_bi * num_layers, batch_size, hidden_size]).astype(np.float32))
output_0, (hn_0, cn_0) = lstm(input_seq_data, (h0, c0))
output_1, (hn_1, cn_1) = lstm(input_seq_data, (h0, c0), input_seq_length)
print("====== double layer bi lstm output 0 shape: {} ======\n{}".format(output_0.shape, output_0), flush=True)
print("====== double layer bi lstm hn0 shape: {} ======\n{}".format(hn_0.shape, hn_0), flush=True)
print("====== double layer bi lstm cn0 shape: {} ======\n{}".format(cn_0.shape, cn_0), flush=True)
print("====== double layer bi lstm output 1 shape: {} ======\n{}".format(output_1.shape, output_1), flush=True)
print("====== double layer bi lstm hn1 shape: {} ======\n{}".format(hn_1.shape, hn_1), flush=True)
print("====== double layer bi lstm cn1 shape: {} ======\n{}".format(cn_1.shape, cn_1), flush=True)
示例代码输出内容如下:
对输出内容进行分析:
- output_0和output_1维度都是[4, 8, 16],即batch_size, seq_length和hidden_size * 2,这里乘2是因为是双向输出。
- output_0和output_1皆是第二层(最后一层)的输出,中间层(本例为第一层)输出没有显示给出。
- output_0对应的是调用时seq_length为None的情况,即默认有效seq_length为8,可以看到output_0各个长度输出数值皆非全零。
- output_1对应的是调用时seq_length为设定值[3, 8, 5, 1],可以看到output_1超过有效长度的输出部分皆为全零。
- hn和cn分别为隐层状态和细胞状态输出。下面以hn_1和cn_1为例进行讲解。
- hn_1维度为[4, 4, 8],4代表双向双层(2*2),4代表batch_size,8代表hidden_size。
- 6中说明4代表双向双层(2*2),hn_1包含各层的最终有效隐层状态输出,这里同output_1只包含最后一层的输出不同。
- 仔细观察可以看出,hn_1中第一维度第2索引位置(即最后一层)的正向输出部分与output_1最后一维输出前hidden_size数值一致,即与有效长度内最后一个的输出的前hidden_size数值保持一致。
- 仔细观察可以看出,hn_1中第一维度第3索引位置(即最后一层)的反向输出部分与output_1开始一维输出后hidden_size数值一致。
- cn_1为有效最后一步的细胞状态。
====== double layer bi lstm output 0 shape: (4, 8, 16) ======
[[[ 3.70550364e-01 2.17652053e-01 3.79816592e-01 5.39002419e-01
2.28588611e-01 3.83301824e-02 2.20795229e-01 2.44438455e-01
2.06572518e-01 -3.78293954e-02 2.60271341e-01 -4.60247397e-02
-3.78369205e-02 -1.90976545e-01 -1.01466656e-01 1.76680252e-01]
[ 1.65173441e-01 7.22418576e-02 4.98769164e-01 2.52682149e-01
2.94478923e-01 -1.56086944e-02 1.32235214e-01 4.96024750e-02
1.81777030e-01 -7.20555857e-02 2.31085896e-01 7.43698841e-03
-2.21280195e-02 -1.63902551e-01 -8.19268897e-02 1.90522313e-01]
[ 3.94219235e-02 -8.84856097e-03 4.88511086e-01 1.51095495e-01
2.83691764e-01 -2.36562286e-02 1.14125453e-01 -4.99135666e-02
1.84900641e-01 -9.07974318e-02 2.06634849e-01 5.43768853e-02
-2.88773868e-02 -1.41080543e-01 -7.59911761e-02 1.93940982e-01]
[-3.12361736e-02 -5.87114133e-02 4.53683615e-01 7.93214589e-02
2.92402357e-01 -2.14897078e-02 1.08925141e-01 -9.88882780e-02
1.98123455e-01 -8.50049481e-02 1.91045731e-01 8.83036405e-02
-1.40397642e-02 -1.22237459e-01 -6.35140762e-02 1.80813670e-01]
[-6.87201098e-02 -6.12376854e-02 4.39131975e-01 2.83084475e-02
2.86313444e-01 -9.33245104e-03 1.12482831e-01 -1.27253398e-01
2.32264340e-01 -7.87357539e-02 1.86317161e-01 1.59440145e-01
1.36264751e-03 -9.95954126e-02 -4.97992262e-02 1.69756234e-01]
[-8.29227120e-02 -6.19332492e-02 4.27550107e-01 -1.70003679e-02
2.88041800e-01 3.62846977e-03 1.04239471e-01 -1.43706441e-01
2.90384740e-01 -5.84731065e-02 1.86135545e-01 2.26804867e-01
3.95135172e-02 -6.33978993e-02 -1.63939036e-02 1.48533911e-01]
[-8.72429982e-02 -6.10240139e-02 4.20974702e-01 -6.44157380e-02
2.92603880e-01 2.60243341e-02 9.26012769e-02 -1.46479979e-01
3.93343538e-01 -1.41044548e-02 1.96197629e-01 3.05834383e-01
1.02294169e-01 2.09005456e-03 5.07600456e-02 1.33950055e-01]
[-8.48584175e-02 -4.15292941e-02 4.26153004e-01 -1.12198450e-01
2.93441713e-01 4.73045520e-02 7.22456872e-02 -1.52661309e-01
6.08003795e-01 1.02589525e-01 2.28410736e-01 3.57809156e-01
2.30974391e-01 7.29562640e-02 1.54908523e-01 1.37615114e-01]]
[[ 3.73128176e-01 2.24487275e-01 3.83654892e-01 5.39644539e-01
2.24863932e-01 3.69703583e-02 2.22563371e-01 2.47377262e-01
2.09958509e-01 -3.67934220e-02 2.55294740e-01 -5.44558465e-02
-3.49954516e-02 -1.88630879e-01 -9.97974724e-02 1.72440261e-01]
[ 1.63444579e-01 7.47621208e-02 4.95126337e-01 2.49838263e-01
2.98441172e-01 -2.29644943e-02 1.30464450e-01 4.65075821e-02
1.87749639e-01 -5.69685884e-02 2.30926782e-01 -1.89751368e-02
-6.57672016e-03 -1.64301425e-01 -7.78417960e-02 1.70920238e-01]
[ 2.62361914e-02 -2.09027641e-02 4.81326580e-01 1.54101923e-01
2.95957267e-01 -3.76441851e-02 1.13665104e-01 -5.53984046e-02
1.96336910e-01 -6.99553713e-02 2.13279501e-01 2.09173746e-02
-6.90750126e-03 -1.42273992e-01 -7.39771128e-02 1.70230061e-01]
[-4.55044061e-02 -8.81957486e-02 4.49505210e-01 8.37849677e-02
3.12549353e-01 -3.09768375e-02 9.69471037e-02 -9.93652195e-02
2.07049429e-01 -6.65001795e-02 1.99929893e-01 4.60516922e-02
9.15598311e-03 -1.23334207e-01 -6.36003762e-02 1.58215716e-01]
[-8.54137391e-02 -9.59964097e-02 4.28828478e-01 2.81018596e-02
3.12747598e-01 -1.96594596e-02 1.04248613e-01 -1.21685371e-01
2.44353175e-01 -6.95914254e-02 1.87495902e-01 1.23339958e-01
1.20015517e-02 -9.19487774e-02 -5.30561097e-02 1.52850106e-01]
[-1.03212453e-01 -9.74086747e-02 4.10266966e-01 -2.03387272e-02
3.20133060e-01 5.47134259e-04 1.07527576e-01 -1.26215830e-01
3.05427969e-01 -5.22202961e-02 1.89031556e-01 2.16380343e-01
4.27359492e-02 -5.31105101e-02 -2.50125714e-02 1.44858196e-01]
[-1.06917843e-01 -7.13929683e-02 4.10624832e-01 -5.51486127e-02
3.07110429e-01 1.92907490e-02 1.03878655e-01 -1.38662428e-01
4.00884181e-01 -1.43600125e-02 1.82524621e-01 2.97586024e-01
9.52146128e-02 9.59962141e-03 5.30949272e-02 1.37635604e-01]
[-9.71160829e-02 -4.43801992e-02 4.20233607e-01 -1.02356419e-01
3.03063601e-01 3.99401113e-02 8.28935355e-02 -1.43912748e-01
6.09543681e-01 1.04935512e-01 2.27933496e-01 3.57850134e-01
2.31336534e-01 7.57181123e-02 1.55172557e-01 1.39436752e-01]]
[[ 3.74232024e-01 2.23312378e-01 3.80826175e-01 5.25748074e-01
2.30494052e-01 3.75359394e-02 2.19325155e-01 2.45338157e-01
1.90327644e-01 -9.49237868e-03 2.51282185e-01 -4.07305919e-02
-7.68693071e-03 -1.96041882e-01 -9.43402052e-02 1.52500823e-01]
[ 1.65756628e-01 8.52986127e-02 5.00474215e-01 2.32285380e-01
2.97197372e-01 -2.87767611e-02 1.31484732e-01 4.05624248e-02
1.72598451e-01 -3.74435596e-02 2.30013907e-01 1.03627918e-02
1.63554456e-02 -1.71838194e-01 -7.55213797e-02 1.56671956e-01]
[ 3.51614878e-02 -3.49920541e-02 4.85133171e-01 1.37813956e-01
3.03884476e-01 -3.76141518e-02 9.96868908e-02 -4.97255772e-02
1.81163609e-01 -4.24254723e-02 2.27177203e-01 3.23883444e-02
2.71688756e-02 -1.56165496e-01 -6.69138283e-02 1.53632939e-01]
[-4.10026051e-02 -8.96424949e-02 4.60784853e-01 8.30888674e-02
3.03816915e-01 -2.20339652e-02 9.38846841e-02 -9.45615992e-02
2.04564795e-01 -4.51925248e-02 2.18029544e-01 6.01283386e-02
3.36706154e-02 -1.35854393e-01 -5.57745472e-02 1.48557410e-01]
[-8.25456828e-02 -1.13149934e-01 4.36939508e-01 3.75392586e-02
3.10225427e-01 -7.73321884e-03 9.12441462e-02 -1.16306305e-01
2.42686659e-01 -4.25874330e-02 2.11468235e-01 1.09053820e-01
4.69379947e-02 -1.04551256e-01 -4.02252935e-02 1.34793952e-01]
[-1.08169496e-01 -1.15720116e-01 4.16452408e-01 4.10868321e-03
3.16107094e-01 6.06524665e-03 9.51950625e-02 -1.27826288e-01
3.06058168e-01 -3.21962573e-02 2.01961204e-01 1.87839821e-01
6.73103184e-02 -5.98271154e-02 -1.05028180e-02 1.28264755e-01]
[-1.16449505e-01 -1.07103497e-01 4.10319597e-01 -3.42636257e-02
3.23818535e-01 2.40915213e-02 9.08538699e-02 -1.28739789e-01
4.00041372e-01 5.13588311e-03 2.06977740e-01 2.77402431e-01
1.18934669e-01 6.60364656e-03 5.48240133e-02 1.22762337e-01]
[-1.07369550e-01 -7.64680207e-02 4.24612671e-01 -8.88631567e-02
3.25147092e-01 5.22605665e-02 7.02133700e-02 -1.30118832e-01
6.03053808e-01 1.08490229e-01 2.35621274e-01 3.42306137e-01
2.33348757e-01 7.23976195e-02 1.51835442e-01 1.38724014e-01]]
[[ 3.68833274e-01 2.19720796e-01 3.75712991e-01 5.39344609e-01
2.32777387e-01 3.75517495e-02 2.15990663e-01 2.38119900e-01
2.03846872e-01 -3.31601547e-03 2.63746709e-01 -5.33154309e-02
-1.53900171e-02 -1.96350247e-01 -9.86721516e-02 1.51238605e-01]
[ 1.61587596e-01 7.25713074e-02 4.97545034e-01 2.48409301e-01
3.00032824e-01 -2.52650958e-02 1.25469610e-01 4.12617065e-02
1.75564945e-01 -3.84877101e-02 2.34954998e-01 1.90881861e-03
7.01279286e-03 -1.72224715e-01 -7.77121335e-02 1.60935923e-01]
[ 2.84800380e-02 -2.69929953e-02 4.86053288e-01 1.57494590e-01
2.96494991e-01 -3.40557620e-02 1.04029477e-01 -5.39027080e-02
1.82317436e-01 -5.37234657e-02 2.23423839e-01 4.04849648e-02
8.95922631e-03 -1.53901607e-01 -7.44922534e-02 1.65948585e-01]
[-3.72786410e-02 -7.53442869e-02 4.61774200e-01 8.63353312e-02
2.97733396e-01 -2.75274049e-02 9.13189948e-02 -1.00060880e-01
1.94108337e-01 -5.79617955e-02 2.08687440e-01 6.31403774e-02
2.11703759e-02 -1.34831637e-01 -6.31042644e-02 1.52588978e-01]
[-7.60064349e-02 -1.06220305e-01 4.34687048e-01 3.19332667e-02
3.09678972e-01 -1.16188908e-02 8.85540992e-02 -1.18266501e-01
2.29653955e-01 -5.94241545e-02 2.00053185e-01 1.14932276e-01
3.13343108e-02 -1.04001120e-01 -4.90994565e-02 1.44359529e-01]
[-9.52797905e-02 -9.27509218e-02 4.22483116e-01 -1.29148299e-02
3.04568678e-01 9.32686683e-03 9.81104076e-02 -1.28704712e-01
2.98035592e-01 -5.08954525e-02 1.98656082e-01 2.12906018e-01
5.04655764e-02 -6.18565194e-02 -2.38872226e-02 1.40028179e-01]
[-9.81744751e-02 -8.54582712e-02 4.15283144e-01 -6.42896220e-02
3.11841279e-01 3.18106599e-02 8.80582407e-02 -1.32987425e-01
3.88665676e-01 -1.39519377e-02 1.92815915e-01 2.86827296e-01
1.07908145e-01 2.11709971e-03 4.85477857e-02 1.27813160e-01]
[-9.11041871e-02 -4.77942340e-02 4.29545075e-01 -1.14117011e-01
3.04611683e-01 5.14086746e-02 7.33837485e-02 -1.44734517e-01
6.06585741e-01 9.89784896e-02 2.24559098e-01 3.55441421e-01
2.28052005e-01 7.30600879e-02 1.55306384e-01 1.37683451e-01]]]
====== double layer bi lstm hn0 shape: (4, 4, 8) ======
[[[ 0.25934413 -0.07461581 0.19370164 0.11095355 0.02041678
0.29797387 0.03047622 0.19640712]
[ 0.2874061 -0.08844143 0.22119689 0.1251989 -0.01900517
0.29294112 0.05027778 0.2071664 ]
[ 0.2596095 0.03271259 0.26155 0.10348854 0.08536521
0.28197888 -0.08929807 0.18018515]
[ 0.2509837 -0.07010224 0.20813467 0.10349585 0.04007874
0.27277622 0.01278557 0.18474495]]
[[-0.00949934 0.10407767 0.038502 0.14573903 -0.14825179
-0.08745017 0.3038079 0.28010136]
[ 0.05813041 0.14894389 0.05397653 0.15691832 -0.16107248
-0.06869183 0.27977887 0.26698047]
[-0.05296279 0.02392143 0.06922498 0.16198513 -0.12499766
-0.063968 0.2682934 0.25862688]
[-0.03301367 0.04014921 -0.00048225 0.1180163 -0.12858163
-0.07102007 0.35664883 0.26105112]]
[[-0.08485842 -0.04152929 0.426153 -0.11219845 0.2934417
0.04730455 0.07224569 -0.15266131]
[-0.09711608 -0.0443802 0.4202336 -0.10235642 0.3030636
0.03994011 0.08289354 -0.14391275]
[-0.10736955 -0.07646802 0.42461267 -0.08886316 0.3251471
0.05226057 0.07021337 -0.13011883]
[-0.09110419 -0.04779423 0.42954507 -0.11411701 0.30461168
0.05140867 0.07338375 -0.14473452]]
[[ 0.20657252 -0.0378294 0.26027134 -0.04602474 -0.03783692
-0.19097655 -0.10146666 0.17668025]
[ 0.20995851 -0.03679342 0.25529474 -0.05445585 -0.03499545
-0.18863088 -0.09979747 0.17244026]
[ 0.19032764 -0.00949238 0.2512822 -0.04073059 -0.00768693
-0.19604188 -0.09434021 0.15250082]
[ 0.20384687 -0.00331602 0.2637467 -0.05331543 -0.01539002
-0.19635025 -0.09867215 0.1512386 ]]]
====== double layer bi lstm cn0 shape: (4, 4, 8) ======
[[[ 0.5770398 -0.16899881 0.40028483 0.25001454 0.04046626
0.57915956 0.05266067 0.52447474]
[ 0.66343445 -0.19959925 0.49729916 0.27566156 -0.03596141
0.5509572 0.0853648 0.5394346 ]
[ 0.5707181 0.07038814 0.5712474 0.2565448 0.1530705
0.57276523 -0.15605333 0.46282846]
[ 0.55990976 -0.16366895 0.4313923 0.23668876 0.08243398
0.53433377 0.02196771 0.4817235 ]]
[[-0.02554817 0.2071405 0.07978731 0.2778875 -0.24753608
-0.2485388 0.62492937 0.6474521 ]
[ 0.16052538 0.31375027 0.1059354 0.2853353 -0.26115927
-0.20904504 0.5899866 0.56931025]
[-0.14657407 0.05189808 0.13706218 0.33399543 -0.2142592
-0.16363172 0.612855 0.61697096]
[-0.0884767 0.07950284 -0.00107491 0.2254872 -0.21063672
-0.20023198 0.72448045 0.60711044]]
[[-0.2504415 -0.0814982 0.7923428 -0.19285998 0.5903069
0.13990048 0.15511556 -0.2908177 ]
[-0.28950468 -0.08669281 0.7886544 -0.17458251 0.6081315
0.12001925 0.17698732 -0.2759574 ]
[-0.30495524 -0.14845964 0.79688644 -0.15463473 0.6548568
0.15446547 0.1526669 -0.24459954]
[-0.265516 -0.09397535 0.79843074 -0.19696996 0.6198776
0.15148453 0.15768716 -0.275381 ]]
[[ 0.32853472 -0.05710489 0.7447654 -0.0758819 -0.09938034
-0.47783113 -0.28168824 0.36019933]
[ 0.33408064 -0.05591211 0.7391405 -0.08961775 -0.0917803
-0.47115833 -0.278066 0.35383248]
[ 0.30187273 -0.01431822 0.7146605 -0.06792408 -0.02012375
-0.48834586 -0.26035625 0.3151392 ]
[ 0.32118577 -0.00497683 0.7502155 -0.08775105 -0.04013083
-0.4903597 -0.27541417 0.30617815]]]
====== double layer bi lstm output 1 shape: (4, 8, 16) ======
[[[ 3.5416836e-01 2.0936093e-01 3.8317284e-01 5.3357160e-01
2.4053907e-01 4.1459590e-02 2.0509864e-01 2.5311515e-01
3.7313861e-01 2.2726113e-02 2.4815443e-01 1.6349553e-01
1.1913014e-02 -1.0416587e-01 -4.6682160e-02 1.2466244e-01]
[ 1.6695338e-01 8.1573747e-02 5.0642765e-01 2.2585270e-01
3.1199178e-01 7.0200888e-03 1.0298288e-01 7.1754217e-02
4.2964008e-01 2.7423983e-02 2.2389892e-01 2.8188041e-01
9.3678713e-02 -1.6824452e-02 4.4604652e-02 1.2561245e-01]
[ 6.0777575e-02 3.0208385e-02 5.1636058e-01 8.0109224e-02
3.0168548e-01 1.5010678e-02 5.8312915e-02 -2.7518146e-02
6.2040079e-01 1.1676422e-01 2.4167898e-01 3.6679846e-01
2.2570200e-01 6.9053181e-02 1.5332413e-01 1.3909420e-01]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]]
[[ 3.7312818e-01 2.2448727e-01 3.8365489e-01 5.3964454e-01
2.2486390e-01 3.6970358e-02 2.2256340e-01 2.4737728e-01
2.0995849e-01 -3.6793407e-02 2.5529474e-01 -5.4455854e-02
-3.4995444e-02 -1.8863088e-01 -9.9797480e-02 1.7244026e-01]
[ 1.6344458e-01 7.4762136e-02 4.9512634e-01 2.4983825e-01
2.9844120e-01 -2.2964491e-02 1.3046446e-01 4.6507578e-02
1.8774964e-01 -5.6968573e-02 2.3092678e-01 -1.8975141e-02
-6.5767197e-03 -1.6430146e-01 -7.7841796e-02 1.7092024e-01]
[ 2.6236186e-02 -2.0902762e-02 4.8132658e-01 1.5410189e-01
2.9595733e-01 -3.7644185e-02 1.1366512e-01 -5.5398405e-02
1.9633688e-01 -6.9955371e-02 2.1327947e-01 2.0917373e-02
-6.9075003e-03 -1.4227399e-01 -7.3977120e-02 1.7023006e-01]
[-4.5504406e-02 -8.8195749e-02 4.4950521e-01 8.3784960e-02
3.1254938e-01 -3.0976830e-02 9.6947111e-02 -9.9365219e-02
2.0704943e-01 -6.6500187e-02 1.9992988e-01 4.6051688e-02
9.1559850e-03 -1.2333421e-01 -6.3600369e-02 1.5821570e-01]
[-8.5413747e-02 -9.5996402e-02 4.2882851e-01 2.8101865e-02
3.1274763e-01 -1.9659458e-02 1.0424862e-01 -1.2168537e-01
2.4435315e-01 -6.9591425e-02 1.8749590e-01 1.2333996e-01
1.2001552e-02 -9.1948770e-02 -5.3056102e-02 1.5285012e-01]
[-1.0321245e-01 -9.7408667e-02 4.1026697e-01 -2.0338718e-02
3.2013306e-01 5.4713513e-04 1.0752757e-01 -1.2621583e-01
3.0542794e-01 -5.2220318e-02 1.8903156e-01 2.1638034e-01
4.2735931e-02 -5.3110521e-02 -2.5012573e-02 1.4485820e-01]
[-1.0691784e-01 -7.1392961e-02 4.1062483e-01 -5.5148609e-02
3.0711043e-01 1.9290760e-02 1.0387863e-01 -1.3866244e-01
4.0088418e-01 -1.4360026e-02 1.8252462e-01 2.9758602e-01
9.5214583e-02 9.5995963e-03 5.3094927e-02 1.3763560e-01]
[-9.7116083e-02 -4.4380195e-02 4.2023361e-01 -1.0235640e-01
3.0306363e-01 3.9940134e-02 8.2893521e-02 -1.4391276e-01
6.0954368e-01 1.0493548e-01 2.2793353e-01 3.5785013e-01
2.3133652e-01 7.5718097e-02 1.5517256e-01 1.3943677e-01]]
[[ 3.6901441e-01 2.1822800e-01 3.7994039e-01 5.2547783e-01
2.3396042e-01 3.9366722e-02 2.1538821e-01 2.4702020e-01
2.4914475e-01 -6.9778422e-03 2.4806115e-01 2.1838229e-02
-1.3991867e-02 -1.6620368e-01 -8.7110944e-02 1.4123847e-01]
[ 1.6616049e-01 8.4187903e-02 4.9948204e-01 2.2646046e-01
3.0369779e-01 -1.7643329e-02 1.2668489e-01 4.9117617e-02
2.6261702e-01 -2.7619595e-02 2.2540939e-01 1.1914852e-01
2.3004401e-02 -1.2194993e-01 -5.5561494e-02 1.3998528e-01]
[ 4.2908981e-02 -2.5578242e-02 4.8486653e-01 1.1890158e-01
3.1149039e-01 -1.4618633e-02 9.1249026e-02 -3.3213440e-02
3.1701097e-01 -1.8276740e-02 2.2031868e-01 2.0087981e-01
5.8553118e-02 -7.3650509e-02 -1.7827954e-02 1.3095699e-01]
[-2.2401063e-02 -6.7246288e-02 4.6379456e-01 4.6429519e-02
3.1024706e-01 1.2560772e-02 7.6885723e-02 -7.1739145e-02
4.0658230e-01 1.3608186e-02 2.1248461e-01 2.7639762e-01
1.0969905e-01 -1.7181308e-03 5.7507429e-02 1.2614906e-01]
[-4.9086079e-02 -6.1570432e-02 4.6209678e-01 -3.5342608e-02
3.1426692e-01 4.2432975e-02 5.4815758e-02 -9.5721334e-02
6.0554379e-01 1.1493160e-01 2.4293001e-01 3.4404746e-01
2.3283333e-01 6.8980336e-02 1.5239350e-01 1.3767722e-01]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]]
[[ 3.3036014e-01 2.2069807e-01 4.0932164e-01 5.0686938e-01
2.5304586e-01 4.5349576e-02 1.6947377e-01 2.6356062e-01
6.4686131e-01 1.8447271e-01 2.6571944e-01 3.6628011e-01
2.0576611e-01 5.9034787e-02 1.3657802e-01 1.4004102e-01]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]]]
====== double layer bi lstm hn1 shape: (4, 4, 8) ======
[[[ 0.30786592 -0.05702875 0.2098356 0.1831936 0.1446731
0.35495615 0.10906219 0.2584008 ]
[ 0.28740606 -0.08844142 0.2211969 0.12519889 -0.01900517
0.29294112 0.05027781 0.2071664 ]
[ 0.25389883 0.05431987 0.24731106 0.1163514 0.12489295
0.31806058 -0.07178076 0.20686159]
[ 0.47720045 0.11175225 0.22376464 0.36412558 0.46750376
0.28765967 0.38535532 0.33306697]]
[[ 0.0012262 0.3199089 -0.02733669 0.17044675 -0.04726706
-0.02164171 0.28464028 0.3348536 ]
[ 0.05813042 0.14894389 0.05397653 0.15691833 -0.16107246
-0.06869183 0.27977887 0.26698047]
[-0.04329334 0.12033389 0.03753637 0.15189895 -0.11344916
-0.04964198 0.27086687 0.28215134]
[ 0.05921583 0.543903 0.00194274 0.27610534 0.16461822
0.25555757 0.18277422 0.3662175 ]]
[[ 0.06077757 0.03020838 0.5163606 0.08010922 0.30168548
0.01501068 0.05831292 -0.02751815]
[-0.09711608 -0.0443802 0.4202336 -0.1023564 0.30306363
0.03994013 0.08289352 -0.14391276]
[-0.04908608 -0.06157043 0.46209678 -0.03534261 0.31426692
0.04243298 0.05481576 -0.09572133]
[ 0.33036014 0.22069807 0.40932164 0.5068694 0.25304586
0.04534958 0.16947377 0.26356062]]
[[ 0.3731386 0.02272611 0.24815443 0.16349553 0.01191301
-0.10416587 -0.04668216 0.12466244]
[ 0.2099585 -0.03679341 0.25529474 -0.05445585 -0.03499544
-0.18863088 -0.09979748 0.17244026]
[ 0.24914475 -0.00697784 0.24806115 0.02183823 -0.01399187
-0.16620368 -0.08711094 0.14123847]
[ 0.6468613 0.18447271 0.26571944 0.3662801 0.20576611
0.05903479 0.13657802 0.14004102]]]
====== double layer bi lstm cn1 shape: (4, 4, 8) ======
[[[ 0.7061355 -0.13162777 0.46092123 0.4033497 0.2930356
0.76054144 0.18314546 0.70929015]
[ 0.6634344 -0.19959924 0.4972992 0.27566153 -0.0359614
0.5509572 0.08536483 0.5394347 ]
[ 0.5526391 0.1161246 0.5316373 0.28497726 0.22511882
0.67451394 -0.12430747 0.5528798 ]
[ 1.0954192 0.29093137 0.8067771 0.8504353 0.7032547
0.97427243 0.5589305 0.8662672 ]]
[[ 0.00324558 0.6688721 -0.05317001 0.32999027 -0.07784042
-0.05728557 0.58330244 0.8111321 ]
[ 0.16052541 0.31375027 0.1059354 0.28533533 -0.26115924
-0.20904504 0.5899867 0.56931025]
[-0.11802054 0.26023 0.07224996 0.31177503 -0.19568688
-0.12562011 0.6177163 0.6840635 ]
[ 0.16791074 1.2188046 0.00349617 0.670789 0.2591958
0.46886685 0.5807996 0.86447406]]
[[ 0.16193499 0.06143508 1.1399425 0.13840833 0.69956493
0.04888431 0.1235408 -0.0485969 ]
[-0.28950468 -0.0866928 0.7886544 -0.17458248 0.6081316
0.12001929 0.17698729 -0.27595744]
[-0.13397661 -0.12149224 0.9074148 -0.06176313 0.6541451
0.12807912 0.1181712 -0.17463374]
[ 0.8489872 0.6016479 1.3853014 0.8196937 1.020999
0.24127276 0.45320526 0.4759813 ]]
[[ 0.6076499 0.03351691 0.812855 0.27901018 0.02922555
-0.26106828 -0.12472634 0.24901994]
[ 0.3340806 -0.05591209 0.7391405 -0.08961776 -0.09178029
-0.47115833 -0.27806604 0.35383248]
[ 0.3964765 -0.01050393 0.7366462 0.03638346 -0.03574796
-0.41335842 -0.23882627 0.28892466]
[ 1.0575086 0.23200202 0.8150203 0.7750988 0.42505968
0.24064866 0.46888143 0.26767123]]]
本文总结
本文简单介绍了LSTM的基本原理,然后结合MindSpore中文档说明,通过案例解说详细介绍参数设定和输入输出情况,让读者更好的理解MindSpore中的LSTM算子。
本文参考
本文为原创文章,版权归作者所有,未经授权不得转载!
|