一、MNIST数据集
1.MNIST数据集简介
MNIST数据集是一个公开的数据集,相当于深度学习的hello world,用来检验一个模型/库/框架是否有效的一个评价指标。
MNIST数据集是由0?9手写数字图片和数字标签所组成的,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片。MNIST 数据集来自美国国家标准与技术研究所,整个训练集由250个不同人的手写数字组成,其中50%来自美国高中学生,50%来自人口普查的工作人员。
2.MNIST数据集包含四部分
MNIST数据集是集成的API无需手动下载,可以通过torch里面的API直接获取。
官方文档:https://pytorch.org/vision/stable/datasets.html
参数:
- root:指的是下载的目录
- ?train:如果设置成True的话表示取训练集,如果要取测试集就设置成False
- download:如果设置成True,会先判断是否下载过,如果未下载过,就会下载文件;如果已经下载过了,设置成True和False都一样,不会重新下载。
- transform:是对图片进行预处理的一些操作,可以将一个PIL 图片翻译成张量或其他内容? ? ? ? ? ?
from torchvision.datasets import MNIST#获取MNIST的数据集
mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=my_transforms)
#root表示下载路径,训练模式取训练集,是否下载:是
注:当我们后期要训练自己的数据集的时候需要将MNIST 数据集换成我们自己的数据集
print(len(mnist_train))#len表示求它的长度,训练集有10000张
print(mnist_train[0])#getitem表示通过索引把图像取出来
运行结果:
60000
(<PIL.Image.Image image mode=L size=28x28 at 0x2175A412320>, 5)
解释说明:
运行结果中60000是训练集的长度,即训练集有60000张图片
第二行返回结果有两部分,一部分是PIL图像,另一部分是图片的标签
标签为5,就表示这个数字是5.
怎样能看一下这张图片是什么呢?
import matplotlib.pyplot as plt#安装matlab库并导入matplotlib.pyplot
from torchvision.datasets import MNIST#获取MNIST的数据集
mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=None)
#root表示下载路径,训练模式取训练集,是否下载:是
print(len(mnist_train))#len表示求它的长度,训练集有10000张
print(mnist_train[0][0])#getitem表示通过索引把图像取出来
image = mnist_train[0][0]#取出具体的一张图片
plt.imshow(image)
plt.show()#把图片展示出来
print(mnist_train[0][1])#把图片的标签打印出来
运行结果:
60000
<PIL.Image.Image image mode=L size=28x28 at 0x1DAF3D3BB00>
5
图片展示??
遇到OMP报错的话,在代码中添加下面两行代码即可
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"#前两行代码解决一个OMP报错
?可以参考下面链接:
OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.OMP: Hint_fencecat的博客-CSDN博客OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.OMP: Hinthttps://blog.csdn.net/fencecat/article/details/122887204?spm=1001.2014.3001.5502
有时即使模型再好,识别率也达不到100%;因为有些数字写的实在太飘逸了,标签也是随心所欲😂
二、数据加载
MNIST数据集继承了torch.utils.data.Dataset
需要自己实现__len__和__getitem__两个方法:
- __len__实现获取数据集长度的操作
- __getitem__实现获取第几个对象的操作,通过索引的方式把图片取出来。
torch已封装好的加载器
前边已经得到MNIST数据集的实例化对象,接下来就可以进行数据的加载,加载器功能较多,如果自己实现的话会比较复杂,我们可以借助torch已经封装好的加载器来处理
官方文档https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader#导入数据加载器
mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=None)
dataloader = DataLoader(mnist_train, batch_size=2, shuffle=True)
#实例化一个类,传入把训练集,batch_size设为1,shuffle设为true打乱
print(dataloader)#打印一下
运行结果:
<torch.utils.data.dataloader.DataLoader object at 0x00000297C76711D0>
迭代DataLoader类:
# 加载器
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader#导入数据加载器
mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=None)
dataloader = DataLoader(mnist_train, batch_size=2, shuffle=True)
#实例化一个类,传入把训练集,batch_size设为1,shuffle设为true打乱
# print(dataloader)#打印一下
for i in dataloader:
print(i)
报错:
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>
#batch必须包含张量,numpy数组,数字,字典或列表,不支持PIL图像
怎样才能迭代将PIL图像打印呢?需要引入图像处理
三、transforms图像处理
1.导入transforms方法,并将MNIST数据集的transfrom改为transforms.ToTensor()
#图片处理
#导入transforms方法,并将MNIST数据集中transform改为transforms.ToTensor()
from torchvision import transforms#导入transforms方法
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader#导入数据加载器
mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=transforms.ToTensor())
dataloader = DataLoader(mnist_train, batch_size=2, shuffle=True)
#实例化一个类,传入把训练集,batch_size设为1,shuffle设为true打乱
# print(dataloader)#打印一下
for i in dataloader:
print(i)
运行结果:将PIL图像转换成了张量形式
[tensor([[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]]]), tensor([8, 0])]
2.集合transforms.Compose(transforms)可以将transforms组合起来使用
#图片处理
#导入transforms方法,并将MNIST数据集中transform改为transforms.ToTensor()
from torchvision import transforms#导入transforms方法
from torchvision.datasets import MNIST
my_transforms = transforms.Compose(
[transforms.PILToTensor()])
from torch.utils.data import DataLoader#导入数据加载器
mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=transforms.ToTensor())
dataloader = DataLoader(mnist_train, batch_size=1, shuffle=True)#实例化一个类,传入把训练集,batch_size设为1,shuffle设为true打乱
for i in dataloader:
print(i)
exit()#打印一次后退出
运行结果:
[tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.1608, 0.5961, 0.9137, 0.5961, 0.5961,
0.2000, 0.0392, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.7961, 0.9922, 0.9882, 0.9922, 0.9882,
0.9922, 0.6745, 0.1608, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.4000, 0.9961, 0.9922, 0.4000, 0.2392,
0.6392, 0.9529, 0.9176, 0.2000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.4000, 0.9922, 0.9882, 0.0000, 0.0000,
0.0000, 0.3176, 0.9922, 0.9098, 0.1608, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.4000, 0.9961, 0.9922, 0.0000, 0.0000,
0.0000, 0.0000, 0.5176, 0.9922, 0.6392, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0784, 0.8353, 0.9882, 0.1608, 0.0000,
0.0000, 0.1608, 0.5176, 0.9882, 0.8745, 0.0784, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.3608, 0.9922, 0.8392, 0.2000,
0.4431, 0.9137, 0.7961, 0.7961, 0.3216, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.2000, 0.9882, 0.9922, 0.9882,
0.9922, 0.8314, 0.0784, 0.0784, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7961, 0.9961, 0.9922,
0.5569, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6353, 0.9922, 0.9882,
0.0784, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.5176, 0.9922, 0.9961, 0.9922,
0.2431, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.1608, 0.9922, 0.9882, 0.9922, 0.9882,
0.4000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.6392, 0.9961, 0.6745, 0.5961, 0.9922,
0.4000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0824, 0.8745, 0.8353, 0.0392, 0.2784, 0.9882,
0.7176, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.2039, 0.9922, 0.7961, 0.0000, 0.1608, 0.9529,
0.9961, 0.1961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.5176, 0.9882, 0.7961, 0.0000, 0.0000, 0.7961,
0.9922, 0.1961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.6000, 0.9922, 0.8000, 0.0000, 0.1216, 0.9137,
1.0000, 0.1961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.5961, 0.9882, 0.7961, 0.0000, 0.6784, 0.9882,
0.6745, 0.0392, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.3608, 0.9922, 1.0000, 0.9922, 1.0000, 0.9922,
0.1608, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0392, 0.5137, 0.8353, 0.9882, 0.9137, 0.2745,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000]]]]), tensor([8])]
Process finished with exit code 0
# for i in dataloader:
# print(i)
# exit()#打印一次后退出
# for (images, labels) in dataloader:
# print(images, labels)
#for i in dataloader:
# print(i[0], i[1])
?3.transfroms方法
官方文档:torchvision.transforms — Torchvision 0.11.0 documentationhttps://pytorch.org/vision/stable/transforms.html
(1) transfroms简介
- transfroms是一种常用的图像转换方法,他们可以通过Compose方法组合到一起,这样可以实现许多个transfroms对图像进行处理。transfroms方法提供图像的精细化处理,例如在分割任务的情况下? ,你必须建立一个更复杂的转换管道,这时transfroms方法是很有用的。
- 很多转换器既接受PIL图像,也接受tensor图像。一张tensor图像是形状为(C,?H,?W)的张量,这里C表示通道数,H和W 是图像的高和宽。1batch 的tensor图像是一个形状为
(B,?C,?H,?W) ?的张量,这里B表示在batch上有多少张图片。 - transfroms方法处理过后,会把通道移到最前边。比如MNIST h*w*c为:28*28*1,tensor处理完,通道数会提前,并且做了轴交换,变为了c*h*w为:1*28*28,为什么要这样设计呢?据说是做矩阵加减乘除以及卷积等运算是需要用cuda和cudnn的函数的,而这些接口都设成chw格式了。
a. 轴交换
transfroms方法处理过后,如果我们需要把图片转回PIL,需要进行一次轴交换;因为无法处理一个28通道数的图片。
#轴交换之前打印一下图片的形状
from torchvision import transforms#导入transforms方法
from torchvision.datasets import MNIST
my_transforms = transforms.Compose(
[transforms.PILToTensor()]
)
from torch.utils.data import DataLoader
mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform= transforms.ToTensor())
dataloader = DataLoader(mnist_train, batch_size=1, shuffle=True)
for (images, labels) in dataloader:
print(images.shape)
exit()#打印一次后退出
#运行结果
torch.Size([1, 1, 28, 28])
#说明:
这里第一个“1”表1 batch_size,即一次加载一张图片
第二个“1”表示通道数,后边两个“28”分别表示图片的高和宽
#使用make_grid方法将两张图片融合
from torchvision.utils import make_grid##即使一张图片我们也要将它融合一下,使用make_grid方法
from torchvision import transforms#导入transforms方法
from torchvision.datasets import MNIST
my_transforms = transforms.Compose(
[transforms.PILToTensor()]
)#将多个transforms组合在一起,还可以加入标准化等图像处理
from torch.utils.data import DataLoader#导入数据加载器
mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=transforms.ToTensor())
dataloader = DataLoader(mnist_train, batch_size=1, shuffle=True)
#实例化一个类,传入把训练集,batch_size设为1,shuffle设为true打乱
print(dataloader)#打印一下
for (images, labels) in dataloader:
print(make_grid(images).shape)
exit()#打印一次后退出
#运行结果:
<torch.utils.data.dataloader.DataLoader object at 0x00000289D3634438>
torch.Size([3, 28, 28])
Process finished with exit code 0
#结果图像的形状变成了3*28*28
#如果将上述代码中dataloader = DataLoader(mnist_train, batch_size=1, shuffle=True)换成
dataloader = DataLoader(mnist_train, batch_size=2, shuffle=True)
#运行结果变为torch.Size([3, 32, 62]),相当于把两张图片融合了
b. 使用轴交换边回去
#轴交换
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"#前两行代码解决一个OMP报错
from torchvision.utils import make_grid##即使一张图片我们也要将它融合一下,使用make_grid方法
from torchvision import transforms#导入transforms方法
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt#安装matlab库并导入matplotlib.pyplot
my_transforms = transforms.Compose(
[transforms.PILToTensor()]
)#将多个transforms组合在一起,还可以加入标准化等图像处理
from torch.utils.data import DataLoader#导入数据加载器
mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=transforms.ToTensor())
dataloader = DataLoader(mnist_train, batch_size=2, shuffle=True)
#实例化一个类,传入把训练集,batch_size设为1,shuffle设为true打乱
print(dataloader)#打印一下
for (images, labels) in dataloader:
image = make_grid(images).permute(1, 2, 0).numpy()
#permute(1, 2, 0)实际上就是把通道数移到后边的过程,忘记的话回看第二节课的视频
#轴交换了之后转换成numpy的数组,之后就可以做加载了
plt.imshow(image)
plt.show()
exit()
#运行结果
<torch.utils.data.dataloader.DataLoader object at 0x0000027DCB7A4710>
Process finished with exit code 0
图片展示:
添加代码:print(labels)可以将标签打印出来
这个操作一般只有调试时才会用,正常运算不需要把tensor图像转换成PIL图像再看一下
(2)进阶了解transfroms方法
?参考文档:PyTorch 学习笔记:transforms的二十二个方法(transforms用法非常详细)_liangbaqiang的博客-CSDN博客_transforms.scale
四、模型和优化器
1.简介
模型四深度学习的关键内容,是深度学习的核心。
深度神经网络的种类主要有:
- 传统神经网络CNN
- 卷积网络CNN
- 循环神经网络(递归神经网络)RNN
目前比较流行的深度神经模型几乎都是卷积和循环两种模型的延伸。
2.全连接层:torch.nn.Linear
(1)简介
官方文档:https://pytorch.org/docs/stable/nn.html
对于MNIST数据集这种简单的,且样本数量足够多的项目,一个全连接层就能达到不错的效果。
后期会对这些模型的“层”进行组合实现。有卷积层、池化层、标准化层等等。
?全连接层指的是层中的每个节点都会连接它下一层的所有节点,它是模仿人脑神经结构来构建的。最左边是输入的是图像,实际上就是图像的像素点,全连接层每层之间都是线性关系。
比如:假设输入为、、……,那么与输入层直接相连的中间层就是这样计算来的
,同理可以计算出第二层的、……,同样中间各层之间都有一个权重,下一层的输出都是由上一层的每个输入乘以相应的权重累加得出的。最终得到的是两个输出结果,这是一个二分类的问题。输出几个值几分类问题。
(2)全连接层的实现
#全连接层
#首先我们要新建一个类,这个类要继承nn.Module
class MnistModel(nn.Module):
def __init__(self):#继承__init__方法
super(MnistModel, self).__init__()
self.fc2 = nn.Linear(1*28*28, 10)#最初传入的图片的像素点是1*28*28的,最后我们要收敛成10个结果
#如果先收敛成100个,然后在写一个全连接层
# self.fc2 = nn.Linear(1 * 28 * 28, 100)
# self.fc2 = nn.Linear(100, 10)
#激活函数,激励函数,通过数学手段将线性计算过程进行优化,使其加速。最常用的线性激活函数Relu
self.relu = nn.ReLU()
def forward(self, image):#继承前向传播的方法
image_viwed = image.view(-1, 1*28*28)#此处需要拍平
out = self.fc2(image_viwed)
fcl_out = self.relu(out)#激活函数对应一下
return out
3.优化器
#优化器官方文档:https://pytorch.org/docs/stable/optim.html
(1)简介
- 优化器的作用就是寻求模型最优解,优化器有梯度下降,动量优化,自适应优化等,梯度下降是最原始的,也是最基础的。
- 梯度下降算法,载入数据集,计算所有梯度,然后执行决策。依据是损失函数,通过损失进行每一步计算,梯度下降算法分为:标准梯度下降法、批量梯度下降法和随机梯度下降法。
(2)优化器实现?
from torch import optim#导入优化器
#需要把实例化的模型传进去
model = MnistModel()
optim.Adam(model.parameters(), lr=1e-4)#这是一种自适应的优化器,不需要调参
#lr表示学习率,1e-4表示10的4次方
#优化器官方文档:https://pytorch.org/docs/stable/optim.html
4.损失函数
(1)简介
- 损失函数,设计一个损失函数的计算方法,让他统一一个损失值,算出一个结论,进而判断下次模型要朝着那个方向去优化权重,最终损失函数的选择取决于最终的结果和标签之间的关系。
- 每一种损失函数都对应着一种数学模型计算,目的就是把模型训练结果与标签之间建立起关系,在梯度下降优化器中,让损失不断减小的方向就是训练的方向 #损失函数的实现
(2)损失函数实现
LOST = nn.CTCLoss()#调用nn的损失函数,实例化
LOST(MODEL_RESULT, LABELS)#把模型的结果和标签传进去,得到一个数字就是损失值,就是优化器朝哪个方向去做的一个依据
|