IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 2021-6至2021-11周报 -> 正文阅读

[人工智能]2021-6至2021-11周报

摘要

1 文献阅读
2 深度学习实践:实践手写数字识别的代码
3 任务项目:将虚拟机文件与虚拟机配置文件打包好,上传至云盘,只需最后下载。
4 带毕设:沟通进度情况

深度学习实践

手写数字识别代码:
一、搭建模型Model.py,构造一个CNN模型

import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter


class CnnModel(nn.Module):
    def __init__(self):
        super(CnnModel, self).__init__()
        self.x1 = torch.ones(1, 3, 5, 5)
        self.model1 = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3, stride=1, padding=0),
        nn.MaxPool2d(2),
        nn.Conv2d(3, 3, 3, 1, 0),
        nn.MaxPool2d(2)
        )
        self.model2 = nn.Sequential(
        nn.Flatten(),
        nn.Linear(75, 10),
        nn.Softmax(1)
        )

    def forward(self, x):
        self.x1 = self.model1(x)
        x = self.model2(self.x1)
        return x

if __name__ == '__main__':      #测试model是否正确
    x = torch.ones(1, 1, 28, 28)
    model = CnnModel()
    output = model.forward(x)
    writer = SummaryWriter("./logs_model")
    writer.add_graph(model, x)
    writer.close()

二、导入训练集,导入模型进行训练的Minist_model.py

import torch
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from Model import *

#读取图片并转换为tensor型数据
trans_tensor = torchvision.transforms.ToTensor() #实例化图片转tensor的类
train_data = torchvision.datasets.MNIST("./data_minst", train=True, transform=trans_tensor, download=True)
test_data = torchvision.datasets.MNIST("./data_minst", train=False, transform=trans_tensor, download=True)
print(train_data)
writer =SummaryWriter("./logs")

#将数据分批量打包
batch = 50
train_data_batch = DataLoader(train_data, batch_size=batch)
test_data_batch = DataLoader(test_data, batch_size=batch)

#模型实例化
MinModel = CnnModel()

#优化器实例化
learingrate = 1e-2
optimizer = torch.optim.SGD(MinModel.parameters(), lr=learingrate)

#损失函数
loss_f = nn.CrossEntropyLoss()

epoach = 12
train_size = 1


for i in range(epoach):
    total_loss = 0
    print("——————————第{}轮训练开始——————————————".format(i+1))
    for data in train_data_batch:
        img, label = data   #取出img和label
        y = MinModel(img)   #得到model输出
        loss = loss_f(y, label)    #计算损失值
        writer.add_scalar("Loss值随训练次数的变化", loss, train_size)
        total_loss += loss  #每一轮的累计损失值
        optimizer.zero_grad()   #每一次训练将梯度清0
        loss.backward()     #自动求导
        optimizer.step()    #用优化器自动更新参数

        acc_train = ((y.argmax(1)==label).sum())/batch
        writer.add_scalar("准确率随训练次数的变化", acc_train, train_size)

        if(train_size%100==0): #每训练100次输出此batch的损失值
            print("第{}次训练的loss值为{}".format(train_size, loss))
        train_size += 1

    print("********第{}轮训练全部的loss值是:{}***********".format(i+1, total_loss))

torch.save(MinModel, "./Model_save/model.pth") #保持model

writer.close()

Minist_model.py将CNN模型导入,以train=true从数据集中导入训练数据、以train=false从数据集中导入测试数据,读取训练图片并转换成tensor型数据,将训练结果输出到logs文件中,将训练好的模型保存到model.pth文件中。

三、Minist_test.py测试模型准确率

import torch
import torchvision.transforms
from Model import *

trans_tensor = torchvision.transforms.ToTensor()
test_dataset = torchvision.datasets.MNIST("./data_minst", train=False, transform=trans_tensor, download=True)

Model_test = torch.load("Model_save/model.pth")
writer = SummaryWriter("./test_logs", )
img, label = test_dataset[0]
img = torch.reshape(img, (1, 1, 28, 28))
test_output = Model_test(img)
writer.add_images("输出图片", Model_test.x1)
print(test_output)
print(label)

取测试数据,调用训练好的模型,输出test_logs文件。
训练过程如下图所示。
0

在终端输入命令

tensorboard --logdir="logs"

得到可视化结果,如下图所示。
1
2

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-12-13 12:48:58  更:2021-12-13 12:49:26 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 0:26:50-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码