假设数据目录结构如下：
`data/images/` 和 `data/labels/` 中保存的都是 `.npy` 数组文件，且同一个样本的图像文件和标签文件同名。
images 中每个样本的 shape 为 [128, 128, 16, 1]；labels 中每个样本的 shape 为 [128, 128, 16, 2]，因为这是二分类语义分割任务，最后一维的 2 对应两个类别。
data_loader.py
import math
import os

import numpy as np
from tensorflow.keras.utils import Sequence
class seg3D_Sequence(Sequence):
    """Keras ``Sequence`` serving batches of paired image/label ``.npy`` arrays.

    Every name in ``file_name_list`` is loaded from both ``image_path`` and
    ``label_path``, so an image and its label must share the same file name.

    Args:
        file_name_list: File names (not full paths) of the samples in this split.
            A private copy is stored, so the per-epoch shuffle never reorders
            the caller's list.
        batch_size: Samples per batch (positive int). The last batch of an epoch
            may be smaller.
        image_path: Directory containing the image arrays.
        label_path: Directory containing the label arrays.
        shuffle: If True (the default, matching the original behaviour), the
            sample order is reshuffled at the end of every epoch.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, file_name_list, batch_size, image_path="data/images/",
                 label_path="data/labels/", shuffle=True):
        # Keras 3's PyDataset warns when subclasses skip the base initializer;
        # on tf.keras 2.x this is a harmless no-op.
        super().__init__()
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}")
        # Copy: on_epoch_end() shuffles in place, and the original code thereby
        # reordered the list object owned by the caller.
        self.file_name_list = list(file_name_list)
        self.batch_size = batch_size
        self.image_path = image_path
        self.label_path = label_path
        self.shuffle = shuffle

    @property
    def x(self):
        """Image paths for all samples, in the current (possibly shuffled) order."""
        return [os.path.join(self.image_path, name) for name in self.file_name_list]

    @property
    def y(self):
        """Label paths for all samples, in the current (possibly shuffled) order."""
        return [os.path.join(self.label_path, name) for name in self.file_name_list]

    def __len__(self):
        """Return the number of batches per epoch, counting a final partial batch."""
        return math.ceil(len(self.file_name_list) / self.batch_size)

    def __getitem__(self, idx):
        """Load batch ``idx`` and return ``(images, labels)``.

        Each element of the pair is the list of loaded arrays for the batch
        converted with ``np.array``, i.e. stacked along a new leading axis.
        """
        start = idx * self.batch_size
        # Only this batch's names become paths; the original rebuilt the path
        # lists for the whole dataset on every call.
        batch_names = self.file_name_list[start:start + self.batch_size]
        # os.path.join works whether or not the directory ends with a separator.
        x_re = [np.load(os.path.join(self.image_path, name)) for name in batch_names]
        y_re = [np.load(os.path.join(self.label_path, name)) for name in batch_names]
        return np.array(x_re), np.array(y_re)

    def on_epoch_end(self):
        """Reshuffle the (private) sample order between epochs if enabled."""
        if self.shuffle:
            np.random.shuffle(self.file_name_list)
train.py
import warnings
warnings.filterwarnings("ignore")
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from utils.data_loader import seg3D_Sequence
from tensorflow.keras import Model,Sequential
from tensorflow.keras.layers import Conv3D,Input
import numpy as np
if __name__ == "__main__":
    # Train a minimal 3D segmentation model on the .npy samples in image_path,
    # holding out the last `train_val_split` fraction for validation.
    print('-' * 60)
    num_classes = 2
    batch_size = 7
    train_val_split = 0.2  # fraction of samples held out for validation
    image_path = "data/images/"

    # os.listdir order is filesystem-dependent; sort so the split is reproducible.
    # Only .npy files are samples; anything else would make np.load fail mid-epoch.
    file_name_list = sorted(
        name for name in os.listdir(image_path) if name.endswith(".npy")
    )
    if not file_name_list:
        raise SystemExit(f"No .npy files found in {image_path!r}")

    # Derive the split point from train_val_split instead of a separate hard-coded
    # 0.8, so the two values cannot drift apart.
    split_idx = int(len(file_name_list) * (1 - train_val_split))
    train_name_list = file_name_list[:split_idx]
    val_name_list = file_name_list[split_idx:]

    # Pass image_path explicitly so the loaders read from the directory that was listed.
    train_data_loader = seg3D_Sequence(train_name_list, batch_size, image_path=image_path)
    val_data_loader = seg3D_Sequence(val_name_list, batch_size, image_path=image_path)

    # A single 1x1x1 convolution producing num_classes output channels per voxel.
    inputs = Input(shape=[128, 128, 16, 1])
    outputs = Conv3D(num_classes, 1, activation='sigmoid')(inputs)
    print(outputs.shape)
    model = Model(inputs, outputs, name='test3d')
    model.summary()

    # NOTE(review): labels carry one channel per class; if they are one-hot
    # (mutually exclusive classes), softmax + 'categorical_crossentropy' is the
    # usual pairing rather than independent sigmoid + binary_crossentropy.
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
    model.fit(train_data_loader,
              epochs=3,
              validation_data=val_data_loader)
|