官网地址:https://pytorch.org/vision/stable/index.html
Torchvision 是 PyTorch 的一个视觉处理工具包,独立于PyTorch,需要另外安装
它包括4个类,各类的主要功能如下:
- 1)datasets:提供常用的数据集加载,设计上都是继承自torch.utils.data.Dataset,主要
包括MMIST、CIFAR10/100、ImageNet和COCO等。 - 2)models:提供深度学习中各种经典的网络结构以及训练好的模型(参数选择
pretrained=True),包括AlexNet、VGG系列、ResNet系列、Inception系列等。 - 3)transforms:常用的数据预处理操作,主要包括对Tensor及PIL Image对象的操作。
- 4)utils:含两个函数,一个是make_grid,它能将多张图片拼接在一个网格中;另一个是 save_img,它能将 Tensor 保存成图片
1、torchvision.datasets
1.1 常用数据集加载 MNIST等
举例,通过torchvision下载 MNIST (mnist 全称:mixed national institute of standards and technology database)
train_dataset = torchvision.datasets.MNIST(root,
train=True,
transform=transform,
download=True)
root :需要下载至地址的根目录位置 train:如果是True, 下载训练集 trainin.pt; 如果是False,下载测试集 test.pt; 默认是True transform:一系列作用在PIL图片上的转换操作,返回一个转换后的版本 download:是否下载到 root指定的位置,如果指定的root位置已经存在该数据集,则不再下载
1.2 自定义数据集读取 ImageFolder
torchvision.datasets.ImageFolder(root, transform, target_transform, loader)
- root:图片存储的根目录,即各类别文件夹所在目录的上一级目录,在下面的例子中是 “…/input/data/”
- transform:对图片进行预处理操作(函数),原始图片作为输入,返回一个转换后的图片。
- target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
- loader:表示数据集加载方式,通常默认加载方式即可
另外,该 API 有以下成员变量:
- self.classes:用一个 list 保存类别名称
- self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
- self.imgs:保存(img-path, class) tuple的 list,与我们自定义 Dataset类的 def getitem(self, index): 返回值类似。注意看下面实例中 dataset.imgs 的返回值
举例:数据存储结构如下
import torchvision
from torchvision import transforms, utils
trans = transforms.Compose([transforms.RandomCrop(400), transforms.ToTensor()])
dataset = torchvision.datasets.ImageFolder('/Users/manmi/Desktop/data/data', transform=trans)
print(dataset.classes)
print(dataset.class_to_idx)
print(dataset.imgs)
print(len(dataset))
print(dataset[0][0].size())
print(dataset[0][1])
2、torchvision.models
后补…
3、torchvision.transforms
3.1 对PIL Image的常见操作
1)转换为 tensor ToTensor()
ToTensor() 做了三件事:
- 把灰度范围从0-255 变换到 0-1之间,其将每一个像素值归一化到 [0,1],其归一化方法比较简单,直接除以255即可
- 将 nump.ndarray 或 PIL.Image 转为 tensor,数据类型为 torch.FloatTensor
- 将shape 由 (H,W, C) 转为shape为 (C, H, W)
2)中心裁剪 CenterCrop()
torchvision.transforms.CenterCrop(size) # 所需裁剪的图片尺寸
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
img_src = Image.open('./bird.jpg')
img_1 = transforms.CenterCrop(200)(img_src)
img_2 = transforms.CenterCrop((200, 200))(img_src)
img_3 = transforms.CenterCrop((300, 200))(img_src)
img_4 = transforms.CenterCrop((500, 500))(img_src)
plt.subplot(231)
plt.imshow(img_src)
plt.subplot(232)
plt.imshow(img_1)
plt.subplot(233)
plt.imshow(img_2)
plt.subplot(234)
plt.imshow(img_3)
plt.subplot(235)
plt.imshow(img_4)
plt.show()
以上例子我们可知: (1)如果切正方形,transforms.CenterCrop(100) 和 transforms.CenterCrop((100, 100)),两种写size的方法,效果一样 (2)如果设置的输出的图片尺寸大于原尺寸,会在边上补黑色
3)随机裁剪 RandomCrop()
# 依据给定的size随机裁剪
torchvision.transforms.RandomCrop(size,
padding = None,
pad_if_needed = False,
fill=0,
padding_mode ='constant')
功能: 从图片中随机裁剪出尺寸为 size 的图片,如果有 padding,那么先进行 padding,再随机裁剪 size 大小的图片。
参数:
size :所需裁剪的图片尺寸padding : 设置填充大小
- 当为 a 时,上下左右均填充 a 个像素
- 当为 (a, b) 时,左右填充 a 个像素,上下填充 b 个像素
- 当为 (a, b, c, d) 时,左上右下分别填充 a,b,c,d
pad_if_needed :当图片小于设置的 size,是否填padding_mode :
- constant: 像素值由 fill 设定 (默认)
- edge: 像素值由图像边缘像素设定
- reflect: 镜像填充,最后一个像素不镜像。([1,2,3,4] -> [3,2,1,2,3,4,3,2])
- symmetric: 镜像填充,最后一个像素也镜像。([1,2,3,4] -> [2,1,1,2,3,4,4,4,3])
fill :当 padding_mode 为 constant 时,设置填充的像素值 (默认为0)
4)其他更多图像变换操作
其他更多的图像变换操作,看这里吧
3.2 对 Tensor 的常见操作
1)归一化 Normalize()
作用: 用均值和标准差对张量图像进行归一化, 公式:
i
m
a
g
e
=
(
i
m
a
g
e
?
m
e
a
n
)
/
s
t
d
image = (image-mean) / std
image=(image?mean)/std
比如,原像素值的取值区间为 [0, 1],在使用 transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]) 进行归一化后,原像素值被分布到了 [-1, 1] 区间:
- 原来的 0~1 最小值 0 则变成 (0 - 0.5) / 0.5 = -1
- 最大值1则变成 (1 - 0.5) / 0.5 = 1
其中 mean 和 std 的3个值分表表示图像的3个通道 如果是单通道的灰度图,可以写成 transforms.Normalize(mean=[0.5], std=[0.5])
我们可能会看到很多代码里面是这样的: torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 这一组值是怎么来的呢?答案就是通过数据集,提前抽样计算出来的
2)转换为图像 ToPILImage()
将 Tensor类型数据转换为图片数据 PILImage, torchvision.transforms.ToPILImage() 函数的作用是把Tensor数据变为完整原始的图片数据(保存后可以直接双击打开的那种) 其内部处理过程为:
- 将Tensor的每个元素乘以255
- 将数据由Tensor转化成Uint8
- 将Tensor转化成numpy的ndarray类型
- 对ndarray对象做permute (1, 2, 0)的转置,将shape 由 (C, H, W) 转为shape为(H,W, C)
- 将ndarray对象转化成PILImage数据格式
- 输出该PILImage数据(save后可以直接打开)
4、torchvision.utils
4.1 图像拼接 grid
一行最多展示8张图片
import torch
import torchvision
from torchvision import transforms, utils
from torch.utils import data
import matplotlib.pyplot as plt
trans = transforms.Compose([transforms.RandomCrop(400), transforms.ToTensor()])
dataset = torchvision.datasets.ImageFolder('./data', transform=trans)
train_loader = data.DataLoader(dataset, batch_size=2, shuffle=True)
for (img, label) in train_loader:
grid = utils.make_grid(img)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.show()
break
4.2 tensor存储为图片 save_img
torchvision.utils.save_img(img, path)
image 的数据类型是tensor
|