- 参考:《动手学深度学习》(Pytorch)版 3.5 节
- 注:本文是 jupyter notebook 文档转换而来,部分代码可能无法直接复制运行!
1. 获取数据集
-
通过 torchvision.datasets.FashionMNIST 方法获取数据集 mnist_train = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST', train=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST', train=False, transform=transforms.ToTensor())
参数说明
-
root 参数指定数据集保存路径 -
train 参数指定获取训练集还是测试集 -
download 参数若设置为 True ,则在发现 root 路径下没有数据集时自动从网上下载,若已有数据集则不动作 -
transform = transforms.ToTensor() 使所有数据转换为 Tensor ,如果不转换则返回的是 PIL 图片
transforms.ToTensor() 将 “尺寸为
H
×
W
×
C
H \times W \times C
H×W×C 且数据位于
[
0
,
255
]
[0, 255]
[0,255] 的PIL图片” 或者 “数据类型为 np.uint8 的NumPy数组” 转换为 “尺寸为
C
×
H
×
W
C \times H \times W
C×H×W 且数据类型为 torch.float32 且位于 [0.0, 1.0] 的Tensor”
注意 transforms.ToTensor() 在内的一些关于图片的函数默认输入为 uint8 类型,如果不是则可能得到不想要的结果,所以如果用
[
0
,
255
]
[0,255]
[0,255] 的像素值表示图片数据,则一律将其类型设置为 uint8 ,以免不必要的bug -
这里加载的 mnist_train 和 mnist_test 都是 torch.utils.data.Dataset 的子类,一些常用方法如下 print(type(mnist_train))
print(len(mnist_train), len(mnist_test))
feature, label = mnist_train[0]
print(feature.shape, label)
'''
torchvision.datasets.mnist.FashionMNIST
60000 10000
torch.Size([1, 28, 28]) 9
'''
-
Fashion-MNIST中一共包括了10个类别,分别为
- t-shirt(T恤)
- trouser(裤子)
- pullover(套衫)
- dress(连衣裙)
- coat(外套)
- sandal(凉鞋)
- shirt(衬衫)
- sneaker(运动鞋)
- bag(包)
- ankle boot(短靴)
使用以下函数将数值标签列表转成相应的文本标签列表 def get_fashion_mnist_labels(labels):
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
-
使用以下函数在一行里绘制多个图像和对应的标签 def show_fashion_mnist(images, labels):
display.set_matplotlib_formats('svg')
_, figs = plt.subplots(1, len(images), figsize=(12, 12))
for f, img, lbl in zip(figs, images, labels):
f.imshow(img.view((28, 28)).numpy())
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
plt.show()
-
随机显示 10 个样本 X, y = [], []
for i in np.random.randint(0,60000,size = 10).tolist():
X.append(mnist_train[i][0])
y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))
这里我遇到一个报错,请参考 ‘OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program’,我删除了虚拟环境中的 libiomp5md.dll 解决此问题
2. 读取小批量
-
在实践中,数据读取经常是训练的性能瓶颈,torch.utils 模块提供的 DataLoader 方法允许我们方便地使用多进程来加速数据读取 -
mnist_train 是 torch.utils.data.Dataset 的子类,所以我们可以将其传入 torch.utils.data.DataLoader 来创建一个读取小批量数据样本的DataLoader 实例,在创建时
- 通过参数
num_workers 来指定读取数据的进程数量 - 通过
shuffle 参数指定读取时是否打乱 batch_size = 256
if sys.platform.startswith('win'):
num_workers = 4
else:
num_workers = 0
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
-
查看读取一遍数据的耗时 start = time.time()
for X, y in train_iter:
continue
print('%.2f sec' % (time.time() - start))
经测试,我的笔记本电脑在不使用多进程加速时耗时 5.88s,使用后减少到 3.18s
|