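Broadcasting conditions
Two tensors are broadcastable only if both of the following conditions hold:
- Each tensor has at least one dimension.
- When the dimension sizes are compared starting from the trailing dimension (right to left), each pair of sizes must be equal, or one of them must be 1, or one of the tensors must have no dimension left at that position.
Example:
>>> import torch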
>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# same shapes are always broadcastable
>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
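# x and y are not broadcastable: x is an empty tensor whose only dimension has size 0, which neither equals 2 nor is 1 (the PyTorch docs describe this as x not having at least one dimension)
# the trailing dimensions can be lined up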
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty( 3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist
>>> x=torch.empty(3,2,4,1)
>>> y=torch.empty( 3,1,1)
# x and y are not broadcastable: in the 3rd trailing dimension the sizes differ (2 vs 3) and neither is 1
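The two conditions can be checked mechanically on the shapes alone. Below is a minimal sketch in plain Python; `is_broadcastable` is a hypothetical helper written for this post, not a PyTorch function, and it follows the conditions exactly as stated above.

```python
from itertools import zip_longest

def is_broadcastable(shape_x, shape_y):
    """Hypothetical helper: check the two conditions on a pair of shapes."""
    # Condition 1 as stated in the text: each tensor has at least one dimension.
    if len(shape_x) == 0 or len(shape_y) == 0:
        return False
    # Condition 2: walk both shapes from the last dimension to the first.
    # A dimension that does not exist in the shorter shape is treated like size 1.
    for a, b in zip_longest(reversed(shape_x), reversed(shape_y), fillvalue=1):
        if a != b and a != 1 and b != 1:
            return False
    return True

print(is_broadcastable((5, 7, 3), (5, 7, 3)))     # True
print(is_broadcastable((0,), (2, 2)))             # False
print(is_broadcastable((5, 3, 4, 1), (3, 1, 1)))  # True
print(is_broadcastable((3, 2, 4, 1), (3, 1, 1)))  # False
```

The four calls mirror the four shape pairs in the example above.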
———————————————————————————
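Computing the result
If two tensors x and y are broadcastable, the size of the resulting tensor is computed as follows:
- If x and y have different numbers of dimensions, prepend dimensions of size 1 to the tensor with fewer dimensions until both have the same number of dimensions.
- For each dimension, the resulting size is the maximum of the sizes of x and y along that dimension. (If the two sizes differ, one of them must be 1.)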
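The two steps above can be sketched in plain Python. `broadcast_shape` is a hypothetical helper name used only for illustration; it operates on shape tuples and does not touch any tensor data.

```python
def broadcast_shape(shape_x, shape_y):
    """Hypothetical helper: compute the broadcast result shape or raise ValueError."""
    ndim = max(len(shape_x), len(shape_y))
    # Step 1: prepend 1s to the shorter shape so both have ndim dimensions.
    padded_x = (1,) * (ndim - len(shape_x)) + tuple(shape_x)
    padded_y = (1,) * (ndim - len(shape_y)) + tuple(shape_y)
    result = []
    for dim, (a, b) in enumerate(zip(padded_x, padded_y)):
        if a != b and a != 1 and b != 1:
            raise ValueError(f"sizes {a} and {b} do not match at dimension {dim}")
        # Step 2: take the size that is not 1; for the positive sizes used
        # in this post this is the maximum of the two sizes.
        result.append(a if b == 1 else b)
    return tuple(result)

print(broadcast_shape((5, 1, 4, 1), (3, 1, 1)))  # (5, 3, 4, 1)
print(broadcast_shape((1,), (3, 1, 7)))          # (3, 1, 7)
try:
    broadcast_shape((5, 2, 4, 1), (3, 1, 1))
except ValueError as err:
    print(err)  # sizes 2 and 3 do not match at dimension 1
```

Recent PyTorch versions also provide `torch.broadcast_shapes`, which computes the result shape directly from shapes.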
Example 1:
>>> import torch
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty( 3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])
>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
The above shows the rule in terms of shapes. The following examples show how it works element by element:
Example 2:
import torch
x = torch.arange(0,24).reshape(2,4,3)
print(x)
y = torch.arange(0,3).reshape(1,3)
print(y)
z = x + y
print(z)
tensor([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]],
[[12, 13, 14],
[15, 16, 17],
[18, 19, 20],
[21, 22, 23]]])
tensor([[0, 1, 2]])
tensor([[[ 0, 2, 4],
[ 3, 5, 7],
[ 6, 8, 10],
[ 9, 11, 13]],
[[12, 14, 16],
[15, 17, 19],
[18, 20, 22],
[21, 23, 25]]])
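To make the element-level behaviour of Example 2 explicit, here is a sketch that reproduces it with nested Python lists instead of tensors. It is only an illustration of the indexing, not how PyTorch implements the addition.

```python
# x has shape (2, 4, 3) and holds 0..23; y has shape (1, 3), as in Example 2.
x = [[[i * 12 + j * 3 + k for k in range(3)] for j in range(4)] for i in range(2)]
y = [[0, 1, 2]]

# y is treated as shape (1, 1, 3). Along every size-1 dimension its single
# entry is reused, so the only index of y that varies is the last one, k.
z = [[[x[i][j][k] + y[0][k] for k in range(3)] for j in range(4)] for i in range(2)]

print(z[0][1])  # [3, 5, 7]
print(z[1][3])  # [21, 23, 25]
```

Every row of x receives the same row [0, 1, 2], which matches the printed result above.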
Example 3:
x = torch.arange(0,24).reshape(2,4,3)
print(x)
y = torch.arange(0,4).reshape(4,1)
print(y)
z = x + y
print(z)
tensor([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]],
[[12, 13, 14],
[15, 16, 17],
[18, 19, 20],
[21, 22, 23]]])
tensor([[0],
[1],
[2],
[3]])
tensor([[[ 0, 1, 2],
[ 4, 5, 6],
[ 8, 9, 10],
[12, 13, 14]],
[[12, 13, 14],
[16, 17, 18],
[20, 21, 22],
[24, 25, 26]]])
Example 4:
x = torch.arange(0,48).reshape(2,4,3,2)
print(x)
y = torch.arange(0,3).reshape(3,1)
print(y)
z = x + y
print(z)
Result:
tensor([[[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15],
[16, 17]],
[[18, 19],
[20, 21],
[22, 23]]],
[[[24, 25],
[26, 27],
[28, 29]],
[[30, 31],
[32, 33],
[34, 35]],
[[36, 37],
[38, 39],
[40, 41]],
[[42, 43],
[44, 45],
[46, 47]]]])
tensor([[0],
[1],
[2]])
tensor([[[[ 0, 1],
[ 3, 4],
[ 6, 7]],
[[ 6, 7],
[ 9, 10],
[12, 13]],
[[12, 13],
[15, 16],
[18, 19]],
[[18, 19],
[21, 22],
[24, 25]]],
[[[24, 25],
[27, 28],
[30, 31]],
[[30, 31],
[33, 34],
[36, 37]],
[[36, 37],
[39, 40],
[42, 43]],
[[42, 43],
[45, 46],
[48, 49]]]])
Example 5:
x = torch.arange(0,48).reshape(2,4,3,2)
print(x)
y = torch.arange(0,8).reshape(4,1,2)
print(y)
z = x + y
print(z)
tensor([[[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15],
[16, 17]],
[[18, 19],
[20, 21],
[22, 23]]],
[[[24, 25],
[26, 27],
[28, 29]],
[[30, 31],
[32, 33],
[34, 35]],
[[36, 37],
[38, 39],
[40, 41]],
[[42, 43],
[44, 45],
[46, 47]]]])
tensor([[[0, 1]],
[[2, 3]],
[[4, 5]],
[[6, 7]]])
tensor([[[[ 0, 2],
[ 2, 4],
[ 4, 6]],
[[ 8, 10],
[10, 12],
[12, 14]],
[[16, 18],
[18, 20],
[20, 22]],
[[24, 26],
[26, 28],
[28, 30]]],
[[[24, 26],
[26, 28],
[28, 30]],
[[32, 34],
[34, 36],
[36, 38]],
[[40, 42],
[42, 44],
[44, 46]],
[[48, 50],
[50, 52],
[52, 54]]]])
Example 6:
x = torch.arange(0,48).reshape(2,4,3,2)
print(x)
y = torch.arange(2,4).reshape(1,1,2)
print(y)
z = x + y
print(z)
tensor([[[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15],
[16, 17]],
[[18, 19],
[20, 21],
[22, 23]]],
[[[24, 25],
[26, 27],
[28, 29]],
[[30, 31],
[32, 33],
[34, 35]],
[[36, 37],
[38, 39],
[40, 41]],
[[42, 43],
[44, 45],
[46, 47]]]])
tensor([[[2, 3]]])
tensor([[[[ 2, 4],
[ 4, 6],
[ 6, 8]],
[[ 8, 10],
[10, 12],
[12, 14]],
[[14, 16],
[16, 18],
[18, 20]],
[[20, 22],
[22, 24],
[24, 26]]],
[[[26, 28],
[28, 30],
[30, 32]],
[[32, 34],
[34, 36],
[36, 38]],
[[38, 40],
[40, 42],
[42, 44]],
[[44, 46],
[46, 48],
[48, 50]]]])
Example 7:
x = torch.arange(0,48).reshape(2,4,3,2)
print(x)
y = torch.arange(2,8).reshape(2,1,3,1)
print(y)
z = x + y
print(z)
tensor([[[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15],
[16, 17]],
[[18, 19],
[20, 21],
[22, 23]]],
[[[24, 25],
[26, 27],
[28, 29]],
[[30, 31],
[32, 33],
[34, 35]],
[[36, 37],
[38, 39],
[40, 41]],
[[42, 43],
[44, 45],
[46, 47]]]])
tensor([[[[2],
[3],
[4]]],
[[[5],
[6],
[7]]]])
tensor([[[[ 2, 3],
[ 5, 6],
[ 8, 9]],
[[ 8, 9],
[11, 12],
[14, 15]],
[[14, 15],
[17, 18],
[20, 21]],
[[20, 21],
[23, 24],
[26, 27]]],
[[[29, 30],
[32, 33],
[35, 36]],
[[35, 36],
[38, 39],
[41, 42]],
[[41, 42],
[44, 45],
[47, 48]],
[[47, 48],
[50, 51],
[53, 54]]]])
Example 8:
x = torch.arange(0,24).reshape(2,4,3,1)
print(x)
y = torch.arange(0,12).reshape(2,1,3,2)
print(y)
z = x + y
print(z)
tensor([[[[ 0],
[ 1],
[ 2]],
[[ 3],
[ 4],
[ 5]],
[[ 6],
[ 7],
[ 8]],
[[ 9],
[10],
[11]]],
[[[12],
[13],
[14]],
[[15],
[16],
[17]],
[[18],
[19],
[20]],
[[21],
[22],
[23]]]])
tensor([[[[ 0, 1],
[ 2, 3],
[ 4, 5]]],
[[[ 6, 7],
[ 8, 9],
[10, 11]]]])
tensor([[[[ 0, 1],
[ 3, 4],
[ 6, 7]],
[[ 3, 4],
[ 6, 7],
[ 9, 10]],
[[ 6, 7],
[ 9, 10],
[12, 13]],
[[ 9, 10],
[12, 13],
[15, 16]]],
[[[18, 19],
[21, 22],
[24, 25]],
[[21, 22],
[24, 25],
[27, 28]],
[[24, 25],
[27, 28],
[30, 31]],
[[27, 28],
[30, 31],
[33, 34]]]])
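All of the examples above follow one pattern: after the shapes are aligned, an index along a size-1 dimension is pinned to 0. The sketch below expresses this for flat, row-major lists. `broadcast_add` and `flat_index` are hypothetical names for this illustration, and the approach is one way to picture the mechanism, not PyTorch's actual implementation.

```python
from itertools import product

def broadcast_add(x, shape_x, y, shape_y):
    """Hypothetical helper: add two row-major flat lists with broadcasting."""
    ndim = max(len(shape_x), len(shape_y))
    # Prepend 1s so both shapes have the same number of dimensions.
    sx = (1,) * (ndim - len(shape_x)) + tuple(shape_x)
    sy = (1,) * (ndim - len(shape_y)) + tuple(shape_y)
    for a, b in zip(sx, sy):
        if a != b and a != 1 and b != 1:
            raise ValueError(f"shapes {shape_x} and {shape_y} are not broadcastable")
    out_shape = tuple(a if b == 1 else b for a, b in zip(sx, sy))

    def flat_index(index, shape):
        # Row-major offset; along a size-1 dimension the index is pinned to 0.
        offset = 0
        for i, size in zip(index, shape):
            offset = offset * size + (i if size != 1 else 0)
        return offset

    # Visit every output index in row-major order and add the matching elements.
    out = [x[flat_index(idx, sx)] + y[flat_index(idx, sy)]
           for idx in product(*(range(n) for n in out_shape))]
    return out, out_shape

# Example 8: x has shape (2, 4, 3, 1), y has shape (2, 1, 3, 2).
z, shape = broadcast_add(list(range(24)), (2, 4, 3, 1), list(range(12)), (2, 1, 3, 2))
print(shape)   # (2, 4, 3, 2)
print(z[:6])   # [0, 1, 3, 4, 6, 7]
print(z[-6:])  # [27, 28, 30, 31, 33, 34]
```

The first and last six values correspond to the first and last 3x2 blocks of the result printed in Example 8.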