目前大部分最先进的(SOTA)深度学习技术虽然效果好,但由于其模型参数量和计算量过高,难以用于实际部署。而众所周知,生物神经网络使用高效的稀疏连接(生物大脑神经网络balabala啥的都是稀疏连接的),考虑到这一点,为了减少内存、容量和硬件消耗,同时又不牺牲模型预测的精度,在设备上部署轻量级模型,并通过私有的设备上计算以保证隐私,通过减少参数数量来压缩模型的最佳技术非常重要。
稀疏神经网络在预测精度方面可以达到密集神经网络的水平,但由于模型参数量小,理论上来讲推理速度也会快很多。而模型剪枝是一种将密集神经网络训练成稀疏神经网络的方法。
本文将通过学习官方示例教程,介绍如何通过一个简单的实例教程来进行模型剪枝,实践深度学习模型压缩加速。
相关链接
深度学习模型压缩与加速技术(一):参数剪枝
PyTorch模型剪枝实例教程一、非结构化剪枝
PyTorch模型剪枝实例教程二、结构化剪枝
PyTorch模型剪枝实例教程三、多参数与全局剪枝
通过教程一和教程二,我们可以了解如何通过PyTorch进行非结构化和结构化的剪枝,一般而言,我们会考虑将较深的网络进行参数剪枝,此时,通过一个个检查模块诶个给它们剪枝就比较麻烦,我们可以利用多参数和全局剪枝的方法对同类型参数进行剪枝。
1.导包&定义一个简单的网络
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
'''搭建类LeNet网络'''
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(1, 3, 5)
self.conv2 = nn.Conv2d(3, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, int(x.nelement() / x.shape[0]))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
2.多参数剪枝
new_model = LeNet()
for name, module in new_model.named_modules():
if isinstance(module, torch.nn.Conv2d):
prune.l1_unstructured(module, name='weight', amount=0.2)
elif isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name='weight', amount=0.4)
print(dict(new_model.named_buffers()).keys())
输出:
dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])
3.全局剪枝
前面所有提到的方法,都是局部剪枝方法,我们还可以使用全局剪枝方法,通过删除整个模型最低的20%的连接,而非删除每个层中最低20%的连接,也就是说,可能会出现层与层之间删除的百分比不一样的情况。
model = LeNet()
parameters_to_prune = (
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.fc1, 'weight'),
(model.fc2, 'weight'),
(model.fc3, 'weight'),
)
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.2,
)
print(
"稀疏性 in conv1.weight: {:.2f}%".format(
100. * float(torch.sum(model.conv1.weight == 0))
/ float(model.conv1.weight.nelement())
)
)
print(
"稀疏性 in conv2.weight: {:.2f}%".format(
100. * float(torch.sum(model.conv2.weight == 0))
/ float(model.conv2.weight.nelement())
)
)
print(
"稀疏性 in fc1.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc1.weight == 0))
/ float(model.fc1.weight.nelement())
)
)
print(
"稀疏性 in fc2.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc2.weight == 0))
/ float(model.fc2.weight.nelement())
)
)
print(
"稀疏性 in fc3.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc3.weight == 0))
/ float(model.fc3.weight.nelement())
)
)
print(
"全局稀疏性: {:.2f}%".format(
100. * float(
torch.sum(model.conv1.weight == 0)
+ torch.sum(model.conv2.weight == 0)
+ torch.sum(model.fc1.weight == 0)
+ torch.sum(model.fc2.weight == 0)
+ torch.sum(model.fc3.weight == 0)
)
/ float(
model.conv1.weight.nelement()
+ model.conv2.weight.nelement()
+ model.fc1.weight.nelement()
+ model.fc2.weight.nelement()
+ model.fc3.weight.nelement()
)
)
)
输出
稀疏性 in conv1.weight: 8.00%
稀疏性 in conv2.weight: 9.33%
稀疏性 in fc1.weight: 22.07%
稀疏性 in fc2.weight: 12.20%
稀疏性 in fc3.weight: 11.31%
全局稀疏性: 20.00%
4.总结
本示例首先搭建了一个类LeNet网络模型,为了进行多参数剪枝,我们使用.named_modules()遍历了所有层,并利用isinstance()方法判断是否为Conv2d或Linear结构,以此来对相同结构参数进行同等类型剪枝。为了进行全局剪枝,我们使用了 .global_unstructured参数进行剪枝,可以发现,全局剪枝与多参数剪枝不一样的地方在于,全局剪枝最终的稀疏性虽然和多参数剪枝稀疏性相同,但全局剪枝稀疏性并非对每层均等稀疏的。
本文用到的核心函数方法:
- .named_modules(),获取模型的参数名和结构
- isinstance(),判断类型是否一致
- .global_unstructured,全局剪枝方法
参考:
Torch官方剪枝教程
|