IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 结合代码看Vision Transformer【ViT】 -> 正文阅读

[人工智能]结合代码看Vision Transformer【ViT】

参考仓库:

论文:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

有相关问题搜索知识星球号:1453755 【CV老司机】加入星球提问。扫码也可加入:

也可以搜索关注微信公众号: CV老司机

相关代码和详细资源或者相关问题,可联系牛先生小猪wx号: jishudashou

结构介绍:

ViT: Transformer + Head

Transformer: Embeddings [1x197x768] + Encoder

Encoder: N x { Block_Sequence + layerNorm [非全局均值方差,有的实现没做】}

Block: LayerNorm + MultiHeadAttension + LayerNorm + Mlp [中间有两次残差累加]

>>> 以输入224x224x3为例,embedding :196+1 个patch , 768 通道【embedding dimension】

Embedding说明

patch embedding : 将图像分为16x16的小块,然后把16x16x3的小块拉平并跟一个FC做特征映射【有的实现是直接使用了kernel为16x16,stride为16的卷积实现,当然这里和我们说的Vision Transformer就有一定出入】

clstokens : 如果上面的patch embedding 为 1x196x768 , 这里的class token就是 1x1x768, 然后cat在一块儿。作为分类特征得汇总。【作用可以视为在CNN里面最后一层的FC logits】【注:维度也可以是 1x196x1024,:超参调节】

position embedding : NLP的实现中,position embedding 用来标记这个patch 在全局中的哪个位置,用于学习一定的结构信息。这里Vision Transformer 遵从原本设计,加入了这个可学习的position embedding.

参考的path embedding 代码:

 
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, dim),
        )

维度变化: 1x3x224x224  --->   1x196x768

LayerNorm说明

涉及到的算子如下:也就是上面的公式:减均值,除方差,乘以scale,加bias

multi-Head Attention 实现

取得全图注意力。

在这里实现没有mask,NLP中用于句子不一样时,还有填空时,做掩码。

参考代码1:【和上图除了mask部分,流程部分基本一致】

 
 def forward(self, hidden_states):
        # 1 x 197 x 768   197 个patch
        mixed_query_layer = self.query(hidden_states)  # fc implement query
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        # 1 x 12 x 197 x 64  197pathch  12组特征,每组64维
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # 1 x 12 x 197 x 197 ,感受野添加至 每个patch两两之间通道特征相关性
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # 类似于使用温度系数添加注意力多样性
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # patch 内部通道【特征】重要程度使用softmax添加注意力
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        # 感受野范围添加至全图所有patch
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights

参考代码2:

class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
 

整体流程描述:

从整理流程聊聊为什么多头注意力可以获取全图的感受野。首先第一个矩阵乘,以上图为例,每一个输出点,获取到了输入图与输入图两两对应的乘累加结果【我们这里叫相关性,也可以叫感受野就是和另外一个patch通道维度】,第二次矩阵乘,输入就是两个patch相关性,单个点具备两个patch所有通道信息,乘累加过程就是,乘以对应通道权重,然后与其他所有patch乘以对应通道权重结果的累加,这样之后就影响这个点结果的因子覆盖了全图。也就是我们常说的感受野是全图。

>>> attention 每一行,patchN 与所有patch相关性,乘以对应通道权重的累加和。

MLP实现

两个FC+一个激活【gelu激活】

class Mlp(nn.Module):
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
 

gelu实现:

x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 

? ? ? ?全文差不多就这些了。Vit用尽量接近Transformer的方式来做了视觉任务最基本的分类任务,并且也取得了十分SOTA的效果。十分新颖!

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-15 11:47:53  更:2021-10-15 11:50:09 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 11:14:54-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码