import os
import numpy as np
import codecs
import SimpleITK as sitk
import pandas as pd
import torch
'''
Expected content of dice.txt
---number of images (num_images)
---path of ground truth of image_001
---path of seg_mask of image_001
---path of ground truth of image_002
---path of seg_mask of image_002
...
...
...
'''
def readlines(file):
    """
    Read a UTF-8 text file and return its lines without line terminators.

    :param file: path of a text file
    :return: a list of line strings; the trailing newline and the carriage
        return left by Windows-style line endings are removed
    """
    # The context manager closes the handle even when decoding fails midway.
    with codecs.open(file, 'r', encoding='utf-8') as fp:
        # codecs.open does no newline translation, so strip '\r' explicitly.
        return [line.rstrip('\r\n') for line in fp.readlines()]
def read_test_txt(imlist_file):
    '''
    Parse an image list file into two parallel lists of paths.

    The first line holds the number of cases; each case then contributes two
    consecutive lines. Lines beyond the declared number of cases are ignored.

    :param imlist_file: image list file path
    :return: (im_list, seg_list), the first and the second path of every case
    :raises ValueError: if the file is empty, the first line is not an
        integer, or fewer path lines are present than the count declares
    :raises FileNotFoundError: if a listed path does not exist
    '''
    lines = readlines(imlist_file)
    if not lines:
        raise ValueError('imlist file is empty: {}'.format(imlist_file))

    num_cases = int(lines[0])
    if (len(lines) - 1) < (num_cases * 2):
        raise ValueError('too few lines in imlist file')

    im_list, seg_list = [], []
    for i in range(num_cases):
        im_path, seg_path = lines[1 + i * 2].strip(), lines[2 + i * 2].strip()
        # Explicit raises instead of assert: asserts vanish under `python -O`.
        if not os.path.exists(im_path):
            raise FileNotFoundError('image not exist: {}'.format(im_path))
        if not os.path.exists(seg_path):
            raise FileNotFoundError('mask not exist: {}'.format(seg_path))
        im_list.append(im_path)
        seg_list.append(seg_path)
    return im_list, seg_list
def cal_dice(input_tensor, target, num_class, epsilon=1e-6):
    '''
    Compute the Dice score of every label from 1 to num_class - 1.

    Label 0 is skipped. A label that is absent from both tensors scores 0,
    because epsilon is only added to the denominator. Each score is also
    printed.

    :params input_tensor: the result of segmentation (tensor of label values)
    :params target: ground truth mask with as many elements as input_tensor
    :params num_class: number of labels, label 0 included
    :params epsilon: small constant that keeps the denominator non-zero
    :return: list with one 0-dim float tensor per label 1 .. num_class - 1
    '''
    dice_score = []
    for i in range(1, num_class):
        # reshape (unlike view) also accepts non-contiguous tensors, e.g. a
        # transposed mask, and copies only when it has to.
        pred_mask = (input_tensor == i).reshape(-1)
        gt_mask = (target == i).reshape(-1)
        intersect = torch.sum(pred_mask & gt_mask).float()
        sum_area = (torch.sum(pred_mask) + torch.sum(gt_mask)).float() + 2 * epsilon
        dice_score_i = 2 * intersect / sum_area
        dice_score.append(dice_score_i)
        print('class = {}, dice = {}'.format(i, dice_score_i))
    return dice_score
def val(input_path, results_csv, case_name_index=5):
    '''
    Compute the Dice scores of every listed case and write them to a CSV file.

    :param input_path: txt list file, parsed with read_test_txt; the first
        path of each case is used as ground truth, the second as prediction
    :param results_csv: path of the CSV file to write
    :param case_name_index: index of the '/'-separated component of the
        prediction path that is used as case name (default 5)
    :raises ValueError: if input_path does not end with 'txt'
    :raises IndexError: if a prediction path has too few components
    '''
    if not input_path.endswith('txt'):
        raise ValueError('image test_list must be a txt file')
    gt_list, pre_list = read_test_txt(input_path)

    rows = []
    for gt_path, pre_path in zip(gt_list, pre_list):
        print('{}: {}'.format(gt_path, pre_path))
        gt_mask_np = sitk.GetArrayFromImage(sitk.ReadImage(gt_path)).astype(float)
        pre_mask_np = sitk.GetArrayFromImage(sitk.ReadImage(pre_path)).astype(float)

        case_name = pre_path.split('/')[case_name_index]
        print(case_name)

        # The class count is the number of distinct values in the ground
        # truth. NOTE(review): cal_dice then scores labels 1..num_class-1,
        # which presumes the labels are consecutive from 0 - confirm.
        num_class = len(np.unique(gt_mask_np))

        gt_mask = torch.unsqueeze(torch.from_numpy(gt_mask_np), 0).float()
        pre_mask = torch.unsqueeze(torch.from_numpy(pre_mask_np), 0).float()
        dice_score = cal_dice(pre_mask, gt_mask, num_class)

        if num_class == 3:
            label_names = ['left_testis', 'right_testis']
        elif num_class == 2:
            label_names = ['tumor']
        else:
            # Any other class count used to reuse the previous case's row
            # (or fail with NameError on the first case); record it under
            # generic column names instead.
            label_names = ['class_{}'.format(i) for i in range(1, num_class)]

        row = {'case_name': case_name}
        for label_name, score in zip(label_names, dice_score):
            row[label_name] = score.item()
        rows.append(row)

    # Keep 'case_name' and 'tumor' as leading columns, then any further
    # column in order of first appearance.
    columns = ['case_name', 'tumor']
    for row in rows:
        for key in row:
            if key not in columns:
                columns.append(key)

    # Build the frame once: DataFrame.append was removed in pandas 2.0.
    dice_score_record = pd.DataFrame(rows, columns=columns)
    dice_score_record.to_csv(results_csv, index=False)
# Default locations of the case list and of the CSV report.
input_path = '/home/xxx/06_datalist/dice.txt'
results_csv = '/home/xxx/06_datalist/dice.csv'

# Run the evaluation only when executed as a script, so that importing this
# module for its helper functions does not start reading images.
if __name__ == '__main__':
    val(input_path, results_csv)