IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> ResNet详解与CIFAR10数据集实战 -> 正文阅读

[人工智能]ResNet详解与CIFAR10数据集实战

1、引言

由于神经网络深度增加而导致的参数量急剧增大使得我们对其训练越来越困难,并且随着深度增加会出现网络退化现象。这时我们就需要采取一种方法去有效地解决这个问题。在2015年,何凯明大神提出了residual nets(深度残差网络,后面简称为ResNet)能够有效地解决这个问题,同时还能有效解决梯度消失和梯度爆炸问题从而更进一步提高网络的学习效果。

2、ResNet原理

深度残差神经网络不需要去拟合底层的映射,而是去拟合相对于输入的残差。残差模块如下:
在这里插入图片描述
我们设 x x x为输入数据,设原来的底层映射为 H ( x ) H(x) H(x),设相对于 x x x的残差为 F ( x ) F(x) F(x),所以可以得到 F ( x ) = H ( x ) ? x F(x)=H(x)-x F(x)=H(x)?x。经过残差模块后底层映射就是 F ( x ) + x F(x)+x F(x)+x。从上面可以看出,在极端情况下,如果底层映射足够好,那么残差就是0,这时就不需要从残差中学习了,此时底层映射就是恒等映射。

3、ResNet解决网络退化的机理

(1)深层梯度回传顺畅

恒等映射这一条路的梯度为1,能够通过这条路很好地把深层梯度注入底层,防止梯度消失,没有像sigmoid这样中间商的层层剥夺。

(2)网络自身构建的优势

1)每次都是去拟合上一层的误差,这样会让误差更尽可能变小
2)残差类似于LSTM的遗忘门,如果某个信息有用的,则记住;如果没有用,则忘记,最后再让后面的过程去拟合。这样能高效率得到有效信息,加快模型收敛。
3)使用ReLu激活函数,不会像sigmoid这样因为多次迭代而使梯度近似0,以至于无法更加逼近更好的结果。

(3)传统的线性网络很难去拟合“恒等映射”,而ResNet可以

由于深度学习的需要,我们有时需要原封不动地保存之前的信息。ResNet的残差模块就能根据需要自动选择是否要更新那些信息。,这样就弥补了高度非线性造成的不可逆的信息损失。

4、CIFAR10数据集实战

这里我们选择构建ResNet18,具体模型结构如下:
在这里插入图片描述

(1)导入数据

使用dataloader导入数据,并转换成tensor

# 导入训练集数据
train_data = dataloader.DataLoader(
    datasets.CIFAR10(root='data/', train=True, transform=transforms.Compose([
        transforms.Resize(32, 32),      # 重新设置图片大小
        transforms.ToTensor(),      # 将图片转化为tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])         # 进行归一化
    ]), download=True), shuffle=True, batch_size=batch_sz
)

# 导入测试集数据
train_test = dataloader.DataLoader(
    datasets.CIFAR10(root='data/', train=False, transform=transforms.Compose([
        transforms.Resize(32, 32),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True), shuffle=True, batch_size=batch_sz
)

(2)定义网络结构

首先我们需要先定义残差的block
根据上面的残差模块图片不难写出

# 定义残差块
class ResBlk(nn.Module):
    def __init__(self, ch_in, ch_out, stride):
        super(ResBlk, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
            nn.BatchNorm2d(ch_out)
        )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.extra(x) + out
        out = F.relu(out)
        return out

然后再根据上图写ResNet18网络结构,注意维度变换

# 定义ResNet18网络结构
class ResNet18(nn.Module):
    def __init__(self):
        super(ResNet18, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(64)
        )
        self.blk1 = ResBlk(64, 64, stride=2)
        self.blk2 = ResBlk(64, 128, stride=2)
        self.blk3 = ResBlk(128, 256, stride=2)
        self.blk4 = ResBlk(256, 512, stride=2)

        self.outlayer = nn.Linear(512*1*1, 10)

    def forward(self, x):

        x = F.relu(self.conv1(x))

        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        x = F.adaptive_avg_pool2d(x, [1, 1])
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)

        return x

(3)定义损失函数和优化方式

损失函数使用CrossEntropyLoss(交叉熵),优化方式使用Adam。至于为什么使用交叉熵,它会使梯度变得更大,优化起来更快。如果使用sigmoid+MSE的话会比较容易出现sigmoid饱和的情况,这时会出现梯度消失情况。最后只能说根据实验交叉熵在分类问题上表现效果非常好。

# 定义损失函数和优化方式
criteon = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

(4)训练测试模型

# 训练模型
for epoch in range(1):
    model.train()
    for batch_idx, (x, label) in enumerate(train_data):
        x = x.to(device)
        label = label.to(device)

        logits = model(x)       # 经过模型得到的数据

        loss = criteon(logits, label)
        print('logits:', logits[0])
        print('label:', label[0])
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_idx == len(train_data) - 1:
            print(epoch, 'loss:', loss.item())

    # 进行测试
    model.eval()
    with torch.no_grad():
        total_correct = 0
        total_num = 0
        for x, label in train_test:
            x = x.to(device)
            label = label.to(device)

            logits = model(x)

            pred = logits.argmax(dim=1)

            correct = torch.eq(pred, label).float().sum().item()
            total_correct += correct
            total_num += x.size(0)

        acc = total_correct / total_num
        print(epoch, 'test acc:', acc)

最后的分类准确率大约能达到百分之90左右。
github完整代码:ResNet源代码


小伙伴喜欢文章的话记得 点赞加关注哦,后面会更新其他深度学习的文章。
如果有什么写得有问题的地方希望大家能值出,谢谢。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-10 22:31:02  更:2022-03-10 22:34:52 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 16:31:28-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码