# 学习使用pytorch训练手写数字识别的模型和测试。
# (Learning exercise: train and test a handwritten-digit recognition model with PyTorch.)
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
from net import CNN
from torch.nn import functional as F
# Training hyper-parameters:
#   EPOCH         - number of passes the training loop makes over train_loader
#   BATCH_SIZE    - mini-batch size used by the data loaders
#   learning_rate - step size handed to the SGD optimizer
EPOCH, BATCH_SIZE, learning_rate = 8, 50, 1e-3
# Preprocessing shared by both MNIST splits: ToTensor followed by Normalize with
# mean 0.1307 / std 0.3081 (presumably the MNIST training-set statistics).
# Defined once so the train and test pipelines cannot drift apart.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Training batches are shuffled every epoch.
train_loader = Data.DataLoader(
    torchvision.datasets.MNIST(
        root='./data/',
        train=True,
        download=True,
        transform=mnist_transform,
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# The test split is only used for aggregate loss/accuracy, so batch order does
# not matter; a fixed order keeps evaluation reproducible. download=True makes
# this loader self-sufficient instead of relying on the train dataset's download.
test_loader = Data.DataLoader(
    torchvision.datasets.MNIST(
        root='./data/',
        train=False,
        download=True,
        transform=mnist_transform,
    ),
    batch_size=BATCH_SIZE,
    shuffle=False,
)
# Hand-made parameters of a 784 -> 200 -> 200 -> 10 fully connected network.
# Weights are laid out (out_features, in_features) and biases (out_features,),
# all created as leaf tensors with requires_grad=True so they can be passed
# straight to the optimizer. Tensors are drawn from torch.randn in the order
# w1, b1, w2, b2, w3, b3.
w1, b1, w2, b2, w3, b3 = (
    torch.randn(*shape, requires_grad=True)
    for shape in ((200, 784), (200,), (200, 200), (200,), (10, 200), (10,))
)

# Replace the standard-normal weight values in place with Kaiming-normal ones;
# the biases keep their standard-normal draws.
torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)
def forward(x, params=None):
    """Run the three-layer MLP and return raw class scores (logits).

    Args:
        x: input batch. The training loop flattens images with
            ``view(-1, 28 * 28)`` before calling this, so the last dimension
            is expected to match the first layer's input size.
        params: optional ``(w1, b1, w2, b2, w3, b3)`` tuple of weights
            (``(out, in)``) and biases (``(out,)``). When omitted, the
            module-level parameters are used, exactly as before.

    Returns:
        Unnormalised logits, one per output unit of the last layer
        (10 with the module-level weights).
    """
    if params is None:
        params = (w1, b1, w2, b2, w3, b3)
    l1_w, l1_b, l2_w, l2_b, l3_w, l3_b = params

    x = F.relu(x @ l1_w.t() + l1_b)
    x = F.relu(x @ l2_w.t() + l2_b)
    # The output layer stays linear: CrossEntropyLoss applies log-softmax
    # itself and expects logits that may be negative. A ReLU here would clamp
    # every negative score to 0 and cripple learning.
    return x @ l3_w.t() + l3_b
# Loss function (applied to forward()'s output and the integer targets) and a
# plain SGD optimizer over the six hand-made parameter tensors.
criteon = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    [w1, b1, w2, b2, w3, b3],
    lr=learning_rate,
)

# Training: EPOCH passes over train_loader with one SGD step per mini-batch,
# printing a progress line every 100 batches.
for epoch in range(EPOCH):
    for step, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28 * 28)  # flatten each image into a 784-vector
        output = forward(data)
        loss = criteon(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 != 0:
            continue
        print('Train Epoch:{} [{}/{} ({:.0f}%)]\t Loss: {:.6f}'.format(
            epoch,
            step * len(data),
            len(train_loader.dataset),
            100. * step / len(train_loader),
            loss.item(),
        ))
# Single evaluation pass over the held-out test split.
# - Runs once, over test_loader (not train_loader, and not EPOCH times).
# - torch.no_grad(): no autograd graph is built or kept alive across batches.
# - criteon uses CrossEntropyLoss's default mean reduction, so each batch loss
#   is scaled back up by the batch size. Dividing by the dataset size then
#   gives a true per-sample average.
# - Correct predictions are accumulated over all batches, not overwritten.
test_loss = 0
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        data = data.view(-1, 28 * 28)
        output = forward(data)
        test_loss += criteon(output, target).item() * len(target)
        pred = output.argmax(dim=1)  # index of the highest logit per sample
        correct += pred.eq(target).sum().item()

num_test = len(test_loader.dataset)
test_loss /= num_test
print('\n Test set :Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, num_test,
    100. * correct / num_test
))