IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【pytorch实战学习】第六篇:CIFAR-10分类实现 -> 正文阅读

[人工智能]【pytorch实战学习】第六篇:CIFAR-10分类实现

往期相关文章列表:



本文是基于 pytorch官网教程,然后在此基础上,写了一些自己的理解和修改。

1. 数据集简介

CIFAR-10数据集共有60000张彩色图像,这些图像是32*32,分为10个类,每类6000张图。需要说明的是,这10类都是各自独立的,不会出现重叠。

题外话,MNIST数据集是只有1个通道的灰度图,尺寸大小为28*28。

在这里插入图片描述
CIFAR-10的相关链接如下:

下载解压完成之后,格式如下:
在这里插入图片描述

2. 相关理论介绍

(1)卷积层

nn.Conv2d(3,16,3,padding=1)

卷积操作在上一篇文章中也有介绍,只不过这里省略了in_channels, out_channels, kernel_size,下面是它的参数:

  • 输入通道:3,也就是RGB图像
  • 输出通道:16,所以这里用到了16个不同的卷积核。
  • 卷积核:kernel_size为 3×3,正方形kernel可以只写其中一个。
  • padding = 1

(2)输出维度计算

  • 卷积后的维度计算
输出维度 = 输入维度 + 2*padding - kernel_size +1

nn.Conv2d(3,16,3,padding=1)举例(原始为3通道,RGB),如果输入图像是32*32,那么输出也是32*32。也就是说输入为:3*32*32,输出为:16*32*32。

  • pooling维度计算

nn.MaxPool2d(2, 2),表示该最大池化层在 2x2 空间里向下采样,步长stride=2。如果它的输入维度是:32*32,那么它的输出为:16*16。

  • BN层和激活层的维度保持不变。

2. 数据集下载与显示

  • 数据下载

torchvision格式的数据集的范围为[0,1],我们这里使用transforms方法将它归一化为[-1,1]的张量。

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  • 数据显示

使用torchvision.utils.make_grid将多个图片拼接成一张图片显示。

import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

#随机选取图片,因为trainloader是随机的
dataiter = iter(trainloader)
images, labels = dataiter.next()

#显示图片
imshow(torchvision.utils.make_grid(images)) #make_grid的作用是将若干幅图像拼成一幅图像。

#打印labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))

输出:

dog   deer  car   truck

在这里插入图片描述

3. 自定义CNN模型

这里没有使用官方的CNN自定义的模型,而是使用自己定义的一个模型,代码如下:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5) #3*32*32 ==》6*28*28(32 + 2P - kernel + 1)
        self.pool = nn.MaxPool2d(2, 2) #6*28*28 ==> 6*14*14
        self.conv2 = nn.Conv2d(6, 16, 5) # 6*14*14 ==> 16 * 10 * 10
        self.fc1 = nn.Linear(16 * 5 * 5, 120) # 还要一个pooling 所以输入是16 * 5 * 5 ==> 120
        self.fc2 = nn.Linear(120, 84) # 120 ==> 84
        self.fc3 = nn.Linear(84, 10) # 84 ==> 10

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
if torch.cuda.is_available():
    net = net.cuda()

使用print(net)打印网络结构如下:

Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

4. 损失函数与优化器

由于是分类器,选用交叉熵损失,优化器选用SGD或者Adam等都可以。

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

5. 训练网络

这部分和官方唯一的区别是,是否选用cuda进行加速,被注释掉的部分为每个epoch打印一次信息,这里使用和官方一致的信息输出模式。

print('---------- Train Start ----------')
epochs = 3
test_iter = 0
for epoch in range(epochs): 
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        if torch.cuda.is_available():
            inputs = inputs.cuda()
            labels = labels.cuda()

        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = net(inputs)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    #没2000个小批次打印一次
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0
       # test_iter = i * labels.data[0].item()

    # 每个epoch计算一次损失
    # print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / test_iter:.3f}')
    # running_loss = 0.0
    # test_iter = 0

print('----------Finished Training----------')

输出如下:

---------- Train Start ----------
[1,  2000] loss: 2.162
[1,  4000] loss: 1.846
[1,  6000] loss: 1.673
[1,  8000] loss: 1.561
[1, 10000] loss: 1.497
[1, 12000] loss: 1.469
[2,  2000] loss: 1.428
[2,  4000] loss: 1.389
[2,  6000] loss: 1.361
[2,  8000] loss: 1.342
[2, 10000] loss: 1.308
[2, 12000] loss: 1.326
[3,  2000] loss: 1.255
[3,  4000] loss: 1.237
[3,  6000] loss: 1.244
[3,  8000] loss: 1.212
[3, 10000] loss: 1.208
[3, 12000] loss: 1.231
----------Finished Training----------

6. 保存模型

pytorch提供了便捷的保存模型方法,如下所示:

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

7. 加载、测试模型

其实在这里保存和加载模型是没有必要的,只是告诉大家怎么去保存和加载而已

#显示测试gt
dataiter = iter(testloader)
images, labels = dataiter.next()

# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))

#加载模型
net = Net()
net.load_state_dict(torch.load(PATH))
outputs = net(images)

_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}'
                              for j in range(4)))

#准确度分析
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in testloader:
        images, labels = data
        # calculate outputs by running images through the network
        outputs = net(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')

在这里插入图片描述

GroundTruth:  cat   ship  ship  plane
Predicted:  cat   ship  car   plane
Accuracy of the network on the 10000 test images: 56 %

看起来准确度还不够好,让我们把epochs加大到20,再看看输出结果。

GroundTruth:  cat   ship  ship  plane
Predicted:  dog   ship  truck plane
Accuracy of the network on the 10000 test images: 59 %

测试结果有所提升,但不是很明显,待修炼之后再来补功课,提升一下指标!


参考链接:https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-24 00:32:32  更:2022-03-24 00:35:53 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 14:45:53-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码