Summary of the steps for building a neural network:

1. Load the data. Required imports: from torchvision import datasets, transforms (plus DataLoader from torch.utils.data, used below). Loading splits into two parts:
- load the training set
- load the test set

The two are almost identical; the only difference is that the train argument is True for the training set and False for the test set:
```python
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

batchsz = 32

# Training set: train=True
cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor()
]), download=True)
cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

# Test set: train=False
cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor()
]), download=True)
cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

# Grab one batch to inspect
x, label = next(iter(cifar_train))
```
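Printing the shapes confirms that a full batch came back; this matches the x torch.Size([32, 3, 32, 32]) line in the run output at the end of the post:

```python
print('x', x.shape, 'label', label.shape)
# x torch.Size([32, 3, 32, 32]) label torch.Size([32])
```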
To load whole batches rather than single images, each dataset is wrapped in a DataLoader (shown again here), which can also use multiple worker processes to load data in parallel:

```python
cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
```
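A minimal sketch of the parallel-loading variant. Here train_ds stands for the raw datasets.CIFAR10(...) object before wrapping, and 4 is just an illustrative worker count; num_workers is a standard DataLoader parameter:

```python
from torch.utils.data import DataLoader

# Background worker processes prepare batches in parallel with the
# training loop; the default num_workers=0 loads in the main process.
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)
```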
2. Build the network structure. At this stage we don't yet know how to design a network from scratch; the best we can do is reproduce an existing diagram, so next we build the classic LeNet-5 architecture (adapted to 3-channel CIFAR-10 input). Building it splits into two steps. First, the convolutional unit:
```python
self.conv_unit = nn.Sequential(
    # [b, 3, 32, 32] -> [b, 6, 28, 28]
    nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
    # [b, 6, 28, 28] -> [b, 6, 14, 14]
    nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
    # [b, 6, 14, 14] -> [b, 16, 10, 10]
    nn.Conv2d(6, 16, kernel_size=5, padding=0),
    # [b, 16, 10, 10] -> [b, 16, 5, 5]
    nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
)
```
Then the fully connected unit:
```python
self.fc_unit = nn.Sequential(
    nn.Linear(16 * 5 * 5, 120),  # flattened conv output: 16*5*5 = 400
    nn.ReLU(),
    nn.Linear(120, 84),
    nn.ReLU(),
    nn.Linear(84, 10)            # 10 CIFAR-10 classes
)
```
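The 16*5*5 input size of the first Linear layer is not arbitrary: it is exactly the flattened output of the convolutional unit. A standalone sketch to verify it by pushing a dummy batch through the same layers:

```python
import torch
from torch import nn

conv_unit = nn.Sequential(
    nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
    nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
    nn.Conv2d(6, 16, kernel_size=5, padding=0),
    nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
)
tmp = torch.randn(2, 3, 32, 32)  # dummy batch of 2 CIFAR-sized images
print(conv_unit(tmp).shape)      # torch.Size([2, 16, 5, 5]) -> 16*5*5 = 400
```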
The two units are then chained together by forward(). Every network must define forward(); you can think of it as the network's instruction manual, since running it is what actually wires the layers together. There is also a backward(), but you normally never write it yourself: because autograd records the computation graph while forward() runs, backward() can simply be called afterwards.
```python
def forward(self, x):
    '''
    :param x: [b, 3, 32, 32]
    :return: logits [b, 10]
    '''
    batchsz = x.size(0)
    x = self.conv_unit(x)            # [b, 3, 32, 32] -> [b, 16, 5, 5]
    x = x.view(batchsz, 16 * 5 * 5)  # flatten -> [b, 400]
    logits = self.fc_unit(x)         # [b, 400] -> [b, 10]
    return logits
```
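A minimal, self-contained illustration of why backward() needs no definition of its own: autograd records the forward computation and derives the gradients automatically.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
x = torch.tensor(3.0)
loss = (w * x - 1) ** 2  # forward pass: autograd records this graph
loss.backward()          # backward pass: gradient derived automatically
print(w.grad)            # d(loss)/dw = 2*(w*x - 1)*x = 30.0
```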
The complete network code:
```python
import torch
from torch import nn
from torch.nn import functional as F

class lenet5(nn.Module):
    def __init__(self):
        super(lenet5, self).__init__()
        self.conv_unit = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(6, 16, kernel_size=5, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        )
        self.fc_unit = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        '''
        :param x: [b, 3, 32, 32]
        :return: logits [b, 10]
        '''
        batchsz = x.size(0)
        x = self.conv_unit(x)
        x = x.view(batchsz, 16 * 5 * 5)
        logits = self.fc_unit(x)
        return logits
```
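A quick smoke test of the finished class; the printed module tree should match the run output further below:

```python
model = lenet5()
print(model)                             # prints the conv_unit / fc_unit tree
out = model(torch.randn(32, 3, 32, 32))
print(out.shape)                         # torch.Size([32, 10])
```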
3. Set up the training components, in three steps:
a) select the GPU device; b) instantiate the model; c) define the loss function and the optimizer.
```python
from torch import optim

device = torch.device('cuda')    # assumes a CUDA-capable GPU is available
model = lenet5().to(device)
criteon = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
```
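One detail worth stressing: nn.CrossEntropyLoss expects raw logits, not probabilities, because it applies log-softmax internally. A tiny sketch with made-up numbers:

```python
import torch
from torch import nn

criteon = nn.CrossEntropyLoss()
logits = torch.tensor([[2.0, 0.5, -1.0]])  # raw scores for 3 classes, batch of 1
target = torch.tensor([0])                 # true class index
print(criteon(logits, target))             # small loss: class 0 already scores highest
```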
4. Train and test the network:
```python
for epoch in range(1000):
    model.train()                  # switch to training mode
    for batchidx, (x, label) in enumerate(cifar_train):
        x, label = x.to(device), label.to(device)
        logits = model(x)          # [b, 10]
        loss = criteon(logits, label)
        optimizer.zero_grad()      # clear old gradients
        loss.backward()            # backpropagate
        optimizer.step()           # update weights
    print(epoch, loss.item())

    model.eval()                   # switch to evaluation mode
    with torch.no_grad():          # no gradient tracking during testing
        total_correct = 0
        total_num = 0
        for x, label in cifar_test:
            x, label = x.to(device), label.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)  # predicted class per sample
            total_correct += torch.eq(pred, label).float().sum().item()
            total_num += x.size(0)
        acc = total_correct / total_num
        print(epoch, acc)
```
During training and testing, remember to switch the model between training and evaluation mode with model.train() and model.eval().
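The switch matters for layers such as Dropout and BatchNorm, which behave differently at train and test time (this LeNet-5 happens to contain neither, but the habit is worth keeping). A minimal illustration with Dropout:

```python
import torch
from torch import nn

layer = nn.Dropout(p=0.5)
x = torch.ones(1, 4)

layer.train()
print(layer(x))  # training mode: about half the values zeroed, the rest scaled by 2

layer.eval()
print(layer(x))  # eval mode: dropout disabled, input passes through unchanged
```

Run output (for each epoch, the first line is the last batch's training loss and the second is the test accuracy):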
```
Files already downloaded and verified
Files already downloaded and verified
x torch.Size([32, 3, 32, 32]) label torch.Size([32])
lenet5(
  (conv_unit): Sequential(
    (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (3): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (fc_unit): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU()
    (4): Linear(in_features=84, out_features=10, bias=True)
  )
)
0 1.610519528388977
0 0.4228
1 1.4471698999404907
1 0.4843
2 1.3781787157058716
2 0.5167
3 1.4658881425857544
3 0.5177
4 0.8951991200447083
4 0.5259
5 1.1710333824157715
5 0.5387
6 1.562530279159546
6 0.5442
7 0.8673275113105774
7 0.5443
8 1.4436936378479004
8 0.5421
9 1.3614578247070312
9 0.5434
10 1.341037392616272
10 0.5531
11 1.10409414768219
11 0.5487
12 0.8848217725753784
12 0.5474
13 0.980401337146759
13 0.5486
14 1.0723121166229248
14 0.5463
15 0.8616059422492981
15 0.5471
16 1.2519404888153076
16 0.5482
17 1.1916581392288208
17 0.5415
18 0.7907223105430603
18 0.5441
19 0.43285447359085083
19 0.5436
20 0.7941941618919373
20 0.5457
21 0.6954908967018127
21 0.5415
22 1.2862902879714966
22 0.5441
23 0.46743243932724
23 0.5439
24 1.0503959655761719
24 0.5413
25 1.3914414644241333
25 0.5507
26 0.933451235294342
26 0.5387
27 0.7994612455368042
27 0.5381
28 0.8405050039291382
28 0.5352
29 0.9237154722213745
29 0.5325
30 1.0469434261322021
30 0.5349
31 1.2612384557724
31 0.5267
32 0.9177519679069519
32 0.5325
33 1.0856417417526245
33 0.5257
34 0.6346448063850403
34 0.5315
35 1.3680710792541504
35 0.5263
36 0.5967005491256714
36 0.5231
37 1.1949267387390137
37 0.5285
38 0.9418867826461792
38 0.527
```