Swin Transformer study notes 1: understanding the window_partition function
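What it does: it splits a feature map of shape (B, H, W, C) into non-overlapping windows of size window_size × window_size.
The corresponding code from the original implementation is as follows: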
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    # Split H into (H // window_size, window_size) and W into (W // window_size, window_size)
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # Bring the two window-grid dims together, then merge them with the batch dim
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
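A minimal shape check, as a sketch: it repeats window_partition from above so that it runs on its own, and the sizes (2 images, 8 × 8 pixels, 3 channels, window_size = 4) are arbitrary choices for illustration. H and W must be divisible by window_size; otherwise view raises an error.

```python
import torch

def window_partition(x, window_size):
    """Split (B, H, W, C) into (num_windows*B, window_size, window_size, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

# Illustrative sizes (arbitrary): 2 images of 8x8 pixels with 3 channels
x = torch.randn(2, 8, 8, 3)
windows = window_partition(x, window_size=4)
print(windows.shape)  # torch.Size([8, 4, 4, 3]): 2 images x (8/4)*(8/4) windows each

# The first window is the top-left 4x4 patch of the first image
print(torch.equal(windows[0], x[0, :4, :4, :]))  # True
# The second window is the top-right 4x4 patch of the first image
print(torch.equal(windows[1], x[0, :4, 4:, :]))  # True
```

Windows come out image by image and, within one image, row by row across the window grid, which is why windows[1] is the top-right patch.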
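First, look at the shape of the input x: B is the batch size, H is the image height, W is the image width, and C is the number of channels. For a contiguous tensor, PyTorch's view can be thought of as first flattening the tensor into one dimension (in row-major order) and then filling those elements into the new shape; no data is copied.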
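To make the "flatten, then refill" description concrete, the sketch below compares view on a small contiguous tensor (sizes chosen arbitrarily) with an explicit flatten followed by view. This mental model only holds for contiguous tensors; on a non-contiguous tensor view may raise an error instead, which is why the original code calls contiguous() before its final view.

```python
import torch

# A contiguous 2x6 tensor holding 0..11 in row-major order
t = torch.arange(12).view(2, 6)

# view reads elements in row-major order and refills them into the new shape,
# so it matches flattening first and then viewing
v1 = t.view(3, 4)
v2 = t.flatten().view(3, 4)
print(torch.equal(v1, v2))  # True
print(v1)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

# view shares storage with t: no data is copied
print(v1.data_ptr() == t.data_ptr())  # True
```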
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
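Here B is unchanged, which means the elements of each image stay within that image; C is also unchanged, which means each pixel keeps its channel values together. So we can set these two dimensions aside. What remains is that a tensor of shape (H, W) is split into a tensor of shape (H // window_size, window_size, W // window_size, window_size). Let's analyze this process with an example: a 6 × 6 tensor with window_size = 2, so that H // window_size = W // window_size = 3.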
import torch
a = torch.arange(36).view(6, -1)
print(a)
print(a.view(3, -1))
Output:
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]])
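The output can be checked directly: each row of a.view(3, -1) is window_size = 2 consecutive rows of a placed end to end. A minimal, illustrative check (not part of the original code):

```python
import torch

a = torch.arange(36).view(6, -1)
strips = a.view(3, -1)  # shape (3, 12)

# Row k of `strips` is rows 2k and 2k+1 of `a` concatenated
for k in range(3):
    assert torch.equal(strips[k], torch.cat([a[2 * k], a[2 * k + 1]]))
print(strips.shape)  # torch.Size([3, 12])
```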
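In other words, the first step cuts the image horizontally: each row of the new 3 × 12 tensor holds window_size = 2 consecutive rows of the original. Next, we split further: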
import torch
a = torch.arange(36).view(6, -1)
print(a)
print(a.view(3, 2, -1))
Output:
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35]])
tensor([[[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11]],
[[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23]],
[[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35]]])
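Block k of a.view(3, 2, -1) is exactly the 2-row horizontal strip a[2k : 2k + 2] of the original tensor. An illustrative check:

```python
import torch

a = torch.arange(36).view(6, -1)
blocks = a.view(3, 2, -1)  # shape (3, 2, 6)

# Block k is the horizontal strip made of rows 2k and 2k+1
for k in range(3):
    assert torch.equal(blocks[k], a[2 * k : 2 * k + 2])
print(blocks.shape)  # torch.Size([3, 2, 6])
```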
This step separates the original rows again while keeping them grouped vertically into blocks of window_size = 2 rows. Next:
import torch
a = torch.arange(36).view(6, -1)
print(a)
print(a.view(3, 2, 3, -1))
Output:
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35]])
tensor([[[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]]],
[[[12, 13],
[14, 15],
[16, 17]],
[[18, 19],
[20, 21],
[22, 23]]],
[[[24, 25],
[26, 27],
[28, 29]],
[[30, 31],
[32, 33],
[34, 35]]]])
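Reading this result: the four dimensions of v = a.view(3, 2, 3, 2) are (window row, row inside the window, window column, column inside the window). So pixel (i, j) lands at index (i // 2, i % 2, j // 2, j % 2), and the 2 × 2 window at grid position (r, c) is v[r, :, c, :]. The sketch below (the names ws, v and ok are only for illustration) checks both claims:

```python
import torch

ws = 2  # window_size used in this example
a = torch.arange(36).view(6, -1)
v = a.view(3, ws, 3, ws)  # (window row, row in window, window col, col in window)

# Every pixel (i, j) sits at (i // ws, i % ws, j // ws, j % ws)
ok = all(
    v[i // ws, i % ws, j // ws, j % ws].item() == a[i, j].item()
    for i in range(6)
    for j in range(6)
)
print(ok)  # True

# The top-left window a[0:2, 0:2] exists in v, but it is gathered across dim 1
print(v[0, :, 0, :])
# tensor([[0, 1],
#         [6, 7]])
```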
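Now every row has been cut into chunks of window_size elements, but the split is too fine: the chunks that form one window, such as [0, 1] and [6, 7], still sit in different rows. To bring the corresponding chunks from different rows together, we use permute: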
x.permute(0,2,1,3)
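Applied to the 4D tensor from the example, swapping dimensions 1 and 2 gives the layout (H // window_size, W // window_size, window_size, window_size). This amounts to first cutting the image into an (H // window_size) × (W // window_size) grid of windows, and then subdividing each window into window_size × window_size pixels. In the full function, with B and C added, the same swap is written as permute(0, 1, 3, 2, 4, 5), and the final contiguous().view(-1, window_size, window_size, C) merges the batch and window-grid dimensions into num_windows*B.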
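The sketch below continues the 6 × 6 example (window_size = 2) through permute and the final view. It also illustrates why the original code calls contiguous(): permute only reorders strides, so calling view(-1, 2, 2) directly on the permuted tensor raises a RuntimeError here, while contiguous() first copies the data into the new order.

```python
import torch

a = torch.arange(36).view(6, -1)
v = a.view(3, 2, 3, 2)

# Swap dims 1 and 2: (window row, window col, row in window, col in window)
p = v.permute(0, 2, 1, 3)
print(p.shape)  # torch.Size([3, 3, 2, 2])
print(p[0, 1])
# tensor([[2, 3],
#         [8, 9]])

# permute only changes strides, so p is not contiguous and view cannot merge dims 0 and 1
print(p.is_contiguous())  # False
try:
    p.view(-1, 2, 2)
    view_ok = True
except RuntimeError:
    view_ok = False
print(view_ok)  # False

# contiguous() copies the data into the permuted order; then view works
windows = p.contiguous().view(-1, 2, 2)
print(windows.shape)  # torch.Size([9, 2, 2])
print(windows[0])
# tensor([[0, 1],
#         [6, 7]])
```

p.reshape(-1, 2, 2) would give the same result, copying implicitly when a plain view is impossible.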