IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 模型剪枝-ICCV2017-ThiNet: A Filter Level Pruning Method -> 正文阅读

[人工智能]模型剪枝-ICCV2017-ThiNet: A Filter Level Pruning Method

卷积核层面的模型剪枝

本质方法。就是计算每一个filter上权重绝对值之和,去掉m个权重最小的filters,并同时去掉与其相关的特征图及下一层的所有相关的输入filters;

在这里插入图片描述
筛选需要裁减的卷积核步骤为:
1.对每个Filter,使用L1-Norm来计算每一个filter上权重绝对值之和;
2.对所得权重之和进行排序,和大小反映了相关filter的重要性;
3.选择前k个较大的权重之和保留,建立一个mask,大保留的部分为1,小于阈值的部分为0。

cfg_mask = []
layer_id = 0 # 统计层数
for m in model.modules(): # 遍历vgg的每个module
    if isinstance(m, nn.Conv2d): # 如果发现卷积层
        out_channels = m.weight.data.shape[0]

        # cfg[layer_id]: 每一层要保留的通道数量
        if out_channels == cfg[layer_id]:
            # 如果这一层的通道数已经满足,直接进入下一层循环
            cfg_mask.append(torch.ones(out_channels))
            layer_id += 1
            continue

        # 克隆所有卷积层的权重
        weight_copy = m.weight.data.abs().clone()
        weight_copy = weight_copy.cpu().numpy()

        # weight_copy: [c_out, c_in, kernal, kernal]
        # L1_norm : [c_out]
        L1_norm = np.sum(weight_copy, axis=(1, 2, 3))

        # arg_max为从大到小排序后的下标
        arg_max = np.argsort(L1_norm)[::-1]

        # 取前cfg[layer_id]个较大值
        arg_max_rev = arg_max[:cfg[layer_id]]
        assert arg_max_rev.size == cfg[layer_id], "size of arg_max_rev not correct"
        
        # 删除的通道mask=0,保留的通道mask=1
        mask = torch.zeros(out_channels)
        mask[arg_max_rev.tolist()] = 1

        # 记录每个卷积层保留的权重
        cfg_mask.append(mask)
        layer_id += 1

    elif isinstance(m, nn.MaxPool2d):
        layer_id += 1

之后需要进行BN2D层的剪枝,即需要丢弃刚才被抛弃的卷积核。

start_mask = torch.ones(3)
layer_id_in_cfg = 0
end_mask = cfg_mask[layer_id_in_cfg]
for [m0, m1] in zip(model.modules(), newmodel.modules()):

    # 对BN2d层进行剪枝
    if isinstance(m0, nn.BatchNorm2d):

        # 获取大于0的所有数据的索引,使用squeeze变成向量
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:
            idx1 = np.resize(idx1,(1,))

        # 用经过剪枝后的层参数的替换原来的
        # [c]
        m1.weight.data = m0.weight.data[idx1.tolist()].clone()
        m1.bias.data = m0.bias.data[idx1.tolist()].clone()
        m1.running_mean = m0.running_mean[idx1.tolist()].clone()
        m1.running_var = m0.running_var[idx1.tolist()].clone()

        # 下一层
        layer_id_in_cfg += 1

        # 当前在处理的层的mask
        start_mask = end_mask

        # 全连接层不做处理
        if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
            end_mask = cfg_mask[layer_id_in_cfg]

最后需要进行卷积层剪枝,根据前后BN层的保留层,可以计算得到卷积层保留的卷积核大小(上层BN层输出,下层BN层输入),保留前后BN的对应保留的元素,其余剪枝。

   # 对卷积层进行剪枝
    elif isinstance(m0, nn.Conv2d):
        # 卷积后面会接bn
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        # 剪枝
        # [c_out, c_in, kernal, kernal]
        w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
        w1 = w1[idx1.tolist(), :, :, :].clone()
        m1.weight.data = w1.clone()

最后对FC层进行剪枝,由于最后一层FC层的输出是固定的(分类类数),因此只对FC层的输入维度进行剪枝,也是根据上一层BN层的输出,对应保留的元素,其余剪枝。

    # 对全连接层进行剪枝
    elif isinstance(m0, nn.Linear):

        # 最后一层全连接层进行剪枝
        if layer_id_in_cfg == len(cfg_mask):
            idx0 = np.squeeze(np.argwhere(np.asarray(cfg_mask[-1].cpu().numpy())))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            # [c_out, c_in]
            m1.weight.data = m0.weight.data[:, idx0].clone()
            m1.bias.data = m0.bias.data.clone()
            layer_id_in_cfg += 1
            continue

        # 其余全连接层不剪枝
        m1.weight.data = m0.weight.data.clone()
        m1.bias.data = m0.bias.data.clone()

对BN1d层不剪枝

    # 对BN1d层不进行剪枝,直接使用原始模型参数
    elif isinstance(m0, nn.BatchNorm1d):
        m1.weight.data = m0.weight.data.clone()
        m1.bias.data = m0.bias.data.clone()
        m1.running_mean = m0.running_mean.clone()
        m1.running_var = m0.running_var.clone()

Thinet核心思路

Filters Pruning在卷积核的层面进行剪枝,上述思路filter的修剪取决于当前层,Thinet则选择使用下一层的输出进行卷积核的修剪依据。思想就是如果某层输入数据中的一部分就可以得到与全部输入非常近似的结果,那么就可以将输入数据中其他部分去掉,同时其对应的前面层的filter也就可以去掉。
在这里插入图片描述

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-11-30 15:37:27  更:2021-11-30 15:38:27 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 2:36:35-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码