IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Pytorch卷积神经网络Mnist手写数字识别-GPU训练 -> 正文阅读

[人工智能]Pytorch卷积神经网络Mnist手写数字识别-GPU训练

导入工具包?

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms 
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

?定义超参数

# 定义超参数 
input_size = 28  #图像的总尺寸28*28
classes = 10  #标签的种类数
epochs = 10  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片
learning_rate=0.001

?通过torchvision的dataset导入Mnist数据集

# 训练集
train_dataset = datasets.MNIST(root='./data',  
                            train=True,   
                            transform=transforms.ToTensor(),  
                            download=True) 

# 测试集
test_dataset = datasets.MNIST(root='./data', 
                           train=False, 
                           transform=transforms.ToTensor())

?通过DataLoader实现构建batch数据,进一步简化了代码。(这样就不用写关于设置batch的循环了)

# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)

?构建CNN网络,卷积层-池化层-卷积层-池化层-全连接层

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()       #输入大小为(1,28,28)
        self.conv1=nn.Sequential(      
           nn.Conv2d(
               in_channels=1,            #灰度图,通道只有一个特征图
               out_channels=16,          #输出16个特征图
               kernel_size=5,            #卷积核大小为5*5
               stride=1,                 #步长为1
               padding=2,                #填充2圈变为32*32
           ) ,                           #输出为16*28*28
            nn.ReLU(),                   #ReLU层
            nn.MaxPool2d(kernel_size=2), #进行池化操作,2*2
        )                                #输出为16*14*14
        self.conv2=nn.Sequential(
           nn.Conv2d(
               in_channels=16,           #输入16*14*14
               out_channels=32,          #输出32*14*14
               kernel_size=5,
               stride=1,
               padding=2,
           ) ,
            nn.ReLU(),                   #ReLU层
            nn.MaxPool2d(kernel_size=2), #输出(32,7,7)
        )
        self.out=nn.Linear(32*7*7,10)    #全连接层输出结果
        
    def forward(self,x):
        x=self.conv1(x)
        x=self.conv2(x)
        x=x.view(x.size(0),-1)
        output=self.out(x)
        return output

?定义准确率计算函数

def accuracy(predictions,labels):
    pred=torch.max(predictions.data,1)[1]
    rights=pred.eq(labels.data.view_as(pred)).sum()
    return rights,len(labels)

?网络实例化,设置优化器,损失函数,设定gpu训练(将模型,数据导入gpu即可)

#实例化
net=CNN()
#损失函数
criterion=nn.CrossEntropyLoss()
#优化器
optimizer=optim.Adam(net.parameters(),lr=learning_rate)

device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)

?训练模型并且打印输出?

#开始训练循环
for epoch in range(epochs):
    #保存当前epoch的结果
    train_rights=[]
    
    for batch_idx,(data,target) in enumerate(train_loader):
        data=data.to(device)
        target=target.to(device)
        net.train()
        output=net(data)
        loss=criterion(output,target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        right=accuracy(output,target)
        train_rights.append(right)
        
        if batch_idx%100==0:
            net.eval()
            val_rights=[]
            
            for (data,target) in test_loader:
                data=data.to(device)
                target=target.to(device)
                output=net(data)
                right=accuracy(output,target)
                val_rights.append(right)
                
            #准确率计算
            train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))
            val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))

            print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(
                epoch+1, batch_idx * batch_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), 
                loss.data, 
                100. * train_r[0].numpy() / train_r[1], 
                100. * val_r[0].numpy() / val_r[1]))

?

?打印最后得到的训练模型总的准确率

train_rights=[]
val_rights=[]
for (data,target) in train_loader:
    data=data.to(device)
    target=target.to(device)
    train_output=net(data)
    right=accuracy(train_output,target)
    train_rights.append(right)
for (data,target) in test_loader:
    data=data.to(device)
    target=target.to(device)
    test_output=net(data)
    right=accuracy(test_output,target)
    val_rights.append(right)
#总准确率计算
train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))
val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))

print('训练集总准确率: {:.2f}%\t测试集总准确率: {:.2f}%'.format( 
    100. * train_r[0].numpy() / train_r[1], 
    100. * val_r[0].numpy() / val_r[1]))

?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-01 16:51:20  更:2021-10-01 16:54:45 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 16:18:32-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码