IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> Colab Tensorboard 批次(batch)级数据显示 -> 正文阅读

[Python知识库]Colab Tensorboard 批次(batch)级数据显示

前言

太坑了。。查了很多资料,终于解决了。

方法

首先,根据官方API的参数列表得知,在tf.keras.callbacks.TensorBoard中修改update_freq参数为batch或一个整数

update_freq='batch'

# update_freq=10
# 如果改为使用一个整数N的话,过N批次后更新一次数据,
# 这样可以避免由于更新过于频繁而降低网络训练速度

然后,根据这则帖子,由于TensorFlow 2.3做了一个优化,导致上面的方法在这里不管用。

解决方法是,除了TensorBoard之类的callback以外,再添加一个LambdaCallback,具体代码如下:

   def batchOutput(batch, logs):
       tf.summary.scalar('batch_loss', data=logs['loss'], step=batch)
       tf.summary.scalar('batch_accuracy', data=logs['accuracy'], step=batch)
       return batch
       
   batch_log_callback = callbacks.LambdaCallback(
       on_batch_end=batchOutput)

于是终于成功
在这里插入图片描述在这里插入图片描述

示例代码

改完后,我的训练部分的完整代码是这样的:

def train_model(save:bool=True):
    # load and compile model
    model = create_model()    
    model.compile(
        loss='mean_squared_error',
        optimizer='adam',
        metrics=['accuracy'])

    # prepare tensorflow dashboard
    logdir = os.path.join(
        'logs', 
        datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    tensorboard_callback = callbacks.TensorBoard(
        logdir, 
        histogram_freq=1, 
        write_images=True, 
        update_freq=10, # 查看批次级别的数据变化,需要结合LambdaCallback
        embeddings_freq=1,
        profile_batch=1)
        
	# 实现查看批次级别数据变化
    def batchOutput(batch, logs):
        tf.summary.scalar('batch_loss', data=logs['loss'], step=batch)
        tf.summary.scalar('batch_accuracy', data=logs['accuracy'], step=batch)
        return batch
    batch_log_callback = callbacks.LambdaCallback(
        on_batch_end=batchOutput)

    # prepare early stop
    early_stop_callback = callbacks.EarlyStopping(
        monitor='val_loss', 
        patience=0,
        restore_best_weights=True)

    # train model
    epochs_num = 4
    model.fit(x=X,
              y=X, 
              epochs=epochs_num, 
              batch_size=64, 
              validation_data=(X_eval, X_eval),
              verbose=1, # 0:silent, 1:progress bar, 2:one line per epoch
              callbacks=[tensorboard_callback, 
                         batch_log_callback,
                         early_stop_callback])

    # save model
    if save:
        MODEL_FOLDER = '/content/drive/MyDrive/A-Million-Headlines/pretrained'
        model_name = 'AutoEncoder-model-{}-epochs-{}.h5'.format(epochs_num, int(time.time()))
        joblib.dump(model, os.path.join(MODEL_FOLDER, model_name))
    
    return model

参考资料

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-05-11 16:25:25  更:2022-05-11 16:26:32 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/27 20:08:17-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计