pytorch 模型容器 - 模块化构建深度学习网络
pytorch 模型容器总结
pytorch 提供的模型容器包括: nn.Sequential:按顺序包含多个网络层 nn.ModuleList:类似列表(list)的形式包含多个网络层 nn.ModuleDict:类似字典(dict)的形式包含多个网络层
一、nn.Sequential
nn.Sequential中的模块按顺序排列,输入数据从第一个block输入,依次经过后续的block后得到输出结果 代码示例:
class Net(nn.Module):
def __init__(self):
super(CoarseNet_v1, self).__init__()
self.seq1=nn.Sequential(
nn.Conv2d(16,32,3,1),
Block1(32, 32),
nn.ReLU(inplace=True)
)
def forward(self,inputs)
output=self.seq1(inputs)
return output
这样在定义前向传播时要便捷很多,但是 nn.Sequential 内部模块执行的顺序是固定的,这是应用时存在的不便之处。
二、nn.ModuleList
nn.ModuleList中的模块以列表的形式进行存储,在使用时,可以采用列表索引的形式实现模块的引用 代码示例:
class NetList(nn.Module):
def __init__(self):
super(NetList, self).__init__()
self.list1=nn.ModuleList(
[
nn.Conv2d(16, 32,3,1),
Block1(32, 32),
nn.Conv2d(32, 64, 3, 1),
nn.Conv2d(96, 64, 3, 1),
Block2(64, 32),
]
)
def forward(self,inputs):
x= self.list1[0](inputs)
x1= self.list1[1](x)
x2= self.list1[2](x)
x3=torch.cat((x1,x2),1)
x4=self.list1[3](x3)
x5=self.list1[4](x4)
使用ModuleList对模块进行包装,可以实现List内模块的索引,灵活性要好一些。但是也存在一定的问题,但模块数增加的时候,没有对应的注释,程序将会变得不方便阅读,List结构也不利于后续的修改。
三、nn.ModuleDict
nn.ModuleDict中的模块以字典的形式进行存储,在使用时,可以采用字典键值索引的形式实现模块的引用
代码示例:
class NetDict(nn.Module):
def __init__(self):
super(NetDict, self).__init__()
self.net1=nn.ModuleDict({
'block11':Block1(1, 32),
'downsample11':ConvLRelu(32, 32, 2, 1),
'block12':Block2(32, 64),
'downsample12': ConvLRelu(64, 64, 2, 1),
"block13": ConvLRelu(64, 128),
"block21": Block2(1, 32),
"downsample21":ConvLRelu(32, 32, 2, 2),
"block22":Block2(32, 64),
"downsample22":ConvLRelu(64, 64, 2, 2),
'block23':Block2(64, 128),
})
self.net2 = nn.ModuleDict({
'block_r1' : Block1(128, 128),
'deconv1' : DeConvLRelu(128, 128),
'block_r2' : Block1(128, 64),
'deconv2' : DeConvLRelu(64, 32),
'clr' :ConvLRelu(64, 32, 1, 1)
})
def forward(self, inputs, testflag=0):
x11 = self.net1['block11'](inputs)
x21 = self.net1["block21"](inputs)
xd11 = self.net1['downsample11'](x11)
xd21 = self.net1['downsample21'](x21)
x12 = self.net1['block12'](xd11)
x22 = self.net1['block22'](xd21)
xd12 = self.net1['downsample12'](x12)
xd22 = self.net1['downsample22'](x22)
x_cat = torch.cat((xd12 , xd22 ), 1)
xr1 = self.net2['block_r1'](x_cat)
xdc1 = self.net2['deconv1'](x12)
xadd=torch.add(xr1,xdc1 )
xr2 = self.net1['block_r2'](xadd)
xdc2 = self.net1['deconv2'](xr2)
output= self.net1['clr'](xdc2 )
return output
使用ModuleDict对模块进行包装,可以实现Dict内模块的索引,在定义前向时,可以使用键值索引,灵活性好,也使得程序可读性较好。
后记
使用pytorch 模型容器,主要是考虑对整个网络进行其他操作,所以将网络模块包装起来,方便操作。最近运行一个模型较大,考虑采用模型并行,需要将模型拆开到两个GPU上运行,逐个对网络进行GPU位置的设置比较麻烦,使用容器的话便能很方便地实现网络模块GPU位置的设定。 使用多显卡加速网络训练(数据并行和模型并行)可参考:
B站链接
|