在《深度学习入门:基于Python的理论与实现》章节的第三章就开始以MNIST数据集为基础编写代码。然而根据源码的操作,很有可能会出现mnist下载超时的情况。以下是解决方案:
1. 获取代码读取数据集的路径
以mnist_show.py 为例: mnist_show.py源码:
import sys, os
sys.path.append(os.pardir)
import numpy as np
from dataset.mnist import load_mnist
from PIL import Image
def img_show(img):
pil_img = Image.fromarray(np.uint8(img))
pil_img.show()
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)
img = x_train[0]
label = t_train[0]
print(label)
print(img.shape)
img = img.reshape(28, 28)
print(img.shape)
img_show(img)
然后就会执行load_mnist 函数。 dataset目录内mnist.py 文件的 load_mnist 函数代码的开头:
print("!!!!",save_file)
if not os.path.exists(save_file):
init_mnist()
然后打印save_file 变量,获得路径地址,一般都是默认的dataset目录内。
2. 手动下载MNIST数据集
MNIST数据集的官网,下载:
train-images-idx3-ubyte: training set images
train-labels-idx1-ubyte: training set labels
t10k-images-idx3-ubyte: test set images
t10k-labels-idx1-ubyte: test set labels
保存到第一步中获取的路径。
3. 执行代码
一般都可以顺利执行了:
PS.如果自己来复现项目,一定记得要复制dataset 目录到自己项目的目录下。
|