# 一、数据集 (1. Dataset)
import tensorflow as tf
# Report the TensorFlow version; the recorded notebook run used '2.6.0'.
# (A bare expression only displays inside a notebook, so print explicitly.)
print(tf.__version__)

mnist = tf.keras.datasets.mnist
# Recorded output of the original run while loading the data:
#   Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
#   11493376/11490434 [==============================] - 0s 0us/step
#   11501568/11490434 [==============================] - 0s 0us/step
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Recorded output: ((60000, 28, 28), (10000, 28, 28), (60000,))
print(train_images.shape, test_images.shape, train_labels.shape)
import matplotlib.pyplot as plt
def plot_image(image):
    """Display ``image`` with matplotlib's 'binary' colormap and show the figure."""
    plt.imshow(image, cmap='binary')
    plt.show()
# Quick visual sanity check on the first training image.
plot_image(train_images[0])

# Keep the first 80% of the official training set for training and hold out
# the remaining 20% for validation; the official test set is left as-is.
total_num = len(train_images)
split_valid = 0.2
train_num = int((1 - split_valid) * total_num)

train_x, valid_x = train_images[:train_num], train_images[train_num:]
train_y, valid_y = train_labels[:train_num], train_labels[train_num:]
test_x, test_y = test_images, test_labels

# Flatten each image to a 784-long row, divide pixel values by 255 and cast
# to float32 for the dense layers.
train_x, valid_x, test_x = [
    tf.cast(images.reshape(-1, 784) / 255.0, dtype=tf.float32)
    for images in (train_x, valid_x, test_x)
]
# One-hot encode the integer labels over 10 classes.
train_y, valid_y, test_y = [
    tf.one_hot(labels, 10)
    for labels in (train_y, valid_y, test_y)
]
# 二、模型 (2. Model)
# Two-layer MLP parameters: Input_Dim -> H1_NN (ReLU) -> Output_Dim.
Input_Dim = 784
H1_NN = 64
Output_Dim = 10

# Standard deviation of the random initial weights. tf.random.normal defaults
# to stddev=1.0; with that, every hidden unit sums hundreds of pixels times
# N(0, 1) weights, the initial logits are very large and the softmax starts
# out saturated. A small stddev keeps the first forward passes well-scaled.
Init_Std = 0.1

W1 = tf.Variable(tf.random.normal(shape=(Input_Dim, H1_NN), stddev=Init_Std), dtype=tf.float32)
B1 = tf.Variable(tf.zeros(H1_NN), dtype=tf.float32)
W2 = tf.Variable(tf.random.normal(shape=(H1_NN, Output_Dim), stddev=Init_Std), dtype=tf.float32)
B2 = tf.Variable(tf.zeros(Output_Dim), dtype=tf.float32)

# Packed the way model/loss/grad index them: w[0], w[1] and b[0], b[1].
W = [W1, W2]
B = [B1, B2]
def _logits(w, x, b):
    """Return the unnormalised class scores of the two-layer MLP.

    Computes ``relu(x @ w[0] + b[0]) @ w[1] + b[1]``; ``w`` and ``b`` are
    the two-element weight and bias lists.
    """
    hidden = tf.nn.relu(tf.matmul(x, w[0]) + b[0])
    return tf.matmul(hidden, w[1]) + b[1]


def model(w, x, b):
    """Return class probabilities: softmax over the MLP's logits for each row of ``x``."""
    return tf.nn.softmax(_logits(w, x, b))


def loss(w, x, y, b):
    """Return the mean categorical cross-entropy between one-hot ``y`` and the model.

    The loss is computed from the logits with ``from_logits=True`` instead of
    from ``model``'s softmax output. On the probability path the Keras
    backend clips predictions to [epsilon, 1 - epsilon] (in eager mode). A
    saturated, confidently wrong prediction then contributes a constant loss
    with zero gradient, so it is never corrected. The logits path is also
    more numerically stable.
    """
    per_example = tf.keras.losses.categorical_crossentropy(
        y_true=y, y_pred=_logits(w, x, b), from_logits=True)
    return tf.reduce_mean(per_example)
