#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @version: v1.0
# @Author : Meng Li
# @contact: 925762221@qq.com
# @FILE : torch_mnist.py
# @Time : 2022/5/31 9:29
# @Software : PyCharm
# @site:
# @Description : 自己动手实现mnist数据集的10分类任务
# 同等条件下,batch_size 越小,模型越收敛。但是更容易震荡。learning_rate越小,模型收敛速度越慢
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torchsummary
import torch.optim as optim
from torch.utils.data import Dataset
import matplotlib.pylab as plt
class Net(nn.Module):
    """Small fully connected classifier that also computes its own loss.

    The network is ``Linear(784, 100) -> ReLU -> Linear(100, 10)``. Unlike a
    conventional module, ``forward`` takes the labels as well and returns the
    cross-entropy loss, the batch accuracy and the predicted class indices.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 100)
        self.fc2 = nn.Linear(100, 10)
        # The criterion lives on the module because forward() returns the loss.
        self.crition = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        """Run one batch through the network and score it against ``y``.

        Args:
            x: 4-D tensor; its last two dimensions are flattened and must
                multiply to 784 to match ``fc1``. Every leading dimension is
                folded into the batch, so the labels only line up when the
                second dimension has size 1.
            y: class-index labels, one per flattened row.

        Returns:
            A ``(loss, accuracy, index)`` tuple: the cross-entropy loss, the
            fraction of rows predicted correctly (a CPU tensor), and the
            predicted class index for each row.
        """
        _, _, height, width = x.size()
        flattened = x.view(-1, height * width)
        hidden = F.relu(self.fc1(flattened))
        logits = self.fc2(hidden)
        loss = self.crition(logits, y)
        _, index = torch.max(logits, 1)
        hits = torch.eq(index, y).float().cpu().sum()
        return loss, hits.float() / y.size(0), index
def train():
    """Train ``Net`` on the MNIST training split and checkpoint the best epoch.

    Runs 30 epochs of plain SGD (batch size 64, learning rate 0.001). After
    each epoch the accuracy accumulated over the whole epoch is printed, and
    whenever it beats the best value seen so far the entire module is pickled
    to ``limeng.pth`` (the file ``test()`` loads).

    The dataset is read from ``./`` with ``download=False``, so it must
    already be present on disk.
    """
    net = Net()

    # Manual switch for printing a layer summary.
    # NOTE(review): Net.forward also requires labels and a 4-D input, so this
    # call looks like it would fail if enabled -- confirm before turning it on.
    show_sum_flg = False
    if show_sum_flg:
        torchsummary.summary(net, (28, 28))

    train_data = torchvision.datasets.MNIST(
        root="./",
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=False,
    )
    batch_size = 64
    learning_rate = 0.001
    optimizer = optim.SGD(net.parameters(), lr=learning_rate)
    train_iter = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True)

    epoch = 30
    max_acc = 0
    for i in range(epoch):
        # Judge the epoch on every sample it saw, not on whichever mini-batch
        # happened to come last: a single batch of 64 is far too noisy to
        # decide which checkpoint is "best".
        correct = 0.0
        seen = 0
        for image, label in train_iter:
            loss, batch_acc, _ = net(image, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # batch_acc is a fraction; weight it by the batch size so that a
            # short final batch does not skew the epoch average.
            n = label.size(0)
            correct += batch_acc.item() * n
            seen += n

        # Running training accuracy: the weights change during the epoch.
        acc = correct / seen if seen else 0.0
        print("epoch {0} acc {1}".format(i, acc))
        if acc > max_acc:
            max_acc = acc
            torch.save(net, 'limeng.pth')
def test():
    """Load the saved model and plot its predictions for one batch of digits.

    Reads ``limeng.pth``, draws one shuffled batch of 10 images from the
    held-out MNIST split under ``./`` (``download=False``, so the data must
    already be on disk) and shows them in a 2x5 grid, each titled with the
    predicted class.
    """
    # The checkpoint is a whole pickled module, so full unpickling is needed.
    # PyTorch >= 2.6 defaults to weights_only=True, which rejects such files;
    # releases before 1.13 do not accept the argument at all.
    # NOTE: unpickling executes code stored in the file -- only load
    # checkpoints you created yourself.
    try:
        net = torch.load('limeng.pth', weights_only=False)
    except TypeError:
        net = torch.load('limeng.pth')
    net.eval()

    # Evaluate on the held-out split rather than the data the model was fit on.
    test_data = torchvision.datasets.MNIST(
        root="./",
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=False,
    )
    batch_size = 10
    test_iter = torch.utils.data.DataLoader(test_data, batch_size, shuffle=True)

    # Only a single batch is visualised.
    image, label = next(iter(test_iter))
    with torch.no_grad():  # inference only; no autograd graph needed
        _, _, predict = net(image, label)

    # Iterate over the images actually returned, not the nominal batch size.
    for i in range(image.size(0)):
        plt.subplot(2, 5, i + 1)
        plt.imshow(image[i, 0, :, :])
        # .item() yields the plain integer; formatting the tensor itself
        # would render the title as "tensor(7)".
        plt.title("{0}".format(predict[i].item()))
    plt.show()
if __name__ == "__main__":
    # Only the visual check runs by default. Re-enable the call below to
    # retrain first and regenerate limeng.pth.
    # train()
    test()
先上代码。工作期间我接触过 TensorFlow 和 PyTorch 两种框架,总的来说,PyTorch 的编码风格更接近 Python 原生语法,因此更容易上手。作为深度学习中的"hello world",还是有必要自己动手写一遍从数据输入、模型训练、模型保存到模型测试的全流程。
测试模型时,运行效果图大概是这样的:
(原文此处为运行效果截图,图片未随文本保留。)