Object Detection with YOLOv5: Fusing the Convolution and BN Layers
That is, fusing Conv2d and BatchNorm2d
flyfish
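To reduce inference time, attempt_load in the YOLOv5 source code already merges these two layers. The merge is mainly used for inference and when exporting the model so that other platforms can run inference with it.
The call chain is
attempt_load -> fuse -> fuse_conv_and_bn
fuse_conv_and_bn is the function that merges a Conv2d layer with a BatchNorm2d layer.
Once training is finished, YOLOv5 fuses the convolution and BN layers at inference time and when exporting the model.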
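As a rough illustration of what a fuse step in a call chain like this does, the sketch below walks a model and replaces each conv + BN pair with a single fused convolution. ConvBN and fuse_model are hypothetical names made up for this example, not YOLOv5's actual classes; the fusion itself is delegated to PyTorch's fuse_conv_bn_eval from torch.nn.utils.fusion.

```python
import torch
from torch.nn.utils.fusion import fuse_conv_bn_eval


class ConvBN(torch.nn.Module):
    """Hypothetical Conv2d -> BatchNorm2d -> ReLU block (illustration only)."""

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = torch.nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False)
        self.bn = torch.nn.BatchNorm2d(c_out)
        self.act = torch.nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        # After fusion the bn attribute has been removed and conv already includes it.
        if hasattr(self, "bn"):
            x = self.bn(x)
        return self.act(x)


def fuse_model(model):
    """Replace each ConvBN's conv with a fused conv and drop its BN layer."""
    for m in list(model.modules()):
        if isinstance(m, ConvBN) and hasattr(m, "bn"):
            m.conv = fuse_conv_bn_eval(m.conv, m.bn)
            del m.bn
    return model


torch.manual_seed(0)
model = torch.nn.Sequential(ConvBN(3, 8), ConvBN(8, 16))

# A few training-mode passes so the BN running statistics differ from their defaults.
model.train()
with torch.no_grad():
    for _ in range(3):
        model(torch.randn(4, 3, 32, 32))

model.eval()  # fusion is only valid with frozen (eval-mode) statistics
with torch.no_grad():
    x = torch.randn(2, 3, 32, 32)
    y_before = model(x)
    fuse_model(model)
    y_after = model(x)

max_abs_diff = (y_before - y_after).abs().max().item()
n_bn_left = sum(isinstance(m, torch.nn.BatchNorm2d) for m in model.modules())
print("max abs diff:", max_abs_diff, "BN layers left:", n_bn_left)
```

The fused model should give the same outputs up to floating-point rounding while running one layer fewer per block.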
Fusion process
(1) The convolution layer, written in shorthand (x_{i-1} is the layer's input and x_i its output):
$$x_{i} = w * x_{i-1} + b$$
(2) The Batch Normalization layer, written in shorthand. This is only a brief recap; for slightly more detail, see the introduction to BatchNorm2d.
$$\mu_{\mathcal{B}} \leftarrow \frac{1}{m} \sum_{i=1}^m x_i$$
$$\sigma_{\mathcal{B}}^2 \leftarrow \frac{1}{m} \sum_{i=1}^m (x_i - \mu_{\mathcal{B}})^2$$
$$\hat{x}_i \leftarrow \dfrac{x_i-\mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2+\epsilon}}$$
$$y_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta$$
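To make the four BN formulas concrete, here is a minimal pure-Python sketch that applies them to one small mini-batch of scalars. The function name batch_norm_1d and the sample values are made up for illustration.

```python
import math


def batch_norm_1d(xs, gamma, beta, eps=1e-5):
    """Apply the mini-batch BN formulas to a list of scalars."""
    m = len(xs)
    mu = sum(xs) / m                                   # mini-batch mean
    var = sum((x - mu) ** 2 for x in xs) / m           # biased variance (divide by m)
    x_hat = [(x - mu) / math.sqrt(var + eps) for x in xs]  # normalize
    ys = [gamma * xh + beta for xh in x_hat]           # scale and shift
    return mu, var, ys


mu, var, ys = batch_norm_1d([1.0, 2.0, 3.0, 4.0], gamma=2.0, beta=0.5)
print("mean:", mu)        # 2.5
print("variance:", var)   # 1.25
print("outputs:", ys)
```

After normalization the outputs are centered on beta, and their spread is set by gamma.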
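There are four parameters. $\gamma$ and $\beta$ are learnable parameters during training. At inference time, the mean and variance are those of the whole training set. With unlimited hardware resources, the mean and variance of every mini-batch could be stored during training and the inference-time values computed from them at the end. With limited resources, there is no need to store every mini-batch's mean and variance (there would be too many to handle conveniently); instead, the training-set mean and variance are estimated with a moving average. Once training is complete, all four values are fixed, that is, constants.
(3) Substituting the convolution formula into the BN formula gives the fusion: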
$$\begin{aligned} y_i &= \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta \\ &= \frac{w * x_{i-1} + b - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta \\ &= \frac{w * \gamma}{\sqrt{\sigma^2 + \epsilon}} \cdot x_{i-1} + \left(\frac{b-\mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta\right) \end{aligned}$$
After fusion
The expression above can be viewed simply as the equation of a line:
$$y = kx + b$$
To avoid reusing letters from the formulas above, rename the coefficients:
$$y = k_1 x + k_2$$
$$k_1 = \frac{w * \gamma}{\sqrt{\sigma^2 + \epsilon}}$$
$$k_2 = \frac{b-\mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta$$
$$\begin{aligned} w_{new} &= \frac{w * \gamma}{\sqrt{\sigma^2 + \epsilon}} \\ b_{new} &= \left(\frac{b-\mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta\right) \end{aligned}$$
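The formulas for w_new and b_new can be checked with plain arithmetic. The sketch below treats the convolution as a scalar multiply-add, which is enough to verify the algebra; the helper names bn_eval and fuse_scalar and all numeric values are arbitrary choices for illustration.

```python
import math


def bn_eval(x, gamma, beta, mean, var, eps):
    """Inference-mode BN with fixed statistics."""
    return (x - mean) / math.sqrt(var + eps) * gamma + beta


def fuse_scalar(w, b, gamma, beta, mean, var, eps):
    """Fold BN into the preceding scalar 'convolution' y = w * x + b."""
    scale = gamma / math.sqrt(var + eps)
    w_new = w * scale
    b_new = (b - mean) * scale + beta
    return w_new, b_new


w, b = 0.8, -0.3
gamma, beta, mean, var, eps = 1.5, 0.2, 0.4, 2.0, 1e-5
w_new, b_new = fuse_scalar(w, b, gamma, beta, mean, var, eps)

# Compare "conv then BN" against the single fused multiply-add.
max_diff = 0.0
for x in (-2.0, 0.0, 3.5):
    two_step = bn_eval(w * x + b, gamma, beta, mean, var, eps)
    one_step = w_new * x + b_new
    max_diff = max(max_diff, abs(two_step - one_step))

print("max difference:", max_diff)
```

The two computations should agree up to floating-point rounding. Note that the bias b is scaled by the same factor as the weight before beta is added.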
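YOLOv5's implementation of the fusion
import torch
import torchvision

def fuse_conv_and_bn(conv, bn):
    # Build a new Conv2d with the same geometry; the fused layer always needs a bias.
    # The in-place copy_ calls below require the caller to have autograd disabled.
    fusedconv = torch.nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=True
    )
    print("conv.out_channels:", conv.out_channels)
    # Flatten each filter to a row: (out_channels, in_channels * kH * kW).
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    print("w_conv:", w_conv.shape)
    # Diagonal matrix of gamma / sqrt(running_var + eps), one entry per output channel.
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    print("w_bn:", w_bn.shape)
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
    if conv.bias is not None:
        b_conv = conv.bias
    else:
        b_conv = torch.zeros(conv.weight.size(0))
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    # Scale the conv bias by the same per-channel factor as the weights, matching b_new.
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
    return fusedconv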
Test code
torch.set_grad_enabled(False)  # inference only; also lets copy_ write into parameters
x = torch.randn(16, 3, 256, 256)
rn18 = torchvision.models.resnet18(pretrained=True)
rn18.eval()  # BN must use its running statistics
net = torch.nn.Sequential(
    rn18.conv1,
    rn18.bn1
)
y1 = net.forward(x)
fusedconv = fuse_conv_and_bn(net[0], net[1])
y2 = fusedconv.forward(x)
# Relative L2 error between the two-layer output and the fused output.
d = (y1 - y2).norm().div(y1.norm()).item()
print("error: %.8f" % d)
Test results
conv.out_channels: 64
w_conv: torch.Size([64, 147])
w_bn: torch.Size([64, 64])
error: 0.00000030
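rn18.conv1 has no bias term, so the test above does not exercise the bias part of the fusion formula. As a complementary check, the sketch below fuses a randomly initialized Conv2d that does have a bias, using the b_new formula directly. The layer sizes and the random BN statistics are arbitrary choices for illustration. It also computes the error when the conv bias is simply added without scaling, which is expected to be noticeably larger.

```python
import torch

torch.manual_seed(0)

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True).eval()
bn = torch.nn.BatchNorm2d(8).eval()

with torch.no_grad():
    # Give BN non-trivial statistics and affine parameters (random illustrative values).
    bn.running_mean.normal_()
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.normal_()
    bn.bias.normal_()

    # Per-channel factor gamma / sqrt(running_var + eps).
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)

    fused = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True).eval()
    fused.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))
    # b_new = (b - mean) * gamma / sqrt(var + eps) + beta
    fused.bias.copy_((conv.bias - bn.running_mean) * scale + bn.bias)

    x = torch.randn(2, 3, 16, 16)
    y_ref = bn(conv(x))
    y_fused = fused(x)
    rel_err_fused = ((y_ref - y_fused).norm() / y_ref.norm()).item()

    # For contrast: add the conv bias without scaling it.
    fused.bias.copy_(conv.bias + (bn.bias - bn.running_mean * scale))
    y_unscaled = fused(x)
    rel_err_unscaled = ((y_ref - y_unscaled).norm() / y_ref.norm()).item()

print("relative error, scaled bias:   %.8f" % rel_err_fused)
print("relative error, unscaled bias: %.8f" % rel_err_unscaled)
```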
PyTorch's implementation of the fusion
import copy
import torch

def fuse_conv_bn_eval(conv, bn):
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)
    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)
    return fused_conv

def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    # A missing conv bias or missing BN affine parameters fall back to identity values.
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    # 1 / sqrt(running_var + eps)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)
    # Scale each output channel; the reshape lets the factor broadcast over the other dims.
    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)
We add the following code ourselves to test it
import torchvision

torch.set_grad_enabled(False)
x = torch.randn(16, 3, 256, 256)
rn18 = torchvision.models.resnet18(pretrained=True)
rn18.eval()
net = torch.nn.Sequential(
    rn18.conv1,
    rn18.bn1
)
y1 = net.forward(x)
fusedconv = fuse_conv_bn_eval(net[0], net[1])
y2 = fusedconv.forward(x)
# Relative L2 error between the two-layer output and the fused output.
d = (y1 - y2).norm().div(y1.norm()).item()
print("error: %.8f" % d)
Output
error: 0.00000030
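On this test, the YOLOv5 version of the two-layer fusion and the fusion provided by PyTorch give exactly the same result: both report a relative error of 0.00000030.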
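The agreement is not a coincidence: the two versions scale the convolution weights in equivalent ways. One flattens each filter into a row and left-multiplies by a diagonal matrix; the other broadcasts a per-channel factor over the remaining dimensions. A small sketch comparing the two on random tensors follows. The shapes and the stand-in scale vector are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
out_ch, in_ch, k = 4, 3, 3
w_conv = torch.randn(out_ch, in_ch, k, k)
# Stands in for gamma / sqrt(running_var + eps), one value per output channel.
scale = torch.rand(out_ch) + 0.5

# Approach 1: flatten each filter to a row and left-multiply by diag(scale).
w_diag = torch.mm(torch.diag(scale), w_conv.view(out_ch, -1)).view_as(w_conv)

# Approach 2: broadcast scale over the (in_ch, k, k) dimensions.
w_bcast = w_conv * scale.reshape([-1] + [1] * (w_conv.dim() - 1))

max_abs_diff = (w_diag - w_bcast).abs().max().item()
print("max abs diff:", max_abs_diff)
```

The broadcasting form avoids building an out_channels x out_channels matrix, which matters little at this size but is the lighter option for wide layers.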