'''
Author: 365JHWZGo
Description: 12.cnn study
Date: 2021-10-28 19:21:45
FilePath: \pytorch\pytorch\cnn\day10-1.py
LastEditTime: 2021-10-28 23:26:28
LastEditors: 365JHWZGo
'''
import torch
import torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.utils.data as Data

torch.manual_seed(1)        # reproducible results

# Hyperparameters
BATCH_SIZE = 50
LR = 0.001                  # learning rate
DOWNLOAD_MNIST = False      # set to True on the first run to download the dataset
EPOCH = 1                   # number of full passes over the training set
# Training set: ToTensor() converts the images to FloatTensors scaled to [0, 1]
train_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST
)
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2
)

# Test set: loaded without a transform, so its tensors are prepared manually below
test_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=False
)
test_loader = Data.DataLoader(
    dataset=test_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2
)

# Take the first 2000 test images, add a channel dimension -> (2000, 1, 28, 28),
# and scale the raw uint8 pixels to [0, 1]. .data holds the raw images and
# .targets the integer labels.
test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:2000] / 255.
test_y = test_data.targets[:2000]
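
# Shape flow through the network (each Conv2d keeps the spatial size thanks to
# padding=2; each MaxPool2d(2) halves it):
#   input (1, 28, 28) -> conv1 (16, 28, 28) -> pool (16, 14, 14)
#                     -> conv2 (32, 14, 14) -> pool (32, 7, 7)
# so the final Linear layer takes 32 * 7 * 7 = 1568 features.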
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(            # input shape (1, 28, 28)
            nn.Conv2d(
                in_channels=1,                 # grayscale input
                out_channels=16,               # number of filters
                kernel_size=5,
                stride=1,
                padding=2                      # keeps the 28x28 spatial size
            ),
            nn.ReLU(),
            nn.MaxPool2d(
                kernel_size=2                  # -> (16, 14, 14)
            )
        )
        self.conv2 = nn.Sequential(            # input shape (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),        # -> (32, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)        # -> (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)   # 10 output classes (digits 0-9)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)              # flatten to (batch_size, 32*7*7)
        output = self.out(x)
        return output
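
# Quick shape check: a single dummy 28x28 grayscale image maps to 10 class
# scores, e.g. CNN()(torch.randn(1, 1, 28, 28)).shape == torch.Size([1, 10]).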
cnn = CNN()
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)   # optimize all CNN parameters
loss_func = nn.CrossEntropyLoss()                       # expects raw logits and integer class labels
if __name__ == '__main__':
    for epoch in range(EPOCH):
        for step, (batch_x, batch_y) in enumerate(train_loader):
            # batch_x: (BATCH_SIZE, 1, 28, 28), batch_y: (BATCH_SIZE,)
            output = cnn(batch_x)                # forward pass -> class logits
            loss = loss_func(output, batch_y)    # cross-entropy loss
            optimizer.zero_grad()                # clear gradients from the previous step
            loss.backward()                      # backpropagate
            optimizer.step()                     # update parameters

            if step % 50 == 0:
                # evaluate on the 2000 held-out test images
                out_y = cnn(test_x)
                pred_y = torch.max(out_y, 1)[1].data.numpy().squeeze()
                accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
                print(
                    'epoch:', epoch,
                    'loss: %.4f' % loss.item(),
                    'accuracy: %.2f' % accuracy
                )
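
    # After training, a quick sanity check (a minimal sketch reusing the test_x /
    # test_y tensors prepared above): compare the network's predictions for the
    # first 10 test images against the true labels.
    sample_output = cnn(test_x[:10])
    sample_pred = torch.max(sample_output, 1)[1].data.numpy()
    print(sample_pred, 'prediction number')
    print(test_y[:10].numpy(), 'real number')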