| |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
-> Python知识库 -> 四、肺癌检测-数据集准备 dsets.py文件 -> 正文阅读 |
|
[Python知识库]四、肺癌检测-数据集准备 dsets.py文件 |
一、目标数据集准备需要完成以下几个工作: 1. 读取annotations.csv内容; 2. 读取candidates.csv内容; 3. 构造Ct类,用于根据输入的series_uid,获取该uid的CT数据的信息。 4. 构造Dataset类,用于加载数据集。 二、要点说明1. SimpleITK库读取和解析CT结果的【mhd】文件需要使用SimpleITK库,可通过【conda install simpleitk】命令安装。 其中主要用到以下几个函数说明如下:
2. functools库代码中用到了functools库,用于将某些函数的结果缓存到内存中。 @functools.lru_cache(1):代表1次缓存。用于存放在需要缓存的函数定义的代码的开头。意义是:如果该函数之前已经输入过相同的参数,下一次再输入相同参数时,函数直接从缓存调用结果,而不会从新执行函数内部代码。 3. diskcache库代码中用到了diskcache库,用于将CT数据解析后缓存到磁盘中,使用缓存可以较大的提高训练时数据加载速度。库的使用可参考相关文章: 【编程】Python : diskcache 本地缓存持久化,一行代码_哔哩哔哩_bilibili Python 爬虫进阶篇——diskcache缓存_十先生(公众号:Python知识学堂)的博客-CSDN博客_diskcache python 4. CT文件信息4.1 csv文件annotations.csv: 记录实际结节。文件结构: uid, x, y, z, diameter candidates.csv: 记录候选结节。文件结构: uid, x, y, z, class 注意:两个文件中,相同的uid对应的xyz坐标可能有偏差,要将偏差大于半径的一半(即diameter/4)的数据的diameter强制为0,即认为这个结节异常,不处理。 5. XYZ、IRC坐标轴5.1 坐标轴方向CT数据中,有XYZ坐标轴,训练时需要转换为IRC坐标轴,两个坐标轴分别对应着: xyz:各坐标轴正的方向指向的人体的方向为为: irc:各坐标轴正的方向指向的人体的方向为为: 其中i-index,r-row,? c-column 简记为:xyz-左后上,irc-上后左
|
良性结节和恶性结节的特征区别 | ||
特征 | 良性 | 恶性 |
生长速度 | 迅速 | 缓慢 |
查体表现 | 软,活动度大 | 硬,活动度小 |
超声检查 | 边界清晰,与组织分解明显 | 边界不清晰,与组织分解不明显 |
形态 | 光滑,圆 | 不规则,纵横比>1,直立生长 |
下图第一行是对CT文件中,三维CT矩阵用不同维度索引下的结果;
下图第二行是对某个结节中,三维结节矩阵用不同维度索引下的结果。
更多可视化内容可参照原书代码的ipynb文件。
candidateInfo_list = getCandidateInfoList(requireOnDisk_bool=True)
返回candidates.csv文件对应的list,其中每个元素为名称为candidateInfoTuple的元组,元组有如下节点:
class, diameter, id, xyz
属性如下:
CT.hu_a:以HU为单位的三维array,存储的是CT的所有体素数据。
CT.origin_xyz:xyz坐标和irc坐标的原点偏移量
CT.vzSize_xyz:体素在xyz坐标轴的尺度大小
CT.direction_a:体素的空间矩阵
CT.getRawCandidate函数:
ct_chunk, center_irc = getRawCandidate(center_xyz, width_irc)
center_xyz:结节在xyz坐标系的坐标值。
width_irc:结节在irc坐标系的尺寸大小。也是数据集输入到模型的input_size
ct_chunk:结节在irc坐标轴的HU值的三维矩阵。
center_irc:结节中心在irc坐标系的坐标值。
ds = LunaDataset(val_stride=0, isValSet_bool=False, series_uid=None)
val_stride:作为验证集时,从数据集中抽取样本作为验证集样本的步长。即每隔val_stride抽取一个样本作为验证集样本。
isValSet_bool:是否作为验证集。
series_uid:获取某个uid对应的所有样本。
书中代码【dsets.py】如下:
import copy
import csv
import functools
import glob
import os
from collections import namedtuple
import SimpleITK as sitk
import numpy as np
import torch
import torch.cuda
from torch.utils.data import Dataset
from util.disk import getCache
from util.util import XyzTuple, xyz2irc
from util.logconf import logging
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)
raw_cache = getCache('part2ch10_raw')
CandidateInfoTuple = namedtuple(
'CandidateInfoTuple',
'isNodule_bool, diameter_mm, series_uid, center_xyz',
)
@functools.lru_cache(1)
def getCandidateInfoList(requireOnDisk_bool=True):
# We construct a set with all series_uids that are present on disk.
# This will let us use the data, even if we haven't downloaded all of
# the subsets yet.
mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
diameter_dict = {}
with open('data/part2/luna/annotations.csv', "r") as f:
for row in list(csv.reader(f))[1:]:
series_uid = row[0]
annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
annotationDiameter_mm = float(row[4])
diameter_dict.setdefault(series_uid, []).append(
(annotationCenter_xyz, annotationDiameter_mm)
)
candidateInfo_list = []
with open('data/part2/luna/candidates.csv', "r") as f:
for row in list(csv.reader(f))[1:]:
series_uid = row[0]
if series_uid not in presentOnDisk_set and requireOnDisk_bool:
continue
isNodule_bool = bool(int(row[4]))
candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
candidateDiameter_mm = 0.0
for annotation_tup in diameter_dict.get(series_uid, []):
annotationCenter_xyz, annotationDiameter_mm = annotation_tup
for i in range(3):
delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
if delta_mm > annotationDiameter_mm / 4:
break
else:
candidateDiameter_mm = annotationDiameter_mm
break
candidateInfo_list.append(CandidateInfoTuple(
isNodule_bool,
candidateDiameter_mm,
series_uid,
candidateCenter_xyz,
))
candidateInfo_list.sort(reverse=True)
return candidateInfo_list
class Ct:
def __init__(self, series_uid):
mhd_path = glob.glob(
'data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid)
)[0]
ct_mhd = sitk.ReadImage(mhd_path)
ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
# CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
# HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
# The lower bound gets rid of negative density stuff used to indicate out-of-FOV
# The upper bound nukes any weird hotspots and clamps bone down
ct_a.clip(-1000, 1000, ct_a)
self.series_uid = series_uid
self.hu_a = ct_a
self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)
def getRawCandidate(self, center_xyz, width_irc):
center_irc = xyz2irc(
center_xyz,
self.origin_xyz,
self.vxSize_xyz,
self.direction_a,
)
slice_list = []
for axis, center_val in enumerate(center_irc):
start_ndx = int(round(center_val - width_irc[axis]/2))
end_ndx = int(start_ndx + width_irc[axis])
assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
if start_ndx < 0:
# log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
# self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
start_ndx = 0
end_ndx = int(width_irc[axis])
if end_ndx > self.hu_a.shape[axis]:
# log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
# self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
end_ndx = self.hu_a.shape[axis]
start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])
slice_list.append(slice(start_ndx, end_ndx))
ct_chunk = self.hu_a[tuple(slice_list)]
return ct_chunk, center_irc
@functools.lru_cache(1, typed=True)
def getCt(series_uid):
return Ct(series_uid)
@raw_cache.memoize(typed=True)
def getCtRawCandidate(series_uid, center_xyz, width_irc):
ct = getCt(series_uid)
ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
return ct_chunk, center_irc
class LunaDataset(Dataset):
def __init__(self,
val_stride=0,
isValSet_bool=None,
series_uid=None,
):
self.candidateInfo_list = copy.copy(getCandidateInfoList())
if series_uid:
self.candidateInfo_list = [
x for x in self.candidateInfo_list if x.series_uid == series_uid
]
if isValSet_bool:
assert val_stride > 0, val_stride
self.candidateInfo_list = self.candidateInfo_list[::val_stride]
assert self.candidateInfo_list
elif val_stride > 0:
del self.candidateInfo_list[::val_stride]
assert self.candidateInfo_list
log.info("{!r}: {} {} samples".format(
self,
len(self.candidateInfo_list),
"validation" if isValSet_bool else "training",
))
def __len__(self):
return len(self.candidateInfo_list)
def __getitem__(self, ndx):
candidateInfo_tup = self.candidateInfo_list[ndx]
width_irc = (32, 48, 48)
candidate_a, center_irc = getCtRawCandidate(
candidateInfo_tup.series_uid,
candidateInfo_tup.center_xyz,
width_irc,
)
candidate_t = torch.from_numpy(candidate_a)
candidate_t = candidate_t.to(torch.float32)
candidate_t = candidate_t.unsqueeze(0)
pos_t = torch.tensor([
not candidateInfo_tup.isNodule_bool,
candidateInfo_tup.isNodule_bool
],
dtype=torch.long,
)
return (
candidate_t,
pos_t,
candidateInfo_tup.series_uid,
torch.tensor(center_irc),
)
import functools
import glob
import os.path
import csv
import SimpleITK as sitk
import numpy as np
import copy
import torch
import torch.cuda
from torch.utils.data import Dataset
from collections import namedtuple
from util.disk import getCache
from util.util import XyzTuple, xyz2irc
from util.logconf import logging
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
# annotations.csv: 记录实际结节。文件结构: uid, x, y, z, diameter
# candidates.csv: 记录候选结节。文件结构: uid, x, y, z, class
raw_cache = getCache('part2ch10_raw')
# 构建用于存储候选结节的元组, 结构: class, diameter, id, xyz
candidateInfoTuple = namedtuple('candidateInfoTuple',
'isNodule_bool, diameter_mm, series_uid, center_xyz')
@functools.lru_cache(1) # 缓存一次调用结果
def getCandidateInfoList(requireOnDisk_bool=True):
"""
加载annotations.csv和candidates.csv,分别存到diameter_list和candidateInfo_list
:param requireOnDisk_bool. 如果文件不存在,是否跳过
:return candidateInfo_list. 由candidateInfoTuple构成的list
"""
mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list} # 提取所有文件名,即uid
diameter_dict= {}
with open('data/part2/luna/annotations.csv', 'r') as f:
for row in list(csv.reader(f))[1:]:
series_uid = row[0]
annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
annotationDiameter_mm = float(row[4])
diameter_dict.setdefault(series_uid, []).append(
(annotationCenter_xyz, annotationDiameter_mm)
)
candidateInfo_list = []
with open('data/part2/luna/candidates.csv', 'r') as f:
for row in list(csv.reader(f))[1:]:
series_uid = row[0]
# 如果annotations.csv中找不到这个id,则跳过
if series_uid not in presentOnDisk_set and requireOnDisk_bool:
continue
candidateDiameter_xyz = tuple([float(x) for x in row[1:4]])
isNodule_bool = bool(int(row[4]))
# 如果candidate中的xyz坐标和annotation中的xyz坐标偏差大于半径的一半,
# 则认为它们不是同一个节点,将直接用零代替,即认为这不是结节
candidateDiameter_mm = 0.0
for annotation_tup in diameter_dict.get(series_uid, []):
annotation_xyz, annotationDiameter_mm = annotation_tup
for i in range(3):
delta_mm = abs(candidateDiameter_xyz[i] - annotation_xyz[i])
if delta_mm > annotationDiameter_mm/4:
break
else:
candidateDiameter_mm = annotationDiameter_mm
break
candidateInfo_list.append(candidateInfoTuple(
isNodule_bool,
candidateDiameter_mm,
series_uid,
candidateDiameter_xyz,
))
candidateInfo_list.sort(reverse=True)
return candidateInfo_list
class Ct:
def __init__(self, series_uid):
mhd_path = glob.glob(r'data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
# 用SampleSTK包可直接读取CT扫描数据
ct_mhd = sitk.ReadImage(mhd_path)
# HU: 亨氏单位,Hounsfield Unit.
# 空气为-1000 HU,约等于0 g/cm3. 水为0 HU,约等于1 g/cm3, 骨骼至少时1000HU,约等于2~3g/cm3
ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32) # 读取到的数据单位为HU
# 将数据限定再-1000~1000 HU
ct_a.clip(-1000, 1000, ct_a)
self.series_uid = series_uid
self.hu_a = ct_a
self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin()) # xyz坐标和irc坐标的原点偏移量
self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing()) # 体素在xyz坐标轴的大小
self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3) # 体素方向矩阵,等于eye(3)
def getRawCandidate(self, center_xyz, width_irc):
"""
根据xyz坐标算出病人坐标irc。然后根据每个结节的irc和体素宽度,算出结节包含的体素块数据
:param center_xyz: 结节的xyz坐标
:param width_irc: 体素宽度,也是数据集输入到模型的输入尺寸
:return ct_chunk: 结节包含的体素块的HU值,array
:return center_irc: 结节的病人坐标信息
"""
center_irc = xyz2irc(
center_xyz,
self.origin_xyz,
self.vxSize_xyz,
self.direction_a
)
slice_list = []
for axis, center_val in enumerate(center_irc):
start_ndx = int(round(center_val - width_irc[axis]/2))
end_ndx = int(start_ndx + width_irc[axis])
assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
if start_ndx < 0:
# log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
# self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
start_ndx = 0
end_ndx = int(width_irc[axis])
if end_ndx > self.hu_a.shape[axis]:
# log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
# self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
end_ndx = self.hu_a.shape[axis]
start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])
slice_list.append(slice(start_ndx, end_ndx))
ct_chunk = self.hu_a[tuple(slice_list)]
return ct_chunk, center_irc
@functools.lru_cache(1, typed=True) # 保留一次缓存结果
def getCt(series_uid):
return Ct(series_uid)
@raw_cache.memoize(typed=True) # 数据缓存到同路径的cache文件夹下
def getCtRawCandidate(series_uid, center_xyz, width_irc):
ct = getCt(series_uid)
ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
return ct_chunk, center_irc
class LunaDataset(Dataset):
def __init__(self, val_stride=0, isValSet_bool=False, series_uid=None):
"""
val_stride:作为验证集时,从数据集中抽取样本作为验证集样本的步长。即每隔val_stride抽取一个样本作为验证集样本。
isValSet_bool:是否作为验证集。
series_uid:获取某个uid对应的所有样本。
"""
self.candidateInfo_list = copy.copy(getCandidateInfoList())
if series_uid:
self.candidateInfo_list = [x for x in self.candidateInfo_list if x.series_uid==series_uid]
if isValSet_bool:
assert val_stride > 0, val_stride
self.candidateInfo_list = self.candidateInfo_list[::val_stride]
assert self.candidateInfo_list
elif val_stride > 0:
del self.candidateInfo_list[::val_stride]
assert self.candidateInfo_list
log.info("(!r): {} {} samples".format(
self,
len(self.candidateInfo_list),
"validation" if isValSet_bool else "training",
))
def __len__(self):
return len(self.candidateInfo_list)
def __getitem__(self, ndx):
"""
返回指定索引对应的结节信息
:param ndx: 某个ct数据中的第ndx个结节索引
:return: candidate_t. 结节所包含的所有体素的三位数组。t代表数组时个tensor
:return: post_t. 结节是否为肿瘤。0代表不是,1代表肿瘤。
:return: series_uid. ndx所对应的结节uid
:return: center_irc. 结节的重心坐标。类型为tensor
"""
candidateInfo_tup = self.candidateInfo_list[ndx]
width_irc = (32, 48, 48)
candidate_a, center_irc = getCtRawCandidate(
candidateInfo_tup.series_uid,
candidateInfo_tup.center_xyz,
width_irc,
)
candidate_t = torch.from_numpy(candidate_a)
candidate_t = candidate_t.to(torch.float32)
candidate_t = candidate_t.unsqueeze(0)
post_t = torch.tensor([
not candidateInfo_tup.isNodule_bool,
candidateInfo_tup.isNodule_bool
],
dtype=torch.long,
)
return (
candidate_t,
post_t,
candidateInfo_tup.series_uid,
torch.tensor(center_irc)
)
|
|
上一篇文章 下一篇文章 查看所有文章 |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 | -2024/12/26 3:22:04- |
|
网站联系: qq:121756557 email:121756557@qq.com IT数码 |
数据统计 |