IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> CAM (Classification Attention Maps) Generation with Multi-workers in Keras -> 正文阅读

[Python知识库]CAM (Classification Attention Maps) Generation with Multi-workers in Keras

Keras使用多个workers加速生成CAM

简述

Keras官方文档中有如何生成CAM热图的参考代码,但是如果对大量测试图像生成CAM则需要使计算过程更加高效。本博文在Keras官方文档代码基础上进行修改实现使用多个workers加速生成CAM。

使用multi-workers加速的部分主要是读取原始图像和向原始图像上添加热图两部分。加速生成热图部分主要是调高batch size。

下面以一个适用于2分支分割任务的代码来举例,读者可根据自己的需求自行修改。

代码细节

1. 主程序模块

def grad_cam_multi_worker(model_path,ori_root_path,save_root_path,data_folder_name,save_folder_name,last_conv_layer_name,num_workers,batch_size,custom_objects=None):
    if custom_objects is None:
        model = tf.keras.models.load_model(model_path)
    else:
        model = tf.keras.models.load_model(model_path,custom_objects=custom_objects)
    for root, dirs, files in os.walk(ori_root_path):
        if not os.path.exists(root.replace(data_folder_name,save_folder_name)):
            os.makedirs(root.replace(data_folder_name,save_folder_name))
        if len(files)==0:
            continue
        if len(os.listdir(root.replace(data_folder_name,save_folder_name)))!=0:
            continue
        print(root.split('//')[-1])
        path_list = []
        for name in files:
            if 'image' in name:
                path_list.append(os.path.join(root,name))
        path_list = sorted(path_list)
        indexes_total = list(np.arange(len(path_list)))
        for i in range(int(np.ceil(len(indexes_total)/batch_size))): # batchSize level
            indexes = []
            cnt = 0
            while len(indexes_total)!=0:
                indexes.append(indexes_total.pop())
                cnt = cnt+1
                if cnt==batch_size:
                    break
            if len(indexes)==0:
                continue

            ## load original images with multiple workers
            dataset = Dataset(path_list=path_list)
            samples = []
            indexes_batch = []
            with ThreadPoolExecutor(max_workers=num_workers) as executor:
                for sample in executor.map(lambda i: dataset[i],indexes):
                    samples.append(sample[0])
                    indexes_batch.append(sample[1])
            ori_img = default_collate_fn(samples) # 16*128*128*7*1

            cam_1,cam_2 = make_gradcam_heatmap(img_array=ori_img, model=model, last_conv_layer_name=last_conv_layer_name)
            if len(cam_1.shape)==2: #  batchsize=1 or the last batch not full
                cam_1= np.expand_dims(cam_1,axis=0)
                cam_2 = np.expand_dims(cam_2,axis=0)
            ## Add heat maps to original image with multiple workers
            samples2 = []
            indexes_batch2 = []
            with ThreadPoolExecutor(max_workers=num_workers) as executor:
                datasetadd = DatasetAddHeatMap(cam=[cam_1,cam_2],ori_img=ori_img,path_list=path_list,indexes_batch=indexes_batch)
                for sample in executor.map(lambda i: datasetadd[i],indexes):
                    samples2.append(sample[0])
                    indexes_batch2.append(sample[1])
            gradcam1,gradcam2 = default_collate_fn_2(samples2) 
            cnt=0
            for grad1, grad2 in zip(gradcam1,gradcam2):
                figure = int(re.findall(r"\d+",path_list[indexes_batch2[cnt]].split('\\')[-1])[-1])
                grad1.save(os.path.join(save_root_path,root.split('\\')[-2],root.split('\\')[-1],'CAM_1_'+str(figure).zfill(3)+'.png'))
                grad2.save(os.path.join(save_root_path,root.split('\\')[-2],root.split('\\')[-1],'CAM_2_'+str(figure).zfill(3)+'.png'))
                cnt = cnt+1
                
if __name__ == '__main__':
    model_path = r'\model.hdf5'
    ori_root_path = r'\test_data_root_path' 
    save_root_path = r'\CAM_save_root_path'
    last_conv_layer_name = 'conv3d_9' # 最后一个卷积层的名字
    # 如果希望CAM存储文件夹的结构与test data文件夹的结构相同,可以使用文件夹名字替换递归生成相应的文件夹
    data_folder_name = 'test_data_root_folder_name'
    save_folder_name = 'CAM_save_root_folder'
    num_workers = 8
    batch_size = 32
    # 如果训练时使用了自定义的模块需要声明custom_objects,否则存储的模型无法正确载入
    custom_objects={"DiceBCELoss":DiceBCELoss}
    grad_cam_multi_worker(model_path=model_path,save_root_path=save_root_path,ori_root_path=ori_root_path,data_folder_name=data_folder_name,save_folder_name=save_folder_name,last_conv_layer_name=last_conv_layer_name,num_workers=num_workers,batch_size=batch_size,custom_objects=custom_objects)

