当时看老师留的作业的时候发现要求读取历史数据,但网上有没找到,我自己找了手册啥的,查到了几个参数,就分享下,希望各位大佬不要看不上哈...
那就先把程序从零实现:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import keras
from tensorflow.keras import *
gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_visible_devices(gpus[0], True)
file_data = np.load(r'E:/mnist.npz')
x_train= file_data['x_train']
y_train= file_data['y_train']
x_test= file_data['x_test']
y_test= file_data['y_test']
先读取一些包,然后设置自己的gpu,如果预先下载好数据集的话可以放在自己固定放程序的文件夹内,修改上面的绝对路径就好。
下面是参数的初始化,将输入变成0.0-1.0之间浮点数,便于运算和加速收敛(当然,计算性能现今提升了好多,这个小程序没啥影响。)当时我们同学有人整个程序运行完会报错,提示Dim不匹配,如果不匹配的话就把下面注释的取消掉,reshape一下就解决了。
X_train = tf.cast(x_train, dtype=tf.float32)/ 255.0
X_test = tf.cast(x_test, dtype=tf.float32)/ 255.0
# X_train = X_train.reshaper((60000, 28, 28, 1))
# X_test = X_train.reshaper((10000, 28, 28, 1))
print (X_train.shape)
下面就是固定的那些格式了,没什么东西,都API什么的,主要是Model.fit返回一个参数,用来可视化,他是保存历史数据的,ACC和Loss,程序里面的那些超参数都可以修改,当然,那些优化方法我用的是'adam',loss是‘sparse_categorical_crossentropy’,同理都可以修改的。
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(16, kernel_size= (3,3), padding = 'same',
activation= tf.nn.relu, input_shape= (28, 28, 1)),
tf.keras.layers.MaxPool2D(pool_size = (2,2)),
tf.keras.layers.Conv2D(16, kernel_size= (3,3), padding = 'same',activation= tf.nn.relu),
tf.keras.layers.MaxPool2D(pool_size= (2,2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation= 'relu'),
tf.keras.layers.Dense(10, activation= 'softmax')
])
print (model.summary())
model.compile (optimizer= 'adam',
loss='sparse_categorical_crossentropy',
metrics=["sparse_categorical_accuracy"])
history = model.fit(X_train, y_train, batch_size = 100, epochs =5,
validation_split=0.20, shuffle=True)
model.evaluate(X_test, y_test,verbose= 2)
History返回了4个参数,分别是['loss', 'sparse_categorical_accuracy', 'val_loss', 'val_sparse_categorical_accuracy'],就是训练集的损失和精确度,和测试集的损失和精度。
下面是用matplotlib可视化,这个就没什么东西了,显示一下训练的结果。
plt.figure()
plt.subplot(2,2,1)
plt.title('Loss')
plt.plot(history.history['loss'],label='Validation bit_error')
plt.subplot(2,2,2)
plt.title('Sparse Categorical Accuracy')
plt.plot(history.history['sparse_categorical_accuracy'],label='Validation bit_error')
plt.subplot(2,2,3)
plt.title('Val Loss')
plt.plot(history.history['val_loss'],label='Validation bit_error')
plt.subplot(2,2,4)
plt.title('Val Sparse Categorical Accuracy')
plt.plot(history.history['val_sparse_categorical_accuracy'],label='Validation bit_error')
plt.show()
下面是显示的结果和图形。
?
?好了,这里就结束了,谢谢观看,嘿嘿
|