引言
本文主要参考了FCN论文[1]和一篇综述[2]中的指标描述,用Pytorch实现了pixel accuarcy,mean accuarcy,mean IU,以及frequency weighted IU等指标。
正文
由于FCN模型得到的数据形状为(N,NC,H,W),其中N为batch size个数,NC为类别总个数,包括背景类,高和宽为H,W。这个数据先经过softmax在dim=1处处理,然后再选择每个像素在不同类别下的最大值即得到形状为(N,H,W)的数据,在Pytorch中可由代码torch.argmax(input, dim=1)实现。 由于原本的labels为(N,C,H,W),其中C为RGB通道数,需要建立一个索引把labels转换成形状为(N,H,W)的数据,其中每个像素值范围为[0,NC),即用不同像素值来表示不同类别,本文并不涉及这个转换过程。前期准备工作完成,接下来就是用Pytorch代码进行实现,如下图所示为FCN论文当中对指标的定义:
需要说明的一点:我在计算pixel accuracy指标时并没有用到混淆矩阵,但在训练过程中,所有指标的计算都需要混淆矩阵,并且随着训练数据的增加需要不断地更新混淆矩阵中的值来计算指标。
1. pixel accuarcy
def confusion_matrix(input, target, num_classes):
"""
input: torch.LongTensor:(N, H, W)
target: torch.LongTensor:(N, H, W)
num_classes: int
results:Tensor
"""
assert torch.max(input) < num_classes
assert torch.max(target) < num_classes
H, W = target.size()[-2:]
results = torch.zeros((num_classes, num_classes), dtype=torch.long)
for i, j in zip(target.flatten(), input.flatten()):
results[i, j] += 1
return results
def pixel_accuracy(input, target):
"""
input: torch.FloatTensor:(N, C, H, W)
target: torch.LongTensor:(N, H, W)
return: Tensor
"""
assert len(input.size()) == 4
assert len(target.size()) == 3
N, H, W = target.size()
input = F.softmax(input, dim=1)
arg_max = torch.argmax(input, dim=1)
return torch.sum(arg_max == target) / (N * H * W)
2. mean accuarcy
def mean_pixel_accuarcy(input, target):
"""
input: torch.FloatTensor:(N, C, H, W)
target: torch.LongTensor:(N, H, W)
return: Tensor
"""
N, num_classes, H, W = input.size()
input = F.softmax(input, dim=1)
arg_max = torch.argmax(input, dim=1)
confuse_matrix = confusion_matrix(arg_max, target, num_classes)
result = 0
for i in range(num_classes):
result += (confuse_matrix[i, i] / torch.sum(confuse_matrix[i, :]))
return result / num_classes
3. mean IU
def mean_iou(input, target):
"""
input: torch.FloatTensor:(N, C, H, W)
target: torch.LongTensor:(N, H, W)
return: Tensor
"""
assert len(input.size()) == 4
assert len(target.size()) == 3
N, num_classes, H, W = input.size()
input = F.softmax(input, dim=1)
arg_max = torch.argmax(input, dim=1)
result = 0
confuse_matrix = confusion_matrix(arg_max, target, num_classes)
for i in range(num_classes):
nii = confuse_matrix[i, i]
if nii == 0:
continue
else:
ti, tj = torch.sum(confuse_matrix[i, :]), torch.sum(confuse_matrix[:, i])
result += (nii / (ti + tj - nii))
return result / num_classes
4. frequecy weighted IU
def frequency_weighted_iou(input, target):
"""
input: torch.FloatTensor:(N, C, H, W)
target: torch.LongTensor:(N, H, W)
return: Tensor
"""
assert len(input.size()) == 4
assert len(target.size()) == 3
N, num_classes, H, W = input.size()
input = F.softmax(input, dim=1)
arg_max = torch.argmax(input, dim=1)
result = 0
confuse_matrix = confusion_matrix(arg_max, target, num_classes)
for i in range(num_classes):
nii = confuse_matrix[i, i]
if nii == 0:
continue
else:
ti, tj = torch.sum(confuse_matrix[i, :]), torch.sum(confuse_matrix[:, i])
result += (ti * nii / (ti + tj - nii))
return result / torch.sum(confuse_matrix)
结语
本文主要是用Pytorch实现了语义分割常用的一些指标,代码部分还有优化和改进的地方,欢迎大家进行交流并提出建议。
参考文献
[1] Shelhamer E, Long J, Darrell T. Fully Convolutional Networks for Semantic Segmentation. IEEE Trans Pattern Anal Mach Intell. 2017 Apr;39(4):640-651. doi: 10.1109/TPAMI.2016.2572683. Epub 2016 May 24. PMID: 27244717. [2] Garcia-Garcia A , Orts-Escolano S , Oprea S , et al. A Review on Deep Learning Techniques Applied to Semantic Segmentation[J]. 2017.
|