IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> pytorch-神经网络分类任务 -> 正文阅读

[人工智能]pytorch-神经网络分类任务

import pickle
import gzip

with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
x_train.shape  #第一个表示样本个数,第二个表示特征28*28*1=784

(50000, 784)

784是mnist数据集每个样本的像素点个数

注意数据需转换成tensor才能参与后续建模训练

import torch

x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape
x_train, x_train.shape, y_train.min(), y_train.max()
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())

torch.nn.functional 很多层和函数在这里都会见到

torch.nn.functional中有很多功能,后续会常用的。那什么时候使用nn.Module,什么时候使用nn.functional呢?一般情况下,如果模型有可学习的参数,最好用nn.Module,其他情况nn.functional相对更简单一些

import torch.nn.functional as F

loss_func = F.cross_entropy

def model(xb):
    return xb.mm(weights) + bias
bs = 64
xb = x_train[0:bs]  # a mini-batch from x
yb = y_train[0:bs]
weights = torch.randn([784, 10], dtype = torch.float,  requires_grad = True) #随机初始化w
bs = 64
bias = torch.zeros(10, requires_grad=True)  #偏置:常数或随机初始化,对结果影响较小

print(loss_func(model(xb), yb))

创建一个model来更简化代码

1.必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数
2.无需写反向传播函数,nn.Module能够利用autograd自动实现反向传播
3.Module中的可学习参数可以通过named_parameters()或者parameters()返回迭代器

from torch import nn

class Mnist_NN(nn.Module):
    def __init__(self):    #构造函数
        super().__init__()
        self.hidden1 = nn.Linear(784, 128)  #全连接层,输入784个像素点,输出128个特征
        self.hidden2 = nn.Linear(128, 256)   #第二个隐层,第一层的输出,是第二层的输入,
        #256个神经元
        self.out  = nn.Linear(256, 10) #10个分类
        self.dropout=nn.Dropout(0.5)  #按照百分比 自己定
        

    def forward(self, x):    #前向传播 自己定义 反向传播是自动的
        #输入x,是数据给的,batch 64*784
        x = F.relu(self.hidden1(x))  #得到64*128的结果
        x=self.dropout(x)            #每个全连接层都要加dropout,随机杀除神经元,防止过拟合
        x = F.relu(self.hidden2(x))  #得到65*256 的向量
        x=self.dropout(x)  
        x = self.out(x)  #输出层,256*10的矩阵
        return x
        
net = Mnist_NN()
print(net)

Mnist_NN(
(hidden1): Linear(in_features=784, out_features=128, bias=True)
(hidden2): Linear(in_features=128, out_features=256, bias=True)
(out): Linear(in_features=256, out_features=10, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
)

可以打印我们定义好名字里的权重和偏置项

for name, parameter in net.named_parameters():
    print(name, parameter,parameter.size())

使用TensorDataset和DataLoader来简化?

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

train_ds = TensorDataset(x_train, y_train) #做一个封装
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)

valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)
def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )

一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout

import numpy as np

def fit(steps, model, loss_func, opt, train_dl, valid_dl): #训练函数
        #迭代  net               优化器
    for step in range(steps):
        model.train()  #训练模式,更新每一层权置和参数 
        for xb, yb in train_dl: #打包好数据一个个去返
            loss_batch(model, loss_func, xb, yb, opt) # xb, yb,输入的数据和标签

        model.eval()  #验证,不更新
        with torch.no_grad():
            #返回当前损失和数量,一个loss对应一个num
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums) #平均损失
        print('当前step:'+str(step), '验证集损失:'+str(val_loss))
#zip用法
a=[1,2,3]
b=[4,5,6]
zipped=zip(a,b)
print(list(zipped))   #[(1, 4), (2, 5), (3, 6)]
a2,b2=zip(*zip(a,b))
print(a2)  #(1, 2, 3)
print(b2)  #(4, 5, 6)
from torch import optim
def get_model():
    model = Mnist_NN()
    return model, optim.Adam(model.parameters(), lr=0.001)  #返回 优化器 lr学习率
def loss_batch(model, loss_func, xb, yb, opt=None): #1.计算loss 2.更新w,b 
    loss = loss_func(model(xb), yb)# model(xb)把模型放输入当中,得到预测值   yb真实值

    if opt is not None: #优化器
        loss.backward()  #反向传播,算出每一层的权置参数,梯度
        opt.step()   # 执行更新,沿着梯度方向更新
        opt.zero_grad() #torch每次迭代累加记录,把之前迭代清空

    return loss.item(), len(xb)
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(20, model, loss_func, opt, train_dl, valid_dl)

当前step:0 验证集损失:0.18639104866981507
当前step:1 验证集损失:0.1372520131058991
当前step:2 验证集损失:0.12028736076653004
当前step:3 验证集损失:0.10732126496359706
当前step:4 验证集损失:0.10093651054650545
当前step:5 验证集损失:0.09517242526896298
当前step:6 验证集损失:0.09194612504523247
当前step:7 验证集损失:0.08943103497959673
当前step:8 验证集损失:0.0877913200291805
当前step:9 验证集损失:0.08958465236043557
当前step:10 验证集损失:0.08797709809066728
当前step:11 验证集损失:0.08352214076635428
当前step:12 验证集损失:0.0866958644344937
当前step:13 验证集损失:0.08074293819144368
当前step:14 验证集损失:0.08045687620015815
当前step:15 验证集损失:0.08040115665267222
当前step:16 验证集损失:0.07971061569196172
当前step:17 验证集损失:0.08116058965921402
当前step:18 验证集损失:0.0811522187425755
当前step:19 验证集损失:0.0807436868159566

correct=0
total=0
for xb,yb in valid_dl:  #去验证集里取数据
    outputs=model(xb)   #128*10 每个样本属于各个类别的概率值
    _, predicted=torch.max(outputs.data,1) #返回最大的值和索引,算概率值哪个大,沿着1这个维度(行)
    total+=yb.size(0)
    correct+=(predicted==yb).sum().item() #预测值=真实值.对了多少.转换成数组格式(可绘制图)
    
print('Accuracy of the network on the 10000 test images:%d %%' %(
100 * correct / total))

Accuracy of the network on the 10000 test images:97 %

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-09-04 01:12:01  更:2022-09-04 01:13:46 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/28 19:12:16-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计