IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Transformer中Relative Position Bias以及DropPath细节梳理 -> 正文阅读

[人工智能]Transformer中Relative Position Bias以及DropPath细节梳理

1、Relative Position Bias[相对位置编码]

在transformer系列模型结构中,有关位置编码出现了一些变体,transformer以及ViT中使用原生的sine-cosine周期绝对位置编码(periodic absolute position encoding);而在最近的transformer变体工作中,e.g. SwinTransformer,BEIT等模型都使用相对位置编码(relative position encoding)。

二者有什么异同点?首先从paper empirical experiments上总结得到:对于image classification任务而言,二者差异不大,但是对于更细粒度的任务,如, object detection,image segmentation,相对位置编码泛化能力更强。所以目前在cv上,基于transformer的模型优化采用相对位置编码。

简要总结一下这二种位置编码方式:

  • 绝对位置编码

绝对位置编码, 论文中一般称1D绝对位置编码 , 看其张量的shape,本质上是一个二维张量,x.shape = [L, D],L是序列长度,对于每个token都会分配一个D维的位置embedding。其作用于token embedding之上,序列的token embedding + 绝对位置编码embedding作为transformer的序列输入

  • 相对位置编码

相对位置编码,论文中一般称2D相对位置编码,看其张量的shape,本质上是一个三维张量,

x.shape = [N, L, L],N是heads数量,L是序列长度,其作用于attention score上。相对位置编码,把图像序列恢复成2维空间结构,然后分别计算x轴,y轴平面上每个坐标点相对于其他所有坐标点的位置差异。

假设序列长度L = 9, 即3x3[window=3]图像平面,接下来分别计算在x轴以及y轴的坐标位置差异

x, y = torch.meshgird(torch.arange(3), torch.arange(3))

X = [[0,0,0],[1,1,1],[2,2,2]] y = [[0,1,2],[0,1,2],[0,1,2]]; 计算x,y平面上每个点相对于其他点的位置差异,应该是一个9x9矩阵; 计算时把 x, y拉成一个向量, x= [0,0,0,1,1,1,2,2,2], y = [0,1,2,0,1,2,0,1,2]

pos_x = x[:, None] - x[None, :] = [[0,0,0,1,1,1,2,2,2]]-[[0],[0],[0],[1],[1],[1],[2],[2],[2]] = [[0,0,0,1,1,1,2,2,2],[0,0,0,1,1,1,2,2,2],[0,0,0,1,1,1,2,2,2],[1,1,1,0,0,0,-1,-1,-1],[1,1,1,0,0,0,-1,-1,-1],[1,1,1,0,0,0,-1,-1,-1],[2,2,2,1,1,1,0,0,0],[2,2,2,1,1,1,0,0,0],[2,2,2,1,1,1,0,0,0]]

pos_y = y[:,None] - y[None, :] = [[0,1,2,0,1,2,0,1,2]]-[[0],[1],[2],[0],[1],[2],[0],[1],[2]] =

[[0,1,2,0,1,2,0,1,2],[-1,0,1,-1,0,1,-1,0,1],[-2,-1,0,-2,-1,0,-2,-1,0],[0,1,2,0,1,2,0,1,2],[-1,0,1,-1,0,1,-1,0,1],[-2,-1,0,-2,-1,0,-2,-1,0],[0,1,2,0,1,2,0,1,2],[-1,0,1,-1,0,1,-1,0,1],[-2,-1,0,-2,-1,0,-2,-1,0]]

可以看到pos_x, pos_y取值范围是在[-window+1, window-1], pos_x, pos_y同时加上window-1, 使得x轴,y轴上相对位置偏置变成从0开始; 理论上pos_x + pos_y 就是二维的相对位置编码索引了,但是会发现索引位置x+y的值会出现不唯一的情况,因此在pos_x + pos_y 之前对pos_x的坐标乘以2*window-1;

pos_x *= 2*window-1, pos_index = pos_x + pos_y。此时索引的最大值为: (2*window - 1)(2*window-2)[x轴的最大值] + 2*window - 2[y轴最大值] = (2*window-1)**2 - 1

二维可学习的相对位置查找表大小为: rel_pos_table = torch.zeros((2*window-1, 2*window-1))[可以看到二维相对位置索引的最大值,刚好覆盖到二维相对位置编码查找表]

具体源码如下:

class RelativePositionBias(nn.Module):

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) 
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        

        self.register_buffer("relative_position_index", relative_position_index)

        # trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self):
        relative_position_bias = \
            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

2、DropPath原理分析

droppath是用来实现随机深度的一种正则化手段,其作用在有多个路径分支的时候,随机丢弃某个路径的输出。数学上等价于,在样本维度,随机丢弃某些样本的输出{本质}。可以用dropout层来实现,dropout的维度在样本维度, tf.keras.layers.Dropout可以通过axis来控制在哪些维度进行dropout。

pytorch实现:

def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # 注意是在样本维度随机丢弃
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    # 除以keep_prob保证输出期望不变
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
        
       
# 具体调用
self.drop_path = DropPath(drop_prob) if drop_prob > 0. else nn.Identity()
# -------------note-------------- #
# droppath一般作用在具有多个分支路径的网络结构上
x = x + self.drop_path(self.token_mixer(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-07-21 21:32:47  更:2022-07-21 21:32:51 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 1:21:17-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码