目录
1. 加载MNIST的train数据和test数据
2. 定义神经网络
3. 使用定义的网络进行训练
4. 使用测试集,计算预测精度
5. 辅助工具函数
1. 加载MNIST的train数据和test数据
import torch
import torchvision # 处理图像视频, 包含一些常用的数据集、模型、转换函数等等
from torch import nn, optim
from torch.nn import functional as F
from matplotlib import pyplot as plt
from utils import plot_curve, plot_image, one_hot
batch_size = 512
# step1. 加载数据集
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data_john', train=True, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,)
)
])),
batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data_john/', train=False, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,)
)
])),
batch_size=batch_size, shuffle=False)
# # test: 显示下训练数据集的前6张图像及对应的标签
# x,y = next(iter(train_loader))
# print(x.shape, y.shape, x.min(), x.max())
# plot_image(x,y,"image_gt")
显示下训练数据集的前6张图像及对应的标签:
2. 定义神经网络
这里采用 3个线性层,做简单示范。
# step2. 定义神经网络
class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
# 定义三个线性层 y = wx+b
# 输入X的size是:[batch_size, 28*28=784]
# y = w1 * x + b1, 例如:参数数量:w1.size是[256,784] (一张图像x的size是[784,1]), b.size是[256]
self.fc1 = nn.Linear(28 * 28, 256) # 28*28是输入图像的大小,256是自定义的中间层大小
# y = w2 * x + b2
self.fc2 = nn.Linear(256, 64) # 中间层数的结果,本层输入层数取决于上一层的输出层数,本层输出决定了下一层的输入层数。
# y = w3 * x + b3
self.fc3 = nn.Linear(64, 10) # 10是要求的输出分类层数
def forward(self, x):
x = F.relu(self.fc1(x)) # 使用激活函数,增加非线性
x = F.relu(self.fc2(x))
x = self.fc3(x) # 最后一层根据网络结果输出
return x
3. 使用定义的网络进行训练
# step3. 开始训练
net = MyNet()
# 定义梯度下降方式
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
train_loss = []
for epoch in range(3):
for batch_idx, (x, y) in enumerate(train_loader): # train_loader中有 n 个 (x,y), 每一个 x和 y包含 batch_size张图像,所以总的图像数量是= batch_idx * batch_size
x = x.view(x.size(0), 28 * 28) # view 等价于reshape, [512,1,28*28] => [512,28*28]
out = net(x) # [batch_size,10]
y_onehot = one_hot(y) # 如将 [512,] 变成 [512,10], 原来一维数组中对应m行的值n,对应新的二维数组m行的第n列设置为1
loss = F.mse_loss(out, y_onehot)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss.append(loss.item())
if batch_idx % 10 == 0:
print(epoch, batch_idx, loss.item())
# 打印训练过程中的loss结果
plot_curve(train_loss)
loss的下降结果:
4. 使用测试集,计算预测精度
# step4. 进行测试,计算预测精度
total_correct = 0
for x, y in test_loader:
x = x.view(x.size(0), 28 * 28)
out = net(x)
pred = out.argmax(dim=1)
correct = pred.eq(y).sum().float().item()
total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct / total_num
print("acc: ", acc)
# 可视化部分预测结果
x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28 * 28)) # 二维 [batch_size, 10]
pred = out.argmax(dim=1) # 一维 [batch_size,]
plot_image(x, pred, 'image_predict')
预测精度:
可视化部分预测结果:
5. 辅助工具函数
定义到uitls.py文件中:
用于显示图像,打印一维数组,one hot操作
import torch
from matplotlib import pyplot as plt
# 绘制一维数据图
def plot_curve(data):
fig = plt.figure() # 定义一张图纸
plt.plot(range(len(data)), data, color="blue") # 绘制一维数组
plt.legend(["value"], loc="upper right") # 添加图例,即数据说明标签
plt.xlabel("step")
plt.ylabel("value")
plt.show()
def plot_image(img, label, name):
'''
:param img: 比如:torch.Size([batch_size=512, 1, 28, 28])
:param label: 比如:torch.Size([512])
:param name: string
'''
fig = plt.figure()
for i in range(6):
plt.subplot(2, 3, i + 1) # 2*3个小图像
plt.tight_layout()
plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap="gray", interpolation="none") # 图像进行正则化,然后显示出来
plt.title("{}: {}".format(name, label[i].item())) # 显示每一个plot的标题
plt.xticks([]) # 设置x轴的刻度标签为空,即不显示刻度
plt.yticks([])
plt.show()
def one_hot(label, depth=10):
out = torch.zeros(label.size(0), depth) # 定义一个 [batch_size, 10]大小的矩阵
idx = torch.LongTensor(label).view(-1, 1) # 把 label reshape成 [batch_size, 1]尺寸的2维tensor
out.scatter_(dim=1, index=idx, value=1) # 改变 out的第dim=1维度的数据,out中值被改变值的索引,是index中对应的值, 填充的值是 1,
# 如第2个样本是“6”,则第1行第5列(矩阵索引从0开始)填充为1.
return out
# if __name__ == '__main__':
# data = [1,2,3,4,5,4,3,2,5,6,8]
# plot_curve(data)
参考:
深度学习入门_哔哩哔哩_bilibili
|