IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> GoogLeNet的实现 -> 正文阅读

[人工智能]GoogLeNet的实现

以下代码是阅读了torchvision中的GoogLeNet的实现后,复现出来的。

一、Inception模块的写法

每一个inception的结构都一样,不同就是输入输出的通道数不一样。所以可以写成一个Module,这样就可以复用。

以inception(3a)为例:

?代码:

# @file name  : test.py
# @brief      : Inception模块的写法
# @author     : liupc
# @date       : 2021/8/10


import torch
import torch.nn as nn

def BasicConv2d(in_channels, out_channels, **kwargs):
    model = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, **kwargs),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    )
    return model


class Inception_block(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        #ch1x1: 第一个分支的通道数
        #ch3x3red: 第二个分支的1x1卷积的通道数
        #ch3x3: 第二个分支的3x3卷积的通道数
        #ch5x5red: 第三个分支的1x1卷积的通道数
        #ch5x5: 第三个分支的5x5卷积的通道数
        #pool_proj: 第四个分支的通道数
        super(Inception_block, self).__init__()

        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)   #经过1x1卷积

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, stride=1, padding=1),
        )

        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=3, stride=1, padding=1),
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        x1 = self.branch1(x)     #B,C1,H,W
        x2 = self.branch2(x)     #B,C2,H,W
        x3 = self.branch3(x)     #B,C3,H,W
        x4 = self.branch4(x)     #B,C4,H,W

        out = [x1, x2, x3, x4]
        return torch.cat(out, 1)

if __name__ == "__main__":
    inputs = torch.rand(2, 192, 28, 28)
    inception = Inception_block(192, 64, 96, 128, 16, 32, 32)
    outputs = inception(inputs)
    print(outputs.shape)


运行结果:

在写一个模块的时候,其实就是写一个Module的类。这个类怎么写呢,通过这个例子可以看到有两种方式:第一种就是写一个class,继承nn.Module类,在里面写init函数和forward函数。另一种写法就是定义一个函数,里面通过nn.Sequential包装多个现成的Module,然后返回。

第一种写法适合那种比较复杂的情况,可以自由发挥。第二种情况适合于简单地将现成的模块进行组合的情况。

这两种写法都行吧,怎么熟悉怎么来,怎么方便怎么来。比如,在GoogLeNet中,他们 BasicConv2d就不是我这样定义的,而是:

#GoogLeNet中的BasicConv2d的写法

import torch.nn.functional as F

class BasicConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)

当然也可以这么写了:

class BasicConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

这个写法还不如我的简便呢。。。随意写都行。。。不用太在意

二、辅助损失模块的写法

写了Inception之后,还可以把辅助损失的代码写成模块:

代码:

# @file name  : auxloss.py
# @brief      : 辅助损失模块的写法
# @author     : liupc
# @date       : 2021/8/10


import torch
import torch.nn as nn
from inception import BasicConv2d, Inception_block

class Auxloss_block(nn.Module):
    def __init__(self, in_channels, num_classes=1000):
        super(Auxloss_block, self).__init__()

        self.avgpool = nn.AdaptiveAvgPool2d((4, 4))  # H和W变成4*4.

        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)

        self.fc1 = nn.Linear(2048, 1024)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.7)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = self.avgpool(x)      # N*512*4*4
        x = self.conv(x)         # N*128*4*4
        x = torch.flatten(x, 1)  # N* (128*4*4) = N *2048
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)

        return x

if __name__ == "__main__":
    inputs = torch.rand(2, 512, 14, 14)
    net = Auxloss_block(512)
    outputs = net(inputs)
    print(outputs.shape)

?运行结果:

三、GoogLeNet的写法

有了Inception和辅助损失的模块之后,就可以写GoogLeNet了。

# @file name  : googlenet.py
# @brief      : 
# @author     : liupc
# @date       : 2021/8/10

import torch
import torch.nn as nn
from inception import BasicConv2d, Inception_block
from auxloss import Auxloss_block

class GoogLeNet(nn.Module):
    def __init__(self, num_classes=1000, aux_logits=False):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits

        #第一部分
        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        #第二部分
        self.inception3a = Inception_block(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception_block(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception_block(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception_block(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception_block(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception_block(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception_block(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.inception5a = Inception_block(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception_block(832, 384, 192, 384, 48, 128, 128)

        #第三部分
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))    #H和W变成1*1.
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(1024, num_classes)
        )

        #第四部分
        if aux_logits:
            self.aux1 = Auxloss_block(512, num_classes)
            self.aux2 = Auxloss_block(528, num_classes)

    def forward(self, x):
        #第一部分
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool2(x)

        # 第二部分
        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)

        x = self.inception4a(x)
        if self.training and self.aux_logits:
            aux1 = self.aux1(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        if self.training and self.aux_logits:
            aux2 = self.aux2(x)
        x = self.inception4e(x)
        x = self.maxpool4(x)

        x = self.inception5a(x)
        x = self.inception5b(x)

        # 第三部分
        x = self.avgpool(x)
        x = torch.flatten(x, 1)   #由B*C*1*1 变为:B*C,才能经过FC层。
        x = self.classifier(x)

        if self.training and self.aux_logits:
            return x, aux1, aux2
        else:
            return x

if __name__ == "__main__":
    inputs = torch.rand(2, 3, 224, 224)
    net = GoogLeNet(1000, aux_logits=True)
    outputs, a1, a2 = net(inputs)
    print(outputs.shape, a1.shape, a2.shape)





运行结果:

分别输出了最后的分类和两个辅助损失的分类,是对的。?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-11 12:23:47  更:2021-08-11 12:26:46 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/12 1:46:38-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码