标准版本:
import torch
from torch import Tensor
import torch.nn.functional as F
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    """Dice coefficient for binary masks, averaged over the batch when needed.

    Args:
        input: predicted mask, shape (H, W) or (B, H, W) — assumed in [0, 1].
        target: ground-truth mask, same shape as ``input``.
        reduce_batch_first: if True, compute a single Dice over the whole
            batch (flattened) instead of averaging per-sample scores.
        epsilon: smoothing term that keeps the ratio finite for empty masks.

    Returns:
        Scalar tensor with the Dice coefficient in [0, 1].

    Raises:
        ValueError: if ``reduce_batch_first`` is requested for a 2-D tensor
            that has no batch dimension.
    """
    assert input.size() == target.size()
    if input.dim() == 2 and reduce_batch_first:
        raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')

    if input.dim() == 2 or reduce_batch_first:
        # Flattened intersection / union over everything at once.
        inter = torch.dot(input.reshape(-1), target.reshape(-1))
        sets_sum = torch.sum(input) + torch.sum(target)
        if sets_sum.item() == 0:
            # Both masks empty: define Dice as 1 (2*inter + eps over 2*inter + eps).
            sets_sum = 2 * inter
        return (2 * inter + epsilon) / (sets_sum + epsilon)
    else:
        # Average the per-sample Dice over the batch dimension.
        dice = 0
        for i in range(input.shape[0]):
            # Bug fix: forward epsilon so per-sample scores honour the
            # caller's smoothing term (it was silently reset to 1e-6 before).
            dice += dice_coeff(input[i, ...], target[i, ...], epsilon=epsilon)
        return dice / input.shape[0]
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    """Mean of the per-channel Dice coefficients of ``input`` vs ``target``.

    Both tensors are expected to share shape (B, C, ...); each class channel
    is scored independently with :func:`dice_coeff` and the scores averaged.
    """
    assert input.size() == target.size()
    n_channels = input.shape[1]
    total = sum(
        dice_coeff(input[:, c, ...], target[:, c, ...], reduce_batch_first, epsilon)
        for c in range(n_channels)
    )
    return total / n_channels
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    """Dice loss: ``1 - Dice``, reducing the whole batch to one score.

    When ``multiclass`` is True the per-channel mean Dice is used; otherwise
    the plain binary Dice coefficient.
    """
    assert input.size() == target.size()
    if multiclass:
        coeff = multiclass_dice_coeff(input, target, reduce_batch_first=True)
    else:
        coeff = dice_coeff(input, target, reduce_batch_first=True)
    return 1 - coeff
if __name__ == "__main__":
    # Bug fix: the original used a C-style `//` comment, which is a Python
    # SyntaxError. Translated note from the original:
    # masks_pred is the model prediction, true_masks the ground-truth label.
    # In the multi-class case, if the label is only (B, H, W) it must be
    # one-hot encoded to add a channel dimension so its size matches the
    # prediction. NOTE(review): masks_pred / true_masks / net are
    # illustrative names defined elsewhere — this block is a usage example.
    dice_loss(F.softmax(masks_pred, dim=1).float(),
              F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
              multiclass=True)
简单版本(需要保持预测结果logits与标签targets的size一致):
import torch.nn as nn
import torch.nn.functional as F
class SoftDiceLoss(nn.Module):
    """Soft Dice loss over sigmoid probabilities.

    Expects ``logits`` and ``targets`` of identical size; probabilities are
    obtained with a sigmoid, so this suits (multi-label) binary segmentation.
    """

    def __init__(self, weight=None, size_average=True):
        # weight / size_average are kept for interface compatibility; unused.
        super(SoftDiceLoss, self).__init__()

    def forward(self, logits, targets):
        batch_size = targets.size(0)
        smooth = 1
        # F.sigmoid is deprecated; torch.sigmoid is the supported spelling.
        probs = torch.sigmoid(logits)
        m1 = probs.view(batch_size, -1)
        m2 = targets.view(batch_size, -1)
        intersection = m1 * m2
        # Bug fix: smooth is *added* to 2*intersection, not grouped inside the
        # factor of 2. The original 2.*(inter + smooth) gave score = 2
        # (i.e. loss = -1) when both prediction and target were empty.
        score = (2. * intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
        return 1 - score.sum() / batch_size
简化版本: 利用了pytorch的广播机制 pred与target通道需要保持一致(此处用在图像分割中,默认类别是互斥的,非多通道二分类)
def dice_loss(pred, target, smooth = 1.):
    """Mean Dice loss over (B, C, H, W) tensors via broadcasting.

    ``pred`` and ``target`` must have matching channels; classes are assumed
    mutually exclusive (single-label segmentation).
    """
    pred = pred.contiguous()
    target = target.contiguous()
    # Reduce the two spatial axes in one call; equivalent to the chained
    # .sum(dim=2).sum(dim=2) form.
    overlap = (pred * target).sum(dim=(2, 3))
    denom = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
    dice = (2. * overlap + smooth) / (denom + smooth)
    return (1 - dice).mean()
由于广播机制和均值运算的线性性质,此处 1 - loss.mean() 与 (1 - loss).mean() 结果是相同的
注意:上述dice loss方法,涉及到了pred与target保持size一致的方式。实际分割任务中target的size与所使用的loss函数有关: 多分类使用交叉熵–torch.nn.CrossEntropyLoss,此时target不需要channel,所以在后续加dice loss的时候,需要通过F.one_hot()函数处理target,保持size与pred一致 多标签二分类通常使用torch.nn.BCEWithLogitsLoss,该函数要求target与pred的size一致,一般我们在自定义的dataloader中就会对label进行维度处理,或者在loss计算前处理。后续加dice loss,不需要再onehot了。
|