IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【TensorFlow2.x】Keras高层接口 -> 正文阅读

[人工智能]【TensorFlow2.x】Keras高层接口

TensorFlow2.x学习笔记—Keras高层接口

TensorFlow2.x版本中,Keras被正式确定为TensorFlow的高层API唯一接口,取代了TensorFlow1.x版本中自带的tf.layers等高层接口。也就是说,现在只能使用Keras的接口来完成TensorFlow层方式的模型搭建与训练。在TensorFlow中,Keras被实现在 tf.keras子模块中。对于使用TensorFlow的开发者来说,tf.keras可以理解为一个普通的子模块,与其他子模块,如tf.mathtf.data等并没有什么差别。


1. 常见功能模块

  • 常见数据集加载函数
  • 网络层类
  • 模型容器
  • 损失函数类
  • 优化器类
  • 经典模型类

1.1 常见数据集加载函数

在这里插入图片描述

该路径下面有一个mnist.py文件

from tensorflow.keras.datasets.mnist import load_data
data = load_data("mnist.npz")
x_train, y_train = data[0][0], data[0][1]
x_test, y_test = data[1][0], data[1][1]

在这里插入图片描述


1.2 网络层类

import tensorflow as tf
from tensorflow.keras import layers
x = tf.constant([2., 1.])
layer = layers.Softmax(axis = -1)
layer(x)

1.3 网络容器

Keras网络容器Sequential将多个网络层封装成一个大网络模型,只需要调用网络模型的实例一次即可完成数据从第一层到最末层的顺序运算。

from tensorflow.keras import layers, Sequential
network = Sequential([layers.Dense(3, activation = None),
                      layers.ReLU(),
                      layers.Dense(2, activation = None),
                      layers.ReLU()])
x = tf.random.normal([4, 3])
network(x)

追加网络层

layer_num = 2
network = Sequential([])
for _ in range(layer_num):
    network.add(layers.Dense(3))
    network.add(layers.ReLU())
network.build(input_shape = (None, 4))
network.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 3)                 15        		(4 * 3 + 3)
_________________________________________________________________
re_lu (ReLU)                 (None, 3)                 0         
_________________________________________________________________
dense_1 (Dense)              (None, 3)                 12        		(3 * 3 + 3)
_________________________________________________________________
re_lu_1 (ReLU)               (None, 3)                 0         
=================================================================
Total params: 27
Trainable params: 27
Non-trainable params: 0
_________________________________________________________________
for p in network.trainable_variables:
    print(p.name, p.shape)

dense_2/kernel:0 (4, 3)
dense_2/bias:0 (3,)
dense_3/kernel:0 (3, 3)
dense_3/bias:0 (3,)

2. 模型装配、训练与测试

2.1 模型装配

  • keras.Model

  • keras.layers.Layer

network = Sequential([layers.Dense(256, activation = "relu"),
                      layers.Dense(128, activation = "relu"),
                      layers.Dense(64, activation = "relu"),
                      layers.Dense(32, activation = "relu"),
                      layers.Dense(10)])
network.build(input_shape = (None, 28 * 28))
network.summary()
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_15 (Dense)             (None, 256)               200960    
_________________________________________________________________
dense_16 (Dense)             (None, 128)               32896     
_________________________________________________________________
dense_17 (Dense)             (None, 64)                8256      
_________________________________________________________________
dense_18 (Dense)             (None, 32)                2080      
_________________________________________________________________
dense_19 (Dense)             (None, 10)                330       
=================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
________________________________________________________________

782 × 256 + 256 = 200960 782\times 256 + 256 = 200960 782×256+256=200960

256 × 128 + 128 = 32896 256\times 128 + 128 = 32896 256×128+128=32896

128 × 64 + 64 = 8256 128\times 64 +64 = 8256 128×64+64=8256

64 × 32 + 32 = 2080 64\times 32+ 32 = 2080 64×32+32=2080

