Keras使用多个workers加速生成CAM
简述
Keras官方示例中提供了生成Grad-CAM热图的参考代码,但若要对大量测试图像生成CAM,就需要让计算过程更加高效。本文在Keras官方示例代码的基础上进行修改,使用多个worker(线程)并行加速CAM的生成。
使用多个worker并行加速的环节主要有两个:读取原始图像,以及将热图叠加到原始图像上。热图本身的计算则主要通过增大batch size来加速。
下面以一个双分支(两个输出)分割模型的代码为例进行说明,读者可根据自己的需求自行修改。
代码细节
1. 主程序模块
def grad_cam_multi_worker(model_path, ori_root_path, save_root_path, data_folder_name,
                          save_folder_name, last_conv_layer_name, num_workers, batch_size,
                          custom_objects=None):
    """Generate and save two-branch Grad-CAM overlays for every image folder under a root.

    Walks ``ori_root_path``. In every directory that contains files, the files
    whose name contains ``'image'`` are processed (sorted by path) in batches of
    ``batch_size``:
      1. images are loaded in parallel through ``Dataset``;
      2. both CAMs are computed for the whole batch by ``make_gradcam_heatmap``;
      3. the overlays are built in parallel through ``DatasetAddHeatMap``;
      4. they are saved as ``CAM_1_NNN.png`` / ``CAM_2_NNN.png``, where NNN is the
         last number in the source file name, zero-padded to 3 digits.

    Args:
        model_path: path of the saved Keras model.
        ori_root_path: root directory containing the input images.
        save_root_path: output root; each input directory ``.../<parent>/<leaf>``
            is written to ``save_root_path/<parent>/<leaf>``.
        data_folder_name: substring of each input directory path that is
            replaced by ``save_folder_name`` to form a "mirror" directory. The
            mirror is created if missing, and an input directory whose mirror
            already contains entries is skipped (treated as already done).
        save_folder_name: replacement for ``data_folder_name`` (see above).
        last_conv_layer_name: name of the conv layer used for Grad-CAM.
        num_workers: number of threads used for loading and overlaying.
        batch_size: number of images per Grad-CAM forward/backward pass.
        custom_objects: optional mapping forwarded to ``load_model``.
    """
    if custom_objects is None:
        model = tf.keras.models.load_model(model_path)
    else:
        model = tf.keras.models.load_model(model_path, custom_objects=custom_objects)

    # One thread pool reused for every batch instead of two new pools per batch.
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        for root, dirs, files in os.walk(ori_root_path):
            mirror_dir = root.replace(data_folder_name, save_folder_name)
            os.makedirs(mirror_dir, exist_ok=True)
            if not files:
                continue
            # A non-empty mirror directory marks this folder as already processed.
            if os.listdir(mirror_dir):
                continue

            # os.path handles the platform's separators; splitting on a
            # hard-coded '\\' or '//' does not.
            parent_dir, leaf_name = os.path.split(os.path.normpath(root))
            print(leaf_name)
            save_dir = os.path.join(save_root_path, os.path.basename(parent_dir), leaf_name)
            # The PNGs below are written here, so make sure the directory exists.
            os.makedirs(save_dir, exist_ok=True)

            path_list = sorted(os.path.join(root, name) for name in files if 'image' in name)
            dataset = Dataset(path_list=path_list)
            # np.arange (not range): the indices stay NumPy integers, matching
            # what the per-sample helpers have always received.
            all_indexes = np.arange(len(path_list))

            for start in range(0, len(all_indexes), batch_size):
                indexes = list(all_indexes[start:start + batch_size])

                # Parallel image loading; executor.map preserves input order.
                samples = []
                indexes_batch = []
                for sample in executor.map(lambda i: dataset[i], indexes):
                    samples.append(sample[0])
                    indexes_batch.append(sample[1])
                ori_img = default_collate_fn(samples)

                cam_1, cam_2 = make_gradcam_heatmap(
                    img_array=ori_img, model=model, last_conv_layer_name=last_conv_layer_name
                )
                # A single-image batch may come back without its batch axis.
                if len(cam_1.shape) == 2:
                    cam_1 = np.expand_dims(cam_1, axis=0)
                    cam_2 = np.expand_dims(cam_2, axis=0)

                # Parallel heat-map overlay.
                datasetadd = DatasetAddHeatMap(
                    cam=[cam_1, cam_2], ori_img=ori_img,
                    path_list=path_list, indexes_batch=indexes_batch,
                )
                samples2 = []
                indexes_batch2 = []
                for sample in executor.map(lambda i: datasetadd[i], indexes):
                    samples2.append(sample[0])
                    indexes_batch2.append(sample[1])
                gradcam1, gradcam2 = default_collate_fn_2(samples2)

                for grad1, grad2, path_index in zip(gradcam1, gradcam2, indexes_batch2):
                    file_name = os.path.basename(path_list[path_index])
                    # Output number = last run of digits in the source file name.
                    figure = int(re.findall(r"\d+", file_name)[-1])
                    suffix = str(figure).zfill(3) + '.png'
                    grad1.save(os.path.join(save_dir, 'CAM_1_' + suffix))
                    grad2.save(os.path.join(save_dir, 'CAM_2_' + suffix))
