IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> tensorflow2实现vision transformer -> 正文阅读

[人工智能]tensorflow2实现vision transformer

从结构来看,主要需要实现:

  1. patch_embedding;包括image的embedding+一个分类头,以及pos_embedding
  2. muliHead_Self_Attention;也就是怎么得到q、k、v,以及它们怎么乘得到attention
  3. MLP

用这三个层就可以堆叠出一个transformer encoder,然后循环layer遍就可以了。

然后取出第一个分类头,经过MLP_head(其实就是一个Dense层)做分类就结束了。

让我们看代码吧:

import warnings
warnings.filterwarnings("ignore")
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
from tensorflow.keras.layers import (Layer,Conv2D,LayerNormalization,
                                    Dense,Input,Dropout,Softmax,Add)

from tensorflow.keras.models import Model
from tensorflow.keras.activations import gelu



# patch_embedding层,包括图片的embedding+分类头, 加上pos_embedding
class PatchEmbedding(Layer):
    def __init__(self,image_size,patch_size,embed_dim,**kwargs):
        super(PatchEmbedding,self).__init__(**kwargs)

        self.embed_dim = embed_dim
        self.n_patches = (image_size//patch_size) * (image_size//patch_size)
        self.patch_embed = Conv2D(self.embed_dim,patch_size,patch_size)

        # 添加分类的token,会concat到image_tokens中,使得shape为[b,196+1,768]
        self.cls_token = self.add_weight('cls_token',shape=[1,1,self.embed_dim],
                                dtype='float32',initializer='random_normal',
                                trainable=True)
        # pos_embedding与(image_tokens+cls_token)相加,所以shape也必须为[b,197,768]
        self.pos_embeding = self.add_weight('pos_embedding',shape=[1,self.n_patches+1,self.embed_dim],
                                dtype='float32',initializer='random_normal',
                                trainable=True)
        
    def call(self,inputs):
        # patch_size=16, embed_dim=768
        # [b,224,224,3] -> [b,14,14,768]
        x = self.patch_embed(inputs)
        # [b,14,14,768] -> [b,196,768]
        b,h,w,_ = x.shape
        x = tf.reshape(x,shape=[b,h*w,self.embed_dim])
        # 1,1,768 -> b,1,768
        cls_tokens = tf.broadcast_to(self.cls_token,(b,1,self.embed_dim))
        # -> b, 197, 768
        x = tf.concat([x,cls_tokens],axis=1)

        # 加上pos_embedding -> b, 197, 728
        x = x + self.pos_embeding

        return x

    def get_config(self):
        config = super(PatchEmbedding, self).get_config()
        config.update({"embed_dim": self.embed_dim,
                        "num_patches":self.n_patches,
                        })
        return config


# msa层的实现
class multiHead_self_attention(Layer):
    def __init__(self,embed_dim,num_heads,attention_dropout=0.0,**kwargs):
        super(multiHead_self_attention,self).__init__(**kwargs)

        self.num_heads = num_heads
        self.head_dim = embed_dim // self.num_heads
        self.all_head_dim = self.num_heads * self.head_dim
        
        self.scale = self.head_dim ** (-0.5) # q*k之后的变换系数

        self.qkv = Dense(self.all_head_dim*3)
        self.proj = Dense(self.all_head_dim)

        self.attention_dropout = Dropout(attention_dropout)

        self.softmax = Softmax()
    
    def call(self,inputs):
        # -> b,197,768*3
        qkv = self.qkv(inputs)
        # q,k,v: b,197,768
        q,k,v = tf.split(qkv,3,axis=-1)
        
        b,n_patches,all_head_dim = q.shape
        # q,k,v: b,197,768 -> b,197,num_heads, head_dim 假设num_heads=12
        # b,197,768 -> b,197,12,64
        q = tf.reshape(q,shape=[b,n_patches,self.num_heads,self.head_dim])
        k = tf.reshape(k,shape=[b,n_patches,self.num_heads,self.head_dim])
        v = tf.reshape(v,shape=[b,n_patches,self.num_heads,self.head_dim])

        # b,197,12,64 -> b,12,197,64
        q = tf.transpose(q,[0,2,1,3])
        k = tf.transpose(k,[0,2,1,3])
        v = tf.transpose(v,[0,2,1,3])
        # -> b,12,12,64
        attention = tf.matmul(q,k,transpose_b=True)
        attention = self.scale * attention
        attention = self.softmax(attention)
        attention = self.attention_dropout(attention)
        # -> b,12,197,64
        out = tf.matmul(attention,v)
        # b,12,197,64 -> b,197,12,64
        out = tf.transpose(out,[0,2,1,3])
        # b,197,12,64 -> b,197,768
        out = tf.reshape(out,shape=[b,n_patches,all_head_dim])

        out = self.proj(out)
        return out

    def get_config(self):
        config = super(multiHead_self_attention, self).get_config()
        config.update({"num_heads": self.num_heads,
                        "head_dim":self.head_dim,
                        "all_head_dim":self.all_head_dim,
                        "scale":self.scale
                        })
        return config


class MLP(Layer):
    def __init__(self,embed_dim,mlp_ratio=4.0,dropout=0.0,**kwargs):
        super(MLP,self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

    def call(self,inputs):
        # 1,197,768 -> 1,197,768*4
        x = Dense(int(self.embed_dim*self.mlp_ratio))(inputs)
        x = gelu(x)
        x = Dropout(self.dropout)(x)

        # 1,197,768*4 - 1,197,768
        x = Dense(self.embed_dim)(x)
        x = Dropout(self.dropout)(x)

        return x

    def get_config(self):
        config = super(MLP,self).get_config()
        config.update({
            "embed_dim":self.embed_dim,
            "mlp_ratio":self.mlp_ratio,
            "dropout":self.dropout
        })


def VisionTransformer(input_shape=[224,224,3],num_classes=5):
    image_size = 224
    num_heads = 12
    patch_size = 16
    embed_dim = 768
    layer_length = 12


    inputs = Input(shape=input_shape,batch_size=1)
    # 1,224,224,3 -> 1,197,768
    x = PatchEmbedding(image_size,patch_size,embed_dim,name='patchAndPos_embedding')(inputs)
    
    # 循环layer_length遍
    for i in range(1,layer_length+1):
        h = x
        x = LayerNormalization(name=f'LayerNorm{i}_1')(x)
        # 1,197,768 -> 1,197,768
        x = multiHead_self_attention(embed_dim,num_heads,0,name=f'MSA{i}')(x)
        # 1,197,768 -> 1,197,768
        x = Add(name=f'add{i}_1')([x,h])
        h = x

        x = LayerNormalization(name=f'LayerNorm{i}_2')(x)
        # 1,197,768 -> 1,197,768
        x = MLP(embed_dim,name=f'MLP{i}')(x)
        # 1,197,768 -> 1,197,768
        x = Add(name=f'add{i}_2')([x,h])

    # 1,197,768 -> 1,768
    cls_token = x[:,0] # 取出第1个token出来做分类
    # 1,768 -> 1, num_classes
    x = Dense(num_classes,name='classifier')(cls_token)
    out = Softmax()(x)

    model = Model(inputs=inputs,outputs=out,name='tf2-vit')

    return model


    


if __name__ == '__main__':
    input_shape = [224,224,3]
    num_classes = 5
    vitmodel = VisionTransformer(input_shape,num_classes)
    vitmodel.summary()

Model: "tf2-vit"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to
==================================================================================================
 input_1 (InputLayer)           [(1, 224, 224, 3)]   0           []

 patchAndPos_embedding (PatchEm  (1, 197, 768)       742656      ['input_1[0][0]']
 bedding)

 LayerNorm1_1 (LayerNormalizati  (1, 197, 768)       1536        ['patchAndPos_embedding[0][0]']
 on)

 MSA1 (multiHead_self_attention  (1, 197, 768)       2362368     ['LayerNorm1_1[0][0]']
 )

 add1_1 (Add)                   (1, 197, 768)        0           ['MSA1[0][0]',
                                                                  'patchAndPos_embedding[0][0]']

 LayerNorm1_2 (LayerNormalizati  (1, 197, 768)       1536        ['add1_1[0][0]']
 on)

 MLP1 (MLP)                     (1, 197, 768)        0           ['LayerNorm1_2[0][0]']

 add1_2 (Add)                   (1, 197, 768)        0           ['MLP1[0][0]',
                                                                  'add1_1[0][0]']

 LayerNorm2_1 (LayerNormalizati  (1, 197, 768)       1536        ['add1_2[0][0]']
 on)

 MSA2 (multiHead_self_attention  (1, 197, 768)       2362368     ['LayerNorm2_1[0][0]']
 )

 add2_1 (Add)                   (1, 197, 768)        0           ['MSA2[0][0]',
                                                                  'add1_2[0][0]']

 LayerNorm2_2 (LayerNormalizati  (1, 197, 768)       1536        ['add2_1[0][0]']
 on)

 MLP2 (MLP)                     (1, 197, 768)        0           ['LayerNorm2_2[0][0]']

 add2_2 (Add)                   (1, 197, 768)        0           ['MLP2[0][0]',
                                                                  'add2_1[0][0]']

 LayerNorm3_1 (LayerNormalizati  (1, 197, 768)       1536        ['add2_2[0][0]']
 on)

 MSA3 (multiHead_self_attention  (1, 197, 768)       2362368     ['LayerNorm3_1[0][0]']
 )

 add3_1 (Add)                   (1, 197, 768)        0           ['MSA3[0][0]',
                                                                  'add2_2[0][0]']

 LayerNorm3_2 (LayerNormalizati  (1, 197, 768)       1536        ['add3_1[0][0]']
 on)

 MLP3 (MLP)                     (1, 197, 768)        0           ['LayerNorm3_2[0][0]']

 add3_2 (Add)                   (1, 197, 768)        0           ['MLP3[0][0]',
                                                                  'add3_1[0][0]']

 LayerNorm4_1 (LayerNormalizati  (1, 197, 768)       1536        ['add3_2[0][0]']
 on)

 MSA4 (multiHead_self_attention  (1, 197, 768)       2362368     ['LayerNorm4_1[0][0]']
 )

 add4_1 (Add)                   (1, 197, 768)        0           ['MSA4[0][0]',
                                                                  'add3_2[0][0]']

 LayerNorm4_2 (LayerNormalizati  (1, 197, 768)       1536        ['add4_1[0][0]']
 on)

 MLP4 (MLP)                     (1, 197, 768)        0           ['LayerNorm4_2[0][0]']

 add4_2 (Add)                   (1, 197, 768)        0           ['MLP4[0][0]',
                                                                  'add4_1[0][0]']

 LayerNorm5_1 (LayerNormalizati  (1, 197, 768)       1536        ['add4_2[0][0]']
 on)

 MSA5 (multiHead_self_attention  (1, 197, 768)       2362368     ['LayerNorm5_1[0][0]']
 )

 add5_1 (Add)                   (1, 197, 768)        0           ['MSA5[0][0]',
                                                                  'add4_2[0][0]']

 LayerNorm5_2 (LayerNormalizati  (1, 197, 768)       1536        ['add5_1[0][0]']
 on)

 MLP5 (MLP)                     (1, 197, 768)        0           ['LayerNorm5_2[0][0]']

 add5_2 (Add)                   (1, 197, 768)        0           ['MLP5[0][0]',
                                                                  'add5_1[0][0]']

 LayerNorm6_1 (LayerNormalizati  (1, 197, 768)       1536        ['add5_2[0][0]']
 on)

 MSA6 (multiHead_self_attention  (1, 197, 768)       2362368     ['LayerNorm6_1[0][0]']
 )

 add6_1 (Add)                   (1, 197, 768)        0           ['MSA6[0][0]',
                                                                  'add5_2[0][0]']

 LayerNorm6_2 (LayerNormalizati  (1, 197, 768)       1536        ['add6_1[0][0]']
 on)

 MLP6 (MLP)                     (1, 197, 768)        0           ['LayerNorm6_2[0][0]']

 add6_2 (Add)                   (1, 197, 768)        0           ['MLP6[0][0]',
                                                                  'add6_1[0][0]']

 LayerNorm7_1 (LayerNormalizati  (1, 197, 768)       1536        ['add6_2[0][0]']
 on)

 MSA7 (multiHead_self_attention  (1, 197, 768)       2362368     ['LayerNorm7_1[0][0]']
 )

 add7_1 (Add)                   (1, 197, 768)        0           ['MSA7[0][0]',
                                                                  'add6_2[0][0]']

 LayerNorm7_2 (LayerNormalizati  (1, 197, 768)       1536        ['add7_1[0][0]']
 on)

 MLP7 (MLP)                     (1, 197, 768)        0           ['LayerNorm7_2[0][0]']

 add7_2 (Add)                   (1, 197, 768)        0           ['MLP7[0][0]',
                                                                  'add7_1[0][0]']

 LayerNorm8_1 (LayerNormalizati  (1, 197, 768)       1536        ['add7_2[0][0]']
 on)

 MSA8 (multiHead_self_attention  (1, 197, 768)       2362368     ['LayerNorm8_1[0][0]']
 )

 add8_1 (Add)                   (1, 197, 768)        0           ['MSA8[0][0]',
                                                                  'add7_2[0][0]']

 LayerNorm8_2 (LayerNormalizati  (1, 197, 768)       1536        ['add8_1[0][0]']
 on)

 MLP8 (MLP)                     (1, 197, 768)        0           ['LayerNorm8_2[0][0]']

 add8_2 (Add)                   (1, 197, 768)        0           ['MLP8[0][0]',
                                                                  'add8_1[0][0]']

 LayerNorm9_1 (LayerNormalizati  (1, 197, 768)       1536        ['add8_2[0][0]']
 on)

 MSA9 (multiHead_self_attention  (1, 197, 768)       2362368     ['LayerNorm9_1[0][0]']
 )

 add9_1 (Add)                   (1, 197, 768)        0           ['MSA9[0][0]',
                                                                  'add8_2[0][0]']

 LayerNorm9_2 (LayerNormalizati  (1, 197, 768)       1536        ['add9_1[0][0]']
 on)

 MLP9 (MLP)                     (1, 197, 768)        0           ['LayerNorm9_2[0][0]']

 add9_2 (Add)                   (1, 197, 768)        0           ['MLP9[0][0]',
                                                                  'add9_1[0][0]']

 LayerNorm10_1 (LayerNormalizat  (1, 197, 768)       1536        ['add9_2[0][0]']
 ion)

 MSA10 (multiHead_self_attentio  (1, 197, 768)       2362368     ['LayerNorm10_1[0][0]']
 n)

 add10_1 (Add)                  (1, 197, 768)        0           ['MSA10[0][0]',
                                                                  'add9_2[0][0]']

 LayerNorm10_2 (LayerNormalizat  (1, 197, 768)       1536        ['add10_1[0][0]']
 ion)

 MLP10 (MLP)                    (1, 197, 768)        0           ['LayerNorm10_2[0][0]']

 add10_2 (Add)                  (1, 197, 768)        0           ['MLP10[0][0]',
                                                                  'add10_1[0][0]']

 LayerNorm11_1 (LayerNormalizat  (1, 197, 768)       1536        ['add10_2[0][0]']
 ion)

 MSA11 (multiHead_self_attentio  (1, 197, 768)       2362368     ['LayerNorm11_1[0][0]']
 n)

 add11_1 (Add)                  (1, 197, 768)        0           ['MSA11[0][0]',
                                                                  'add10_2[0][0]']

 LayerNorm11_2 (LayerNormalizat  (1, 197, 768)       1536        ['add11_1[0][0]']
 ion)

 MLP11 (MLP)                    (1, 197, 768)        0           ['LayerNorm11_2[0][0]']

 add11_2 (Add)                  (1, 197, 768)        0           ['MLP11[0][0]',
                                                                  'add11_1[0][0]']

 LayerNorm12_1 (LayerNormalizat  (1, 197, 768)       1536        ['add11_2[0][0]']
 ion)

 MSA12 (multiHead_self_attentio  (1, 197, 768)       2362368     ['LayerNorm12_1[0][0]']
 n)

 add12_1 (Add)                  (1, 197, 768)        0           ['MSA12[0][0]',
                                                                  'add11_2[0][0]']

 LayerNorm12_2 (LayerNormalizat  (1, 197, 768)       1536        ['add12_1[0][0]']
 ion)

 MLP12 (MLP)                    (1, 197, 768)        0           ['LayerNorm12_2[0][0]']

 add12_2 (Add)                  (1, 197, 768)        0           ['MLP12[0][0]',
                                                                  'add12_1[0][0]']

 tf.__operators__.getitem (Slic  (1, 768)            0           ['add12_2[0][0]']
 ingOpLambda)

 classifier (Dense)             (1, 5)               3845        ['tf.__operators__.getitem[0][0]'
                                                                 ]

 softmax_12 (Softmax)           (1, 5)               0           ['classifier[0][0]']

==================================================================================================
Total params: 29,131,781
Trainable params: 29,131,781
Non-trainable params: 0
_________________________________________________________________________________________

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-02-27 10:59:45  更:2022-02-27 11:00:15 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 3:05:20-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码