Image Classification Dataset: The Fashion-MNIST Dataset
① Introduction: the Fashion-MNIST dataset contains 10 categories of clothing, namely:
['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boots']
A sample of the items:
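② Details: the dataset contains 70,000 images in total. Each image has the shape [1 channel, height 28, width 28] and belongs to exactly one clothing category (one label). The training set and the test set are kept separate, with 60,000 and 10,000 images respectively.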
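The figures above can be checked with a little arithmetic. The sketch below is plain Python and does not touch the dataset; the variable names are illustrative, and the storage estimate assumes one byte per pixel (8-bit grayscale) before any conversion to floating point.

```python
# Dataset sizes as stated in the text.
num_train, num_test = 60_000, 10_000
channels, height, width = 1, 28, 28

total_images = num_train + num_test
values_per_image = channels * height * width

# Assumption: one byte per pixel (8-bit grayscale) in the raw data.
raw_bytes = total_images * values_per_image

print(total_images)                      # 70000
print(values_per_image)                  # 784
print(round(raw_bytes / 1e6, 1), "MB")   # 54.9 MB
```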
③ Exploring the Fashion-MNIST dataset
Import the required libraries and download the dataset
%matplotlib inline
import torch
from IPython import display
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
import matplotlib.pyplot as plt
import time
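def use_svg_display():
    """Render matplotlib figures as SVG in the notebook."""
    # Note: newer IPython versions deprecate this call in favor of
    # matplotlib_inline.backend_inline.set_matplotlib_formats.
    display.set_matplotlib_formats('svg')
use_svg_display()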
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=trans, download=True)
Exploring the dataset
len(mnist_train), len(mnist_test)  # (60000, 10000)
mnist_train[0][0].shape  # torch.Size([1, 28, 28])
Visualizing the dataset; the result is the image shown in the introduction
def get_fashion_mnist_labels(labels):
    """Return the text labels of the Fashion-MNIST dataset."""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boots']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # Tensor image: convert to a NumPy array for matplotlib
            ax.imshow(img.numpy())
        else:
            # PIL image
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
images = show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))
images
# Save the figure; the file name '部分服饰' means "sample items"
plt.savefig('部分服饰.png', facecolor='white', edgecolor='red')
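To make the grid logic in show_images easier to follow, here is a minimal sketch in plain Python that computes the figure size and the cell each image lands in, without drawing anything. The helper name grid_layout is hypothetical and not part of d2l or matplotlib; it only mirrors the arithmetic used above (figsize is given as width by height, and a flattened axes array is traversed row by row).

```python
def grid_layout(num_images, num_rows, num_cols, scale=1.5):
    """Return the figure size and the (row, col) cell of each image."""
    # matplotlib expects figsize as (width, height) in inches.
    figsize = (num_cols * scale, num_rows * scale)
    # Flattening a 2-D axes array is row-major, so image i lands in
    # row i // num_cols and column i % num_cols. Extra images are ignored,
    # just as zip(axes, imgs) stops at the shorter sequence.
    num_cells = num_rows * num_cols
    cells = [(i // num_cols, i % num_cols)
             for i in range(min(num_images, num_cells))]
    return figsize, cells

figsize, cells = grid_layout(18, 2, 9)
print(figsize)                                   # (13.5, 3.0)
print(cells[0], cells[8], cells[9], cells[17])   # (0, 0) (0, 8) (1, 0) (1, 8)
```

With 18 images in a 2 x 9 grid, the first nine images fill the top row and the remaining nine fill the bottom row.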
④ Importing the dataset
Wrap the loading of the dataset into memory in a function
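def get_dataloader_workers():
    """Number of worker processes used to read the data (4 is an assumed value; adjust for your machine)."""
    return 4
def load_data_fashion_mnist(batch_size, resize=None):
    """Load the Fashion-MNIST dataset into memory."""
    trans = [transforms.ToTensor()]
    if resize:
        # Resize must come before ToTensor so that it operates on the PIL image
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    # Assumes the data was already downloaded to ./data in step ③
    mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=trans)
    mnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=trans)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=get_dataloader_workers()))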
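The two parameters mean the following. batch_size: how many images are read at a time. resize: an optional target size; when given, each image is scaled up or down proportionally before it is converted to a tensor. For example, with resize=66 the 28 x 28 images become 66 x 66.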
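The effect of the two parameters can be worked out by hand. The following is a minimal sketch in plain Python, not the library's implementation: the helper names num_batches and resized_shape are hypothetical, the batch count assumes the DataLoader default of keeping the last partial batch, and the resize rule assumes an integer target is matched to the shorter edge while the aspect ratio is kept.

```python
import math

def num_batches(num_samples, batch_size, drop_last=False):
    """How many batches one pass over the data yields."""
    if drop_last:
        return num_samples // batch_size
    # The last batch may be smaller than batch_size.
    return math.ceil(num_samples / batch_size)

def resized_shape(height, width, size):
    """Output (height, width) for an int or an (h, w) resize target."""
    if isinstance(size, int):
        # An int target sets the shorter edge; the other edge scales with it.
        if height <= width:
            return size, int(size * width / height)
        return int(size * height / width), size
    return tuple(size)

print(num_batches(60_000, 8))       # 7500
print(num_batches(10_000, 256))     # 40
print(resized_shape(28, 28, 66))    # (66, 66)
print(resized_shape(28, 28, 12))    # (12, 12)
```

Because the Fashion-MNIST images are square, an integer resize always produces a square image of that side length.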
⑤ Loading the dataset
train_iter, test_iter = load_data_fashion_mnist(8, 12)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break
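Result: torch.Size([8, 1, 12, 12]) torch.float32 torch.Size([8]) torch.int64
Explanation: each batch holds 8 images, every image has a single channel and a size of 12 x 12, and every image has its own label, giving 8 labels in total.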
⑥ Inspecting a single image
for X, y in test_iter:
    print(X[0].tolist(), y[0])
    break
Result:
[[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003921568859368563, 0.003921568859368563],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.019607843831181526, 0.1411764770746231, 0.019607843831181526, 0.003921568859368563, 0.10196078568696976, 0.062745101749897],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.007843137718737125, 0.18431372940540314, 0.4745098054409027, 0.4745098054409027, 0.43921568989753723, 0.47058823704719543, 0.11372549086809158],
[0.0, 0.0, 0.0, 0.003921568859368563, 0.003921568859368563, 0.125490203499794, 0.38823530077934265, 0.5333333611488342, 0.6039215922355652, 0.6352941393852234, 0.5803921818733215, 0.1921568661928177],
[0.0, 0.003921568859368563, 0.003921568859368563, 0.03529411926865578, 0.14901961386203766, 0.3803921639919281, 0.4588235318660736, 0.5607843399047852, 0.5921568870544434, 0.6117647290229797, 0.5921568870544434, 0.3843137323856354],
[0.08235294371843338, 0.1921568661928177, 0.26274511218070984, 0.3607843220233917, 0.4431372582912445, 0.4745098054409027, 0.5254902243614197, 0.5764706134796143, 0.6078431606292725, 0.6078431606292725, 0.6196078658103943, 0.5176470875740051],
[0.33725491166114807, 0.47058823704719543, 0.5058823823928833, 0.49803921580314636, 0.5137255191802979, 0.5647059082984924, 0.6078431606292725, 0.6392157077789307, 0.6941176652908325, 0.800000011920929, 0.7686274647712708, 0.5333333611488342],
[0.0470588244497776, 0.12156862765550613, 0.24313725531101227, 0.30588236451148987, 0.32156863808631897, 0.3176470696926117, 0.2235294133424759, 0.11764705926179886, 0.20392157137393951, 0.35686275362968445, 0.3176470696926117, 0.20000000298023224],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]
The corresponding label is tensor(9), that is, class index 9. Because the indices start at 0, this is the tenth category in the list, 'ankle boots'.
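The printed pixel values look arbitrary, but they are 8-bit gray levels divided by 255, which is the scaling that transforms.ToTensor applies to an image in the range [0, 255]. A minimal sketch in plain Python that undoes this scaling for a few of the values shown above (the helper name to_uint8 is hypothetical):

```python
def to_uint8(value):
    """Undo the division by 255 and recover the integer gray level."""
    return round(value * 255)

# A few values copied from the printed image.
samples = [0.0, 0.003921568859368563, 0.33725491166114807,
           0.5333333611488342, 0.800000011920929]

print([to_uint8(v) for v in samples])   # [0, 1, 86, 136, 204]
```

The small deviations in the trailing digits (for example 0.800000011920929 instead of 0.8) come from storing the values as 32-bit floats.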
The end!
Full code link: FashionMNIST dataset