IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 使用逻辑回归训练手写数字识别模型 -> 正文阅读

[人工智能]使用逻辑回归训练手写数字识别模型

学习使用pytorch训练手写数字识别的模型和测试。

# 训练+测试
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
from net import CNN
from torch.nn import functional as F


# 超参数
EPOCH = 8  # 训练整批数据的次数
BATCH_SIZE = 50
learning_rate = 0.001  # 学习率

# Torch中的DataLoader是用来包装数据的工具,它能帮我们有效迭代数据,这样就可以进行批训练
# 用于下载minist数据集,并进行正则化、转化成tensor格式,设置训练、测试时的batch大小,打乱数据
train_loader = Data.DataLoader(
    torchvision.datasets.MNIST(
        root='./data/',  # 保存或提取的位置  会放在当前文件夹中
        train=True,  # true说明是用于训练的数据,false说明是用于测试的数据
        download=True,

        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,),(0.3081,))])),  # 转换PIL.Image or numpy.ndarray
    batch_size=BATCH_SIZE,
    shuffle=True  # 是否打乱数据,一般都打乱
)


test_loader = Data.DataLoader(
    torchvision.datasets.MNIST(
        root='./data/',  # 保存或提取的位置  会放在当前文件夹中
        train=False,  # true说明是用于训练的数据,false说明是用于测试的数据
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,),(0.3081,))])
    ),

    batch_size=BATCH_SIZE,
    shuffle=True  # 是否打乱数据,一般都打乱
)

w1,b1=torch.randn(200,784,requires_grad=True),\
        torch.randn(200,requires_grad=True)

w2,b2=torch.randn(200,200,requires_grad=True),\
        torch.randn(200,requires_grad=True)

w3,b3=torch.randn(10,200,requires_grad=True),\
        torch.randn(10,requires_grad=True)

torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)

#传递函数,也是在此构建的特征提取器
#设置了一个三层的网络,网络的函数是y=wx+b 进行特征提取,使用relu函数作为激活函数,
# 最终得到的是一个1*10 向量
def forward(x):
    x=x@w1.t()+b1
    x=F.relu(x)
    x = x @ w2.t() + b2
    x = F.relu(x)
    x = x @ w3.t() + b3
    x = F.relu(x)
    return  x

#使用SGD作为优化器
optimizer = torch.optim.SGD([w1,b1,w2,b2,w3,b3],lr=learning_rate)
#使用交叉损失函数
criteon = nn.CrossEntropyLoss()

# 开始训练
for epoch in range(EPOCH):
    for step, (data, target) in enumerate(train_loader):  # 分配batch data
        data = data.view(-1,28*28)  #把图片设置成28*28的大小

        output = forward(data)  # 先将数据放到cnn中计算output
        loss = criteon(output, target)  # 输出和真实标签的loss,二者位置不可颠倒

        optimizer.zero_grad()  # 清除之前学到的梯度的参数
        loss.backward()  # 反向传播,计算梯度
        optimizer.step()  # 应用梯度 ,进行梯度下降更新网络参数

        if step % 100 == 0:
            print('Train Epoch:{} [{}/{} ({:.0f}%)]\t Loss: {:.6f}'.format(
                epoch, step*len(data) , len(train_loader.dataset),
                100.*step/len(train_loader) , loss.item()))     #使用format 格式化输出的时候,{} 内部不能有空格,不然会报格式化错误

test_loss = 0
correct = 0
for epoch in range(EPOCH):
    for step, (data, target) in enumerate(train_loader):  # 分配batch data
        data = data.view(-1, 28 * 28)  # 把图片设置成28*28的大小

        output = forward(data)  # 先将数据放到cnn中计算output
        test_loss += criteon(output, target)  # 输出和真实标签的loss,二者位置不可颠倒

        pred=output.data.max(1)[1]#得到最大的输出作为预测值
        correct = pred.eq(target.data).sum()#计算总的正确率,还需要处于测试数据长度才是平均的正确率

test_loss /= len(test_loader.dataset)
#使用format 进行格式化输出,len 函数可以计算列表的函数
print('\n Test set :Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss,correct, len(test_loader.dataset),
    100.*correct/ len(train_loader.dataset)
))

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-11-26 08:51:52  更:2021-11-26 08:52:53 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 4:10:18-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码