CIFAR10数据集介绍
10类,每一类有6000张照片,50000张training,10000张test。
实例
从datasets包中加载数据集,使用transforms包进行变换,通过resize获取图片维度,再把图片转换成tensor,因为pytorch的数据类型都是tensor。 cifar_train一次加载一张,我们需要使用DataLoader加载一次一批。直接覆盖写即可。
cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
]), download=True)
cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
]), download=True)
cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
通过iter方法得到DataLoader的迭代器,再使用迭代器的next方法得到一个batch。
接下来新建一个类,lenet5,卷积神经网络的最简单的一个版本。 第一层是卷积层,第一个卷积层输入维度是照片的维度,cifar是彩色照片。 第二层是subsampling,是一个pooling层。 又一个卷积层,pooling不改变channel,输入依然是6。 全连接层时,输入维度是4维的,我们需要打平,但是pytorch中没有自带的FLatten函数,但是Sequential中需要写既有的类,所以我们写两个unit。
self.conv_unit=nn.Sequential(
nn.Conv2d(3,6,kernel_size=5,stride=1,padding=0),
nn.AvgPool2d(kernel_size=2,stride=2,padding=0),
nn.Conv2d(6,16,kernel_size=5,stride=1,padding=0),
nn.AvgPool2d(kernel_size=2,stride=2,padding=0),
)
打平 打平后,看结构图是120层,全连接层是Linear。激活函数一般选择sigmod和relu,而sigmod会出现梯度离散现象,所有选择relu。 我们计算一些输入输出值,使用一个例子tmp,送入第一个unit运行。
self.fc_unit=nn.Sequential(
nn.Linear(16*5*5,120),
nn.ReLU(),
nn.Linear(120,84),
nn.ReLU(),
nn.Linear(84,10)
)
tmp=torch.randn(2,3,32,32)
out=self.conv_unit(tmp)
print('conv_out:',out.shape)
每一个网络结构都需要一个forward前向计算,且不需要backward,自动会有。
def forward(self,x):
batchsz=x.size(0)
x=self.conv_unit(x)
x=x.view(batchsz,16*5*5)
logits=self.fc_unit(x)
return logits
我们取名字叫logits,一般在经过softmax之前的数叫logits。pred和logits的区别在于pred是logits经过softmax操作。
使用loss,我们这是个分类问题,通常使用cross entropy loss。
softmax和loss的操作叫做CELoss
nn上面的类是大写的,F上面的类是小写,两者的区别是nn上面的类先要初始化一下,再在forward里面调用,F里面的类是直接运行的函数,我们可以直接代入数值使用。
total_correct+=torch.eq(pred,label).float().sum().item()
eq函数进行对比,[2 1 1 2 1]的转置与][2 0 0 1 2]的转置进行对比,答案是[1 0 0 0 1]的转置,float后再相加可得2。
完整代码
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn,optim
from lenet5 import Lenet5
def main():
batchsz = 32
cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
]), download=True)
cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
]), download=True)
cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
x,label=iter(cifar_train).next()
print('x:',x.shape,'label:',label.shape)
device = torch.device('cuda')
model=Lenet5().to(device)
criteon=nn.CrossEntropyLoss().to(device)
optimizer=optim.Adam(model.parameters(),lr=1e-3)
print('model:',model)
for epoch in range(1000):
model.train()
for batchsz,(x,label) in enumerate(cifar_train):
x,label=x.to(device),label.to(device)
logits=model(x)
loss=criteon(logits,label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(epoch,loss.item())
model.eval()
with torch.no_grad():
total_correct=0
total_num=0
for x,label in cifar_test:
x, label = x.to(device), label.to(device)
logits=model(x)
pred=logits.argmax(dim=1)
total_correct+=torch.eq(pred,label).float().sum().item()
total_num+=x.size(0)
acc = total_correct/total_num
print('epoch,acc:',epoch ,acc)
if __name__ == '__main__':
main()
import torch
from torch import nn
from torch.nn import functional as F
class Lenet5(nn.Module):
def __init__(self):
super(Lenet5, self).__init__()
self.conv_unit=nn.Sequential(
nn.Conv2d(3,6,kernel_size=5,stride=1,padding=0),
nn.AvgPool2d(kernel_size=2,stride=2,padding=0),
nn.Conv2d(6,16,kernel_size=5,stride=1,padding=0),
nn.AvgPool2d(kernel_size=2,stride=2,padding=0),
)
self.fc_unit=nn.Sequential(
nn.Linear(16*5*5,120),
nn.ReLU(),
nn.Linear(120,84),
nn.ReLU(),
nn.Linear(84,10)
)
tmp=torch.randn(2,3,32,32)
out=self.conv_unit(tmp)
print('conv_out:',out.shape)
self.criteon=nn.CrossEntropyLoss()
def forward(self,x):
batchsz=x.size(0)
x=self.conv_unit(x)
x=x.view(batchsz,16*5*5)
logits=self.fc_unit(x)
return logits
def main():
net=Lenet5()
tmp = torch.randn(2, 3, 32, 32)
out = net(tmp)
print('lenet_out:', out.shape)
if __name__ == '__main__':
main()
|