IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 目标检测 YOLOv5 - 模型压缩 -> 正文阅读

[人工智能]目标检测 YOLOv5 - 模型压缩

目标检测 YOLOv5 - 模型压缩

flyfish

1 什么是剪枝

YOLOv5自带的模型压缩是怎样的呢?就是剪枝。
在一棵树中,把不重要的枝条剪掉,就是剪枝

在这里插入图片描述
园丁的手艺是不同的,不同的园丁剪的效果不同。
在这里插入图片描述

做模型的剪枝与园丁干得工作是一模一样,先看一个回归实例
在这里插入图片描述

拟合数据的结果有正合适,欠拟合,过拟合
直线就是欠拟合,一个每个数据点都经过的曲线就是过拟合了
再看他们的数学表达式,多项式的最高次数是不同的,剪枝就像去掉上图中3次方项和4次方项
在这里插入图片描述

剪枝方法多,因为把不重要的东西去掉,在定义什么东西不重要,各家有各家的方法
如果把项剪的太多了,曲线就变成值线了,这不是我们想要的。
当过拟合的时候,剪太多就把树枝剪的没剩几根,了就变成欠拟合;如果已经拟合的很合适了,再剪也欠拟合了。
期望是精度和召回率都不降低,降低的只有计算量

剪枝的生物学启示

在这里插入图片描述

深度学习的剪枝被认为是人脑中突触剪枝的一个想法,在人脑中之间发生的突触消除。修剪突触从出生时开始,一直持续到20多岁左右。

看一个神经网络
在这里插入图片描述

圈与圈之间的连线就是权重,权重也就是一堆堆的数

在这里插入图片描述
剪节点的可以叫 pruning node 或者 pruning neurons 剪神经元
剪线的可以叫 pruning connections 或者叫 pruning synapses 剪突触,剪权重

在这里插入图片描述
在这里插入图片描述

剪的结果
左图是原始权重矩阵,右图是阈值为0.1的修剪后的矩阵。高亮显示的权重将被删除或者置零。
在这里插入图片描述

weight剪枝的实现

怎么实现呢
在这里插入图片描述

原始的矩阵 * weight_mask = 新的矩阵
这样知道了代码里的weight_mask 和 bias_mask是个什么意思

彩票假说:寻找稀疏的、可训练的神经网络
彩票假说简单说就是是主要是随机初始化的密集神经网络包含一个初始化的子网,通过随机初始化权重的子网络仍可以达到原始网络的精度.

剪枝有unstructured和structured,这两者有什么区别
在这里插入图片描述

非结构化(左)和结构化(右)剪枝的区别:结构化剪枝去除卷积滤波器和内核行,而不仅仅是剪枝连接。
结构化剪枝和非结构化剪枝的主要区别在于剪枝权重的粒度。
非结构化剪枝主要是对单个权重进行裁剪
结构化剪枝的粒度较大,主要是把整行整列的权重移除掉(即把一个神经元去掉),对channel和Filter维度进行裁剪。
可以反过来说因为对单个权重进行裁剪的结果是unstructured,看上去有点乱。对channel和Filter裁剪的结果是structured的,根据裁剪之后是否保持了structe起了一个名字

YOLOv5的剪枝是怎么操作的

YOLOv5提供的一段剪枝代码

def sparsity(model):
    # Return global model sparsity
    a, b = 0., 0.
    for p in model.parameters():
        a += p.numel()
        b += (p == 0).sum()
    return b / a
def prune(model, amount=0.3):
    # Prune model to requested global sparsity
    import torch.nn.utils.prune as prune
    print('Pruning model... ', end='')
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            prune.l1_unstructured(m, name='weight', amount=amount)  # prune
            prune.remove(m, 'weight')  # make permanent
    print(' %.3g global sparsity' % sparsity(model))

写一段代码把YOLOv5的剪枝代码用上去,查看剪枝前和剪枝后的区别

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def sparsity(model):
    # Return global model sparsity
    a, b = 0., 0.
    for p in model.parameters():
        a += p.numel()
        b += (p == 0).sum()
    return b / a
def prune(model, amount=0.3):
    # Prune model to requested global sparsity
    import torch.nn.utils.prune as prune
    print('Pruning model... ', end='')
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            prune.l1_unstructured(m, name='weight', amount=amount)  # prune
            prune.remove(m, 'weight')  # make permanent
    print(' %.3g global sparsity' % sparsity(model))
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)
module = model.conv1
print(list(module.named_parameters()))
prune(module, amount=0.3)
print(list(module.named_parameters()))

剪枝前

[('weight', Parameter containing:
tensor([[[[-0.2032, -0.0269, -0.0981],
          [-0.1920, -0.2737,  0.2451],
          [ 0.1116,  0.1331,  0.0147]]],
...
        [[[ 0.3109,  0.0082, -0.0080],
          [-0.3009, -0.0805, -0.0308],
          [-0.0347, -0.2851,  0.1614]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.0874,  0.2916,  0.2522,  0.2425, -0.2085,  0.2855], device='cuda:0',
       requires_grad=True))]

剪枝后

Pruning model...  0.267 global sparsity

[('bias', Parameter containing:
tensor([-0.0874,  0.2916,  0.2522,  0.2425, -0.2085,  0.2855], device='cuda:0',
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[-0.2032, -0.0000, -0.0981],
          [-0.1920, -0.2737,  0.2451],
          [ 0.1116,  0.1331,  0.0000]]],
...
        [[[ 0.3109,  0.0000, -0.0000],
          [-0.3009, -0.0805, -0.0000],
          [-0.0000, -0.2851,  0.1614]]]], device='cuda:0', requires_grad=True))]

我们看到不重要的权重变成了0

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-22 10:56:24  更:2021-10-22 10:57:46 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 8:20:37-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码