IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> TensorFlow 从入门到精通(5)—— 多层神经网络与应用 -> 正文阅读

[人工智能]TensorFlow 从入门到精通(5)—— 多层神经网络与应用

一、数据集

import tensorflow as tf
tf.__version__
'2.6.0'
# 导入数据集
mnist = tf.keras.datasets.mnist
(train_images,train_labels),(test_images,test_labels) = mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
11501568/11490434 [==============================] - 0s 0us/step
train_images.shape,test_images.shape,train_labels.shape
((60000, 28, 28), (10000, 28, 28), (60000,))
# 展示图片
import matplotlib.pyplot as plt

def plot_image(image):
  plt.imshow(image,cmap='binary')
  plt.show()

plot_image(train_images[0])

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-gaUufC01-1632840460449)(output_4_0.png)]

# 划分数据集
total_num = len(train_images)
split_valid = 0.2
train_num = int((1 - split_valid) * total_num)

# 训练集
train_x = train_images[:train_num]
train_y = train_labels[:train_num]
# 验证集
valid_x = train_images[train_num:]
valid_y = train_labels[train_num:]
# 测试集
test_x = test_images
test_y = test_labels
# 数据塑形+归一化
train_x = tf.cast(train_x.reshape(-1,784)/255.0,dtype=tf.float32)
valid_x = tf.cast(valid_x.reshape(-1,784)/255.0,dtype=tf.float32)
test_x = tf.cast(test_x.reshape(-1,784)/255.0,dtype=tf.float32)

# 标签进行独热编码
train_y = tf.one_hot(train_y,10)
valid_y = tf.one_hot(valid_y,10)
test_y = tf.one_hot(test_y,10)

二、模型

在这里插入图片描述

# 构建模型d
# 定义第一层隐藏层权重和偏执项变量
Input_Dim = 784
H1_NN = 64
W1 = tf.Variable(tf.random.normal(shape=(Input_Dim,H1_NN)),dtype=tf.float32)
B1 = tf.Variable(tf.zeros(H1_NN),dtype=tf.float32)
# 定义输出层权重和偏执项变量
Output_Dim = 10
W2 = tf.Variable(tf.random.normal(shape=(H1_NN,Output_Dim)),dtype=tf.float32)
B2 = tf.Variable(tf.zeros(Output_Dim),dtype=tf.float32)
# 待优化列表
W = [W1,W2]
B = [B1,B2]
# 定义模型的前向计算
def model(w,x,b):
  x = tf.matmul(x,w[0]) + b[0]
  x = tf.nn.relu(x)
  x = tf.matmul(x,w[1]) + b[1]
  return tf.nn.softmax(x)
# 损失函数
def loss(w,x,y,b):
  pred = model(w,x,b)
  loss_ = tf.keras.losses.categorical_crossentropy(y_true=y,y_pred=pred)
  return tf.reduce_mean(loss_)
# 准确率
def accuracy(w,x,y,b):
  pred = model(w,x,b)
  acc = tf.equal(tf.argmax(pred,axis=1),tf.argmax(y,axis=1))
  return tf.reduce_mean(tf.cast(acc,dtype=tf.float32))
# 计算梯度
def grad(w,x,y,b):
  with tf.GradientTape() as tape:
    loss_ = loss(w,x,y,b)
  return tape.gradient(loss_,[w[0],b[0],w[1],b[1]])

三、训练

# 定义超参数
train_epochs = 20
learning_rate = 0.01
batch_size = 50
total_steps = train_num // batch_size
train_loss_list = []
valid_loss_list = []
trian_acc_list = []
valide_acc_list = []
# 优化器
optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
for epoch in range(train_epochs):
  for step in range(total_steps):
    xs = train_x[step*batch_size:(step+1)*batch_size]
    ys = train_y[step*batch_size:(step+1)*batch_size]
    grads = grad(W,xs,ys,B)
    optimizer.apply_gradients(zip(grads,[W[0],B[0],W[1],B[1]]))
  trian_loss = loss(W,train_x,train_y,B).numpy()
  valid_loss = loss(W,valid_x,valid_y,B).numpy()
  train_accuracy = accuracy(W,train_x,train_y,B).numpy()
  valid_accuracy = accuracy(W,valid_x,valid_y,B).numpy()
  trian_acc_list.append(train_accuracy)
  valide_acc_list.append(valid_accuracy)
  train_loss_list.append(trian_loss)
  valid_loss_list.append(valid_loss)
  print(f'{epoch+1}:trian_loss:{trian_loss}valid_loss:{valid_loss}train_accuracy:{train_accuracy}valid_accuracy:{valid_accuracy}')
