以下代码是阅读了torchvision中的GoogLeNet的实现后,复现出来的。
一、Inception模块的写法
每一个inception的结构都一样,不同就是输入输出的通道数不一样。所以可以写成一个Module,这样就可以复用。
以inception(3a)为例:
?代码:
# @file name : test.py
# @brief : Inception模块的写法
# @author : liupc
# @date : 2021/8/10
import torch
import torch.nn as nn
def BasicConv2d(in_channels, out_channels, **kwargs):
model = nn.Sequential(
nn.Conv2d(in_channels, out_channels, **kwargs),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
return model
class Inception_block(nn.Module):
def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
#ch1x1: 第一个分支的通道数
#ch3x3red: 第二个分支的1x1卷积的通道数
#ch3x3: 第二个分支的3x3卷积的通道数
#ch5x5red: 第三个分支的1x1卷积的通道数
#ch5x5: 第三个分支的5x5卷积的通道数
#pool_proj: 第四个分支的通道数
super(Inception_block, self).__init__()
self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1) #经过1x1卷积
self.branch2 = nn.Sequential(
BasicConv2d(in_channels, ch3x3red, kernel_size=1),
BasicConv2d(ch3x3red, ch3x3, kernel_size=3, stride=1, padding=1),
)
self.branch3 = nn.Sequential(
BasicConv2d(in_channels, ch5x5red, kernel_size=1),
BasicConv2d(ch5x5red, ch5x5, kernel_size=3, stride=1, padding=1),
)
self.branch4 = nn.Sequential(
nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
BasicConv2d(in_channels, pool_proj, kernel_size=1)
)
def forward(self, x):
x1 = self.branch1(x) #B,C1,H,W
x2 = self.branch2(x) #B,C2,H,W
x3 = self.branch3(x) #B,C3,H,W
x4 = self.branch4(x) #B,C4,H,W
out = [x1, x2, x3, x4]
return torch.cat(out, 1)
if __name__ == "__main__":
inputs = torch.rand(2, 192, 28, 28)
inception = Inception_block(192, 64, 96, 128, 16, 32, 32)
outputs = inception(inputs)
print(outputs.shape)
运行结果:
在写一个模块的时候,其实就是写一个Module的类。这个类怎么写呢,通过这个例子可以看到有两种方式:第一种就是写一个class,继承nn.Module类,在里面写init函数和forward函数。另一种写法就是定义一个函数,里面通过nn.Sequential包装多个现成的Module,然后返回。
第一种写法适合那种比较复杂的情况,可以自由发挥。第二种情况适合于简单地将现成的模块进行组合的情况。
这两种写法都行吧,怎么熟悉怎么来,怎么方便怎么来。比如,在GoogLeNet中,他们 BasicConv2d就不是我这样定义的,而是:
#GoogLeNet中的BasicConv2d的写法
import torch.nn.functional as F
class BasicConv2d(nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return F.relu(x, inplace=True)
当然也可以这么写了:
class BasicConv2d(nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
这个写法还不如我的简便呢。。。随意写都行。。。不用太在意
二、辅助损失模块的写法
写了Inception之后,还可以把辅助损失的代码写成模块:
代码:
# @file name : auxloss.py
# @brief : 辅助损失模块的写法
# @author : liupc
# @date : 2021/8/10
import torch
import torch.nn as nn
from inception import BasicConv2d, Inception_block
class Auxloss_block(nn.Module):
def __init__(self, in_channels, num_classes=1000):
super(Auxloss_block, self).__init__()
self.avgpool = nn.AdaptiveAvgPool2d((4, 4)) # H和W变成4*4.
self.conv = BasicConv2d(in_channels, 128, kernel_size=1)
self.fc1 = nn.Linear(2048, 1024)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.7)
self.fc2 = nn.Linear(1024, num_classes)
def forward(self, x):
# aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
x = self.avgpool(x) # N*512*4*4
x = self.conv(x) # N*128*4*4
x = torch.flatten(x, 1) # N* (128*4*4) = N *2048
x = self.fc1(x)
x = self.relu(x)
x = self.dropout(x)
x = self.fc2(x)
return x
if __name__ == "__main__":
inputs = torch.rand(2, 512, 14, 14)
net = Auxloss_block(512)
outputs = net(inputs)
print(outputs.shape)
?运行结果:
三、GoogLeNet的写法
有了Inception和辅助损失的模块之后,就可以写GoogLeNet了。
# @file name : googlenet.py
# @brief :
# @author : liupc
# @date : 2021/8/10
import torch
import torch.nn as nn
from inception import BasicConv2d, Inception_block
from auxloss import Auxloss_block
class GoogLeNet(nn.Module):
def __init__(self, num_classes=1000, aux_logits=False):
super(GoogLeNet, self).__init__()
self.aux_logits = aux_logits
#第一部分
self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
self.conv2 = BasicConv2d(64, 64, kernel_size=1)
self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
#第二部分
self.inception3a = Inception_block(192, 64, 96, 128, 16, 32, 32)
self.inception3b = Inception_block(256, 128, 128, 192, 32, 96, 64)
self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
self.inception4a = Inception_block(480, 192, 96, 208, 16, 48, 64)
self.inception4b = Inception_block(512, 160, 112, 224, 24, 64, 64)
self.inception4c = Inception_block(512, 128, 128, 256, 24, 64, 64)
self.inception4d = Inception_block(512, 112, 144, 288, 32, 64, 64)
self.inception4e = Inception_block(528, 256, 160, 320, 32, 128, 128)
self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.inception5a = Inception_block(832, 256, 160, 320, 32, 128, 128)
self.inception5b = Inception_block(832, 384, 192, 384, 48, 128, 128)
#第三部分
self.avgpool = nn.AdaptiveAvgPool2d((1,1)) #H和W变成1*1.
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(1024, num_classes)
)
#第四部分
if aux_logits:
self.aux1 = Auxloss_block(512, num_classes)
self.aux2 = Auxloss_block(528, num_classes)
def forward(self, x):
#第一部分
x = self.conv1(x)
x = self.maxpool1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.maxpool2(x)
# 第二部分
x = self.inception3a(x)
x = self.inception3b(x)
x = self.maxpool3(x)
x = self.inception4a(x)
if self.training and self.aux_logits:
aux1 = self.aux1(x)
x = self.inception4b(x)
x = self.inception4c(x)
x = self.inception4d(x)
if self.training and self.aux_logits:
aux2 = self.aux2(x)
x = self.inception4e(x)
x = self.maxpool4(x)
x = self.inception5a(x)
x = self.inception5b(x)
# 第三部分
x = self.avgpool(x)
x = torch.flatten(x, 1) #由B*C*1*1 变为:B*C,才能经过FC层。
x = self.classifier(x)
if self.training and self.aux_logits:
return x, aux1, aux2
else:
return x
if __name__ == "__main__":
inputs = torch.rand(2, 3, 224, 224)
net = GoogLeNet(1000, aux_logits=True)
outputs, a1, a2 = net(inputs)
print(outputs.shape, a1.shape, a2.shape)
运行结果:
分别输出了最后的分类和两个辅助损失的分类,是对的。?
|