Commonly used metrics for semantic segmentation include: PA (pixel accuracy), CPA (class pixel accuracy), IoU (intersection over union), and mIoU (mean intersection over union). Of these, mIoU is the most widely used evaluation metric.
For a detailed introduction to how these metrics are computed, see the following post, where the author walks through everything very thoroughly: 【语义分割】评价指标:PA、CPA、MPA、IoU、MIoU详细总结和代码实现(零基础从入门到精通系列!)
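As a quick reference (this summary is mine, not taken from the linked post), the metrics can be written in terms of a confusion matrix M, where M_{ij} counts the pixels whose ground-truth class is i and predicted class is j (the same convention as the code below), and k is the number of classes:

\mathrm{PA} = \frac{\sum_i M_{ii}}{\sum_{i,j} M_{ij}}, \qquad
\mathrm{CPA}_i = \frac{M_{ii}}{\sum_j M_{ij}}, \qquad
\mathrm{MPA} = \frac{1}{k}\sum_i \mathrm{CPA}_i

\mathrm{IoU}_i = \frac{M_{ii}}{\sum_j M_{ij} + \sum_j M_{ji} - M_{ii}}, \qquad
\mathrm{mIoU} = \frac{1}{k}\sum_i \mathrm{IoU}_i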
This post is mainly about how to compute these metrics with PyTorch. My original idea was to use the reciprocal of these metrics directly as a loss to train the network, which is why I wanted to compute them in PyTorch in the first place. It turned out I was being naive, haha. If you have the same idea, stop fantasizing; if you are just getting started, practice with cross-entropy loss first.
Without further ado, here is the code:
"""
refer to https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/utils/metrics.py
"""
import torch
import cv2
import numpy as np
__all__ = ['SegmentationMetric']
"""
confusionMetric # 注意:此处横着代表预测值,竖着代表真实值,与之前介绍的相反
P\L P N
P TP FP
N FN TN
"""
class SegmentationMetric(object):
    def __init__(self, numClass):
        self.numClass = numClass
        self.confusionMatrix = torch.zeros((self.numClass,) * 2)

    def pixelAccuracy(self):
        # PA: fraction of correctly classified pixels, i.e. the diagonal sum over the matrix total
        acc = torch.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum()
        return acc

    def classPixelAccuracy(self):
        # CPA: per-class accuracy, the diagonal element divided by its row (ground-truth) sum
        classAcc = torch.diag(self.confusionMatrix) / self.confusionMatrix.sum(axis=1)
        return classAcc

    def meanPixelAccuracy(self):
        """
        Mean Pixel Accuracy (MPA): a simple refinement of PA. Compute the fraction of correctly
        classified pixels within each class, then average over all classes.
        """
        classAcc = self.classPixelAccuracy()
        meanAcc = classAcc[classAcc < float('inf')].mean()  # drop NaN entries from absent classes
        return meanAcc

    def IntersectionOverUnion(self):
        # IoU per class: intersection = TP, union = TP + FP + FN
        intersection = torch.diag(self.confusionMatrix)
        union = torch.sum(self.confusionMatrix, axis=1) + torch.sum(self.confusionMatrix, axis=0) - torch.diag(
            self.confusionMatrix)
        IoU = intersection / union
        return IoU

    def meanIntersectionOverUnion(self):
        IoU = self.IntersectionOverUnion()
        mIoU = IoU[IoU < float('inf')].mean()  # drop NaN entries before averaging
        return mIoU
    def genConfusionMatrix(self, imgPredict, imgLabel, ignore_labels):
        """
        Same idea as fast_hist() in FCN's score.py: build the confusion matrix.
        :param imgPredict: predicted label map
        :param imgLabel: ground-truth label map
        :param ignore_labels: list of label values to exclude from the statistics
        :return: confusion matrix
        """
        mask = (imgLabel >= 0) & (imgLabel < self.numClass)
        for IgLabel in ignore_labels:
            mask &= (imgLabel != IgLabel)
        label = self.numClass * imgLabel[mask] + imgPredict[mask]
        count = torch.bincount(label, minlength=self.numClass ** 2)
        confusionMatrix = count.view(self.numClass, self.numClass)
        return confusionMatrix
    def Frequency_Weighted_Intersection_over_Union(self):
        """
        FWIoU (frequency weighted IoU): a refinement of mIoU that weights each class's IoU
        by how frequently the class appears.
        FWIoU = [(TP + FN) / (TP + FP + TN + FN)] * [TP / (TP + FP + FN)]
        """
        freq = torch.sum(self.confusionMatrix, axis=1) / torch.sum(self.confusionMatrix)
        iu = torch.diag(self.confusionMatrix) / (
                torch.sum(self.confusionMatrix, axis=1) + torch.sum(self.confusionMatrix, axis=0) -
                torch.diag(self.confusionMatrix))
        FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
        return FWIoU
    def addBatch(self, imgPredict, imgLabel, ignore_labels):
        assert imgPredict.shape == imgLabel.shape
        self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel, ignore_labels)
        return self.confusionMatrix

    def reset(self):
        self.confusionMatrix = torch.zeros((self.numClass, self.numClass))
if __name__ == '__main__':
    imgPredict = torch.tensor([[0, 1, 2], [2, 1, 1]]).long()
    imgLabel = torch.tensor([[0, 1, 255], [1, 1, 2]]).long()
    ignore_labels = [255]
    metric = SegmentationMetric(3)
    hist = metric.addBatch(imgPredict, imgLabel, ignore_labels)
    pa = metric.pixelAccuracy()
    cpa = metric.classPixelAccuracy()
    mpa = metric.meanPixelAccuracy()
    IoU = metric.IntersectionOverUnion()
    mIoU = metric.meanIntersectionOverUnion()
    print('hist is :\n', hist)
    print('PA is : %f' % pa)
    print('cPA is :', cpa)
    print('mPA is : %f' % mpa)
    print('IoU is : ', IoU)
    print('mIoU is : ', mIoU)
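For reference, with this toy batch the pixel labeled 255 is ignored, leaving five valid pixels. Working through the confusion matrix by hand (so double-check against your own run), the script should report roughly: hist = [[1, 0, 0], [0, 2, 1], [0, 1, 0]], PA = 0.6, cPA = [1.0, 0.6667, 0.0], mPA ≈ 0.5556, IoU = [1.0, 0.5, 0.0], and mIoU = 0.5.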
Essentially I just replaced numpy with torch in the original program; the two frameworks share many of the same operations. The only real difference is that torch has no direct equivalent of np.nanmean(), which computes the mean of all elements other than NaN. So how do we get this behavior in torch? It is actually very simple and takes one line:
tensor[tensor < float('inf')].mean()
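A quick sanity check of that trick (the variable names below are just for illustration): comparing a tensor with float('inf') returns False for NaN entries, so indexing with that mask drops them before taking the mean, matching np.nanmean:

import torch
import numpy as np

t = torch.tensor([1.0, float('nan'), 3.0])
print(t[t < float('inf')].mean())   # tensor(2.) -- the NaN entry is masked out
print(np.nanmean(t.numpy()))        # 2.0, same result from numpy's nanmean

Note that the mask also filters out +inf values, which np.nanmean would keep; for these per-class ratios that should not matter, since an absent class can only produce a 0/0 = NaN entry.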
For details you can also see my other post: torch实现np.nanmean的功能.
With that solved, we can compute mIoU with PyTorch.
The next question is: what do I do about labels that should be ignored when computing accuracy? For example, during training we often set ignore_index in the cross-entropy loss to 255, so that pixels with label 255 do not contribute to the loss. My understanding is that for those pixels, whatever the network predicts has no effect on the loss. But what about at test time, won't they affect the results? This is what I spent the last couple of days puzzling over, and the answer is to ignore that value when computing the metrics as well, so that only the classes you care about are evaluated. My approach is to modify how the confusion matrix is generated:
    def genConfusionMatrix(self, imgPredict, imgLabel, ignore_labels):
        """
        Same idea as fast_hist() in FCN's score.py: build the confusion matrix.
        :param imgPredict: predicted label map
        :param imgLabel: ground-truth label map
        :param ignore_labels: list of label values to exclude from the statistics
        :return: confusion matrix
        """
        mask = (imgLabel >= 0) & (imgLabel < self.numClass)  # keep only pixels with a valid class index
        for IgLabel in ignore_labels:
            mask &= (imgLabel != IgLabel)                    # additionally drop every ignored label value
        label = self.numClass * imgLabel[mask] + imgPredict[mask]
        count = torch.bincount(label, minlength=self.numClass ** 2)
        confusionMatrix = count.view(self.numClass, self.numClass)
        return confusionMatrix
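To give an idea of how this is typically used during evaluation, here is a minimal sketch of a validation loop; model, val_loader, and the class count are my own placeholders rather than part of the original code. It takes the argmax over the class dimension of the network's logits and accumulates the confusion matrix batch by batch:

# minimal evaluation-loop sketch; `model` and `val_loader` are hypothetical placeholders
metric = SegmentationMetric(numClass=21)   # e.g. 21 classes, as in PASCAL VOC
metric.reset()
ignore_labels = [255]

model.eval()
with torch.no_grad():
    for images, labels in val_loader:      # labels: [N, H, W] long tensor, values in [0, numClass) or 255
        logits = model(images)             # assumed output shape: [N, numClass, H, W]
        preds = logits.argmax(dim=1)       # [N, H, W] predicted class index per pixel
        metric.addBatch(preds.cpu(), labels.cpu(), ignore_labels)

print('mIoU:', metric.meanIntersectionOverUnion())
print('PA  :', metric.pixelAccuracy())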
That's the approach I worked out over the past couple of days, written up here to share with everyone. If you spot any mistakes or have suggestions for improvement, feel free to point them out.