torch.dstack
dstack也是一种粘合张量的方法,不过它也是有方向限定的,从这一点来说还是cat方法更自由一些。
其函数原型如下:
torch.dstack(tensors, *, out=None) → Tensor
例程
>>> a = torch.tensor([1, 2, 3])
>>> b = torch.tensor([4, 5, 6])
>>> a.shape
torch.Size([3])
>>> b.shape
torch.Size([3])
>>> c = torch.dstack((a,b))
>>> c.shape
torch.Size([1, 3, 2])
>>> c
tensor([[[1, 4],
[2, 5],
[3, 6]]])
然后我们再来看看另外一个例程
>>> a = torch.tensor([[1], [2], [3]])
>>> a.shape
torch.Size([3, 1])
>>> b = torch.tensor([[4], [5], [6]])
>>> b.shape
torch.Size([3, 1])
>>> c = torch.dstack((a, b))
>>> c.shape
torch.Size([3, 1, 2])
可以发现,这其实相当于先把所有参与计算的张量在最后的维度上再扩展一位,然后再执行粘接操作,用比较直观的数学表示,大概就是这样吧:
T
e
n
s
o
r
.
s
i
z
e
(
x
n
,
x
n
?
1
,
?
?
,
x
1
)
→
T
e
n
s
o
r
.
s
i
z
e
(
x
n
,
x
n
?
1
,
?
?
,
x
1
,
x
0
)
Tensor.size(x_n, x_{n-1}, \cdots, x_1) \rightarrow Tensor.size(x_n, x_{n-1}, \cdots, x_1, x_0)
Tensor.size(xn?,xn?1?,?,x1?)→Tensor.size(xn?,xn?1?,?,x1?,x0?)
明白这一点后,我们就可以用升维命令 torch.unsqueeze 和 torch.cat 来完成类似的功能:
a = torch.tensor([[1], [2], [3]])
b = torch.tensor([[4], [5], [6]])
c = torch.dstack((a, b))
a1 = torch.unsqueeze(a, -1)
b1 = torch.unsqueeze(b, -1)
c1 = torch.cat((a1, b1), -1)
print(c.shape, c1.shape)
if torch.equal(c, c1):
print("TRUE")
else:
print("FALSE")
输出的结果是True,说明我们的猜想是准确的。
|