import torch
from torch import nn
from torch import optim
import torchvision.transforms as transforms
import torchvision
torch.manual_seed(1)
train_batch_size = 64
test_batch_size = 64
Ⅰ. Data Loader
def get_data():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size)
    return trainloader, testloader
In get_data(), we download the CIFAR10 dataset and wrap the training and test sets in DataLoaders.
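A side note (my addition, not part of the original post): the Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) transform shifts each channel from the [0, 1] range produced by ToTensor() to [-1, 1], since it computes (x - 0.5) / 0.5 per channel. A minimal sketch:

import torch
import torchvision.transforms as transforms

# A fake image tensor in [0, 1], the range ToTensor() produces for image data.
img = torch.rand(3, 32, 32)
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
out = normalize(img)  # per channel: (x - 0.5) / 0.5
print(out.min().item(), out.max().item())  # both values lie in [-1, 1]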
Let's take a look at what the images in CIFAR10 actually look like:
import matplotlib.pyplot as plt
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
examples = enumerate(trainloader, 0)
batch_idx, (example_data, example_targets) = next(examples)
fig = plt.figure()
for i in range(6):
    plt.subplot(2, 3, i+1)
    plt.tight_layout()
    # show only the first (R) channel of each image in grayscale
    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
    plt.title('class: {}'.format(classes[example_targets[i]]))
    plt.xticks([])
    plt.yticks([])
plt.show()
print(example_data.shape)
Although these images have a low resolution (32×32), the main object in each image is still recognizable.
Ⅱ. Building the Network
class CNNNet(nn.Module):
    def __init__(self):
        super(CNNNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
Since the images in CIFAR10 are RGB images with three channels, the first convolutional layer of the network takes 3 input channels. A 32×32 image passed through the first convolution (3 input channels, 6 output channels, kernel size 5) produces a 28×28 feature map; after the first pooling layer this becomes 14×14; after another convolution and pooling layer the feature map is 5×5; finally, three fully connected layers produce the prediction.
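To double-check those sizes, here is a small sketch (my addition, using the CNNNet class defined above) that pushes a dummy batch through the feature extractor. With no padding, each convolution shrinks the spatial size by kernel_size - 1, and each 2×2 max pooling halves it:

# Sketch: verify the feature-map size that reaches the classifier.
dummy = torch.zeros(1, 3, 32, 32)   # one fake CIFAR10 image
feat = CNNNet().features(dummy)     # 32 -> 28 -> 14 -> 10 -> 5
print(feat.shape)                   # torch.Size([1, 16, 5, 5]), flattened to 16*5*5 = 400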
Ⅲ. Model Training
net = CNNNet()
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
net.to(device)
trainloader, testloader = get_data()
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.01)
The loss function is cross-entropy and the optimizer is Adam with a learning rate of 0.01.
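One detail worth noting (my addition, with made-up numbers): nn.CrossEntropyLoss expects raw logits and integer class labels, which is why CNNNet ends with a plain Linear layer rather than a softmax. A minimal sketch:

# Sketch: CrossEntropyLoss takes raw logits and class indices; it applies log-softmax internally.
logits = torch.tensor([[2.0, 0.5, -1.0]])   # raw scores for 3 classes, batch size 1
target = torch.tensor([0])                   # the true class index
loss = nn.CrossEntropyLoss()(logits, target)
print(loss.item())                           # equals -log_softmax(logits)[0, target]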
Start training:
train_losses = []
train_counter = []

def train(model, optimizer, loss_fn, train_loader, epochs=10, device='cpu'):
    for epoch in range(1, epochs+1):
        model.train()
        for train_idx, (inputs, labels) in enumerate(train_loader, 0):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            optimizer.zero_grad()
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            if train_idx % 10 == 0:
                train_losses.append(loss.item())
                # number of images seen so far, used as the x-axis of the loss plot
                counter_idx = train_idx * len(inputs) + (epoch - 1) * len(train_loader.dataset)
                train_counter.append(counter_idx)
                print('epoch: {}, [{}/{}({:.0f}%)], loss: {:.6f}'.format(
                    epoch, train_idx*len(inputs), len(train_loader.dataset),
                    100*(train_idx*len(inputs)+(epoch-1)*len(train_loader.dataset))/(len(train_loader.dataset)*epochs),
                    loss.item()))
    print('training ended')
train(net, optimizer, loss_fn, trainloader, epochs=20, device=device)
Part of the training output is shown above. Let's plot the training loss:
import matplotlib.pyplot as plt
fig = plt.figure()
plt.plot(train_counter, train_losses, color='blue')
plt.legend(['Train Loss'], loc='upper right')
plt.xlabel('Training images number')
plt.ylabel('Loss')
plt.show()
As the plots show, the model does not converge particularly well, whether epochs is set to 5, 10, or 20. Because the network structure is simple, even a large amount of training does not produce very good predictions.
Ⅳ. Model Testing
import numpy as np

def test(model, test_loader, loss_fn, device='cpu'):
    model.eval()
    correct = 0
    total = 0
    test_loss = []
    with torch.no_grad():
        for test_idx, (inputs, labels) in enumerate(test_loader, 0):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            test_loss.append(loss.item())
            # torch.max returns (values, indices); the indices are the predicted classes
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += int((predicted == labels).sum())
    test_avg_loss = np.average(test_loss)
    print('Total: {}, Correct: {}, Accuracy: {:.2f}%, AverageLoss: {:.6f}'.format(
        total, correct, (correct/total*100), test_avg_loss))

test(net, testloader, loss_fn, device=device)
Total: 10000, Correct: 4941, Accuracy: 49.41%, AverageLoss: 1.535578. The accuracy is roughly 50%; the network structure is relatively simple, so even after plenty of training the accuracy is still not very high.
Let's look at some of the predictions:
import matplotlib.pyplot as plt
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
examples = enumerate(testloader)
batch_idx, (inputs, labels) = next(examples)
net.eval()
with torch.no_grad():
    # move the batch to the same device as the model, then bring the outputs back to the CPU
    outputs = net(inputs.to(device)).cpu()
fig = plt.figure()
for i in range(0, 10):
    plt.subplot(2, 5, i+1)
    plt.imshow(inputs[i][0], cmap='gray', interpolation='none')
    plt.title('GT: {}, Prediction: {}'.format(
        classes[labels[i]], classes[outputs.data.max(1, keepdim=True)[1][i].item()]))
    plt.xticks([])
    plt.yticks([])
plt.show()
Ⅴ. Full Code
PyTorch_CIFAR10.py