IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 手写数字识别pytorch 附loss图和结果验证图 -> 正文阅读

[人工智能]手写数字识别pytorch 附loss图和结果验证图

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os
import matplotlib.pyplot as plt

from torchvision import datasets, transforms
n_epochs = 3
random_seed = 1
BATCH_SIZE=512
torch.manual_seed(random_seed)
train_losses = []
train_counter = []
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=BATCH_SIZE, shuffle=True)

examples = enumerate(test_loader)
# print(examples)
batch_idx, (example_data, example_targets) = next(examples)

# for batch_idx, (data, target) in enumerate(train_loader):
#     # print(batch_idx)
#     # print('data:',data)
#     # print(data.shape)
#     # print(target)
#     print(len(train_loader.dataset))
#     print( len(data))
# print(example_targets,example_targets.shape)
# print(example_data,example_data.shape)
# print(batch_idx)
class MNISTNet(nn.Module):
    def __init__(self):
        super(MNISTNet,self).__init__()
        self.layer1 = nn.Linear(784,50)
        self.layer2 = nn.Linear(50,10)
    def forward(self,x):
        #输入层到隐藏层,使用tanh激活函数
        x = self.layer1(x.reshape(-1,784))
        x = torch.tanh(x)
        #隐藏层到输出层,使用relu激活函数
        x = self.layer2(x)
        x = F.relu(x)
        #log(softmax)操作,使用NLLLoss损失函数
        x = F.log_softmax(x,dim=1)
        return x
model=MNISTNet()
optimizer=optim.Adam(model.parameters(),)
def train(epoch):

    model.train()
    for batch_idx,(data,target) in enumerate(train_loader):
        #batch_idx为512张图片分多少批次输入,data.shape=[512,1,28,28],target.shape=[512]
        #target即为每张照片代表的数字
        optimizer.zero_grad()
        output=model(data)
        # pred=output.argmax(dim=1)#找到每行中数值最大的索引
        # #[0,0.1,0.2,0.3,0.4]返回为5表示数字5
        loss = F.nll_loss(output, target)
        # train_loss=train_loss.append(loss.item())#注意加item
        loss.backward()
        optimizer.step()
        if batch_idx%3000==0:
            print('Train Epoch: {} \tLoss: {:.6f}'.format(
                epoch,  loss.item()))
            print(len(train_loader.dataset))
        train_losses.append(loss.item())
        train_counter.append(
                (batch_idx * 512) + ((epoch - 1) * len(train_loader.dataset)))
        #先将图片按照117次batch_idx划分,之后一次输入512张,一共训练epoch次
            # print(pred==target)


# 测试集
def test():
    model.eval()
    correct=0
    test_loss=0
    with torch.no_grad():#括号
        for (data,target) in test_loader:
            output=model(data)
            test_loss=test_loss+F.nll_loss(output,target).item()#交叉熵损失
            pred = output.argmax(dim=1)
            # pred=output.argmax(dim=1)#找到每行中数值最大的索引
            # #[0,0.1,0.2,0.3,0.4]返回为5表示数字5

            correct += pred.eq(target.data.view_as(pred)).sum().item()#统计累积正确率
            # writer.add_scalar("损失/步骤", loss[epoch], epoch)
            #其中eq是判断是否相等,相等则计数,target.data.view_as(pred)为将pred和targe化为相同的矩阵
            # 因为keepdim维度和输入相同,所以并没有自动降维仍旧是2维,但是target是1维(一维数字嘛),所以需要target维度变成和pred一样的维度。view_as
            # 返回被视作与给定的tensor相同大小的原tensor。
            # 由此看完全代码完全等效于图片中的,亲测有效。max自动降维,就不需要target升维了。

        test_loss=test_loss/len(test_loader.dataset)#平均值loss
        print('Test--Average. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100.0 * correct / len(test_loader.dataset)))
        return pred
for epoch in range(1, n_epochs + 1):
  train(epoch)
  a=test()
# fig = plt.figure()
# for i in range(6):
#   plt.subplot(2,3,i+1)#subplot的参数23i意思就是,把整个大的画板分成2*3,这个子图在第i个位置。
#   plt.tight_layout()#放大参数
#   plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
#   plt.title("Ground Truth: {}".format(example_targets[i]))
#
# plt.show()
def loss_plot():
    fig = plt.figure()
    plt.plot(train_counter, train_losses, color='blue')
    # plt.scatter(test_counter, test_losses, color='red')
    plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
    plt.xlabel('number of training examples seen')
    plt.ylabel('negative log likelihood loss')
    plt.show()
loss_plot()
def test_plot():
    examples = enumerate(test_loader)
    batch_idx, (example_data, example_targets) = next(examples)
    with torch.no_grad():
        output=model(example_data)
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2,3,i+1)#subplot的参数23i意思就是,把整个大的画板分成2*3,这个子图在第i个位置。
        plt.tight_layout()#放大参数
        plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
        # plt.title("Ground Truth: {}".format(output.data.max(1, keepdim=True)[1][i].item()))
        plt.title("label=: {},predict=: {}".format(example_targets[i],output.argmax(dim=1)[i].item()))
    plt.show()
test_plot()

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-11-19 17:37:57  更:2021-11-19 17:39:19 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 6:28:24-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码