if __name__ == '__main__':
    # Run-time settings. The paths below are placeholders to adapt locally.
    model_path = r'\model.hdf5'
    # Extra objects load_model needs to deserialize the saved model.
    custom_objects = {"DiceBCELoss": DiceBCELoss}
    last_conv_layer_name = 'conv3d_9'

    # Input tree and the folder-name substring swapped to build output paths.
    ori_root_path = r'\test_data_root_path'
    data_folder_name = 'test_data_root_folder_name'
    save_root_path = r'\CAM_save_root_path'
    save_folder_name = 'CAM_save_root_folder'

    num_workers = 8  # threads for loading / overlaying
    batch_size = 32  # images per Grad-CAM pass

    grad_cam_multi_worker(
        model_path=model_path,
        ori_root_path=ori_root_path,
        save_root_path=save_root_path,
        data_folder_name=data_folder_name,
        save_folder_name=save_folder_name,
        last_conv_layer_name=last_conv_layer_name,
        num_workers=num_workers,
        batch_size=batch_size,
        custom_objects=custom_objects,
    )
2. 读取图像和叠加热图的子程序模块
class Dataset(object):
    """Index-addressable loader that min-max normalizes the images in ``path_list``.

    ``dataset[i]`` returns ``[normalized_image, i]``. The index is echoed back
    so results gathered from a thread pool can be matched to their source path.
    """

    def __init__(self, path_list):
        self.path_list = path_list

    def __getitem__(self, index):
        # +1024 looks like a CT Hounsfield offset (TODO confirm). It cancels out
        # in the min-max scaling below.
        shifted = get_img_array(self.path_list[index]) + 1024
        lowest = shifted.min()
        highest = shifted.max()
        # The epsilon keeps a constant image from dividing by zero.
        normalized = (shifted - lowest) / (highest - lowest + 1e-7)
        return [normalized, index]
class DatasetAddHeatMap(object):
    """Overlay the two branch CAMs of one batch onto the corresponding images.

    ``obj[index]`` takes a dataset-wide image index (an index into
    ``path_list``) and returns ``[[overlay_1, overlay_2], index]``. The index is
    translated to its position inside the current batch through
    ``indexes_batch``, and that position selects the rows of ``cam`` and
    ``ori_img``.
    """

    def __init__(self, cam, ori_img, path_list, indexes_batch):
        self.path_list = path_list
        # [cam_1, cam_2]; each is indexed by batch position along axis 0.
        self.cam = cam
        # Batch of original images, indexed by batch position along axis 0.
        self.ori_img = ori_img
        # Dataset-wide index of each batch position, in batch order.
        self.indexes_batch = indexes_batch
        # `gradcam` is an overlay helper defined elsewhere (not shown here).
        self.gradcam = gradcam

    def __getitem__(self, index):
        """Return both overlays for image ``index``.

        Raises:
            ValueError: if ``index`` is not part of ``indexes_batch``.
        """
        # list.index matches Python and NumPy integers alike. Comparing a plain
        # list to a Python int with `==` yields a single bool, not an
        # element-wise mask, so an argwhere-based lookup cannot be used here.
        idx = list(self.indexes_batch).index(index)
        img_path = self.path_list[index]
        ori = self.ori_img[idx, :]
        superimposed_img_1 = self.gradcam(img_path=img_path, heatmap=self.cam[0][idx, :], ori_img=ori, alpha=0.4)
        superimposed_img_2 = self.gradcam(img_path=img_path, heatmap=self.cam[1][idx, :], ori_img=ori, alpha=0.4)
        return [[superimposed_img_1, superimposed_img_2], index]
