前言
在做深度学习在医学影像上应用的时候,数据集往往是3D的,而网上很多公开的trick或者注意力机制都是2D实现的,因此带来了一些困难。
做法
在github(https://github.com/xmu-xiaoma666/External-Attention-pytorch)上可以看到很多现有即插即用的注意力模块。 git clone到本地后,关键就是怎么在已有的代码框架上插入这些注意力模块。 **step0:**实例化一个模型,利用torch.randn函数模拟输入。我这里输入的是一个batchsize为128,channel为1,x为24,y为24,z为20的图像。
if __name__ == '__main__':
input=torch.randn(128,1, 24, 24, 20)
net=DiscriNet()
out=net(input)
print(out.shape)
**step1:**知道要插的位置,一般在backbone的后面。在def forward(self, x)方法中,利用print(x.size())知道相应的输入尺寸。
def forward(self, x):
x = self.net(x)
if self.is_fc:
x = x.view(-1, 128 * 2 * 2 * 2)
x = self.final(x)
else:
x = self.final(x).squeeze(4).squeeze(3).squeeze(2)
return x
**step2:**根据已经封装好的模块,将它实例化后插入到forward方法中。
def forward(self, x):
from SelfAttention import ScaledDotProductAttention
x = self.net(x)
b = x.permute(0, 2, 3, 4, 1).reshape(128, -1, 128)
sa = ScaledDotProductAttention(d_model=128, d_k=128, d_v=128, h=8)
output = sa(b, b, b)
x = x.reshape(128, 4, 4, 4, 128).permute(0, 4, 1, 2, 3)
x = x.contiguous().view(-1, 128 * 4 * 4 * 4)
x = self.final(x)
return x
利用permute、reshape、view函数,就可以转换了。我这里用的是自注意力机制,因为其本身是从自然语言领域转换而来的,无论是2D还是3D它都会转换成1D的tensor。 BUG1 RuntimeError: Expected all tensors to be on the same device, but found at least two devices 一般是tensor一个在cpu,一个在gpu上报错。
device = torch.device('cuda:0')
emsa = EMSA(d_model=128, d_k=128, d_v=128, h=8, H=8, W=8, ratio=2, apply_transform=True).to(device)
让模型在gpu上运行。 BUG2代码可以运行,但运行到最后一个epoch报错 这是由于最后一个epoch的batchsize不到指定数量所导致的。我batchsize设置为128,最后一个epoch不到128,因此后面的转换就会报错。
for img_batch, label_batch in dataloader:
if(len(img_batch)==self.batch_size):
加一个判断就好。 或者不要固定batchsize
def forward(self, x):
from SelfAttention import ScaledDotProductAttention
x = self.net(x)
x_batchsize=x.size()[0]
device = torch.device('cuda:0')
b = x.permute(0, 2, 3, 4, 1).reshape(x_batchsize, -1, 128).to(device)
sa = ScaledDotProductAttention(d_model=128, d_k=128, d_v=128, h=8).to(device)
output = sa(b, b, b)
x = x.reshape(x_batchsize, 4, 4, 4, 128).permute(0, 4, 1, 2, 3)
x = x.contiguous().view(-1, 128 * 4 * 4 * 4)
x = self.final(x)
return x
|