2. 读取图像和叠加热图的子程序模块

class Dataset(object):
    def __init__(self,path_list):
        self.path_list = path_list

    def __getitem__(self, index): # 使用 multi-workers,每次传进来的index是乱序的
        oriImg_path = self.path_list[index]
        img_array = get_img_array(oriImg_path)+1024 
        img_array = (img_array - img_array.min()) / (img_array.max() - img_array.min() + 1e-7)
        return [img_array, index] # 将index同载入的图像一起传出,供下面DatasetAddHeatMap使用。目的是将对应的热图叠加到正确的原始图像上

class DatasetAddHeatMap(object):
    def __init__(self,cam,ori_img,path_list,indexes_batch):
        self.path_list = path_list
        self.cam = cam
        self.ori_img = ori_img
        self.indexes_batch = indexes_batch
        self.gradcam = gradcam

    def __getitem__(self, index):
        idx = int(np.argwhere(self.indexes_batch==index))
        superimposed_img_1 = self.gradcam(img_path=self.path_list[index], heatmap=self.cam[0][idx,:], ori_img=self.ori_img[idx,:], alpha=0.4)
        superimposed_img_2 = self.gradcam(img_path=self.path_list[index], heatmap=self.cam[1][idx,:], ori_img=self.ori_img[idx,:], alpha=0.4)
        return [[superimposed_img_1,superimposed_img_2], index]

3. 生成热图的子程序模块

def make_gradcam_heatmap(img_array, model, last_conv_layer_name):
    grad_model_1 = tf.keras.models.Model(
        [model.inputs], [model.get_layer(last_conv_layer_name).output, model.output[0]]
    )
    grad_model_2 = tf.keras.models.Model(
        [model.inputs], [model.get_layer(last_conv_layer_name).output, model.output[1]]
    )

    with tf.GradientTape() as tape:
        last_conv_layer_output, preds_1 = grad_model_1({'data':img_array})
    grads_1 = tape.gradient(preds_1, last_conv_layer_output) # 1*128*128*2*64

    with tf.GradientTape() as tape:
        last_conv_layer_output, preds_2 = grad_model_2({'data':img_array})
    grads_2 = tape.gradient(preds_2, last_conv_layer_output) # 1*128*128*2*64

    # This is a vector where each entry is the mean intensity of the gradient
    # over a specific feature map channel
    pooled_grads_1 = tf.reduce_mean(grads_1, axis=(0, 1, 2, 3)) # (64,)
    pooled_grads_2 = tf.reduce_mean(grads_2, axis=(0, 1, 2, 3)) # (64,)

    # We multiply each channel in the feature map array
    # by "how important this channel is" with regard to the top predicted class
    # then sum all the channels to obtain the heatmap class activation
    # last_conv_layer_output = last_conv_layer_output[:,:,:,0,:]
    heatmap_1 = last_conv_layer_output @ pooled_grads_1[..., tf.newaxis] # 1*128*128*2
    heatmap_1 = tf.squeeze(heatmap_1) # 128*128*2
    heatmap_1 = tf.reduce_mean(heatmap_1,axis=-1) # 128*128
    heatmap_2 = last_conv_layer_output @ pooled_grads_2[..., tf.newaxis] # 1*128*128*2
    heatmap_2 = tf.squeeze(heatmap_2) # 128*128*2
    heatmap_2 = tf.reduce_mean(heatmap_2,axis=-1) # 128*128

    # For visualization purpose, we will also normalize the heatmap between 0 & 1
    heatmap_1 = tf.maximum(heatmap_1, 0) / tf.math.reduce_max(heatmap_1)
    heatmap_2 = tf.maximum(heatmap_2, 0) / tf.math.reduce_max(heatmap_2)
    return heatmap_1.numpy(), heatmap_2.numpy()

References

https://keras.io/examples/vision/grad_cam/

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-05-06 11:01:43  更:2022-05-06 11:01:50 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/28 8:31:35-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计