32 ? 10 + 10 = 330 32 * 10 +10 = 330 32?10+10=330


  • 通过compile()函数指定网络使用的优化器对象,损失函数,评价指标等
from tensorflow.keras import optimizers, losses
network.compile(optimizer = optimizers.Adam(lr = 0.01),
                loss = losses.CategoricalCrossentropy(from_logits = True),
                metrics = ["accuracy"])

2.2 模型训练

  • 通过fit()函数送入待训练的数据和验证用的数据集
history = network.fit(train, epochs = 5, validation_data = val, validation_freq = 2)
history.history  # 打印训练记录

2.3 模型测试

  • 通过Model.predict(x)方法完成模型的预测
x, y = next(iter(db_test))
print("predict x:", x.shape)
out = network.predict(x)
print(out)
# network.evaluate(db_test)

2.4 模型保存与加载

  • Tensor方式
# 保存模型参数到文件上
network.save_weights("weights.ckpt")
print("saved weights.")
del network

# 重新创建相同的网络结构
network = Sequential([layers.Dense(256, activation = "relu"),
                      layers.Dense(128, activation = "relu"),
                      layers.Dense(64, activation = "relu"),
                      layers.Dense(32, activation = "relu"),
                      layers.Dense(10)])
network.compile(optimizer = optimizers.Adam(lr = 0.01),
                loss = tf.losses.CategoricalCrossentropy(from_logits = True),
                metrics = ['accuracy'])

# load
network.load_weights("weights.cpkt")
print("loaded weights!")
  • 网络方式
network.save("model.h5")
print("saved total model.")
del network

network = tf.keras.models.load_model("model.h5")
  • Save Model 方式
# 保存模型结构与模型参数到文件
tf.keras.experimental.export_saved_model(network, 'model-savedmodel')
print('export saved model.')
del network # 删除网络对象

# 从文件恢复网络结构与网络参数
network = tf.keras.experimental.load_from_saved_model('model-savedmodel')

2.5 自定义类

  • 创建自定义网络层类,需要继承自 layers.Layer 基类
  • 创建自定义的网络类,需要继承自keras.Model 基类

P180


2.6 模型乐园

P181


2.7 测量工具

  • 新建测量器
from tensorflow.keras import metrics
loss_meter = metrics.Mean()
  • 写入数据
loss_meter.update_state(float(loss))
  • 读取统计数据
print(step, "loss:", loss_meter.result())
  • 清零测量器
if step % 100 == 0:
    print(step, "loss:", loss_meter.result())
    loss_meter.reset_states()

实战

acc_meter = metrics.Accuracy()
out = network(x)
pred = tf.argmax(out, axis = 1)
pred = tf.cast(pred, dtype = tf.int32)
acc_meter.update_state(y, pred)

print(step, "Evaluate Acc:", acc_meter.result().numpy())
acc_meter.reset_states()

2.8 可视化

  • 模型端
# 创建监控类,监控数据将写入 log_dir 目录
summary_writer = tf.summary.create_file_writer(log_dir)

with summary_writer.as_default(): 
    # 当前时间戳 step 上的数据为 loss,写入到 ID 位 train-loss 对象中
    tf.summary.scalar('train-loss', float(loss), step=step)

with summary_writer.as_default():
    # 写入测试准确率
    tf.summary.scalar('test-acc', float(total_correct/total), step=step)
    # 可视化测试用的图片,设置最多可视化 9 张图片
    tf.summary.image("val-onebyone-images:", val_images, max_outputs=9, step=step)

P185

tensorboard --logdir path

with summary_writer.as_default(): 
    # 当前时间戳 step 上的数据为 loss,写入到 ID 为 train-loss 对象中
    tf.summary.scalar('train-loss', float(loss), step=step) 
    # 可视化真实标签的直方图分布
    tf.summary.histogram('y-hist', y, step=step)
    # 查看文本信息
    tf.summary.text('loss-text', str(float(loss)))

Facebook 的 Visdom

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-12-23 15:46:04  更:2021-12-23 15:48:40 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 21:05:22-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码