IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 神经网络分类任务 -> 正文阅读

[人工智能]神经网络分类任务

神经网络分类任务

1.传入数据

数据大小x_train.shape=(x,y)

x=数据集中样本个数

y=每个样本的横*纵**像素

若分类结果为n分类,每个像素点分类为1*n的概率矩阵

将数据转换为tensor模式

import torch

x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)#输入测试集以及验证集

在这里插入图片描述

由上,数据分为训练集输入x_train,训练集输出y_train,验证集输入x_valid,验证集输出y_valid。可以看见,训练集大小为50000784,而测试集大小为10000×784。每784个像素点为一个单独的数据,即说明当前训练集有50000个数据,而验证集有10000个数据。

设置batch大小为bs=64,取得单个batch大小为64*784。

例如:

在这里插入图片描述

2.引入库

functional库可以引入损失函数,激活函数等许多层和函数

import torch.nn.functional as F

例如调用损失函数:

loss_func = F.cross_entropy

3.创建model

采用nn库构造函数:

from torch import nn

class Mnist_NN(nn.Module):#继承nn.Module
    def __init__(self):#定义方法 
        super().__init__()#调用构造函数   必须
        
        self.hidden1 = nn.Linear(784, 128)#定义自身模块隐层1
        self.hidden2 = nn.Linear(128, 256)#定义自身隐层2
        self.out  = nn.Linear(256, 10)#定义输出

    def forward(self, x):#定义前向传播
        x = F.relu(self.hidden1(x))#x传入到第一层,利用激活函数
        x = F.relu(self.hidden2(x))#第一层传递至第二层,激活函数
        x = self.out(x)#输出
        return x
    
    net = Mnist_NN()#定义网络

nn模型只需要输入前向传播过程方式与激活函数,即会自动进行反向传播。

4.通过使用TensorDataset和DataLoader定义batch

为了进行batch操作,首先要将数据转为tensor格式,再对data进行分批读取。

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

train_ds = TensorDataset(x_train, y_train)#tensor格式转换
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)#取一个batch的数据, shuffle=True表示洗牌

valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)

定义得到数据,从测试总矩阵train_ds中取出bs个数据给train_dl,valid_ds中取出bs*2个数据给valid_dl。

def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),##?为什么*2
    )#定义一个取参项

若定义当前batch大小为64,则valid_dl中数据个数为64

执行for循环操作,单步执行64个数据中的16个进行误差分析

5.定义训练器

def loss_batch(model, loss_func, xb, yb, opt=None):#opt为可选项,若未传入opt优化器则不会进行反向传播以及参数更新
    loss = loss_func(model(xb), yb)#loss_func = F.cross_entropy,先前定义好的,此处传入预测值与真实值进行计算

    if opt is not None:
        loss.backward()#反向传播,计算梯度
        opt.step()#优化步骤,更新
        opt.zero_grad()#优化参数的清零

    return loss.item(), len(xb)#返回结果

在loss_batch中传入模型,损失函数,x的数据集,y的数据集,以及是否使用优化器。

在函数中进行损失率的计算,若使用优化器,则对参数进行优化

在这里插入图片描述

6.定义优化器

from torch import optim
def get_model():
    model = Mnist_NN()
    return model, optim.SGD(model.parameters(), lr=0.001)#定义步长

采用先前定义好的Mnist_NN模型,进行SGD优化。

梯度下降法是神经网络中更新参数常用的方法,根据损失进行反向传播(大致为求偏导过程),进行权重及偏移量的更新。

引入batch概念,将数据分批进行梯度下降,提高了拟合的真实性。

SGD为随机梯度下降法,该算法旨在优化神经网络中更新权重。即在样本容量中随机取n个样本进行梯度下降,这样的好处在于避免了批梯度下降法中出现平滑点导致参数无法进行进一步优化的问题,如下图(转载)。

但这样的方法运算时间较高,在此之后提出了一种MBGD的小批量梯度下降方法。关于batch与梯度下降,可参考博客:

(56条消息) 神经网络训练中batch的作用(从更高角度理解)_做个好男人!的博客-CSDN博客_神经网络batch

7.验证及测试

import numpy as np

def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):
        model.train()#按照输入模型进行训练
        for xb, yb in train_dl:#对当前batch进行损失率及参数更新
            loss_batch(model, loss_func, xb, yb, opt)#计算损失率以及进行参数更新

        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )#取
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)#计算缺失率
        print('当前step:'+str(step), '验证集损失:'+str(val_loss),'数据大小'+str(yb.shape))

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ij4EvUAO-1656682382947)(C:\Users\user\AppData\Roaming\Typora\typora-user-images\image-20220701162043169.png)]

8.网络训练

train_dl, valid_dl = get_data(train_ds, valid_ds, bs)#按batch取数据
model, opt = get_model()#得到模型
fit(25, model, loss_func, opt, train_dl, valid_dl)#进行迭代

获取数据,获取模型及优化算法。

使用定义好的fit对网络进行训练以及测试。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-07-04 22:54:12  更:2022-07-04 22:56:40 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/29 8:23:38-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计