IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> PyTorch 模型剪枝实例教程三、多参数与全局剪枝 -> 正文阅读

[人工智能]PyTorch 模型剪枝实例教程三、多参数与全局剪枝

目前大部分最先进的(SOTA)深度学习技术虽然效果好,但由于其模型参数量和计算量过高,难以用于实际部署。而众所周知,生物神经网络使用高效的稀疏连接(生物大脑神经网络balabala啥的都是稀疏连接的),考虑到这一点,为了减少内存、容量和硬件消耗,同时又不牺牲模型预测的精度,在设备上部署轻量级模型,并通过私有的设备上计算以保证隐私,通过减少参数数量来压缩模型的最佳技术非常重要

稀疏神经网络在预测精度方面可以达到密集神经网络的水平,但由于模型参数量小,理论上来讲推理速度也会快很多。而模型剪枝是一种将密集神经网络训练成稀疏神经网络的方法。

本文将通过学习官方示例教程,介绍如何通过一个简单的实例教程来进行模型剪枝,实践深度学习模型压缩加速。

相关链接

深度学习模型压缩与加速技术(一):参数剪枝

PyTorch模型剪枝实例教程一、非结构化剪枝

PyTorch模型剪枝实例教程二、结构化剪枝

PyTorch模型剪枝实例教程三、多参数与全局剪枝

通过教程一和教程二,我们可以了解如何通过PyTorch进行非结构化和结构化的剪枝,一般而言,我们会考虑将较深的网络进行参数剪枝,此时,通过一个个检查模块诶个给它们剪枝就比较麻烦,我们可以利用多参数和全局剪枝的方法对同类型参数进行剪枝。

1.导包&定义一个简单的网络

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

'''搭建类LeNet网络'''
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 单通道图像输入,5×5核尺寸
        self.conv1 = nn.Conv2d(1, 3, 5)
        self.conv2 = nn.Conv2d(3, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

2.多参数剪枝

new_model = LeNet()
for name, module in new_model.named_modules():
    # 对所有Conv2d的参数进行20%的L1非结构化剪枝
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # 对所有Linear的参数进行20%的L1非结构化剪枝
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)
print(dict(new_model.named_buffers()).keys())  # 验证一下下

输出:

dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])

3.全局剪枝

前面所有提到的方法,都是局部剪枝方法,我们还可以使用全局剪枝方法,通过删除整个模型最低的20%的连接,而非删除每个层中最低20%的连接,也就是说,可能会出现层与层之间删除的百分比不一样的情况。

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

print(
    "稀疏性 in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "稀疏性 in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "稀疏性 in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "稀疏性 in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "稀疏性 in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "全局稀疏性: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)

输出

稀疏性 in conv1.weight: 8.00%
稀疏性 in conv2.weight: 9.33%
稀疏性 in fc1.weight: 22.07%
稀疏性 in fc2.weight: 12.20%
稀疏性 in fc3.weight: 11.31%
全局稀疏性: 20.00%

4.总结

本示例首先搭建了一个类LeNet网络模型,为了进行多参数剪枝,我们使用.named_modules()遍历了所有层,并利用isinstance()方法判断是否为Conv2d或Linear结构,以此来对相同结构参数进行同等类型剪枝。为了进行全局剪枝,我们使用了 .global_unstructured参数进行剪枝,可以发现,全局剪枝与多参数剪枝不一样的地方在于,全局剪枝最终的稀疏性虽然和多参数剪枝稀疏性相同,但全局剪枝稀疏性并非对每层均等稀疏的。

本文用到的核心函数方法:

  • .named_modules(),获取模型的参数名和结构
  • isinstance(),判断类型是否一致
  • .global_unstructured,全局剪枝方法

参考:

Torch官方剪枝教程

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-23 10:50:39  更:2022-04-23 10:51:21 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 10:01:19-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码