利用PyTorch训练一个简单的分类器，以手写数字的识别为例。
一、数据
直接用torchvision自带的方法下载构建即可。
def data(root='./', train_batch_size=2, test_batch_size=1, num_workers=4,
         test_num_workers=1, shuffle=True):
    """Build the MNIST train/test DataLoaders, downloading MNIST into ``root`` if needed.

    Args:
        root: Directory where torchvision stores / looks for the MNIST files.
        train_batch_size: Batch size of the training loader.
        test_batch_size: Batch size of the test loader.
        num_workers: Worker processes for the training loader.
        test_num_workers: Worker processes for the test loader.
        shuffle: Reshuffle the training set every epoch. The original code
            never shuffled, so each epoch saw batches in the same fixed order.

    Returns:
        ``(train_data_loader, test_data_loader)``; each yields ``[images, labels]``
        batches with images converted to tensors by ``ToTensor``.
    """
    # ToTensor is stateless, so a single instance can be shared by both splits.
    to_tensor = torchvision.transforms.ToTensor()
    dataset_train = torchvision.datasets.MNIST(root=root, train=True, transform=to_tensor,
                                               download=True)
    dataset_test = torchvision.datasets.MNIST(root=root, train=False, transform=to_tensor,
                                              download=True)
    # drop_last keeps every training batch the same size (an incomplete last
    # batch would otherwise be yielded).
    train_data_loader = torch.utils.data.DataLoader(dataset_train, batch_size=train_batch_size,
                                                    shuffle=shuffle, num_workers=num_workers,
                                                    drop_last=True)
    test_data_loader = torch.utils.data.DataLoader(dataset_test, batch_size=test_batch_size,
                                                   num_workers=test_num_workers)
    return train_data_loader, test_data_loader
二、构建一个简单的网络结构
class Net(torch.nn.Module):
    """Minimal MNIST classifier: strided 3x3 conv -> batch norm -> ReLU -> linear head.

    The linear head expects 196 = 1 * 14 * 14 features. The stride-2, padding-1,
    3x3 convolution produces exactly that from a 1x28x28 image. Inputs whose
    conv output is not 1x14x14 will fail at the linear layer.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in the original order so that, under a fixed
        # seed, parameter initialisation and state_dict keys are unchanged.
        self.conv = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3,
                                    stride=2, padding=1)
        self.bn = torch.nn.BatchNorm2d(1)
        self.relu = torch.nn.ReLU()
        # NOTE: "classifer" (sic) is kept because renaming it would change state_dict keys.
        self.classifer = torch.nn.Linear(in_features=196, out_features=10)

    def forward(self, x):
        """Map a batch of images to unnormalised class scores of shape (batch, 10)."""
        features = self.relu(self.bn(self.conv(x)))
        flat = features.view(features.size(0), -1)
        return self.classifer(flat)
三、训练方法与验证方法
def train(epochs=100, device=None, log_every=500):
    """Train ``Net`` on MNIST with Adam (lr=0.001) and cross-entropy loss.

    Test accuracy is evaluated with ``evalate`` after every epoch.

    Args:
        epochs: Number of passes over the training set.
        device: Device to train on. ``None`` selects CUDA when it is available
            and falls back to CPU, instead of unconditionally calling ``.cuda()``.
        log_every: Print the current batch loss every ``log_every`` iterations.

    Returns:
        The trained model.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net().to(device)
    model.train()
    train_data, test_data = data()
    # CrossEntropyLoss has no parameters, so it does not need to be moved to the device.
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    for i_epoch in range(epochs):
        for i, (img, label) in enumerate(train_data):
            img, label = img.to(device), label.to(device)
            optimizer.zero_grad()
            out = model(img)
            loss = loss_fn(out, label)
            loss.backward()
            optimizer.step()
            if i % log_every == 0:
                # Printed as "<loss> Epoch:<epoch>/<iteration>".
                print(loss.item(), 'Epoch:{}/{}'.format(i_epoch, i))
        # evalate() puts the model in eval mode; switch back before the next epoch.
        evalate(model, test_data)
        model.train()
    return model
def evalate(model, test_data):
    """Compute, print and return the top-1 accuracy (percent) of ``model`` on ``test_data``.

    The misspelled name is kept so existing callers keep working.

    Args:
        model: A ``torch.nn.Module`` returning class scores of shape (batch, num_classes).
        test_data: Iterable of ``(images, labels)`` batches, e.g. a DataLoader.

    Returns:
        Accuracy in percent; 0.0 when ``test_data`` yields no samples (the
        original divided by zero in that case).

    Note:
        Leaves ``model`` in eval mode; call ``model.train()`` afterwards to keep training.
    """
    model.eval()
    # Use the model's own device instead of assuming CUDA. Parameter-free
    # models fall back to CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    total, correct = 0, 0
    with torch.no_grad():
        for img, label in test_data:
            img, label = img.to(device), label.to(device)
            out_label = model(img).argmax(dim=1)
            total += label.size(0)
            correct += (out_label == label).sum().item()
    accuracy = 100.0 * correct / total if total else 0.0
    print("acc of total {} images:{}%".format(total, accuracy))
    return accuracy
|