手把手使用FasterRCNN进行目标检测
代码Github地址
使用 torchvision 自带的预训练 Faster R-CNN 目标检测模型进行目标检测,并把检测到的边界框和类别名称绘制到原图片上。
由于 torchvision 自带的 Faster R-CNN 是使用 COCO 数据集预训练得到的,模型输出的标签是 COCO 的数字类别 id,所以我们需要一张 COCO 类别 id 到类别名称的映射表,把模型输出的结果转换成可读的类别名称。
详细的注释都在代码里面了。
import torch
import torchvision
from torchvision.models.detection.roi_heads import fastrcnn_loss
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image, ImageDraw, ImageFont
# Lookup table from the integer label ids emitted by the COCO-pretrained
# detector to human-readable class names.
# The ids are not contiguous: 12, 26, 29, 30, 45, 66, 68, 69, 71 and 83 have
# no entry. Insertion order matters because colours are assigned by position.
COCO_CLASSES = {
    1: 'person',
    2: 'bicycle',
    3: 'car',
    4: 'motorcycle',
    5: 'airplane',
    6: 'bus',
    7: 'train',
    8: 'truck',
    9: 'boat',
    10: 'traffic light',
    11: 'fire hydrant',
    13: 'stop sign',
    14: 'parking meter',
    15: 'bench',
    16: 'bird',
    17: 'cat',
    18: 'dog',
    19: 'horse',
    20: 'sheep',
    21: 'cow',
    22: 'elephant',
    23: 'bear',
    24: 'zebra',
    25: 'giraffe',
    27: 'backpack',
    28: 'umbrella',
    31: 'handbag',
    32: 'tie',
    33: 'suitcase',
    34: 'frisbee',
    35: 'skis',
    36: 'snowboard',
    37: 'sports ball',
    38: 'kite',
    39: 'baseball bat',
    40: 'baseball glove',
    41: 'skateboard',
    42: 'surfboard',
    43: 'tennis racket',
    44: 'bottle',
    46: 'wine glass',
    47: 'cup',
    48: 'fork',
    49: 'knife',
    50: 'spoon',
    51: 'bowl',
    52: 'banana',
    53: 'apple',
    54: 'sandwich',
    55: 'orange',
    56: 'broccoli',
    57: 'carrot',
    58: 'hot dog',
    59: 'pizza',
    60: 'donut',
    61: 'cake',
    62: 'chair',
    63: 'couch',
    64: 'potted plant',
    65: 'bed',
    67: 'dining table',
    70: 'toilet',
    72: 'tv',
    73: 'laptop',
    74: 'mouse',
    75: 'remote',
    76: 'keyboard',
    77: 'cell phone',
    78: 'microwave',
    79: 'oven',
    80: 'toaster',
    81: 'sink',
    82: 'refrigerator',
    84: 'book',
    85: 'clock',
    86: 'vase',
    87: 'scissors',
    88: 'teddy bear',
    89: 'hair drier',
    90: 'toothbrush',
}
# Palette of 20 hex colours used for box outlines and caption backgrounds.
COLORS = [
    '#e6194b', '#3cb44b', '#ffe119', '#0082c8',
    '#f58231', '#911eb4', '#46f0f0', '#f032e6',
    '#d2f53c', '#fabebe', '#008080', '#000080',
    '#aa6e28', '#fffac8', '#800000', '#aaffc3',
    '#808000', '#ffd8b1', '#e6beff', '#808080',
]
# Give every known class id a colour by its position in COCO_CLASSES, cycling
# through the palette when there are more classes than colours.
LABEL_COLOR_MAP = {
    class_id: COLORS[position % len(COLORS)]
    for position, class_id in enumerate(COCO_CLASSES)
}
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def _load_label_font(size=15):
    """Return Arial at ``size``, or Pillow's built-in font if Arial cannot be read."""
    try:
        return ImageFont.truetype('Arial.ttf', size)
    except OSError:
        # Arial is missing on many systems; a plain caption beats a crash.
        return ImageFont.load_default()


def _text_size(font, text):
    """Return ``(width, height)`` of ``text`` rendered with ``font``."""
    if hasattr(font, 'getbbox'):
        # ``getsize`` was removed in Pillow 10; the right/bottom edges of the
        # bounding box are the extents it used to report.
        _, _, right, bottom = font.getbbox(text)
        return right, bottom
    return font.getsize(text)


def faster_rcnn_detection(img_path, score_threshold=0.8, output_path="result.png"):
    """Detect objects in an image and save a copy annotated with boxes and class names.

    Loads torchvision's pretrained Faster R-CNN (ResNet-50 FPN), runs it on the
    image at ``img_path``, and for every detection whose score exceeds
    ``score_threshold`` draws a two-pixel outline plus a caption with the class
    name. The annotated image is written to ``output_path``.

    Args:
        img_path: Path of the input image; it is converted to RGB.
        score_threshold: Detections with a score not strictly greater than
            this value are discarded.
        output_path: Where the annotated image is saved.

    Returns:
        None. The result is the file written to ``output_path``.
    """
    # NOTE(review): newer torchvision releases prefer a ``weights=`` argument
    # over ``pretrained=True``; kept as-is for compatibility with older installs.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    model.to(device)
    model.eval()

    origin_img = Image.open(img_path, mode='r').convert('RGB')
    img = TF.to_tensor(origin_img).to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = model(img.unsqueeze(0))

    # The model is given a batch of one image, so only output[0] is used.
    labels = output[0]['labels'].cpu().numpy()
    scores = output[0]['scores'].cpu().numpy()
    bboxes = output[0]['boxes'].cpu().numpy()

    kept = np.flatnonzero(scores > score_threshold).tolist()

    draw = ImageDraw.Draw(origin_img)
    font = _load_label_font(15)

    for i in kept:
        label = int(labels[i])
        # Fall back gracefully for ids that have no entry in the class table.
        name = COCO_CLASSES.get(label, str(label))
        color = LABEL_COLOR_MAP.get(label, COLORS[label % len(COLORS)])

        box_location = bboxes[i].tolist()
        # Two rectangles offset by one pixel give a thicker outline.
        draw.rectangle(xy=box_location, outline=color)
        draw.rectangle(xy=[coord + 1. for coord in box_location], outline=color)

        text_w, text_h = _text_size(font, name)
        left = box_location[0]
        # The caption sits just above the box; clamp it so that it stays
        # visible when the box touches the top edge of the image.
        caption_top = max(box_location[1] - text_h, 0.)
        draw.rectangle(
            xy=[left, caption_top, left + text_w + 4., caption_top + text_h],
            fill=color,
        )
        draw.text(xy=[left + 2., caption_top], text=name, fill='white', font=font)

    origin_img.save(output_path)
# Script entry point: annotate "input.jpg" (resolved against the current
# working directory) when this file is run directly.
if __name__ == "__main__":
    faster_rcnn_detection(img_path="input.jpg")
输入图片:
输出图片:
|