IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> YOLOv5的Tricks | 【Trick2】目标检测中进行多模型推理预测(Model Ensemble) -> 正文阅读

[人工智能]YOLOv5的Tricks | 【Trick2】目标检测中进行多模型推理预测(Model Ensemble)


如有错误,恳请指出。


在学习yolov5代码的时候,发现experimental.py文件中有一个很亮眼的模块:Ensemble。接触过机器学习的可能了解到,机器学习的代表性算法是随机森林这种,使用多个模型来并行推理,然后归纳他们的中值或者是平均值来最为整个模型的最后预测结构,没想到的是目标检测中也可以使用,叹为观止。下面就对其进行详细介绍:


1. Ensemble的概念

集成建模是通过使用许多不同的建模算法或使用不同的训练数据集创建多个不同模型来预测结果的过程。使用集成模型的动机是减少预测的泛化误差。只要基础模型是多样且独立的,使用集成方法时模型的预测误差就会减小。该方法在做出预测时寻求群体的智慧。即使集成模型在模型中具有多个基础模型(求多个模型的平均值或最大值),它仍作为单个模型运行和执行(最终还是以一个综合模型的取整进行预测)。

详细介绍见:https://www.sciencedirect.com/topics/computer-science/ensemble-modeling


2. Ensemble的实现

yolov5实现代码如下:

# 集成算法
class Ensemble(nn.ModuleList):
    # Ensemble of models
    def __init__(self):
        super().__init__()

    def forward(self, x, augment=False, profile=False, visualize=False):
        y = []
        # 集成模型为多个模型时, 在每一层forward运算时, 都要运行多个模型在该层的结果append进y中
        for module in self:
            y.append(module(x, augment, profile, visualize)[0])
        # y = torch.stack(y).max(0)[0]  # 求多个模型结果的最大值 max ensemble
        # y = torch.stack(y).mean(0)    # 求多个模型结果的均值   mean ensemble
        y = torch.cat(y, 1)             # 将多个模型结果concat在一起, 后面做做nms等于翻了一倍的pred nms ensemble
        return y, None  # inference, train output

在yolov5中使用attempt_load模块来实现多模型的调用,代码如下:

def attempt_load(weights, map_location=None, inplace=True, fuse=True):
    from models.yolo import Detect, Model

    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
    model = Ensemble()
    
    # weights如果是单路径, 则使用单个模型; 如果是list多路径, 则使用集成模型(多模型)
    for w in weights if isinstance(weights, list) else [weights]:
        # 这里map_location参数可以指定加载设备, 或者实现设备间的转化,eg:cuda1->cuda0 / cuda->cpu
        ckpt = torch.load(attempt_download(w), map_location=map_location)  # load
        if fuse:
            # 参数重结构化
            model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model
        else:
            model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval())  # without layer fuse

    # Compatibility updates(关于版本兼容的设置)
    for m in model.modules():
        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
            m.inplace = inplace  # pytorch 1.7.0 compatibility
            if type(m) is Detect:
                if not isinstance(m.anchor_grid, list):  # new Detect Layer compatibility
                    delattr(m, 'anchor_grid')       # delattr(x, 'y') is equivalent to `del x.y'
                    setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)  # setattr(x, 'y', v) is equivalent to `x.y = v'
        elif type(m) is Conv:
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    # 单模型设置
    if len(model) == 1:
        return model[-1]  # return model

    # 集成模型设置
    else:
        print(f'Ensemble created with {weights}\n')
        # 给每个模型一个name属性
        for k in ['names']:
            setattr(model, k, getattr(model[-1], k))    # getattr(x, 'y') is equivalent to x.y
        # 给每个模型分配stride属性
        model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride  # max stride
        return model  # return ensemble

3. Ensemble的使用

Ensemble使用方法,具体见yolov5的Tutorial:https://github.com/ultralytics/yolov5/issues/318

  • Ensemble Test
# python val.py --weights yolov5x.pt --data coco.yaml --img 640 --half    # use single 
python val.py --weights yolov5x.pt yolov5l6.pt --data coco.yaml --img 640 --half
  • Ensemble Inference
python detect.py --weights yolov5x.pt yolov5l6.pt --img 640 --source data/images

Output:

detect: weights=['yolov5x.pt', 'yolov5l6.pt'], source=data/images, imgsz=640, conf_thres=0.25, iou_thres=0.45, max_det=1000, device=, view_img=False, save_txt=False, save_conf=False, save_crop=False, nosave=False, classes=None, agnostic_nms=False, augment=False, update=False, project=runs/detect, name=exp, exist_ok=False, line_thickness=3, hide_labels=False, hide_conf=False, half=False
YOLOv5 🚀 v5.0-267-g6a3ee7c torch 1.9.0+cu102 CUDA:0 (Tesla P100-PCIE-16GB, 16280.875MB)

Fusing layers... 
Model Summary: 476 layers, 87730285 parameters, 0 gradients
Fusing layers... 
Model Summary: 501 layers, 77218620 parameters, 0 gradients
Ensemble created with ['yolov5x.pt', 'yolov5l6.pt']

image 1/2 /content/yolov5/data/images/bus.jpg: 640x512 4 persons, 1 bus, 1 tie, Done. (0.063s)
image 2/2 /content/yolov5/data/images/zidane.jpg: 384x640 3 persons, 2 ties, Done. (0.056s)
Results saved to runs/detect/exp2
Done. (0.223s)

可以看见输出的时候出打印使用了多少个模型,每个模型的层数,参数量

测试代码:

if __name__ == '__main__':

    x = torch.rand([8, 3, 640, 640])
    weights = ['../weights/yolov5s.pt', '../weights/yolov5m.pt', '../weights/yolov5l.pt']
    device = torch.device('cpu')

    # 集成模型测试
    model = attempt_load(weights, map_location=device)
    print("len(model(x)):", len(model(x)))
    print(model(x)[0].shape)

    # 单模型测试
    model = attempt_load(weights[0], map_location=device)
    print("len(model(x)):", len(model(x)))
    print(model(x)[0].shape)

输出:

Ensemble created with ['../weights/yolov5s.pt', '../weights/yolov5m.pt', '../weights/yolov5l.pt']

len(model(x)): 2
torch.Size([8, 75600, 85])
len(model(x)): 2
torch.Size([8, 25200, 85])

需要注意:

集成模块只能在推理的阶段使用(也就是测试或者验证阶段),因为这时候是调用多个已经训练好的模型权重来分别独立的对输入进行预测,然后每个训练好的模型所得到的结果取平均或者堆叠再做后续的后处理操作。

可以注意到,由于在上诉的Ensemble模块中,源码中选择了将结果拼接了在一起。所以可以看见,在对通一批图像做处理的时候,会得到多个模型预测的结果,预测框成倍的增加,使用多少个模型就会增加多少倍,后处理过程会变慢,但是精度会提高,其实也可以换成mean或者是max的方法。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-06-06 17:19:25  更:2022-06-06 17:20:26 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/30 2:30:09-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码