def accuracy(w, x, y, b):
    """Return the fraction of rows of ``x`` whose argmax prediction matches the argmax of ``y``."""
    predicted_class = tf.argmax(model(w, x, b), axis=1)
    true_class = tf.argmax(y, axis=1)
    hits = tf.cast(tf.equal(predicted_class, true_class), dtype=tf.float32)
    return tf.reduce_mean(hits)
def grad(w, x, y, b):
    """Return the gradients of ``loss`` with respect to [w[0], b[0], w[1], b[1]], in that order."""
    params = [w[0], b[0], w[1], b[1]]
    with tf.GradientTape() as tape:
        batch_loss = loss(w, x, y, b)
    return tape.gradient(batch_loss, params)
# 三、训练 (3. Training)
# Optimisation hyper-parameters.
train_epochs = 20
learning_rate = 0.01
batch_size = 50
# Number of full mini-batches per epoch (integer division).
total_steps = train_num // batch_size

# Per-epoch history. The "trian"/"valide" spellings are kept on purpose:
# the training loop and the plotting code refer to these exact names.
train_loss_list, valid_loss_list = [], []
trian_acc_list, valide_acc_list = [], []

optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
# Mini-batch training loop. After every epoch, loss and accuracy on the full
# training and validation sets are appended to the history lists.
for epoch in range(train_epochs):
    # Reshuffle each epoch so the mini-batches differ between epochs; x and y
    # are gathered with the same permutation so rows stay aligned.
    order = tf.random.shuffle(tf.range(train_x.shape[0]))
    epoch_x = tf.gather(train_x, order)
    epoch_y = tf.gather(train_y, order)

    for step in range(total_steps):
        start, end = step * batch_size, (step + 1) * batch_size
        xs = epoch_x[start:end]
        ys = epoch_y[start:end]
        grads = grad(W, xs, ys, B)
        # Order must match the gradient order returned by grad().
        optimizer.apply_gradients(zip(grads, [W[0], B[0], W[1], B[1]]))

    train_loss = loss(W, train_x, train_y, B).numpy()
    valid_loss = loss(W, valid_x, valid_y, B).numpy()
    train_accuracy = accuracy(W, train_x, train_y, B).numpy()
    valid_accuracy = accuracy(W, valid_x, valid_y, B).numpy()

    trian_acc_list.append(train_accuracy)
    valide_acc_list.append(valid_accuracy)
    train_loss_list.append(train_loss)
    valid_loss_list.append(valid_loss)

    print(f'{epoch + 1}: train_loss: {train_loss:.4f}  valid_loss: {valid_loss:.4f}  '
          f'train_accuracy: {train_accuracy:.4f}  valid_accuracy: {valid_accuracy:.4f}')
