IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> pytorch 模型容器 - 模块化构建深度学习网络 -> 正文阅读

[人工智能]pytorch 模型容器 - 模块化构建深度学习网络

pytorch 模型容器 - 模块化构建深度学习网络


pytorch 模型容器总结

pytorch 提供的模型容器包括:
nn.Sequential:按顺序包含多个网络层
nn.ModuleList:类似列表(list)的形式包含多个网络层
nn.ModuleDict:类似字典(dict)的形式包含多个网络层


一、nn.Sequential

nn.Sequential中的模块按顺序排列,输入数据从第一个block输入,依次经过后续的block后得到输出结果
代码示例:

class Net(nn.Module):
    def __init__(self):
        super(CoarseNet_v1, self).__init__()
        self.seq1=nn.Sequential(
            nn.Conv2d(16,32,3,1),
            Block1(32, 32),
            nn.ReLU(inplace=True)
        )
        # Block1为自定义模块
    def forward(self,inputs)
    	output=self.seq1(inputs)  # 这样用一行代码就可以实现网络的前向传输
    	return output

这样在定义前向传播时要便捷很多,但是 nn.Sequential 内部模块执行的顺序是固定的,这是应用时存在的不便之处。

二、nn.ModuleList

nn.ModuleList中的模块以列表的形式进行存储,在使用时,可以采用列表索引的形式实现模块的引用
代码示例:

class NetList(nn.Module):
    def __init__(self):
        super(NetList, self).__init__()
        self.list1=nn.ModuleList(
            [
            nn.Conv2d(16, 32,3,1),
            Block1(32, 32),
            nn.Conv2d(32, 64, 3, 1),
            nn.Conv2d(96, 64, 3, 1),
            Block2(64, 32),
            ]
        )
        # Block1 和 Block2 是自定义模块
   def forward(self,inputs):
   		x= self.list1[0](inputs)  # 这里引用了nn.Conv2d(16,32,3,1)
   		x1= self.list1[1](x)   # 这里引用了Block1(32, 64)
   		x2= self.list1[2](x)   # 这里引用了 nn.Conv2d(32, 64, 3, 1)
   		x3=torch.cat((x1,x2),1) 
			x4=self.list1[3](x3) # 这里引用了 nn.Conv2d(96, 64, 3, 1)
			x5=self.list1[4](x4) # 这里引用了 Block2(64, 32)	

使用ModuleList对模块进行包装,可以实现List内模块的索引,灵活性要好一些。但是也存在一定的问题,但模块数增加的时候,没有对应的注释,程序将会变得不方便阅读,List结构也不利于后续的修改。

三、nn.ModuleDict

nn.ModuleDict中的模块以字典的形式进行存储,在使用时,可以采用字典键值索引的形式实现模块的引用

代码示例:

class NetDict(nn.Module):
    def __init__(self):
        super(NetDict, self).__init__()
        self.net1=nn.ModuleDict({
            'block11':Block1(1, 32),
            'downsample11':ConvLRelu(32, 32, 2, 1),
            'block12':Block2(32, 64),
            'downsample12': ConvLRelu(64, 64, 2, 1),
            "block13": ConvLRelu(64, 128),
            "block21": Block2(1, 32),
            "downsample21":ConvLRelu(32, 32, 2, 2),
            "block22":Block2(32, 64),
            "downsample22":ConvLRelu(64, 64, 2, 2),
            'block23':Block2(64, 128),
        })
        self.net2 = nn.ModuleDict({
            'block_r1' : Block1(128, 128),
            'deconv1' : DeConvLRelu(128, 128),
            'block_r2' : Block1(128, 64),
            'deconv2' : DeConvLRelu(64, 32),
            'clr' :ConvLRelu(64, 32, 1, 1)
             })
        # Block1,Block2 为自定义模块
    def forward(self, inputs, testflag=0):
   		 # net1 部分
        x11 = self.net1['block11'](inputs)
        x21 = self.net1["block21"](inputs)
        xd11 = self.net1['downsample11'](x11)
        xd21 = self.net1['downsample21'](x21)
        x12 = self.net1['block12'](xd11)
        x22 = self.net1['block22'](xd21)
        xd12 = self.net1['downsample12'](x12)
        xd22 = self.net1['downsample22'](x22)
        
        x_cat = torch.cat((xd12 , xd22 ), 1) 
        # net2 部分
        xr1 = self.net2['block_r1'](x_cat) # 128 
        xdc1 = self.net2['deconv1'](x12) # 128
        xadd=torch.add(xr1,xdc1 )
        xr2 = self.net1['block_r2'](xadd)
        xdc2 = self.net1['deconv2'](xr2)
        output= self.net1['clr'](xdc2 )
        return output
        

使用ModuleDict对模块进行包装,可以实现Dict内模块的索引,在定义前向时,可以使用键值索引,灵活性好,也使得程序可读性较好。

后记

使用pytorch 模型容器,主要是考虑对整个网络进行其他操作,所以将网络模块包装起来,方便操作。最近运行一个模型较大,考虑采用模型并行,需要将模型拆开到两个GPU上运行,逐个对网络进行GPU位置的设置比较麻烦,使用容器的话便能很方便地实现网络模块GPU位置的设定。
使用多显卡加速网络训练(数据并行和模型并行)可参考:

B站链接

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-22 14:41:05  更:2021-09-22 14:43:23 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年5日历 -2024/5/22 8:57:36-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码