3. 生成热图的子程序模块
def _weighted_cam(conv_output, grads):
    """Build per-sample Grad-CAM maps from conv activations and their gradients."""
    # Grad-CAM channel weights = mean gradient over the three spatial axes.
    # Axis 0 (batch) is deliberately NOT pooled: pooling it would mix gradients
    # of different images, making each CAM depend on its batch mates.
    pooled_grads = tf.reduce_mean(grads, axis=(1, 2, 3))
    # Channel-weighted sum per sample: (B, s1, s2, s3, C) . (B, C) -> (B, s1, s2, s3).
    heatmap = tf.einsum('bijkc,bc->bijk', conv_output, pooled_grads)
    # Average out the last spatial axis, as the original implementation did.
    heatmap = tf.reduce_mean(heatmap, axis=-1)
    heatmap = tf.maximum(heatmap, 0)
    # Scale each sample to [0, 1] by its own maximum; divide_no_nan yields 0
    # rather than NaN for an all-zero map.
    per_sample_max = tf.reduce_max(heatmap, axis=(1, 2), keepdims=True)
    return tf.math.divide_no_nan(heatmap, per_sample_max)


def make_gradcam_heatmap(img_array, model, last_conv_layer_name):
    """Compute Grad-CAM heatmaps for both output branches of ``model``.

    Args:
        img_array: batch fed to the model under the input key ``'data'``.
        model: Keras model whose ``model.output`` holds (at least) two outputs.
            The gradient of the whole of ``model.output[0]`` / ``[1]`` is used.
        last_conv_layer_name: conv layer whose activations are weighted. Its
            output is assumed to be rank 5 (batch, 3 spatial axes, channels),
            e.g. a Conv3D layer -- TODO confirm for other models.

    Returns:
        ``(heatmap_1, heatmap_2)``: NumPy arrays of shape (batch, s1, s2) with
        values in [0, 1], each image normalized independently. For a
        single-image batch the batch axis is dropped, matching the original
        behavior.
    """
    conv_layer_output = model.get_layer(last_conv_layer_name).output
    # One model exposing the conv activations and both branch outputs, so a
    # single forward pass serves both CAMs.
    grad_model = tf.keras.models.Model(
        [model.inputs], [conv_layer_output, model.output[0], model.output[1]]
    )
    # persistent=True allows two gradient computations on one recorded pass.
    with tf.GradientTape(persistent=True) as tape:
        last_conv_layer_output, preds_1, preds_2 = grad_model({'data': img_array})
    grads_1 = tape.gradient(preds_1, last_conv_layer_output)
    grads_2 = tape.gradient(preds_2, last_conv_layer_output)
    del tape  # release the resources held by the persistent tape

    heatmap_1 = _weighted_cam(last_conv_layer_output, grads_1).numpy()
    heatmap_2 = _weighted_cam(last_conv_layer_output, grads_2).numpy()
    # Preserve the original single-image contract (tf.squeeze used to drop
    # the batch axis), which callers check via len(shape) == 2.
    if heatmap_1.shape[0] == 1:
        heatmap_1 = heatmap_1[0]
        heatmap_2 = heatmap_2[0]
    return heatmap_1, heatmap_2
References
https://keras.io/examples/vision/grad_cam/