下面借这篇blog记录一些阅读笔记,如果问题,恳请指出。
1. Introduction
paper:Focal Loss for Dense Object Detection
在RetinaNet出现之前,two-step检测网络(代表有Faster RCNN系列等等)的准确率一直要比one-step检测网络(代表有SSD系列、yolo系列)的准确率要高,但是在RetinaNet出现之后,one-step检测网络首次精度要比two-step检测网络要好。为此,two-step检测网络的速度与精度已均不及one-step检测网络。在当时,RetinaNet的精度要比所有的当前的two-step检测网络模型均要好。 
paper指出,导致one-step检测网络精度不高的主要原因是正负样本的不平衡。为此,paper对于分类损失提出了一种新的损失函数:Focal Loss。用以解决正负样本不均衡的问题,本质上是减小占绝大多数的简单负样本的权重,而增加权重在一些分类错误的样本上。
按原话来说,就是通过重塑标准交叉熵损失来解决这种类别不平衡,已此来降低了分类良好的例子的损失。Focal Loss在训练时集中与一些稀疏的硬例子(hard examples)上,并防止训练过程中这些难点被大量的简单负样本吞噬。也就是避免这些稀疏错误被忽视。
为了证实Focal Loss的有效性,在paper中只是简单的搭建了一个one-steps的目标检测网络,其借鉴与FPN与SSD的一些思想,同样是在多个预测特征层上生成一些列的anchors,然后进行筛选再通过分类器预测类别与边界框回归参数。为了进一步区分几个经典的模型,在本文中会区分一下SSD,Faster RCNN,RetinaNet之间的差异。
1.1 Class Imbalance
paper中谈讨了为什么当前的two-step检测网络要比one-step检测网络的精度要好。关键就在于two-step检测网络一定程度上解决了这个类别不均衡问题,而one-step检测网络依然存在。
Class Imbalance会导致两个问题: 1)导致训练效率低,因为很多生成的位置都是一些对网络学习没有作用的简单负样本 2)大量的简单负样本会影响训练,导致忽视一些难点,从而而导致模型的退化
对于two-step检测网络: 第一步,RPN网络会生成一些列的候选框,然后判断其为前景还是背景已经筛选出了绝大多数简单的负样本。它将几乎无限多的可能对象位置减少到一两千个。重要的是,所选的建议不是随机的,而是可能对应于真实的物体位置,这消除了绝大多数容易的负面影响。 第二步,将PRN筛选后的候选框,执行采样试探法,例如固定的前景-背景比(1:3),或在线硬示例挖掘,进一步通过置信度来筛选,这就解决了正负样本严重不平衡的问题。这个比率就像是通过采样实现的隐式α平衡因子。
而在one-step检测网络中,应用于对象位置、比例和纵横比的规则密集采样。 对,于在预测特征层中生成的一些anchor,其中觉得多数都是一些无关的负样本,如果让这些无关重要的负样本对损失产生较小的影响,就是Focal Loss所做的事情。
1.2 Robust Estimation
人们对设计稳健的损失函数(如Huber loss)很感兴趣,这种损失函数通过对误差较大的例子(hard examples)的损失进行加权来减少输出的贡献。
相比之下,focal loss不是解决异常值,而是通过降低简单的负样本的权重来解决类不平衡,这样即使它们的数量很大,它们对总损失的贡献也很小。换句话说,focal loss的作用与robust loss相反:将训练集中在一组稀疏的硬例子上。
也就是说,focal loss专注与难点与错误点。
2. Focal Loss
在paper中,提出了一个新的损失函数Focal Loss,作为一个更有效方法来替代以前的方法来处理正负样本不平衡问题。Focal Loss损失函数是一个动态缩放的交叉熵损失,当正确类别的置信度增加时,缩放因子衰减到零,见图1。  直观地说,这个比例因子可以在训练过程中自动降低简单示例的权重,并快速将模型集中在硬示例上。实验表明,focal loss能够训练一个高精度的one-step检测模型,该检测器明显优于使用采样试探法或硬示例挖掘的替代训练(其中硬示例挖掘是目前训练one-step检测器的最先进技术)
2.1 cross entropy loss
首先看看二值交叉熵损失:  交叉熵损失存在的一个问题是,即使是一些比较容易分类的例子,比如confidence < 0.5,当通过大量简单的例子进行总结时,这些小的损失值可以压倒稀有类别。
2.2 focal loss
以下内容截取至我之前的一篇blog:YOLOv4中的tricks概念总结——Bag of freebies
Focal loss主要是为了解决one-stage目标检测中正负样本比例严重失衡的问题。该损失函数降低了大量简单负样本在训练中所占的权重,也可理解为一种困难样本挖掘。
主要过程就是将又来的二分类交叉熵损失函数:  替换为:  其中:gamma>0使得减少易分类样本的损失。使得更关注于困难的、错分的样本。例如gamma为2,对于正类样本而言,预测结果为0.95肯定是简单样本,所以(1-0.95)的gamma次方就会很小,这时损失函数值就变得更小。而预测概率为0.3的样本其损失相对很大。对于负类样本而言同样,预测0.1的结果应当远比预测0.7的样本损失值要小得多。对于预测概率为0.5时,损失只减少了0.25倍,所以更加关注于这种难以区分的样本。这样减少了简单样本的影响,大量预测概率很小的样本叠加起来后的效应才可能比较有效。  此外,加入平衡因子alpha,用来平衡正负样本本身的比例不均:文中alpha取0.25,即正样本要比负样本占比小,这是因为负例易分。
所以,最后的公式为:
F
L
(
p
t
)
=
?
α
t
(
1
?
p
t
)
γ
l
o
g
(
p
t
)
FL(p_{t}) = -α_{t}(1 - p_{t})^{γ}log(p_{t})
FL(pt?)=?αt?(1?pt?)γlog(pt?)
这种形式比非α平衡形式的精度略有提高。
2.3 focal loss*
focal loss的确切形式并不重要。我们现在展示了focal loss的另一个实例,它具有相似的性质并产生了相似的结果。  公式:
F
L
?
=
?
l
o
g
(
σ
(
γ
x
t
+
β
)
)
/
γ
FL* = -log(\sigma(γx_{t} + β))/γ
FL?=?log(σ(γxt?+β))/γ 其中:
x
t
=
y
x
x_{t} = yx
xt?=yx
FL有两个参数γ和β,它们控制损耗曲线的陡度和偏移。可以看出,像FL一样,具有所选参数的FL* 减少了分配给分类良好的示例的损失。我们使用与之前相同的设置训练了RetinaNet,但是我们用选定的参数将FL换成了FL *。这些模型实现了与用FL训练的模型几乎相同的AP。 
3. RetinaNet Detector
接下来介绍RetinaNet 的网络结构。
RetinaNet网络是一个单一的统一网络,由一个主干网络和两个特定任务的子网组成。主干负责计算整个输入图像上的卷积特征图,是一个独立的卷积网络。第一个子网对主干网的输出执行卷积对象分类;第二个子网执行卷积bounding box回归。 
3.1 Feature Pyramid Network Backbone
RetinaNet采用特征金字塔网络(FPN)作为支持网络的主干网络。简而言之,FPN通过自上而下的路径和横向连接扩展了标准卷积网络,因此该网络可以有效地从单个分辨率输入图像构建丰富的多尺度特征金字塔,参见图3(a)-(b)。金字塔的每一层都可以用来探测不同尺度的物体。
RetinaNet中使用的特征金字塔级别P3 to P7,其中P3 to P7是使用自上而下和横向连接从相应的ResNet剩余级(C3到C5)的输出计算的。其中P6是通过C5上的3×3步长-2 conv获得的,P7是通过应用ReLU然后是P6上的3×3步长-2 conv计算的。
但是在RetinaNet中修改了几点: 1)出于计算原因,RetinaNet不使用高分辨率金字塔级别的P2 2)P6,P7是通过交错卷积而不是下采样计算的 3)增加了P7以提高大对象检测。
通过观察一下图可以很容易了解:  (图片来源:https://www.bilibili.com/video/BV1Q54y1L7sM)
3.2 Anchors
在RetinaNet使用类似于中RPN变体的平移不变anchor。在每个金字塔级别(P3-P7),我们使用三种长宽比{1:2,1:1,2:1}的锚。使用了更密集的anchor规模覆盖,在每个级别,我们在原始的3纵横比锚点集合中添加大小为{
2
0
,
2
1
/
3
,
2
2
/
3
2^{0},2^{1/3},2^{2/3}
20,21/3,22/3}的锚点。也就是对于每一个预测特征点会有9个anchor进行预测。覆盖了相对于网络输入图像的32 - 813像素的比例范围。参数如图所示:  (图片来源:https://www.bilibili.com/video/BV1Q54y1L7sM)
每一个anchor都被分配了一个长度为K的分类目标的One-Hot编码,以及4个边界框回归参数。其中K是对象类的数量。
其中正负样本的设置如下: 1)IoU > 0.5:设置为正样本,分配一个ground truth对象 2)0 < IoU < 0.4:设置为负样本 3)0.4 < IoU < 0.5:丢弃
由于每个anchor最多被分配给一个对象框,因此我们将它的长度为K的标签向量中的相应类别设置为1,将所有其他类别设置为0。如果某个anchor未被分配,这可能会在[0.4,0.5]中发生重叠,则在训练过程中会被忽略。边界框回归目标计算为每个anchor与其分配的对象框之间的偏移,如果没有分配,则省略。
3.3 Classification Subnet
对于每层预测特征图上的每一个特征点上,所生成的9个anchor(不同比例不同大小),都会进行预测K个类别。
这个分类子网是一个附属于每个FPN级的小FCN;该子网的参数在所有金字塔级别共享。但是不与box回归子网共享参数
分类子网应用四个3×3 conv层,每个层带有C过滤器,每个层后面跟着ReLU激活,然后是带有KA过滤器的3×3 conv层。最后,附加sigmoid激活来输出每个空间位置的KA二进制预测,见图3 ?。在大多数实验中,我们使用C = 256和A = 9。
也就是对于每个anchor需要预测k个类别概率。
3.4 Box Regression Subnet
边界框回归子网与对象分类子网并行,RetinaNet将另一个小FCN附加到每个金字塔级别,目的是回归从每个anchor到ground true(如果存在)的偏移。边界框回归子网的设计与分类子网相同,只是它终止于每个空间位置的4A线性输出,见图3 (d)。
也就是对于每个anchor需要预测4个边界框回归参数。
对于每个空间位置的A个anchor中的每个anchors,这4个输出预测anchor和ground truth之间的相对偏移(我们使用来自RCNN 的标准框参数化)。
需要注意,这里使用了一个与类无关的边界框回归器,它使用了更少的参数,我们发现同样有效。对象分类子网和盒子回归子网虽然共享一个公共结构,但使用不同的参数。
4. Similarities and differences
回顾SSD与Faster RCNN与RetinaNet之间的差别,链接如下: 目标检测算法——SSD 目标检测算法——Faster R-CNN 目标检测算法——YOLOv1
4.1 Similarities
本质上,SSD,Faster R-CNN,RetinaNet这些目标检测算法本质上均是基于特征矩阵的的每一个特征点,然后根据不同的面积大小与尺度去产生几个anchor,如果使用了FPN结构,则对每一层的预测特征图上的特征点均预测k个anchor,然后再通过正负样本的筛选与非极大值抑制处理,挑选出最后的预测边界框。
4.2 differences
4.2.1 SSD & RetinaNet
而SSD与RetinaNet在结构上可以说是极其的相识了。同样是采用了FPN的思想,同样也是在多个预测特征层上进行anchor的生成。但是在RetinaNet中是分别区分了两个子网来进行分类预测与边界框回归参数预测;而在SSD中,是直接通过卷积操作预测分类与边界框回归参数。也就是说,SSD的两种预测参数是共享的,而RetinaNet是两个子网,参数不共享,只共享backbone的参数。
4.2.2 Faster R-CNN & RetinaNet
而对于Faster R-CNN进行预测时,则是将anchor筛选出来的候选框进行池化为统一大小,再展平通过两个head,一个预测类别,一个预测边界框回归参数,这两者又是分别预测的,与RetinaNet的操作相识。只不过在RetinaNet中进行的是多次的卷积操作进行预测,而在Faster R-CNN中则是通过展平通过全连接层来预测。
4.2.3 YoLo & RetinaNet
这些目标检测算法与yolo系列算法稍有不同,yolo系列算法是将图像均分为SxS份网格,然后直接使用网格去进行预测k个边界框(这里的k为2),然后再预测类别,这是在同一个网络中进行的,没有另外划分两个子网。
所以对于yolov1的缺点还是比较明显的,YOLO给边界框预测强加空间约束,因为每个网格单元只预测两个框和只能有一个类别。这个空间约束限制了我们的模型可以预测的邻近目标的数量。所以才有了v2,v3,v4的改进。
5. RetinaNet Structure code
完整参考代码:bubbliiiing博主代码写得非常清晰 https://github.com/bubbliiiing/retinanet-pytorch
以下截取RetinaNet的模型结构代码:
resnet.py
import math
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
model_urls = {
'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth',
'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth',
'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth',
'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth',
}
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000):
self.inplanes = 64
super(ResNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool2d(7)
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def resnet18(pretrained=False, **kwargs):
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18'], model_dir='model_data'), strict=False)
return model
def resnet34(pretrained=False, **kwargs):
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet34'], model_dir='model_data'), strict=False)
return model
def resnet50(pretrained=False, **kwargs):
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50'], model_dir='model_data'), strict=False)
return model
def resnet101(pretrained=False, **kwargs):
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet101'], model_dir='model_data'), strict=False)
return model
def resnet152(pretrained=False, **kwargs):
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet152'], model_dir='model_data'), strict=False)
return model
retinanet.py
import math
import itertools
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
class Anchors(nn.Module):
def __init__(self, anchor_scale=4., pyramid_levels=[3, 4, 5, 6, 7]):
super().__init__()
self.anchor_scale = anchor_scale
self.pyramid_levels = pyramid_levels
self.strides = [2 ** x for x in self.pyramid_levels]
self.scales = np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])
self.ratios = [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]
def forward(self, features):
hs = []
ws = []
for feature in features:
_, _, h, w = feature.size()
hs.append(h)
ws.append(w)
boxes_all = []
for i, stride in enumerate(self.strides):
boxes_level = []
for scale, ratio in itertools.product(self.scales, self.ratios):
base_anchor_size = self.anchor_scale * stride * scale
anchor_size_x_2 = base_anchor_size * ratio[0] / 2.0
anchor_size_y_2 = base_anchor_size * ratio[1] / 2.0
x = np.arange(0, ws[i]) * stride + stride/2
y = np.arange(0, hs[i]) * stride + stride/2
xv, yv = np.meshgrid(x, y)
xv = xv.reshape(-1)
yv = yv.reshape(-1)
boxes = np.vstack((yv - anchor_size_y_2, xv - anchor_size_x_2,
yv + anchor_size_y_2, xv + anchor_size_x_2))
boxes = np.swapaxes(boxes, 0, 1)
boxes_level.append(np.expand_dims(boxes, axis=1))
boxes_level = np.concatenate(boxes_level, axis=1)
boxes_all.append(boxes_level.reshape([-1, 4]))
anchor_boxes = np.vstack(boxes_all)
anchor_boxes = torch.from_numpy(anchor_boxes).to(features[0].device)
anchor_boxes = anchor_boxes.unsqueeze(0)
return anchor_boxes
class PyramidFeatures(nn.Module):
def __init__(self, C3_size, C4_size, C5_size, feature_size=256):
super(PyramidFeatures, self).__init__()
self.P5_1 = nn.Conv2d(C5_size, feature_size, kernel_size=1, stride=1, padding=0)
self.P5_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
self.P4_1 = nn.Conv2d(C4_size, feature_size, kernel_size=1, stride=1, padding=0)
self.P4_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
self.P3_1 = nn.Conv2d(C3_size, feature_size, kernel_size=1, stride=1, padding=0)
self.P3_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
self.P6 = nn.Conv2d(C5_size, feature_size, kernel_size=3, stride=2, padding=1)
self.P7_1 = nn.ReLU()
self.P7_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=2, padding=1)
def forward(self, inputs):
C3, C4, C5 = inputs
_, _, h4, w4 = C4.size()
_, _, h3, w3 = C3.size()
P3_x = self.P3_1(C3)
P4_x = self.P4_1(C4)
P5_x = self.P5_1(C5)
P5_upsampled_x = F.interpolate(P5_x, size=(h4, w4))
P4_x = P5_upsampled_x + P4_x
P4_upsampled_x = F.interpolate(P4_x, size=(h3, w3))
P3_x = P3_x + P4_upsampled_x
P3_x = self.P3_2(P3_x)
P4_x = self.P4_2(P4_x)
P5_x = self.P5_2(P5_x)
P6_x = self.P6(C5)
P7_x = self.P7_1(P6_x)
P7_x = self.P7_2(P7_x)
return [P3_x, P4_x, P5_x, P6_x, P7_x]
class RegressionModel(nn.Module):
def __init__(self, num_features_in, num_anchors=9, feature_size=256):
super(RegressionModel, self).__init__()
self.conv1 = nn.Conv2d(num_features_in, feature_size, kernel_size=3, padding=1)
self.act1 = nn.ReLU()
self.conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
self.act2 = nn.ReLU()
self.conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
self.act3 = nn.ReLU()
self.conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
self.act4 = nn.ReLU()
self.output = nn.Conv2d(feature_size, num_anchors * 4, kernel_size=3, padding=1)
def forward(self, x):
out = self.conv1(x)
out = self.act1(out)
out = self.conv2(out)
out = self.act2(out)
out = self.conv3(out)
out = self.act3(out)
out = self.conv4(out)
out = self.act4(out)
out = self.output(out)
out = out.permute(0, 2, 3, 1)
return out.contiguous().view(out.shape[0], -1, 4)
class ClassificationModel(nn.Module):
def __init__(self, num_features_in, num_anchors=9, num_classes=80, prior=0.01, feature_size=256):
super(ClassificationModel, self).__init__()
self.num_classes = num_classes
self.num_anchors = num_anchors
self.conv1 = nn.Conv2d(num_features_in, feature_size, kernel_size=3, padding=1)
self.act1 = nn.ReLU()
self.conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
self.act2 = nn.ReLU()
self.conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
self.act3 = nn.ReLU()
self.conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
self.act4 = nn.ReLU()
self.output = nn.Conv2d(feature_size, num_anchors * num_classes, kernel_size=3, padding=1)
self.output_act = nn.Sigmoid()
def forward(self, x):
out = self.conv1(x)
out = self.act1(out)
out = self.conv2(out)
out = self.act2(out)
out = self.conv3(out)
out = self.act3(out)
out = self.conv4(out)
out = self.act4(out)
out = self.output(out)
out = self.output_act(out)
out1 = out.permute(0, 2, 3, 1)
batch_size, width, height, channels = out1.shape
out2 = out1.view(batch_size, width, height, self.num_anchors, self.num_classes)
return out2.contiguous().view(x.shape[0], -1, self.num_classes)
class Resnet(nn.Module):
def __init__(self, phi, load_weights=False):
super(Resnet, self).__init__()
self.edition = [resnet18,resnet34,resnet50,resnet101,resnet152]
model = self.edition[phi](load_weights)
del model.avgpool
del model.fc
self.model = model
def forward(self, x):
x = self.model.conv1(x)
x = self.model.bn1(x)
x = self.model.relu(x)
x = self.model.maxpool(x)
x = self.model.layer1(x)
feat1 = self.model.layer2(x)
feat2 = self.model.layer3(feat1)
feat3 = self.model.layer4(feat2)
return [feat1,feat2,feat3]
class Retinanet(nn.Module):
def __init__(self, num_classes, phi, pretrain_weights=False):
super(Retinanet, self).__init__()
self.pretrain_weights = pretrain_weights
self.backbone_net = Resnet(phi,pretrain_weights)
fpn_sizes = {
0: [128, 256, 512],
1: [128, 256, 512],
2: [512, 1024, 2048],
3: [512, 1024, 2048],
4: [512, 1024, 2048],
}[phi]
self.fpn = PyramidFeatures(fpn_sizes[0], fpn_sizes[1], fpn_sizes[2])
self.regressionModel = RegressionModel(256)
self.classificationModel = ClassificationModel(256, num_classes=num_classes)
self.anchors = Anchors()
self._init_weights()
def _init_weights(self):
if not self.pretrain_weights:
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
prior = 0.01
self.classificationModel.output.weight.data.fill_(0)
self.classificationModel.output.bias.data.fill_(-math.log((1.0 - prior) / prior))
self.regressionModel.output.weight.data.fill_(0)
self.regressionModel.output.bias.data.fill_(0)
def forward(self, inputs):
p3, p4, p5 = self.backbone_net(inputs)
features = self.fpn([p3, p4, p5])
regression = torch.cat([self.regressionModel(feature) for feature in features], dim=1)
classification = torch.cat([self.classificationModel(feature) for feature in features], dim=1)
anchors = self.anchors(features)
return features, regression, classification, anchors
|