IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 目标检测 YOLOv5 - 卷积层和BN层的融合 -> 正文阅读

[人工智能]目标检测 YOLOv5 - 卷积层和BN层的融合

目标检测 YOLOv5 - 卷积层和BN层的融合

即Conv2d和 BatchNorm2d融合

flyfish

为了减少模型推理时间,YOLOv5源码中attempt_load已经包括两层的合并,主要用在推理和导出模型提供给其他平台推理使用时。

函数调用链条是

attempt_load -》fuse-》fuse_conv_and_bn

fuse_conv_and_bn函数就是Conv2d层和BatchNorm2d层的合并。

在模型训练完成后,YOLOv5在推理阶段和导出模型时,将卷积层和BN层进行融合。

融合过程

(1)卷积层公式简写如下:
x i = w ? x i ? 1 + b x_{i} = w * x_{i-1} + b xi?=w?xi?1?+b

(2)Batch Normalization层公式简写如下
这里简单复述稍微详细一点看介绍BatchNorm2d
μ B ← 1 m ∑ i = 1 m x i \mu_{\mathcal{B}} \leftarrow \frac{1}{m} \sum_{i=1}^m x_i μB?m1?i=1m?xi?
σ B 2 ← 1 m ∑ i = 1 m ( x i ? μ B ) 2 \sigma_{\mathcal{B}}^2 \leftarrow \frac{1}{m} \sum_{i=1}^m (x_i - \mu_{\mathcal{B}})^2 σB2?m1?i=1m?(xi??μB?)2
x ^ i ← x i ? μ B σ B 2 + ? \hat{x}_i \leftarrow \dfrac{x_i-\mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2+\epsilon}} x^i?σB2?+? ?xi??μB??
在这里插入图片描述

y i = x i ? μ σ 2 + ? ? γ + β y_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta yi?=σ2+? ?xi??μ??γ+β

4个参数

γ \gamma γ β \beta β 在训练时,是可学习参数。
推理阶段均值和方差是整个训练集的均值和方差。假如硬件资源无限均值和方差训练时全部储存,最后计算推理时所用的均值和方差。硬件资源有限,训练时mini-batch的均值和方差不需要全部储存,太多了不好算,是通过滑动平均计算出整个训练集的均值和方差 。
训练完成后这四个值都是固定值即常量。
(3)将卷积层的式子带入到 BN 层的式子就是融合
y i = x i ? μ σ 2 + ? ? γ + β = w ? x i + b ? μ σ 2 + ? ? γ + β = w ? γ σ 2 + ? ? x + ( b ? μ σ 2 + ? ? γ + β ) \begin{aligned} &y_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta \\ &= \frac{w * x_i + b - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta \\ &= \frac{w * \gamma}{\sqrt{\sigma^2 + \epsilon}} \cdot x + (\frac{b-\mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta) \end{aligned} ?yi?=σ2+? ?xi??μ??γ+β=σ2+? ?w?xi?+b?μ??γ+β=σ2+? ?w?γ??x+(σ2+? ?b?μ??γ+β)?

融合后

上述式子可以简单看成直线方程
y = k x + b y=kx+b y=kx+b为了防止跟上面式子字母重复换一个字母
y = k 1 x + k 2 y=k_1x+k_2 y=k1?x+k2?
k 1 = w ? γ σ 2 + ? k_1= \frac{w * \gamma}{\sqrt{\sigma^2 + \epsilon}} k1?=σ2+? ?w?γ?
k 2 = b ? μ σ 2 + ? + β k_2= \frac{b-\mu}{\sqrt{\sigma^2 + \epsilon}} + \beta k2?=σ2+? ?b?μ?+β
w n e w = w ? γ σ 2 + ? b n e w = ( b ? μ σ 2 + ? ? γ + β ) \begin{aligned} & w_{new} = \frac{w * \gamma}{\sqrt{\sigma^2 + \epsilon}} \\ & b_{new} = (\frac{b-\mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta) \end{aligned} ?wnew?=σ2+? ?w?γ?bnew?=(σ2+? ?b?μ??γ+β)?

YOLOv5版本的融合代码实现

import torch
import torchvision
def fuse_conv_and_bn(conv, bn):
	#
	# init
	fusedconv = torch.nn.Conv2d(
		conv.in_channels,
		conv.out_channels,
		kernel_size=conv.kernel_size,
		stride=conv.stride,
		padding=conv.padding,
		bias=True
	)
	#
	# prepare filters
	print("conv.out_channels:",conv.out_channels)
	w_conv = conv.weight.clone().view(conv.out_channels, -1)
	print("w_conv:",w_conv.shape)
	w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
	print("w_bn:", w_bn.shape)
	fusedconv.weight.copy_( torch.mm(w_bn, w_conv).view(fusedconv.weight.size()) )
	#
	# prepare spatial bias
	if conv.bias is not None:
		b_conv = conv.bias
	else:
		b_conv = torch.zeros( conv.weight.size(0) )
	b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
	fusedconv.bias.copy_( b_conv + b_bn )
	#
	# we're done
	return fusedconv

#以下代码片段在ResNet18的前两层测试上述函数:

测试代码

torch.set_grad_enabled(False)
x = torch.randn(16, 3, 256, 256)
rn18 = torchvision.models.resnet18(pretrained=True)
rn18.eval()
net = torch.nn.Sequential(
	rn18.conv1,
	rn18.bn1
)
y1 = net.forward(x)
fusedconv = fuse_conv_and_bn(net[0], net[1])
y2 = fusedconv.forward(x)
d = (y1 - y2).norm().div(y1.norm()).item()
print("error: %.8f" % d)

测试结果

conv.out_channels: 64
w_conv: torch.Size([64, 147])
w_bn: torch.Size([64, 64])
error: 0.00000030

PyTorch版本的融合代码实现

import copy
import torch

def fuse_conv_bn_eval(conv, bn):
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)

    return fused_conv

def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b

    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)

我们自己添加如下代码测试

torch.set_grad_enabled(False)
x = torch.randn(16, 3, 256, 256)
rn18 = torchvision.models.resnet18(pretrained=True)
rn18.eval()
net = torch.nn.Sequential(
	rn18.conv1,
	rn18.bn1
)
y1 = net.forward(x)
fusedconv = fuse_conv_bn_eval(net[0], net[1])
y2 = fusedconv.forward(x)
d = (y1 - y2).norm().div(y1.norm()).item()
print("error: %.8f" % d)

输出结果

error: 0.00000030

YOLOv5版本的两层融合与 PyTorch提供的两层融合结果是一模一样的。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-05 10:51:48  更:2021-09-05 10:51:50 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 19:44:20-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码