IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 深度学习||计算机视觉||如何在医学影像3DCNN中插入网上已有的注意力机制 -> 正文阅读

[人工智能]深度学习||计算机视觉||如何在医学影像3DCNN中插入网上已有的注意力机制

前言

在做深度学习在医学影像上应用的时候,数据集往往是3D的,而网上很多公开的trick或者注意力机制都是2D实现的,因此带来了一些困难。

做法

在github(https://github.com/xmu-xiaoma666/External-Attention-pytorch)上可以看到很多现有即插即用的注意力模块。
git clone到本地后,关键就是怎么在已有的代码框架上插入这些注意力模块。
**step0:**实例化一个模型,利用torch.randn函数模拟输入。我这里输入的是一个batchsize为128,channel为1,x为24,y为24,z为20的图像。

if __name__ == '__main__':
    input=torch.randn(128,1, 24, 24, 20)
    net=DiscriNet()
    out=net(input)
    print(out.shape)

**step1:**知道要插的位置,一般在backbone的后面。在def forward(self, x)方法中,利用print(x.size())知道相应的输入尺寸。

    def forward(self, x):
        #print(x.size()) #torch.Size([128, 1, 20, 20, 16]) 128为batchsize 1为channel
        x = self.net(x)
        # print(x.size())
        # print(x.size()) #torch.Size([128, 128, 2, 2, 2])
        if self.is_fc:
            x = x.view(-1, 128 * 2 * 2 * 2)
            x = self.final(x)
        else:
            #print(self.final(x).size()) #torch.Size([128, 2, 1, 1, 1])
            #print(self.final(x).squeeze(4).size()) #torch.Size([128, 2, 1, 1])
            #print(self.final(x).squeeze(4).squeeze(3).squeeze(2).size())#torch.Size([128, 2])
            x = self.final(x).squeeze(4).squeeze(3).squeeze(2)
            #print(x.size()) #torch.Size([128, 2])
        return x

**step2:**根据已经封装好的模块,将它实例化后插入到forward方法中。

    def forward(self, x):
        from SelfAttention import ScaledDotProductAttention
        #print(x.size())
        x = self.net(x)
        #print(x.size()) #torch.Size([128, 128, 4, 4, 4]) 
        b = x.permute(0, 2, 3, 4, 1).reshape(128, -1, 128)
        sa = ScaledDotProductAttention(d_model=128, d_k=128, d_v=128, h=8)
        output = sa(b, b, b)
        #print(output.size())
        x = x.reshape(128, 4, 4, 4, 128).permute(0, 4, 1, 2, 3)
        # print(x.size())
        #print(x.equal(a))
        x = x.contiguous().view(-1, 128 * 4 * 4 * 4)
        #(x.size())
        x = self.final(x)
        return x

利用permute、reshape、view函数,就可以转换了。我这里用的是自注意力机制,因为其本身是从自然语言领域转换而来的,无论是2D还是3D它都会转换成1D的tensor。
BUG1 RuntimeError: Expected all tensors to be on the same device, but found at least two devices
一般是tensor一个在cpu,一个在gpu上报错。

device = torch.device('cuda:0')
emsa = EMSA(d_model=128, d_k=128, d_v=128, h=8, H=8, W=8, ratio=2, apply_transform=True).to(device)

让模型在gpu上运行。
BUG2代码可以运行,但运行到最后一个epoch报错
在这里插入图片描述
这是由于最后一个epoch的batchsize不到指定数量所导致的。我batchsize设置为128,最后一个epoch不到128,因此后面的转换就会报错。

for img_batch, label_batch in dataloader:
            if(len(img_batch)==self.batch_size):

加一个判断就好。
或者不要固定batchsize

    def forward(self, x):
        from SelfAttention import ScaledDotProductAttention
        #print(x.size())
        x = self.net(x)
        x_batchsize=x.size()[0]
        device = torch.device('cuda:0')
        b = x.permute(0, 2, 3, 4, 1).reshape(x_batchsize, -1, 128).to(device)
        sa = ScaledDotProductAttention(d_model=128, d_k=128, d_v=128, h=8).to(device)
        output = sa(b, b, b)
        #print(output.size())
        x = x.reshape(x_batchsize, 4, 4, 4, 128).permute(0, 4, 1, 2, 3)
        # print(x.size())
        #print(x.equal(a))
        x = x.contiguous().view(-1, 128 * 4 * 4 * 4)

        #(x.size())
        x = self.final(x)
        return x
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-09-04 01:12:01  更:2022-09-04 01:14:44 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/28 19:59:00-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计