1:trian_loss:4.090484142303467valid_loss:4.0961079597473145train_accuracy:0.7324583530426025valid_accuracy:0.731083333492279
2:trian_loss:3.873914957046509valid_loss:3.8966963291168213train_accuracy:0.7461875081062317valid_accuracy:0.7425000071525574
3:trian_loss:3.698087215423584valid_loss:3.7547082901000977train_accuracy:0.7597083449363708valid_accuracy:0.7545833587646484
4:trian_loss:2.0992202758789062valid_loss:2.1797149181365967train_accuracy:0.8577708601951599valid_accuracy:0.8530833125114441
5:trian_loss:2.0091030597686768valid_loss:2.1187283992767334train_accuracy:0.8645208477973938valid_accuracy:0.8534166812896729
6:trian_loss:2.05008864402771valid_loss:2.162834405899048train_accuracy:0.8585000038146973valid_accuracy:0.8494166731834412
7:trian_loss:1.9510189294815063valid_loss:2.0553224086761475train_accuracy:0.8664166927337646valid_accuracy:0.8576666712760925
8:trian_loss:1.9326006174087524valid_loss:2.050128221511841train_accuracy:0.8680833578109741valid_accuracy:0.8569999933242798
9:trian_loss:1.9068089723587036valid_loss:2.024397850036621train_accuracy:0.8706041574478149valid_accuracy:0.8599166870117188
10:trian_loss:0.4595804512500763valid_loss:0.5628429651260376train_accuracy:0.9586874842643738valid_accuracy:0.949999988079071
11:trian_loss:0.3590681552886963valid_loss:0.5005843043327332train_accuracy:0.9663333296775818valid_accuracy:0.9556666612625122
12:trian_loss:0.29265761375427246valid_loss:0.46133357286453247train_accuracy:0.9728958606719971valid_accuracy:0.9575833082199097
13:trian_loss:0.3250505030155182valid_loss:0.49780264496803284train_accuracy:0.9699791669845581valid_accuracy:0.9567499756813049
14:trian_loss:0.329074889421463valid_loss:0.4836892783641815train_accuracy:0.9683958292007446valid_accuracy:0.9536666870117188
15:trian_loss:0.2734844386577606valid_loss:0.46817922592163086train_accuracy:0.9743750095367432valid_accuracy:0.9578333497047424
16:trian_loss:0.3187606930732727valid_loss:0.5206401944160461train_accuracy:0.9695624709129333valid_accuracy:0.952750027179718
17:trian_loss:0.23391176760196686valid_loss:0.46213391423225403train_accuracy:0.9774166941642761valid_accuracy:0.9605000019073486
18:trian_loss:0.2218097299337387valid_loss:0.41849949955940247train_accuracy:0.9789999723434448valid_accuracy:0.9635000228881836
19:trian_loss:0.2505856156349182valid_loss:0.45410531759262085train_accuracy:0.9771875143051147valid_accuracy:0.9606666564941406
20:trian_loss:0.2279120683670044valid_loss:0.45335933566093445train_accuracy:0.9788125157356262valid_accuracy:0.9618333578109741
accuracy(W,test_x,test_y,B).numpy()
0.959
# 损失图像
plt.plot(train_loss_list,'r')
plt.plot(valid_loss_list,'b')
[<matplotlib.lines.Line2D at 0x7f78f32a60d0>]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-pUg4t41Y-1632840460452)(output_21_1.png)]

# 准确率图像
plt.plot(trian_acc_list,'r')
plt.plot(valide_acc_list,'b')
[<matplotlib.lines.Line2D at 0x7f78f328cfd0>]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-9O4yxIl2-1632840460454)(output_22_1.png)]

四、预测

def predict(x,w,b):
  pred = model(w,x,b)
  pred_ = tf.argmax(pred,axis=1)
  return pred_
import numpy as np
id = np.random.randint(0,len(test_x)) # 随机生成一个验证id
# 预测值
pred = predict(test_x,W,B)[id]
# 真实值
true = test_labels[id]
print(true,pred.numpy())
1 1
import sklearn.metrics as sm
print(f'r2:{sm.r2_score(test_y,model(W,test_x,B))}')
r2:0.9126431934513113
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-29 10:15:46  更:2021-09-29 10:15:50 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年5日历 -2024/5/21 23:09:35-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码