# 1:trian_loss:4.090484142303467valid_loss:4.0961079597473145train_accuracy:0.7324583530426025valid_accuracy:0.731083333492279
# 2:trian_loss:3.873914957046509valid_loss:3.8966963291168213train_accuracy:0.7461875081062317valid_accuracy:0.7425000071525574
# 3:trian_loss:3.698087215423584valid_loss:3.7547082901000977train_accuracy:0.7597083449363708valid_accuracy:0.7545833587646484
# 4:trian_loss:2.0992202758789062valid_loss:2.1797149181365967train_accuracy:0.8577708601951599valid_accuracy:0.8530833125114441
# 5:trian_loss:2.0091030597686768valid_loss:2.1187283992767334train_accuracy:0.8645208477973938valid_accuracy:0.8534166812896729
# 6:trian_loss:2.05008864402771valid_loss:2.162834405899048train_accuracy:0.8585000038146973valid_accuracy:0.8494166731834412
# 7:trian_loss:1.9510189294815063valid_loss:2.0553224086761475train_accuracy:0.8664166927337646valid_accuracy:0.8576666712760925
# 8:trian_loss:1.9326006174087524valid_loss:2.050128221511841train_accuracy:0.8680833578109741valid_accuracy:0.8569999933242798
# 9:trian_loss:1.9068089723587036valid_loss:2.024397850036621train_accuracy:0.8706041574478149valid_accuracy:0.8599166870117188
# 10:trian_loss:0.4595804512500763valid_loss:0.5628429651260376train_accuracy:0.9586874842643738valid_accuracy:0.949999988079071
# 11:trian_loss:0.3590681552886963valid_loss:0.5005843043327332train_accuracy:0.9663333296775818valid_accuracy:0.9556666612625122
# 12:trian_loss:0.29265761375427246valid_loss:0.46133357286453247train_accuracy:0.9728958606719971valid_accuracy:0.9575833082199097
# 13:trian_loss:0.3250505030155182valid_loss:0.49780264496803284train_accuracy:0.9699791669845581valid_accuracy:0.9567499756813049
# 14:trian_loss:0.329074889421463valid_loss:0.4836892783641815train_accuracy:0.9683958292007446valid_accuracy:0.9536666870117188
# 15:trian_loss:0.2734844386577606valid_loss:0.46817922592163086train_accuracy:0.9743750095367432valid_accuracy:0.9578333497047424
# 16:trian_loss:0.3187606930732727valid_loss:0.5206401944160461train_accuracy:0.9695624709129333valid_accuracy:0.952750027179718
# 17:trian_loss:0.23391176760196686valid_loss:0.46213391423225403train_accuracy:0.9774166941642761valid_accuracy:0.9605000019073486
# 18:trian_loss:0.2218097299337387valid_loss:0.41849949955940247train_accuracy:0.9789999723434448valid_accuracy:0.9635000228881836
# 19:trian_loss:0.2505856156349182valid_loss:0.45410531759262085train_accuracy:0.9771875143051147valid_accuracy:0.9606666564941406
# 20:trian_loss:0.2279120683670044valid_loss:0.45335933566093445train_accuracy:0.9788125157356262valid_accuracy:0.9618333578109741
# Accuracy on the untouched test set (the recorded notebook run showed 0.959).
# A bare expression only displays inside a notebook, so print explicitly.
test_accuracy = accuracy(W, test_x, test_y, B).numpy()
print(f'test_accuracy:{test_accuracy}')

# Loss history, one point per epoch: training in red, validation in blue.
plt.figure()
plt.plot(train_loss_list, 'r', label='train')
plt.plot(valid_loss_list, 'b', label='valid')
plt.xlabel('epoch index')
plt.ylabel('loss')
plt.legend()
plt.show()

# Accuracy history on its own figure, so it is not drawn over the loss curves
# when this runs as a script rather than as separate notebook cells.
plt.figure()
plt.plot(trian_acc_list, 'r', label='train')
plt.plot(valide_acc_list, 'b', label='valid')
plt.xlabel('epoch index')
plt.ylabel('accuracy')
plt.legend()
plt.show()
# 四、预测 (4. Prediction)
def predict(x, w, b):
    """Return the index of the highest-probability class for each row of ``x``."""
    probabilities = model(w, x, b)
    return tf.argmax(probabilities, axis=1)
import numpy as np
# Spot-check one random test example: print its true label next to the
# model's predicted label. `idx` avoids shadowing the builtin `id`, and only
# the selected row is run through the model instead of the whole test set.
idx = np.random.randint(0, len(test_x))
pred = predict(test_x[idx:idx + 1], W, B)[0]
true = test_labels[idx]
print(true, pred.numpy())
# Recorded output of one notebook run: 1 1
import sklearn.metrics as sm
# R^2 between one-hot labels and softmax probabilities treats the task as
# regression and says little about classification quality. Report accuracy
# and per-class precision/recall/F1 on the predicted labels instead.
test_pred = predict(test_x, W, B).numpy()
print(f'accuracy:{sm.accuracy_score(test_labels, test_pred)}')
print(sm.classification_report(test_labels, test_pred))