IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Pytorch实现语义分割指标 -> 正文阅读

[人工智能]Pytorch实现语义分割指标

引言

本文主要参考了FCN论文[1]和一篇综述[2]中的指标描述,用Pytorch实现了pixel accuarcy,mean accuarcy,mean IU,以及frequency weighted IU等指标。

正文

由于FCN模型得到的数据形状为(N,NC,H,W),其中N为batch size个数,NC为类别总个数,包括背景类,高和宽为H,W。这个数据先经过softmax在dim=1处处理,然后再选择每个像素在不同类别下的最大值即得到形状为(N,H,W)的数据,在Pytorch中可由代码torch.argmax(input, dim=1)实现。
由于原本的labels为(N,C,H,W),其中C为RGB通道数,需要建立一个索引把labels转换成形状为(N,H,W)的数据,其中每个像素值范围为[0,NC),即用不同像素值来表示不同类别,本文并不涉及这个转换过程。前期准备工作完成,接下来就是用Pytorch代码进行实现,如下图所示为FCN论文当中对指标的定义:

metrics
需要说明的一点:我在计算pixel accuracy指标时并没有用到混淆矩阵,但在训练过程中,所有指标的计算都需要混淆矩阵,并且随着训练数据的增加需要不断地更新混淆矩阵中的值来计算指标。

1. pixel accuarcy

def confusion_matrix(input, target, num_classes):
    """
    input: torch.LongTensor:(N, H, W)
    target: torch.LongTensor:(N, H, W)
    num_classes: int
    results:Tensor
    """
    assert torch.max(input) < num_classes
    assert torch.max(target) < num_classes
    H, W = target.size()[-2:]
    results = torch.zeros((num_classes, num_classes), dtype=torch.long)
    for i, j in zip(target.flatten(), input.flatten()):
        results[i, j] += 1
    return results
    
def pixel_accuracy(input, target):
    """
    input: torch.FloatTensor:(N, C, H, W)
    target: torch.LongTensor:(N, H, W)
    return: Tensor
    """
    assert len(input.size()) == 4
    assert len(target.size()) == 3
    N, H, W = target.size()
    input = F.softmax(input, dim=1)
    arg_max = torch.argmax(input, dim=1)
    # (TP + TN) / (TP + TN + FP + FN)
    return torch.sum(arg_max == target) / (N * H * W)

2. mean accuarcy

def mean_pixel_accuarcy(input, target):
    """
    input: torch.FloatTensor:(N, C, H, W)
    target: torch.LongTensor:(N, H, W)
    return: Tensor
    """
    N, num_classes, H, W = input.size()
    input = F.softmax(input, dim=1)
    arg_max = torch.argmax(input, dim=1)
    confuse_matrix = confusion_matrix(arg_max, target, num_classes)
    result = 0
    for i in range(num_classes):
        result += (confuse_matrix[i, i] / torch.sum(confuse_matrix[i, :]))

    return result / num_classes

3. mean IU

def mean_iou(input, target):
    """
    input: torch.FloatTensor:(N, C, H, W)
    target: torch.LongTensor:(N, H, W)
    return: Tensor
    """
    assert len(input.size()) == 4
    assert len(target.size()) == 3
    N, num_classes, H, W = input.size()
    input = F.softmax(input, dim=1)
    arg_max = torch.argmax(input, dim=1)
    result = 0
    confuse_matrix = confusion_matrix(arg_max, target, num_classes)
    for i in range(num_classes):
        nii = confuse_matrix[i, i]
        # consider the case where the denominator is zero.
        if nii == 0:
            continue
        else:
            ti, tj = torch.sum(confuse_matrix[i, :]), torch.sum(confuse_matrix[:, i])
            result += (nii / (ti + tj - nii))

    return result / num_classes

4. frequecy weighted IU

def frequency_weighted_iou(input, target):
    """
    input: torch.FloatTensor:(N, C, H, W)
    target: torch.LongTensor:(N, H, W)
    return: Tensor
    """
    assert len(input.size()) == 4
    assert len(target.size()) == 3
    N, num_classes, H, W = input.size()
    input = F.softmax(input, dim=1)
    arg_max = torch.argmax(input, dim=1)
    # get confusion matrix
    result = 0
    confuse_matrix = confusion_matrix(arg_max, target, num_classes)
    for i in range(num_classes):
        nii = confuse_matrix[i, i]
        # consider the case where the denominator is zero.
        if nii == 0:
            continue
        else:
            ti, tj = torch.sum(confuse_matrix[i, :]), torch.sum(confuse_matrix[:, i])
            result += (ti * nii / (ti + tj - nii))

    return result / torch.sum(confuse_matrix)

结语

本文主要是用Pytorch实现了语义分割常用的一些指标,代码部分还有优化和改进的地方,欢迎大家进行交流并提出建议。

参考文献

[1] Shelhamer E, Long J, Darrell T. Fully Convolutional Networks for Semantic Segmentation. IEEE Trans Pattern Anal Mach Intell. 2017 Apr;39(4):640-651. doi: 10.1109/TPAMI.2016.2572683. Epub 2016 May 24. PMID: 27244717.
[2] Garcia-Garcia A , Orts-Escolano S , Oprea S , et al. A Review on Deep Learning Techniques Applied to Semantic Segmentation[J]. 2017.

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-04 11:12:44  更:2021-08-04 11:12:51 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/22 15:02:07-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码