数据集下载
(Windows系统)使用keras导入cifar10数据会自动下载(https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz),但是速度很慢,可以是用迅雷自己下载,然后将下载完的文件改名为cifar-10-batches-py.tar.gz,然后复制到C:\Users\你得用户名\.keras\datasets文件夹下。
分类释义
0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'
数据集查看
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import logging
tf.get_logger().setLevel(logging.ERROR)
cifar_dataset = keras.datasets.cifar10
(train_images, train_labels), (test_images, test_labels) = cifar_dataset.load_data()
# 打印训练集长度
print('train:',len(train_images))
# 打印测试集长度
print('test :',len(test_images))
# 打印训练集形状
print('train_image :',train_images.shape)
# 打印测试集形状
print('test_image :',test_images.shape)
# 打印下标100的训练标签的分类
print('Category: ', train_labels[100])
plt.figure(figsize=(1, 1))
plt.imshow(train_images[100])
plt.show()
?
?
|