python常用函数(一)
repeat()函数
功能:指定维度上的元素重复n次。 例:
a = torch.rand(12,512,1,64)
b = a.repeat(1,1,32,1)
表示第2维上的元素重复32次,其他维度为1表示重复1次, 也就是这维的元素不变动 这样b的维度就是(12,512,32,64)
如果现在这样:
b = a.repeat(1,2,32,2)
表示第0维不变,第1维所有元素重复2次,第2维元素重复32次,第3维元素重复2次 b 的维度:(12,1024,32,128) b 对应维度元素个数=a上该维度元素个数×重复的次数。。。 即:b 的维度(12×1,512×2,1×32,64×2) = (12,1024,32,128)
repeat()函数在深度学习网络之间的维度拼接中发挥很大的作用,当两个特征的维度不一致时,但需要将两个特征融合在一起时,就可以通过repeat()函数将两者的维度化为一致,再融合。 例:
a = torch.rand(12,512,32,64)
b = torch.rand(12,512,1,64)
a和b的第三维表示的是特征,现在需要将a和b第三维特征进行拼接,但由于第2维的维度不一致,则无法拼接。因此引入repeat() 函数
c = torch.cat([a, b.repeat(1, 1, 32, 1)], dim=-1)
b.repeat(1, 1, 32, 1)
→
\to
→ b 的维度(12,512,32,64) dim = -1 表示对最后一维的元素进行拼接。 结果c的维度为:(12,512,32,128)
|