RegionProposalNetwork
In Faster R-CNN, the first stage is the RegionProposalNetwork (RPN): it lays anchors over the backbone feature maps, scores and regresses them, and filters the results (top-k selection, clipping, small-box removal and NMS) to produce the proposals passed to the second stage. Each part of the process is annotated in the code below; a small illustrative usage sketch is appended after the full listing.
import torch
import torchvision
from torch import nn, Tensor
from torch.nn import functional as F
import math
from typing import Dict
def smooth_l1_loss(input, target, beta: float = 1. / 9, size_average: bool = True):
"""
very similar to the smooth_l1_loss from pytorch, but with
the extra beta parameter
"""
n = torch.abs(input - target)
cond = torch.lt(n, beta)
loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta)
if size_average:
return loss.mean()
return loss.sum()
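# Illustrative numeric check (added sketch, not part of the original code):
# with beta = 1/9 the loss is quadratic for |x| < beta and linear beyond it.
def _demo_smooth_l1_loss():
    x = torch.tensor([0.05, 1.0])
    # per-element losses: 0.5 * 0.05**2 / (1/9) = 0.01125 and 1.0 - 0.5 / 9 = 0.9444...
    return smooth_l1_loss(x, torch.zeros_like(x), size_average=False)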
def nms(boxes, scores, iou_threshold):
"""
Performs non-maximum suppression (NMS) on the boxes according
to their intersection-over-union (IoU).
NMS iteratively removes lower scoring boxes which have an
IoU greater than iou_threshold with another (higher scoring)
box.
Parameters
----------
boxes : Tensor[N, 4])
boxes to perform NMS on. They
are expected to be in (x1, y1, x2, y2) format
scores : Tensor[N]
scores for each one of the boxes
iou_threshold : float
discards all overlapping
boxes with IoU > iou_threshold
Returns
-------
keep : Tensor
int64 tensor with the indices
of the elements that have been kept
by NMS, sorted in decreasing order of scores
"""
return torchvision.ops.nms(boxes, scores, iou_threshold)
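# Minimal usage sketch for nms (illustrative): the first two boxes overlap with
# IoU of about 0.68, so with iou_threshold=0.5 the lower-scoring one is suppressed.
def _demo_nms():
    boxes = torch.tensor([[0., 0., 10., 10.],
                          [1., 1., 11., 11.],
                          [50., 50., 60., 60.]])
    scores = torch.tensor([0.9, 0.8, 0.7])
    return nms(boxes, scores, iou_threshold=0.5)  # expected: tensor([0, 2])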
class RPNHead(nn.Module):
def __init__(self, in_channels, num_anchors):
super(RPNHead, self).__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)
for layer in self.children():
if isinstance(layer, nn.Conv2d):
torch.nn.init.normal_(layer.weight, std=0.01)
torch.nn.init.constant_(layer.bias, 0)
def forward(self, x):
logits = []
bbox_reg = []
for i, feature in enumerate(x):
output_33 = F.relu(self.conv(feature))
logits.append(self.cls_logits(output_33))
bbox_reg.append(self.bbox_pred(output_33))
return logits, bbox_reg
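# Shape sketch for RPNHead (illustrative): with 256 input channels and 3 anchors
# per location, each feature map yields a 3-channel objectness map and a
# 3 * 4 = 12 channel regression map of the same spatial size.
def _demo_rpn_head():
    head = RPNHead(in_channels=256, num_anchors=3)
    logits, bbox_reg = head([torch.randn(2, 256, 38, 50)])
    return logits[0].shape, bbox_reg[0].shape  # ([2, 3, 38, 50], [2, 12, 38, 50])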
class AnchorsGenerator(nn.Module):
    def __init__(self, sizes=(128, 256, 512), aspect_ratios=(0.5, 1.0, 2.0)):
        super(AnchorsGenerator, self).__init__()
        # normalize to one tuple of sizes / aspect ratios per feature map so the
        # flat defaults also work with set_cell_anchors and num_anchors_per_location
        if not isinstance(sizes[0], (list, tuple)):
            sizes = tuple((s,) for s in sizes)
        if not isinstance(aspect_ratios[0], (list, tuple)):
            aspect_ratios = (aspect_ratios,) * len(sizes)
        self.sizes = sizes
        self.aspect_ratios = aspect_ratios
        self.cell_anchors = None
        self._cache = {}
def generate_anchors(self, scale, aspect_ratios, dtype=torch.float32, device='cpu'):
scale = torch.as_tensor(scale, dtype=dtype, device=device)
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
h_ratios = torch.sqrt(aspect_ratios)
w_ratios = 1.0 / h_ratios
ws = (w_ratios[:, None] * scale[None, :]).view(-1)
hs = (h_ratios[:, None] * scale[None, :]).view(-1)
base_anchor = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
return base_anchor.round()
def set_cell_anchors(self, dtype, device):
if self.cell_anchors is not None:
cell_anchors = self.cell_anchors
assert cell_anchors is not None
if cell_anchors[0].device == device:
return
cell_anchors = [
self.generate_anchors(sizes, aspect_ratios, dtype, device)
for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)]
self.cell_anchors = cell_anchors
def cached_grid_anchors(self, grid_sizes, strides):
key = str(grid_sizes) + str(strides)
if key in self._cache:
return self._cache[key]
anchors = self.grid_anchors(grid_sizes, strides)
self._cache[key] = anchors
return anchors
def grid_anchors(self, grid_sizes, strides):
anchors = []
cell_anchors = self.cell_anchors
for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
grid_height, grid_width = size
stride_height, stride_width = stride
device = base_anchors.device
shifts_x = torch.arange(0, grid_width, dtype=torch.float32, device=device) * stride_width
shifts_y = torch.arange(0, grid_height, dtype=torch.float32, device=device) * stride_height
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
shifts = torch.stack([shift_x, shift_y, shift_x, shift_y], dim=1)
shifts_anchor = shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)
anchors.append(shifts_anchor.reshape(-1, 4))
return anchors
def num_anchors_per_location(self):
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
def forward(self, image_list, feature_maps):
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
image_size = image_list.tensors.shape[-2:]
dtype, device = feature_maps[0].dtype, feature_maps[0].device
strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
anchors = []
for i, (image_height, image_width) in enumerate(image_list.image_sizes):
anchors_in_image = []
for anchors_per_feature_map in anchors_over_all_feature_maps:
anchors_in_image.append(anchors_per_feature_map)
anchors.append(anchors_in_image)
anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
self._cache.clear()
return anchors
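# Illustrative anchor-grid sketch: one 2x2 feature map with stride 16 and
# 3 scales x 3 ratios gives 2 * 2 * 9 = 36 anchors in (x1, y1, x2, y2) form,
# each base anchor shifted to every grid point.
def _demo_anchor_grid():
    gen = AnchorsGenerator(sizes=((32, 64, 128),), aspect_ratios=((0.5, 1.0, 2.0),))
    gen.set_cell_anchors(torch.float32, torch.device('cpu'))
    anchors = gen.grid_anchors(grid_sizes=[(2, 2)], strides=[(16, 16)])
    return anchors[0].shape  # expected: torch.Size([36, 4])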
def box_area(boxes):
"""
Computes the area of a set of bounding boxes, which are specified by its
(x1, y1, x2, y2) coordinates.
Arguments:
boxes (Tensor[N, 4]): boxes for which the area will be computed. They
are expected to be in (x1, y1, x2, y2) format
Returns:
area (Tensor[N]): area for each box
"""
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2])
right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])
wh = (right_bottom - left_top).clamp(min=0)
inter = wh[:, :, 0] * wh[:, :, 1]
iou = inter / (area1[:, None] + area2 - inter)
return iou
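# Worked IoU example (illustrative): the two boxes below overlap in a 1 x 1
# region, so IoU = 1 / (4 + 4 - 1) = 1/7 (about 0.143).
def _demo_box_iou():
    boxes1 = torch.tensor([[0., 0., 2., 2.]])
    boxes2 = torch.tensor([[1., 1., 3., 3.]])
    return box_iou(boxes1, boxes2)  # expected: tensor([[0.1429]])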
def permute_and_flatten(layer, N, A, C, H, W):
layer = layer.view(N, -1, C, H, W)
layer = layer.permute(0, 3, 4, 1, 2)
layer = layer.reshape(N, -1, C)
return layer
def concat_box_pred_layers(box_cls, box_regression):
box_cls_flattened = []
box_regression_flattened = []
for box_cls_per_level, box_regression_per_level in zip(box_cls, box_regression):
N, AxC, H, W = box_cls_per_level.shape
Ax4 = box_regression_per_level.shape[1]
A = Ax4 // 4
C = AxC // A
box_cls_per_level = permute_and_flatten(box_cls_per_level, N, A, C, H, W)
box_cls_flattened.append(box_cls_per_level)
box_regression_per_level = permute_and_flatten(box_regression_per_level, N, A, 4, H, W)
box_regression_flattened.append(box_regression_per_level)
box_cls = torch.cat(box_cls_flattened, dim=1).flatten(0, -2)
box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4)
return box_cls, box_regression
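# Shape sketch for concat_box_pred_layers (illustrative): per-level head outputs
# are permuted to (N, H, W, A, C) and flattened, so classification scores end up
# as [N * sum(H*W*A), C] and regression deltas as [N * sum(H*W*A), 4].
def _demo_concat_box_pred_layers():
    box_cls = [torch.randn(2, 3, 4, 5)]    # N=2, A=3, C=1, H=4, W=5
    box_reg = [torch.randn(2, 12, 4, 5)]   # A * 4 = 12
    cls_flat, reg_flat = concat_box_pred_layers(box_cls, box_reg)
    return cls_flat.shape, reg_flat.shape  # expected: ([120, 1], [120, 4])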
def clip_boxes_to_image(boxes, size):
    # clamp x and y coordinates to the image, then re-interleave them back into
    # (x1, y1, x2, y2) order with stack + reshape (a plain cat along dim=1 would
    # scramble the coordinate order into (x1, x2, y1, y2))
    boxes_x = boxes[..., 0::2]
    boxes_y = boxes[..., 1::2]
    height, width = size
    boxes_x = boxes_x.clamp(min=0, max=width)
    boxes_y = boxes_y.clamp(min=0, max=height)
    clipped_boxes = torch.stack((boxes_x, boxes_y), dim=boxes.dim())
    return clipped_boxes.reshape(boxes.shape)
def remove_small_boxes(boxes, min_size):
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
keep = (ws >= min_size) & (hs >= min_size)
keep = keep.nonzero().squeeze(1)
return keep
def batched_nms(boxes, scores, level_idxs, iou_threshold):
if boxes.numel() == 0:
return torch.empty((0, ), dtype=torch.int64, device=boxes.device)
max_coordinate = boxes.max()
offset = level_idxs.to(boxes) * (max_coordinate + 1)
boxes_offset = boxes + offset[:, None]
keep = nms(boxes_offset, scores, iou_threshold)
return keep
def encode_boxes(reference_boxes, anchors, weights):
wx = weights[0]
wy = weights[1]
ww = weights[2]
wh = weights[3]
anchors_x1 = anchors[:, 0].unsqueeze(1)
anchors_y1 = anchors[:, 1].unsqueeze(1)
anchors_x2 = anchors[:, 2].unsqueeze(1)
anchors_y2 = anchors[:, 3].unsqueeze(1)
reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)
ex_width = anchors_x2 - anchors_x1
ex_height = anchors_y2 - anchors_y1
ex_ctr_x = anchors_x1 + 0.5 * ex_width
ex_ctr_y = anchors_y1 + 0.5 * ex_height
gt_widths = reference_boxes_x2 - reference_boxes_x1
gt_heights = reference_boxes_y2 - reference_boxes_y1
gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights
target_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_width
target_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_height
target_dw = ww * torch.log(gt_widths / ex_width)
target_dh = wh * torch.log(gt_heights / ex_height)
    # the regression targets are the raw (dx, dy, dw, dh) deltas between anchors
    # and ground-truth boxes; decode_single later inverts exactly this mapping
    targets = torch.cat((target_dx, target_dy, target_dw, target_dh), dim=1)
    return targets
class box_Coder(object):
def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
self.weights = weights
self.bbox_xform_clip = bbox_xform_clip
def encode(self, reference_boxes, anchores):
boxes_per_image = [len(b) for b in reference_boxes]
reference_boxes = torch.cat(reference_boxes, dim=0)
anchors = torch.cat(anchores, dim=0)
targets = self.encode_single(reference_boxes, anchors)
return targets.split(boxes_per_image, 0)
def encode_single(self, reference_boxes, anchors):
dtype = reference_boxes.dtype
device = reference_boxes.device
weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
targets = encode_boxes(reference_boxes, anchors, weights)
return targets
def decode_single(self, rel_codes, boxes):
boxes = boxes.to(rel_codes.dtype)
width = boxes[:, 2] - boxes[:, 0]
height = boxes[:, 3] - boxes[:, 1]
center_x = boxes[:, 0] + 0.5 * width
center_y = boxes[:, 1] + 0.5 * height
wx, wy, ww, wh = self.weights
dx = rel_codes[:, 0::4] / wx
dy = rel_codes[:, 1::4] / wy
dw = rel_codes[:, 2::4] / ww
dh = rel_codes[:, 3::4] / wh
dw = torch.clamp(dw, max=self.bbox_xform_clip)
dh = torch.clamp(dh, max=self.bbox_xform_clip)
pred_center_x = dx * width[:, None] + center_x[:, None]
pred_center_y = dy * height[:, None] + center_y[:, None]
pred_w = torch.exp(dw) * width[:, None]
pred_h = torch.exp(dh) * height[:, None]
pred_boxes_xmin = pred_center_x - torch.tensor(0.5, dtype=pred_center_x.dtype, device=pred_w.device) * pred_w
pred_boxes_ymin = pred_center_y - torch.tensor(0.5, dtype=pred_center_x.dtype, device=pred_w.device) * pred_h
pred_boxes_xmax = pred_center_x + torch.tensor(0.5, dtype=pred_center_x.dtype, device=pred_w.device) * pred_w
pred_boxes_ymax = pred_center_y + torch.tensor(0.5, dtype=pred_center_x.dtype, device=pred_w.device) * pred_h
pred_boxes = torch.stack((pred_boxes_xmin, pred_boxes_ymin, pred_boxes_xmax, pred_boxes_ymax), dim=2).flatten(1)
return pred_boxes
def decode(self, rel_codes, boxes):
concat_boxes = torch.cat(boxes, dim=0)
box_sum = concat_boxes.shape[0]
pred_boxes = self.decode_single(rel_codes, concat_boxes)
pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
return pred_boxes
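# Consistency sketch (illustrative): encoding ground-truth boxes against anchors
# and decoding the resulting deltas should reproduce the ground-truth boxes up
# to floating point error.
def _demo_box_coder_roundtrip():
    coder = box_Coder(weights=(1.0, 1.0, 1.0, 1.0))
    anchors = torch.tensor([[0., 0., 10., 10.], [5., 5., 20., 25.]])
    gt_boxes = torch.tensor([[1., 2., 11., 12.], [4., 6., 22., 24.]])
    deltas = coder.encode_single(gt_boxes, anchors)
    decoded = coder.decode_single(deltas, anchors)
    return torch.allclose(decoded, gt_boxes, atol=1e-4)  # expected: True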
class Matcher(object):
def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
self.BELOW_LOW_THRESHOLD = -1
self.high_threshold = high_threshold
self.low_threshold = low_threshold
self.allow_low_quality_matches = allow_low_quality_matches
def __call__(self, match_quality_matrix):
matched_value, matches_idx = match_quality_matrix.max(dim=0)
if self.allow_low_quality_matches:
all_matches = matches_idx.clone()
else:
all_matches = None
below_low_threshold = matched_value < self.low_threshold
between_threshold = (matched_value >= self.low_threshold) & (matched_value < self.high_threshold)
        matches_idx[below_low_threshold] = -1   # below low_threshold: background
        matches_idx[between_threshold] = -2     # between thresholds: ignored in training
if self.allow_low_quality_matches:
self.set_low_quality_matches_(matches_idx, all_matches, match_quality_matrix)
return matches_idx
def set_low_quality_matches_(self, matches_idx, all_matches, match_quality_matrix):
highest_quality_gt_value, _ = match_quality_matrix.max(dim=1)
gt_anchor_matches_highest_coordiate = torch.nonzero(match_quality_matrix == highest_quality_gt_value[:, None])
gt_anchor_matches_highest_coordiate_update = gt_anchor_matches_highest_coordiate[:, 1]
matches_idx[gt_anchor_matches_highest_coordiate_update] = all_matches[gt_anchor_matches_highest_coordiate_update]
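# Illustrative Matcher example: rows are ground-truth boxes, columns are anchors.
# Anchor 0 is a clear positive (gt 0), anchor 1 falls between the thresholds
# (-2, ignored), anchor 2 is background (-1), and anchor 3 is rescued by
# allow_low_quality_matches because it is the best anchor for gt 1.
def _demo_matcher():
    matcher = Matcher(high_threshold=0.7, low_threshold=0.3, allow_low_quality_matches=True)
    iou = torch.tensor([[0.9, 0.5, 0.1, 0.00],
                        [0.0, 0.1, 0.2, 0.25]])
    return matcher(iou)  # expected: tensor([ 0, -2, -1,  1])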
class BalancedPositiveNegativeSampler(object):
def __init__(self, batch_size_per_image, positive_fraction):
self.batch_size_per_image = batch_size_per_image
self.positive_fraction = positive_fraction
def __call__(self, matched_idxs):
pos_idx = []
neg_idx = []
for matched_idxs_per_image in matched_idxs:
positive = torch.nonzero(matched_idxs_per_image >= 1).squeeze(1)
negative = torch.nonzero(matched_idxs_per_image == 0).squeeze(1)
num_pos = int(self.batch_size_per_image * self.positive_fraction)
num_pos = min(positive.numel(), num_pos)
num_neg = self.batch_size_per_image - num_pos
num_neg = min(negative.numel(), num_neg)
perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
pos_idx_per_image = positive[perm1]
neg_idx_per_image = negative[perm2]
pos_idx_per_image_mask = torch.zeros_like(
matched_idxs_per_image, dtype=torch.uint8
)
neg_idx_per_image_mask = torch.zeros_like(
matched_idxs_per_image, dtype=torch.uint8
)
pos_idx_per_image_mask[pos_idx_per_image] = 1
neg_idx_per_image_mask[neg_idx_per_image] = 1
pos_idx.append(pos_idx_per_image_mask)
neg_idx.append(neg_idx_per_image_mask)
return pos_idx, neg_idx
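# Illustrative sampler example: with batch_size_per_image=4 and
# positive_fraction=0.5 at most 2 positives (label 1) are drawn and the rest of
# the minibatch is filled with negatives (label 0); label -1 is never sampled.
def _demo_fg_bg_sampler():
    sampler = BalancedPositiveNegativeSampler(batch_size_per_image=4, positive_fraction=0.5)
    labels = torch.tensor([1., 1., 1., 0., 0., 0., -1., -1.])
    pos_mask, neg_mask = sampler([labels])
    return pos_mask[0].sum(), neg_mask[0].sum()  # expected: 2 positives, 2 negatives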
class RegionProposalNetwork(nn.Module):
def __init__(self, anchor_generate, rpn_head, fg_iou_thresh, bg_iou_thresh, batch_size_per_image,
positive_fraction, pre_nms_top_n, post_nms_top_n, nms_thresh):
super(RegionProposalNetwork, self).__init__()
self.anchor_generator = anchor_generate
self.head = rpn_head
self.box_coder = box_Coder(weights=(1.0, 1.0, 1.0, 1.0))
self.box_similarity = box_iou
self.proposal_matcher = Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=True)
self.fg_bg_sampler = BalancedPositiveNegativeSampler(
batch_size_per_image, positive_fraction
)
self._pre_nms_top_n = pre_nms_top_n
self._post_nms_top_n = post_nms_top_n
self.nms_thresh = nms_thresh
self.min_size = 1e-3
def pre_nms_top_n(self):
if self.training:
return self._pre_nms_top_n['training']
return self._pre_nms_top_n['testing']
def post_nms_top_n(self):
if self.training:
return self._post_nms_top_n['training']
return self._post_nms_top_n['testing']
def _get_top_n_idx(self, objectness, num_anchors_per_level):
r = []
offset = 0
for ob in objectness.split(num_anchors_per_level, 1):
num_anchors = ob.shape[1]
pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors)
_, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
r.append(top_n_idx + offset)
offset += num_anchors
return torch.cat(r, dim=1)
def assign_targets_to_anchors(self, anchors, targets):
labels = []
matched_gt_boxes = []
for anchors_per_image, targets_per_image in zip(anchors, targets):
gt_boxes = targets_per_image['boxes']
if gt_boxes.numel() == 0:
device = anchors_per_image.device
matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
else:
match_quality_matrix = box_iou(gt_boxes, anchors_per_image)
matched_idxs = self.proposal_matcher(match_quality_matrix)
matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]
labels_per_image = matched_idxs >= 0
labels_per_image = labels_per_image.to(dtype=torch.float32)
bg_indices = matched_idxs == -1
labels_per_image[bg_indices] = 0.0
between_indices = matched_idxs == -2
labels_per_image[between_indices] = -1.0
labels.append(labels_per_image)
matched_gt_boxes.append(matched_gt_boxes_per_image)
return labels, matched_gt_boxes
def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
num_images = proposals.shape[0]
device = proposals.device
objectness = objectness.detach()
objectness = objectness.view(num_images, -1)
levels = [torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level)]
levels = torch.cat(levels, dim=0)
levels = levels.reshape(1, -1).expand_as(objectness)
top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)
image_range = torch.arange(num_images, device=device)
batch_idx = image_range[:, None]
objectness = objectness[batch_idx, top_n_idx]
levels = levels[batch_idx, top_n_idx]
proposals = proposals[batch_idx, top_n_idx]
final_boxes = []
final_scores = []
for boxes, scores, level, img_shape in zip(proposals, objectness, levels, image_shapes):
boxes = clip_boxes_to_image(boxes, img_shape)
keep = remove_small_boxes(boxes, self.min_size)
boxes, scores, level = boxes[keep], scores[keep], level[keep]
keep = batched_nms(boxes, scores, level, self.nms_thresh)
keep = keep[: self.post_nms_top_n()]
boxes, scores = boxes[keep], scores[keep]
final_boxes.append(boxes)
final_scores.append(scores)
return final_boxes, final_scores
def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)
sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)
objectness = objectness.flatten()
labels = torch.cat(labels, dim=0)
regression_targets = torch.cat(regression_targets, dim=0)
box_loss = smooth_l1_loss(pred_bbox_deltas[sampled_pos_inds], regression_targets[sampled_pos_inds], beta=1 / 9, size_average=False) / (sampled_inds.numel())
objectness_loss = F.binary_cross_entropy_with_logits(
objectness[sampled_inds], labels[sampled_inds]
)
return objectness_loss, box_loss
def forward(self, image_list, features, targets=None):
features = list(features.values())
objectness, pred_bbox_deltas = self.head(features)
anchors = self.anchor_generator(image_list, features)
num_images = len(anchors)
num_anchors_per_level_shape_tensor = [o[0].shape for o in objectness]
num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensor]
objectness, pred_bbox_deltas = concat_box_pred_layers(objectness, pred_bbox_deltas)
proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
proposals = proposals.view(num_images, -1, 4)
boxes, scores = self.filter_proposals(proposals, objectness, image_list.image_sizes, num_anchors_per_level)
losses = {}
if self.training:
labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
loss_objectness, loss_rpn_box_reg = self.compute_loss(objectness, pred_bbox_deltas, labels, regression_targets)
losses = {
'loss_objectness': loss_objectness,
'loss_rpn_box_reg': loss_rpn_box_reg
}
return boxes, losses
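# End-to-end usage sketch (illustrative, with assumptions): the ImageList
# container is borrowed from torchvision's detection internals, and the feature
# map size and hyper-parameters below are arbitrary. The RPN is run in eval
# mode so that no targets are required.
def _demo_region_proposal_network():
    from torchvision.models.detection.image_list import ImageList  # assumed helper
    anchor_generator = AnchorsGenerator(sizes=((32, 64, 128),),
                                        aspect_ratios=((0.5, 1.0, 2.0),))
    rpn_head = RPNHead(in_channels=256,
                       num_anchors=anchor_generator.num_anchors_per_location()[0])
    rpn = RegionProposalNetwork(anchor_generator, rpn_head,
                                fg_iou_thresh=0.7, bg_iou_thresh=0.3,
                                batch_size_per_image=256, positive_fraction=0.5,
                                pre_nms_top_n=dict(training=2000, testing=1000),
                                post_nms_top_n=dict(training=2000, testing=1000),
                                nms_thresh=0.7)
    rpn.eval()  # training mode would additionally require targets with 'boxes'
    images = ImageList(torch.randn(1, 3, 224, 224), [(224, 224)])
    features = {'0': torch.randn(1, 256, 14, 14)}  # e.g. one backbone level
    boxes, losses = rpn(images, features)
    return boxes[0].shape, losses  # one [num_proposals, 4] tensor; losses == {}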