IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Tensorflow keras中实现语义分割多分类指标:IOU、MIOU -> 正文阅读

[人工智能]Tensorflow keras中实现语义分割多分类指标:IOU、MIOU

在TF1.x版本中 miou指标可以使用tf.metrics.mean_iou 进行计算:

tf.metrics.mean_iou(labels, predictions, num_classes) 

但是该方法有如下几点限制:

1. 无法在动态图中使用,例如Tensorflow2.x版本中(注:TF2.x中api移动到了tf.compat.v1.metrics.mean_iou中),由于TF2.x默认是开启动态图,因此会报错(见mean_iou方法的源码)

  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_iou is not supported when '
                       'eager execution is enabled.')

2. 使用必须先?sess.run(tf.local_variables_initializer()) 然后 sess.run(update_op),最后sess.run(mean_iou_v),注意次序不能颠倒,不太方便和tf.keras相关训练代码结合使用

mean_iou_v, update_op = tf.metrics.mean_iou(y_true, y_pred, num_classes=4)
sess = tf.Session()
sess.run(tf.local_variables_initializer())
print(sess.run(update_op))
print(sess.run(mean_iou_v))

3. 只能直接输出所有类别的平均IOU即mean_iou, 而不能输出各个类别对应的 iou

针对上述三个问题,我发现有如下两种解决方案:

目录

方案1:自己实现相关计算代码

方案2:继承调用tf.keras.metrics.MeanIoU类


方案1:自己实现相关计算代码

def cal_mean_iou(num_classes=None, ignore_label=None):
    "ignore_label: int,注意这里ignore_label必须为整数,表示需要忽略的(不需要计算miou)类别"

    def MIOU(y_true, y_pred):
        """
        y_true: Tensor,真实标签(one-hot类型),
        y_pred: Tensor,模型输出结果(one-hot类型),二者shape都为[N,H,W,C]或[N,H*W,C],C为总类别数,
        """
        y_true = tf.reshape(tf.argmax(y_true, axis=-1), [-1])  # 求argmax后,展平为一维
        y_pred = tf.reshape(tf.argmax(y_pred, axis=-1), [-1])
        if ignore_label is not None:
            mask = tf.not_equal(y_true, ignore_label)  # 获取需要忽略的标签的位置
            y_true = tf.boolean_mask(y_true, mask)  # 剔除y_true中需要忽略的标签
            y_pred = tf.boolean_mask(y_pred, mask)  # 剔除y_pred中需要忽略的标签
        confusion_matrix = tf.confusion_matrix(y_true, y_pred, num_classes)  # 计算混淆矩阵
        intersect = tf.diag_part(confusion_matrix)  # 获取对角线上的矩阵,形成一维向量
        union = tf.reduce_sum(confusion_matrix, axis=0) + tf.reduce_sum(confusion_matrix, axis=1) - intersect
        iou = tf.div_no_nan(tf.cast(intersect, tf.float32), tf.cast(union, tf.float32)) #一维向量,每个类别的iou
        num_valid_entries = tf.reduce_sum(tf.cast(tf.not_equal(union, 0), dtype=tf.float32)) #统计union中不为0的总数
        mean_iou = tf.div_no_nan(tf.reduce_sum(iou), num_valid_entries)  # mean_iou只需要计算union中不为0的
        return mean_iou

    return MIOU

上述代码是自己实现的各类别IOU以及平均IOU的计算方法,

如果只是想直接显示平均IOU,那么直接这样使用即可:

model.compile(optimizer=optimizer, 
              loss=loss,
              metrics=[cal_mean_iou(num_classes, ignore_label)])

如果想要在tf.keras训练过程中显示各个类别的IOU,一般是继承tf.keras.callbacks.Callback类,然后重写相关的方法,方法可参考:相关参考博客3

方案2:继承调用tf.keras.metrics.MeanIoU类

方案1中的计算方式和tf.keras.metrics.MeanIoU(num_classes)计算方式类似,需要注意tf.keras.metrics.MeanIoU类中update_state(self, y_true, y_pred, sample_weight=None)方法接受的y_true和y_pred一般是非one-hot编码形式的,即如果网络的输入shape为[N,H,W,C]或[N,H*W,C]形式,需要将y_true和y_pred先求argmax,然后调用该方法

因此遇到上述情况,可以先继承tf.keras.metrics.MeanIoU类,然后重写update_state方法,示例代码如下:

class MeanIoU(tf.keras.metrics.MeanIoU):
    """
    y_true: Tensor,真实标签(one-hot类型),
    y_pred: Tensor,模型输出结果(one-hot类型),二者shape都为[N,H,W,C],C为总类别数,
    """
    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.argmax(y_true, axis=-1)
        y_pred = tf.argmax(y_pred, axis=-1)
        super().update_state(y_true, y_pred, sample_weight=sample_weight)
        return self.result()

model.compile(optimizer=optimizer, 
              loss=loss,
              metrics=[tf.keras.metrics.MeanIoU(num_classes)])

相关参考博客:

1.图像分割常用指标及MIoU计算 - 简书

2.MIoU 源码解析 - Wenshan's Blog

3.Keras上实现recall和precision,f1-score(多分类问题)_热爱学习的Valeria的博客-CSDN博客_keras precision

4.Tensorflow中tf.keras.metrics.MeanIoU在shape不一致错误_Bluish White的博客-CSDN博客
?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-01 15:44:16  更:2022-05-01 15:44:29 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/6 17:52:38-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码