IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 基于Pytorch的Mnist手写数字识别 -> 正文阅读

[人工智能]基于Pytorch的Mnist手写数字识别

# !/usr/bin/env Python3
# -*- coding: utf-8 -*-
# @version: v1.0
# @Author   : Meng Li
# @contact: 925762221@qq.com
# @FILE     : torch_mnist.py
# @Time     : 2022/5/31 9:29
# @Software : PyCharm
# @site:
# @Description : 自己动手实现mnist数据集的10分类任务
# 同等条件下,batch_size 越小,模型越收敛。但是更容易震荡。learning_rate越小,模型收敛速度越慢

import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torchsummary
import torch.optim as optim
from torch.utils.data import Dataset
import matplotlib.pylab as plt


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 100)
        self.fc2 = nn.Linear(100, 10)
        self.crition = torch.nn.CrossEntropyLoss()
        pass

    def forward(self, x, y):
        batch_size, _, h, w = x.size()
        x = x.view(-1, h * w)
        output = F.relu(self.fc1(x))
        output = self.fc2(output)
        loss = self.crition(output, y)
        val, index = torch.max(output, 1)
        acc = torch.eq(index, y).float().cpu().sum()
        return loss, acc.float() / y.size(0), index


def train():
    net = Net()
    show_sum_flg = False
    if show_sum_flg:
        torchsummary.summary(net, (28, 28))
    train_data = torchvision.datasets.MNIST(root="./", train=True, transform=torchvision.transforms.ToTensor(),
                                            download=False)
    batch_size = 64
    learning_rate = 0.001
    optimizer = optim.SGD(net.parameters(), lr=learning_rate)

    train_iter = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True)
    epoch = 30
    max_acc = 0
    acc = 0
    for i in range(epoch):
        for image, label in train_iter:
            loss, acc, _ = net(image, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print("epoch {0}  acc {1}".format(i, acc))
        if acc > max_acc:
            max_acc = acc
            torch.save(net, 'limeng.pth')


def test():
    net = torch.load('limeng.pth')
    net.eval()
    train_data = torchvision.datasets.MNIST(root="./", train=True, transform=torchvision.transforms.ToTensor(),
                                            download=False)
    batch_size = 10
    train_iter = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True)

    for image, label in train_iter:
        _, _, predict = net(image, label)
        for i in range(batch_size):
            imagei = image[i, 0, :, :]
            plt.subplot(2, 5, i+1)
            plt.imshow(imagei)
            plt.title("{0}".format(predict[i]))
        plt.show()
        break


if __name__ == '__main__':
    # train()
    test()

先上代码,工作期间接触了Tensorflow和Pytorch两种框架,但是总得来说,pytorch由于编码语法规范更接近于python原生语法,所以更容易上手。作为深度学习中的"hello world",还是有必要自己写一下整个数据输入到模型训练,模型保存再到模型测试的全流程。

测试模型时,运行效果图大概是这样的:

?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-06-03 23:58:52  更:2022-06-04 00:00:38 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 2:28:04-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码