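MindSpore already provides ready-made APIs for loading common datasets, including CelebA, Cifar100, Cifar10, COCO, ImageNet, MNIST and VOC.
The following uses the Cifar10 dataset as an example to show how to call the API and display the images.
Below are the API signature and description from the official documentation: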
class mindspore.dataset.Cifar10Dataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None)
A source dataset for reading and parsing the Cifar10 dataset. This API currently supports only the binary version of the Cifar10 files.
The generated dataset has two columns, [image, label]. The tensor in the image column is of type uint8; the tensor in the label column is a scalar of type uint32.
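The "binary version" stores every sample as a fixed-size record. To illustrate what ends up in the image and label columns, here is a minimal NumPy sketch that parses such records by hand, assuming the standard CIFAR-10 binary layout: 1 label byte followed by 3072 pixel bytes, stored as a red, a green and a blue 32×32 plane. The function parse_cifar10_bin is hypothetical; it is not part of MindSpore and is not how Cifar10Dataset works internally.

```python
import numpy as np

# Assumed CIFAR-10 binary layout: 1 label byte + 3 * 32 * 32 pixel bytes per record.
RECORD_BYTES = 1 + 3 * 32 * 32  # 3073


def parse_cifar10_bin(raw: bytes):
    """Hypothetical helper: split a CIFAR-10 .bin payload into images and labels."""
    if len(raw) % RECORD_BYTES != 0:
        raise ValueError(f"size {len(raw)} is not a multiple of {RECORD_BYTES}")
    records = np.frombuffer(raw, dtype=np.uint8).reshape(-1, RECORD_BYTES)
    # First byte of each record is the class label (0-9); widen it to uint32.
    labels = records[:, 0].astype(np.uint32)
    # Pixels are stored channel-first (CHW); transpose to HWC for PIL / matplotlib.
    images = records[:, 1:].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    return images, labels


if __name__ == "__main__":
    # Two synthetic records stand in for a real data_batch_*.bin file.
    fake = bytearray()
    for label in (3, 7):
        fake.append(label)
        fake.extend((np.arange(3072) % 256).astype(np.uint8).tobytes())
    images, labels = parse_cifar10_bin(bytes(fake))
    print(images.shape, images.dtype, labels.tolist(), labels.dtype)
```

Each image comes out as a (32, 32, 3) uint8 array in height × width × channel order, which is the layout Image.fromarray and plt.imshow expect.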
API call and image display:
```python
import mindspore.dataset as ds
from PIL import Image
import matplotlib.pyplot as plt

# Placeholder path: the extracted binary version of Cifar10
# (the folder holding data_batch_*.bin); adjust to your own location.
data_dir = "./cifar-10-batches-bin"

# Read the first 6 samples in order
sampler = ds.SequentialSampler(num_samples=6)
dataset = ds.Cifar10Dataset(data_dir, sampler=sampler)

# Create an iterator over the dataset; each retrieved item is a dict
for i, data in enumerate(dataset.create_dict_iterator()):
    print("Image shape: {}".format(data['image'].shape), ", Label {}".format(data['label']))
    image = data['image'].asnumpy()  # mindspore.Tensor -> numpy.ndarray
    image = Image.fromarray(image)
    # Place each image in a 2 x 3 grid
    plt.subplot(2, 3, i + 1)
    plt.imshow(image)
    plt.title(f"{i + 1}", fontsize=6)
    plt.xticks([])
    plt.yticks([])
plt.show()
```
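The printed label is only an integer from 0 to 9. To put class names in the subplot titles instead of 1 to 6, a small lookup table is enough. The sketch below assumes the standard CIFAR-10 class order listed in the dataset's batches.meta.txt; label_to_name is a hypothetical helper, not a MindSpore API.

```python
# CIFAR-10 class names in label order 0-9 (as listed in batches.meta.txt).
CIFAR10_CLASSES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)


def label_to_name(label) -> str:
    """Hypothetical helper: map an integer label (int, NumPy scalar or 0-d array) to a name."""
    index = int(label)
    if not 0 <= index < len(CIFAR10_CLASSES):
        raise ValueError(f"CIFAR-10 labels are 0-9, got {index}")
    return CIFAR10_CLASSES[index]


if __name__ == "__main__":
    for label in range(10):
        print(label, label_to_name(label))
```

Inside the loop above one could then write plt.title(label_to_name(data['label'].asnumpy()), fontsize=6); converting with asnumpy() first means int() only ever sees a plain NumPy value.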
Result: