IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Focal and Global Knowledge Distillation for Detectors--FGD论文解读 -> 正文阅读

[人工智能]Focal and Global Knowledge Distillation for Detectors--FGD论文解读

论文:Focal and Global Knowledge Distillation for Detectors

论文:https://arxiv.org/abs/2111.11837

代码:https://github.com/yzd-v/FGD

一,针对问题

1. 目标检测中前背景不平衡问题

????????知识蒸馏旨在使学生学习教师的知识,以获得相似的输出从而提升性能。为了探索学生与教师在特征层面的差异,作者首先对二者的特征图进行了可视化。可以看到在空间与通道注意力上,教师与学生均存在较大的差异。其中在空间注意力上,二者在前景中的差异较大,在背景中的差异较小,这会给蒸馏中的学生带来不同的学习难度。

?????????为了进一步探索前背景对于知识蒸馏的影响,作者分离出前背景进行了蒸馏实验,全图一起蒸馏会导致蒸馏性能的下降,将前景与背景分开学生能够获得更好的表现。

?????????针对学生与教师注意力的差异,前景与背景的差异,作者提出了重点蒸馏Focal Distillation:分离前背景,并利用教师的空间与通道注意力作为权重,指导学生进行知识蒸馏,计算重点蒸馏损失。

二,方法

整体蒸馏损失计算方式:

C,H,W:feature map的通道时和高宽。

?F^TF^{S}为教师和学生模型的输出。

?2.1 分离前背景

前、背景Mask

设置一个二值MASK:

r代表GT bbox,如果feature map的点落在bbox内则该点为1,否则为0.

?2.2 尺度

尺度Mask

?大小目标focal,前、背景

Hr和Wr为bbox的高和宽,如果一个同时属于多个目标(遮挡场景)选取bbox最小的目标去计算S

?2.2 空间与通道注意力

?空间与通道注意力

?C,H,W:feature map的通道时和高宽。

GG^S ,G^{C}?代表空间注意立和通道注意力机制,

Attention MASK:

T为蒸馏温度 ,论文设置为0.5

2.3 全局蒸馏

全局信息的丢失

????????Focal Distillation将前景背景分开进行蒸馏,割断了前背景的联系,缺乏了特征的全局信息的蒸馏。为此提出了全局蒸馏Global Distillation:利用GcBlock分别提取学生与教师的全局信息,并进行全局蒸馏损失的计算。

?使用GCBlock去获取全局信息,使得学生模型从教室模型中学习前背景的联系。

损失计算如下:

    
        self.conv_mask_t = nn.Conv2d(teacher_channels, 1, kernel_size=1)
        self.channel_add_conv_s = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
            nn.LayerNorm([teacher_channels//2, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))
        self.channel_add_conv_t = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
            nn.LayerNorm([teacher_channels//2, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))

    def spatial_pool(self, x, in_type):
        batch, channel, width, height = x.size()
        input_x = x
        # [N, C, H * W]
        input_x = input_x.view(batch, channel, height * width)
        # [N, 1, C, H * W]
        input_x = input_x.unsqueeze(1)
        # [N, 1, H, W]
        if in_type == 0:
            context_mask = self.conv_mask_s(x)
        else:
            context_mask = self.conv_mask_t(x)
        # [N, 1, H * W]
        context_mask = context_mask.view(batch, 1, height * width)
        # [N, 1, H * W]
        context_mask = F.softmax(context_mask, dim=2)
        # [N, 1, H * W, 1]
        context_mask = context_mask.unsqueeze(-1)
        # [N, 1, C, 1]
        context = torch.matmul(input_x, context_mask)
        # [N, C, 1, 1]
        context = context.view(batch, channel, 1, 1)

        return context

   
    def get_rela_loss(self, preds_S, preds_T):
        loss_mse = nn.MSELoss(reduction='sum')

        context_s = self.spatial_pool(preds_S, 0)
        context_t = self.spatial_pool(preds_T, 1)

        out_s = preds_S
        out_t = preds_T

        channel_add_s = self.channel_add_conv_s(context_s)
        out_s = out_s + channel_add_s

        channel_add_t = self.channel_add_conv_t(context_t)
        out_t = out_t + channel_add_t

        rela_loss = loss_mse(out_s, out_t)/len(out_s)
        
        return rela_loss

?2.4?最终Loss


alpha=0.001,beta=0.0005

除此之外,利用L_{at}注意力损失来强迫学生模型去逼近教师模型的空间和通道注意力Mask
公式如下:


gamma=0.0005.

最终loss

lambda=0.000005

关于超参

最终效果:

?

?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-10-17 12:33:42  更:2022-10-17 12:37:31 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/28 1:56:47-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计