from pathlib import Path
import requests
import pickle
import gzip
from matplotlib import pyplot
import numpy as np
import tensorflow as tf
import keras
from keras import layers
# Local directory that holds the MNIST archive; created if it does not exist.
DATA_PATH = Path("data")
# Named MNIST_DIR rather than `Path` so the imported pathlib.Path class is
# not shadowed by a directory instance.
MNIST_DIR = DATA_PATH / "mnist"
MNIST_DIR.mkdir(parents=True, exist_ok=True)

URL = "http://deeplearning.net/data/mnist"
FILENAME = "mnist.pkl.gz"
ARCHIVE_PATH = MNIST_DIR / FILENAME

# The automatic download stays disabled, as in the original script.
# To re-enable it, note that URL has no trailing slash, so a separator is
# needed between URL and FILENAME:
#     if not ARCHIVE_PATH.exists():
#         ARCHIVE_PATH.write_bytes(requests.get(f"{URL}/{FILENAME}").content)
if not ARCHIVE_PATH.is_file():
    # Same exception type gzip.open would raise, but with an actionable message.
    raise FileNotFoundError(
        f"MNIST archive not found at {ARCHIVE_PATH}; place {FILENAME} there first"
    )

# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load an archive obtained from a trusted source.
# The pickle unpacks into three pairs; the third one is ignored here.
# encoding="latin-1" is presumably needed because the archive was pickled
# under Python 2 -- TODO confirm.
with gzip.open(ARCHIVE_PATH, "rb") as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
# Preview the first training sample: view the flat vector as a 28x28 image,
# display it in grayscale, then report the shape of the stored sample.
# (For the raw values, use: print(x_train[0]))
first_sample = x_train[0]
first_image = first_sample.reshape((28, 28))
pyplot.imshow(first_image, cmap="gray")
pyplot.show()
print(first_sample.shape)
# Build the network: a fully connected classifier with two hidden ReLU layers
# of 32 units each and a 10-way softmax output layer.
model = keras.Sequential()
model.add(layers.Dense(32, activation="relu"))
model.add(layers.Dense(32, activation="relu"))
model.add(layers.Dense(10, activation="softmax"))

# The optimizer, loss and metric are taken from the same `keras` package that
# builds the model. Mixing in `tf.keras` objects can fail when the installed
# standalone Keras and TensorFlow's bundled Keras are different implementations.
# The sparse loss/metric assume integer class labels in y_train -- TODO confirm.
model.compile(
    optimizer=keras.optimizers.Adam(0.001),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# First training pass directly on the in-memory arrays.
model.fit(x_train, y_train, epochs=10, batch_size=64, validation_data=(x_valid, y_valid))
# Retrain: wrap the arrays in tf.data pipelines that yield batches of 32 and
# repeat indefinitely.
train = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32).repeat()
# Print the dataset object for a quick look.
print(train)
valid = tf.data.Dataset.from_tensor_slices((x_valid, y_valid)).batch(32).repeat()
# Continue training on the batched datasets. Both datasets repeat without end,
# so steps_per_epoch / validation_steps set how many batches each epoch uses.
model.fit(train, epochs=5, steps_per_epoch=100, validation_data=valid, validation_steps=100)

# Save the weights together with the network structure in a single file.
model.save("figure_model.h5")
# To load the model back: model = keras.models.load_model('figure_model.h5')

# Save only the network architecture as JSON.
config = model.to_json()
# The handle is deliberately not called `json`, to avoid colliding with the
# name of the standard-library module.
with open('config.json', 'w', encoding='utf-8') as config_file:
    config_file.write(config)
# To rebuild the architecture from that file:
#     model = keras.models.model_from_json(open('config.json').read())
# To read the weights: weights = model.get_weights()
# To save the weights only: model.save_weights('weights.h5')
# NOTE(review): newer Keras releases may require the weights file name to end
# in `.weights.h5` -- confirm against the installed version.