IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> pytorch手把手使用FasterRCNN进行目标检测 -> 正文阅读

[人工智能]pytorch手把手使用FasterRCNN进行目标检测

手把手使用FasterRCNN进行目标检测

代码Github地址

使用pytorch自带的预训练FasterRCNN目标检测模型进行目标检测,并把检测到的结果打印到原图片上。

由于pytorch自带的FasterRCNN是使用COCO数据集进行预训练得到的,所以我们需要使用COCO的标签来匹配模型输出的结果。

详细的注释都在代码里面了。

import torch
import torchvision
from torchvision.models.detection.roi_heads import fastrcnn_loss
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image, ImageDraw, ImageFont

# COCO数据集标签对照表
COCO_CLASSES ={1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane',
                   6: 'bus', 7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light',
                   11: 'fire hydrant', 13: 'stop sign', 14: 'parking meter', 15: 'bench',
                   16: 'bird', 17: 'cat', 18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow',
                   22: 'elephant', 23: 'bear', 24: 'zebra', 25: 'giraffe', 27: 'backpack',
                   28: 'umbrella', 31: 'handbag', 32: 'tie', 33: 'suitcase', 34: 'frisbee',
                   35: 'skis', 36: 'snowboard', 37: 'sports ball', 38: 'kite', 39: 'baseball bat',
                   40: 'baseball glove', 41: 'skateboard', 42: 'surfboard', 43: 'tennis racket',
                   44: 'bottle', 46: 'wine glass', 47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon',
                   51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange',
                   56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut',
                   61: 'cake', 62: 'chair', 63: 'couch', 64: 'potted plant', 65: 'bed', 67: 'dining table',
                   70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse', 75: 'remote', 76: 'keyboard',
                   77: 'cell phone', 78: 'microwave', 79: 'oven', 80: 'toaster', 81: 'sink',
                   82: 'refrigerator', 84: 'book', 85: 'clock', 86: 'vase', 87: 'scissors',
                   88: 'teddy bear', 89: 'hair drier', 90: 'toothbrush'}

COLORS = ['#e6194b', '#3cb44b', '#ffe119', '#0082c8', '#f58231', '#911eb4', '#46f0f0', '#f032e6',
            '#d2f53c', '#fabebe', '#008080', '#000080', '#aa6e28', '#fffac8', '#800000', '#aaffc3', '#808000',
            '#ffd8b1', '#e6beff', '#808080']

# 为每一个标签对应一种颜色,方便我们显示
LABEL_COLOR_MAP = {k: COLORS[i%len(COLORS)] for i, k in enumerate(COCO_CLASSES.keys())}

# 判断GPU设备是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def faster_rcnn_detection(img_path):
    # 加载pytorch自带的预训练Faster RCNN目标检测模型
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    model.to(device)
    model.eval()

    # 读取输入图像,并转化为tensor
    origin_img = Image.open(img_path, mode='r').convert('RGB')
    img = TF.to_tensor(origin_img)
    img = img.to(device)

    # 将图像输入神经网络模型中,得到输出
    output = model(img.unsqueeze(0))

    labels = output[0]['labels'].cpu().detach().numpy()     # 预测每一个obj的标签
    scores = output[0]['scores'].cpu().detach().numpy()     # 预测每一个obj的得分
    bboxes = output[0]['boxes'].cpu().detach().numpy()      # 预测每一个obj的边框
    # 这个我们只选取得分大于0.8的
    obj_index = np.argwhere(scores>0.8).squeeze(axis=1).tolist()

    # 使用ImageDraw将检测到的边框和类别打印在图片中,得到最终的输出
    draw = ImageDraw.Draw(origin_img)
    font = ImageFont.truetype('Arial.ttf', 15)

    for i in obj_index:
        box_location = bboxes[i].tolist()
        draw.rectangle(xy=box_location, outline=LABEL_COLOR_MAP[labels[i]])
        draw.rectangle(xy=[l + 1. for l in box_location], outline=LABEL_COLOR_MAP[labels[i]])

        text_size = font.getsize(COCO_CLASSES[labels[i]])
        text_location = [box_location[0] + 2., box_location[1] - text_size[1]]
        textbox_location = [box_location[0], box_location[1] - text_size[1], box_location[0] + text_size[0] + 4., box_location[1]]
        draw.rectangle(xy=textbox_location, fill=LABEL_COLOR_MAP[labels[i]])
        draw.text(xy=text_location, text=COCO_CLASSES[labels[i]], fill='white', font=font)

    del draw

    origin_img.save("result.png")


if __name__ == '__main__':
    faster_rcnn_detection("input.jpg")

输入图片:
请添加图片描述

输出图片:

请添加图片描述

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-07 13:51:15  更:2021-10-07 13:53:58 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 14:15:55-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码