IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> PyTorch实现Mnist手写数字识别 -> 正文阅读

[人工智能]PyTorch实现Mnist手写数字识别

首先下载读取Mnist数据集

%matplotlib inline
from pathlib import Path
import requests
%matplotlib inline

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"

PATH.mkdir(parents=True, exist_ok=True)

URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

if not (PATH / FILENAME).exists():
        content = requests.get(URL + FILENAME).content
        (PATH / FILENAME).open("wb").write(content)

import pickle
import gzip

with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

随机查看数据

#看一个数据
from matplotlib import pyplot as plt
import numpy as np

plt.imshow(x_train[1125].reshape(28,28),cmap="gray")
print(x_train.shape)   

?

?将数据转换为tensor张量形式

import torch

#数据转换为tensor的格式
x_train,y_train,x_valid,y_valid=map(
    torch.tensor,(x_train,y_train,x_valid,y_valid)
)
n,c=x_train.shape

使用nn.moduel构建网络,torch.nn.functional中有很多功能,也会很常用。那什么时候使用nn.Module,什么时候使用nn.functional呢?一般情况下,如果模型有可学习的参数,最好用nn.Module,其他情况nn.functional相对更简单一些

import torch.nn.functional as F

loss_func=F.cross_entropy

#构建网络
from torch import nn
import torch.nn.functional as F

class Mnist_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1=nn.Linear(784,128)
        self.hidden2=nn.Linear(128,256)
        self.out=nn.Linear(256,10)
        
    def forward(self,x):
        x=F.relu(self.hidden1(x))
        x=F.relu(self.hidden2(x))
        x=self.out(x)
        return x

net=Mnist_NN()
print(net)

打印构建的网络net,可以看到有两个隐藏层,一个输出层。输出层的输出特征是10而不是1,因为这是一个十分类的网络,会对每一个类都输出一个概率,因此输出的是一个由10个概率组成的一维矩阵。例如[0,0.1,0.05,0,0,0,0,0,0.85,0],此输出就代表此输入图像为8的概率为0.85,为1的概率为0.1,为3的概率为0.05,其余数字的概率均为0。

?

?构建网络时,已经自动进行了权重以及偏置的初始化,可以用下面的代码进行打印。nn.moduel构建网络有如下特点。

  • 必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数
  • 无需写反向传播函数,nn.Module能够利用autograd自动实现反向传播
  • Module中的可学习参数可以通过named_parameters()或者parameters()返回迭代器
for name, parameter in net.named_parameters():
    print(name, parameter, parameter.size())
    print('---------------------------------------------')

??打印定义好名字里的权重和偏置项

?使用tenordataset和dataloader来简化batch_size需要编写的代码,调用这两个工具包即可完成batch_size的数据拆分。具体代码如下:

#使用tenordataset和dataloader来简化
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

def get_data(x_train, y_train,x_valid,y_valid, bs):
    train_ds=TensorDataset(x_train,y_train)
    valid_ds=TensorDataset(x_valid,y_valid)
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )

接下来写训练函数fit,其中loss_batch用于每一个batch的损失值计算。除此之外,一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout;测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout。

import numpy as np

def loss_batch(model,loss_func,xb,yb,opt=None):
    loss=loss_func(model(xb),yb)
    
    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()
        
    return loss.item(),len(xb)

def fit(steps,model,loss_func,opt,train_dl,valid_dl):
    for step in range(steps):
        model.train()
        for xb,yb in train_dl:
            loss_batch(model,loss_func,xb,yb,opt)
            
        model.eval()
        with torch.no_grad():
            losses,nums=zip(
                *[loss_batch(model,loss_func,xb,yb) for xb,yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print('当前step:'+str(step), '验证集损失:'+str(val_loss))

最后的准备工作是写get_model,导入网络模型。

from torch import optim
def get_model():
    model=Mnist_NN()
    return model,optim.SGD(model.parameters(),lr=0.001)

然后三行代码完成手写数字的识别

train_dl, valid_dl = get_data(x_train, y_train, x_valid, y_valid, bs)
model,opt=get_model()
fit(25,model,loss_func,opt,train_dl,valid_dl)

心得:通过这个简单的网络熟悉PyTorch的神经网络编写过程,这个代码其实更注重调用,并不是完全按照前向传播后向传播的顺序一步一步构建一个网络,而是写了很多函数,最后主要的三行代码就完成了网络的训练。?

?

?

?

?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-30 11:56:30  更:2021-09-30 11:58:52 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 12:39:22-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码