LSTM输入输出维度图示(示例代码以 nn.RNN 为例)
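单向单层
nn.RNN(5, 10, 1) 的三个参数依次为 input_size=5、hidden_size=10、num_layers=1;输入 x 的形状为 (seq_len=6, batch=3, input_size=5)。out 的形状为 (seq_len, batch, hidden_size),ht 的形状为 (num_layers, batch, hidden_size)。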
rnn_seq = nn.RNN(5, 10, 1)
x = torch.randn(6, 3, 5)
out, ht = rnn_seq(x)
out.shape
torch.Size([6, 3, 10])
ht.shape
torch.Size([1, 3, 10])
单向多层
num_layers=3 时,out 仍只包含最后一层在各时间步的输出,因此形状不变;ht 的第一维变为 num_layers=3,保存每一层最后时刻的隐状态。
rnn_seq = nn.RNN(5, 10, 3)
x = torch.randn(6, 3, 5)
out, ht = rnn_seq(x)
out.shape
torch.Size([6, 3, 10])
ht.shape
torch.Size([3, 3, 10])
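双向单层
bidirectional=True 时,out 的最后一维为 2 × hidden_size = 20(前 10 维为正向输出,后 10 维为反向输出),ht 的第一维为 num_layers × 2 = 2。对照下面的输出可以看到:ht[0](正向最终隐状态)等于 out[-1] 的前 10 维,ht[-1](反向最终隐状态)等于 out[0] 的后 10 维。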
rnn_seq = nn.RNN(5, 10, 1, bidirectional=True)
x = torch.randn(6, 3, 5)
out, ht = rnn_seq(x)
out.shape
torch.Size([6, 3, 20])
ht.shape
torch.Size([2, 3, 10])
out[-1]
tensor([[ 0.3912, -0.3131, -0.5704, 0.1386, -0.0805, 0.2840, -0.4612, 0.4176,
0.1815, 0.3982, 0.3504, 0.0681, -0.3936, 0.5383, 0.0282, 0.3985,
0.3291, 0.3125, 0.3637, 0.1893],
[-0.5409, 0.7553, 0.2176, -0.3243, -0.1724, -0.0350, 0.2422, -0.2549,
0.4105, -0.3549, 0.2171, 0.5521, 0.0122, 0.3783, -0.2583, -0.0181,
0.1647, 0.6133, -0.0935, -0.2087],
[-0.1052, 0.7468, -0.3063, -0.3701, -0.5259, -0.3952, -0.4957, 0.0016,
0.7090, -0.1685, 0.2603, 0.1816, -0.3178, 0.3992, -0.6003, -0.5304,
-0.3403, 0.2522, -0.4018, 0.1896]], grad_fn=<SelectBackward>)
ht[0]
tensor([[ 0.3912, -0.3131, -0.5704, 0.1386, -0.0805, 0.2840, -0.4612, 0.4176,
0.1815, 0.3982],
[-0.5409, 0.7553, 0.2176, -0.3243, -0.1724, -0.0350, 0.2422, -0.2549,
0.4105, -0.3549],
[-0.1052, 0.7468, -0.3063, -0.3701, -0.5259, -0.3952, -0.4957, 0.0016,
0.7090, -0.1685]], grad_fn=<SelectBackward>)
out[0]
tensor([[ 0.4740, 0.4121, -0.5868, -0.2711, -0.2606, -0.4430, -0.5782, 0.8062,
0.7675, -0.7180, 0.5319, 0.4218, -0.5257, 0.5148, -0.7651, -0.1566,
-0.1108, 0.2430, -0.1809, -0.1221],
[ 0.4909, 0.1688, -0.2177, -0.2767, 0.2483, -0.3785, -0.3281, 0.8529,
0.6099, -0.4130, 0.3471, 0.6021, -0.7445, 0.1823, -0.6768, 0.2450,
0.1149, 0.2162, -0.3557, -0.5719],
[-0.2244, 0.7206, -0.0976, -0.5866, 0.3540, 0.1325, -0.5411, -0.7152,
0.3517, 0.3375, -0.8289, -0.7162, -0.1566, -0.3909, -0.4418, -0.2623,
0.1497, -0.6729, 0.2449, 0.4574]], grad_fn=<SelectBackward>)
ht[-1]
tensor([[ 0.5319, 0.4218, -0.5257, 0.5148, -0.7651, -0.1566, -0.1108, 0.2430,
-0.1809, -0.1221],
[ 0.3471, 0.6021, -0.7445, 0.1823, -0.6768, 0.2450, 0.1149, 0.2162,
-0.3557, -0.5719],
[-0.8289, -0.7162, -0.1566, -0.3909, -0.4418, -0.2623, 0.1497, -0.6729,
0.2449, 0.4574]], grad_fn=<SelectBackward>)
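双向多层
ht 的第一维为 num_layers × 2 = 4,按层排列,每层内先正向后反向;ht[2:] 即最后一层的两个方向。对照下面的输出可以看到:ht[3](最后一层反向的最终隐状态)等于 out[0] 的后 10 维;相应地,ht[2](最后一层正向的最终隐状态)对应 out[-1] 的前 10 维(此处未打印 out[-1])。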
rnn_seq = nn.RNN(5, 10, 2, bidirectional=True)
x = torch.randn(6, 3, 5)
out, ht = rnn_seq(x)
out.shape
torch.Size([6, 3, 20])
ht.shape
torch.Size([4, 3, 10])
out[0]
tensor([[ 0.1704, 0.4581, 0.1723, 0.2236, 0.0432, 0.6190, 0.1974, 0.2000,
-0.5012, -0.1075, -0.1713, -0.4623, -0.3120, 0.0759, 0.4959, -0.8103,
-0.2548, 0.4587, -0.2821, 0.7620],
[ 0.4484, 0.4716, 0.0320, 0.1836, 0.1117, 0.5677, 0.0560, 0.2435,
-0.3318, -0.0584, -0.4861, 0.0913, -0.2517, 0.1683, 0.5459, -0.3377,
-0.6199, -0.4051, -0.2039, 0.4189],
[ 0.5295, 0.4061, -0.1754, 0.2779, -0.0318, 0.6160, 0.1777, 0.5757,
-0.1380, -0.2663, -0.6953, -0.0388, -0.2153, 0.5317, 0.2948, -0.1002,
-0.6486, 0.1166, 0.0067, 0.1381]], grad_fn=<SelectBackward>)
ht[2:]
tensor([[[-0.0053, 0.7034, 0.1532, -0.0558, 0.5286, 0.8320, -0.2079,
-0.5418, -0.4331, 0.1198],
[ 0.2822, 0.6841, -0.5430, 0.1567, 0.5371, 0.8532, 0.0513,
-0.5214, -0.1258, -0.0206],
[ 0.3767, 0.6684, 0.0900, 0.2732, 0.4522, 0.8421, -0.0925,
-0.1310, -0.2546, -0.0969]],
[[-0.1713, -0.4623, -0.3120, 0.0759, 0.4959, -0.8103, -0.2548,
0.4587, -0.2821, 0.7620],
[-0.4861, 0.0913, -0.2517, 0.1683, 0.5459, -0.3377, -0.6199,
-0.4051, -0.2039, 0.4189],
[-0.6953, -0.0388, -0.2153, 0.5317, 0.2948, -0.1002, -0.6486,
0.1166, 0.0067, 0.1381]]], grad_fn=<SliceBackward>)
参考
[PyTorch] rnn,lstm,gru中输入输出维度