Structurally, there are three main pieces to implement:
- patch_embedding: the patch embedding of the image, plus a learnable class token and the positional embedding
- multi-head self-attention (MSA): how q, k and v are obtained, and how they are multiplied together to produce the attention output
- MLP: the feed-forward block
Stacking these three layers gives one Transformer encoder block, which is then simply repeated for the desired number of layers.
Finally, the class token (the first token) is taken out and passed through the MLP head (really just a Dense layer) to do the classification.
Let's look at the code:
import warnings
warnings.filterwarnings("ignore")
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.keras.layers import (Layer, Conv2D, LayerNormalization,
                                     Dense, Input, Dropout, Softmax, Add)
from tensorflow.keras.models import Model
from tensorflow.keras.activations import gelu

# Patch embedding layer: patch embedding of the image + class token + positional embedding
class PatchEmbedding(Layer):
    def __init__(self, image_size, patch_size, embed_dim, **kwargs):
        super(PatchEmbedding, self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.n_patches = (image_size // patch_size) * (image_size // patch_size)
        # A Conv2D with kernel_size = strides = patch_size embeds all patches in one shot
        self.patch_embed = Conv2D(self.embed_dim, kernel_size=patch_size, strides=patch_size)
        # Learnable class token, prepended to the image tokens so the shape becomes [b,196+1,768]
        self.cls_token = self.add_weight('cls_token', shape=[1, 1, self.embed_dim],
                                         dtype='float32', initializer='random_normal',
                                         trainable=True)
        # pos_embedding is added to (cls_token + image_tokens), so its shape must be [1,197,768]
        self.pos_embedding = self.add_weight('pos_embedding', shape=[1, self.n_patches + 1, self.embed_dim],
                                             dtype='float32', initializer='random_normal',
                                             trainable=True)

    def call(self, inputs):
        # patch_size=16, embed_dim=768
        # [b,224,224,3] -> [b,14,14,768]
        x = self.patch_embed(inputs)
        # [b,14,14,768] -> [b,196,768]
        b, h, w, _ = x.shape
        x = tf.reshape(x, shape=[b, h * w, self.embed_dim])
        # [1,1,768] -> [b,1,768]
        cls_tokens = tf.broadcast_to(self.cls_token, (b, 1, self.embed_dim))
        # prepend the class token so it sits at index 0 -> [b,197,768]
        x = tf.concat([cls_tokens, x], axis=1)
        # add pos_embedding -> [b,197,768]
        x = x + self.pos_embedding
        return x

    def get_config(self):
        config = super(PatchEmbedding, self).get_config()
        config.update({"embed_dim": self.embed_dim,
                       "num_patches": self.n_patches,
                       })
        return config
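# Optional illustrative helper (not used by the model code below): verifies the shape
# contract of PatchEmbedding. A 224x224x3 image cut into 16x16 patches gives 14*14 = 196
# patch tokens, plus the class token = 197 tokens of width 768. Call it manually if needed.
def _check_patch_embedding():
    pe = PatchEmbedding(image_size=224, patch_size=16, embed_dim=768)
    dummy = tf.zeros([1, 224, 224, 3])
    print(pe(dummy).shape)  # expected: (1, 197, 768)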
# Multi-head self-attention (MSA) layer
class multiHead_self_attention(Layer):
    def __init__(self, embed_dim, num_heads, attention_dropout=0.0, **kwargs):
        super(multiHead_self_attention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.head_dim = embed_dim // self.num_heads
        self.all_head_dim = self.num_heads * self.head_dim
        self.scale = self.head_dim ** (-0.5)  # scaling factor 1/sqrt(head_dim) applied to q*k^T
        self.qkv = Dense(self.all_head_dim * 3)
        self.proj = Dense(self.all_head_dim)
        self.attention_dropout = Dropout(attention_dropout)
        self.softmax = Softmax()

    def call(self, inputs):
        # -> b,197,768*3
        qkv = self.qkv(inputs)
        # q,k,v: b,197,768
        q, k, v = tf.split(qkv, 3, axis=-1)
        b, n_patches, all_head_dim = q.shape
        # q,k,v: b,197,768 -> b,197,num_heads,head_dim (with num_heads=12)
        # b,197,768 -> b,197,12,64
        q = tf.reshape(q, shape=[b, n_patches, self.num_heads, self.head_dim])
        k = tf.reshape(k, shape=[b, n_patches, self.num_heads, self.head_dim])
        v = tf.reshape(v, shape=[b, n_patches, self.num_heads, self.head_dim])
        # b,197,12,64 -> b,12,197,64
        q = tf.transpose(q, [0, 2, 1, 3])
        k = tf.transpose(k, [0, 2, 1, 3])
        v = tf.transpose(v, [0, 2, 1, 3])
        # attention scores: -> b,12,197,197
        attention = tf.matmul(q, k, transpose_b=True)
        attention = self.scale * attention
        attention = self.softmax(attention)
        attention = self.attention_dropout(attention)
        # weighted sum of values: -> b,12,197,64
        out = tf.matmul(attention, v)
        # b,12,197,64 -> b,197,12,64
        out = tf.transpose(out, [0, 2, 1, 3])
        # b,197,12,64 -> b,197,768
        out = tf.reshape(out, shape=[b, n_patches, all_head_dim])
        out = self.proj(out)
        return out

    def get_config(self):
        config = super(multiHead_self_attention, self).get_config()
        config.update({"num_heads": self.num_heads,
                       "head_dim": self.head_dim,
                       "all_head_dim": self.all_head_dim,
                       "scale": self.scale
                       })
        return config
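# Optional illustrative helper (not used by the model code below): the attention layer is
# shape-preserving, [b,197,768] in and [b,197,768] out, with 12 heads of size 768 // 12 = 64
# used internally. Call it manually if needed.
def _check_msa():
    msa = multiHead_self_attention(embed_dim=768, num_heads=12)
    dummy = tf.zeros([1, 197, 768])
    print(msa(dummy).shape)  # expected: (1, 197, 768)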
# MLP (feed-forward) block: Dense -> gelu -> Dropout -> Dense -> Dropout
class MLP(Layer):
    def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.0, **kwargs):
        super(MLP, self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        # Sub-layers must be created here rather than inside call(); otherwise their
        # weights are re-created on every forward pass and are never tracked or trained
        self.fc1 = Dense(int(self.embed_dim * self.mlp_ratio))
        self.fc2 = Dense(self.embed_dim)
        self.drop = Dropout(self.dropout)

    def call(self, inputs):
        # b,197,768 -> b,197,768*4
        x = self.fc1(inputs)
        x = gelu(x)
        x = self.drop(x)
        # b,197,768*4 -> b,197,768
        x = self.fc2(x)
        x = self.drop(x)
        return x

    def get_config(self):
        config = super(MLP, self).get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "mlp_ratio": self.mlp_ratio,
            "dropout": self.dropout
        })
        return config
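# Optional illustrative helper (not used by the model code below): one pre-norm encoder
# block, wired the same way as the loop body in VisionTransformer: x + MSA(LN(x)),
# followed by x + MLP(LN(x)). The token shape [b,197,768] is preserved throughout.
def _encoder_block(x, embed_dim=768, num_heads=12):
    h = x
    x = LayerNormalization()(x)
    x = multiHead_self_attention(embed_dim, num_heads)(x)
    x = Add()([x, h])
    h = x
    x = LayerNormalization()(x)
    x = MLP(embed_dim)(x)
    return Add()([x, h])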
def VisionTransformer(input_shape=[224, 224, 3], num_classes=5):
    image_size = 224
    num_heads = 12
    patch_size = 16
    embed_dim = 768
    layer_length = 12
    # batch_size is fixed to 1 because the layers above rely on static shapes in tf.reshape
    inputs = Input(shape=input_shape, batch_size=1)
    # 1,224,224,3 -> 1,197,768
    x = PatchEmbedding(image_size, patch_size, embed_dim, name='patchAndPos_embedding')(inputs)
    # repeat the encoder block layer_length times
    for i in range(1, layer_length + 1):
        h = x
        x = LayerNormalization(name=f'LayerNorm{i}_1')(x)
        # 1,197,768 -> 1,197,768
        x = multiHead_self_attention(embed_dim, num_heads, 0, name=f'MSA{i}')(x)
        # 1,197,768 -> 1,197,768
        x = Add(name=f'add{i}_1')([x, h])
        h = x
        x = LayerNormalization(name=f'LayerNorm{i}_2')(x)
        # 1,197,768 -> 1,197,768
        x = MLP(embed_dim, name=f'MLP{i}')(x)
        # 1,197,768 -> 1,197,768
        x = Add(name=f'add{i}_2')([x, h])
    # 1,197,768 -> 1,768
    cls_token = x[:, 0]  # take the class token (the first token) for classification
    # 1,768 -> 1,num_classes
    x = Dense(num_classes, name='classifier')(cls_token)
    out = Softmax()(x)
    model = Model(inputs=inputs, outputs=out, name='tf2-vit')
    return model


if __name__ == '__main__':
    input_shape = [224, 224, 3]
    num_classes = 5
    vitmodel = VisionTransformer(input_shape, num_classes)
    vitmodel.summary()
Model: "tf2-vit"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(1, 224, 224, 3)] 0 []
patchAndPos_embedding (PatchEm (1, 197, 768) 742656 ['input_1[0][0]']
bedding)
LayerNorm1_1 (LayerNormalizati (1, 197, 768) 1536 ['patchAndPos_embedding[0][0]']
on)
MSA1 (multiHead_self_attention (1, 197, 768) 2362368 ['LayerNorm1_1[0][0]']
)
add1_1 (Add) (1, 197, 768) 0 ['MSA1[0][0]',
'patchAndPos_embedding[0][0]']
LayerNorm1_2 (LayerNormalizati (1, 197, 768) 1536 ['add1_1[0][0]']
on)
MLP1 (MLP) (1, 197, 768) 0 ['LayerNorm1_2[0][0]']
add1_2 (Add) (1, 197, 768) 0 ['MLP1[0][0]',
'add1_1[0][0]']
LayerNorm2_1 (LayerNormalizati (1, 197, 768) 1536 ['add1_2[0][0]']
on)
MSA2 (multiHead_self_attention (1, 197, 768) 2362368 ['LayerNorm2_1[0][0]']
)
add2_1 (Add) (1, 197, 768) 0 ['MSA2[0][0]',
'add1_2[0][0]']
LayerNorm2_2 (LayerNormalizati (1, 197, 768) 1536 ['add2_1[0][0]']
on)
MLP2 (MLP) (1, 197, 768) 0 ['LayerNorm2_2[0][0]']
add2_2 (Add) (1, 197, 768) 0 ['MLP2[0][0]',
'add2_1[0][0]']
LayerNorm3_1 (LayerNormalizati (1, 197, 768) 1536 ['add2_2[0][0]']
on)
MSA3 (multiHead_self_attention (1, 197, 768) 2362368 ['LayerNorm3_1[0][0]']
)
add3_1 (Add) (1, 197, 768) 0 ['MSA3[0][0]',
'add2_2[0][0]']
LayerNorm3_2 (LayerNormalizati (1, 197, 768) 1536 ['add3_1[0][0]']
on)
MLP3 (MLP) (1, 197, 768) 0 ['LayerNorm3_2[0][0]']
add3_2 (Add) (1, 197, 768) 0 ['MLP3[0][0]',
'add3_1[0][0]']
LayerNorm4_1 (LayerNormalizati (1, 197, 768) 1536 ['add3_2[0][0]']
on)
MSA4 (multiHead_self_attention (1, 197, 768) 2362368 ['LayerNorm4_1[0][0]']
)
add4_1 (Add) (1, 197, 768) 0 ['MSA4[0][0]',
'add3_2[0][0]']
LayerNorm4_2 (LayerNormalizati (1, 197, 768) 1536 ['add4_1[0][0]']
on)
MLP4 (MLP) (1, 197, 768) 0 ['LayerNorm4_2[0][0]']
add4_2 (Add) (1, 197, 768) 0 ['MLP4[0][0]',
'add4_1[0][0]']
LayerNorm5_1 (LayerNormalizati (1, 197, 768) 1536 ['add4_2[0][0]']
on)
MSA5 (multiHead_self_attention (1, 197, 768) 2362368 ['LayerNorm5_1[0][0]']
)
add5_1 (Add) (1, 197, 768) 0 ['MSA5[0][0]',
'add4_2[0][0]']
LayerNorm5_2 (LayerNormalizati (1, 197, 768) 1536 ['add5_1[0][0]']
on)
MLP5 (MLP) (1, 197, 768) 0 ['LayerNorm5_2[0][0]']
add5_2 (Add) (1, 197, 768) 0 ['MLP5[0][0]',
'add5_1[0][0]']
LayerNorm6_1 (LayerNormalizati (1, 197, 768) 1536 ['add5_2[0][0]']
on)
MSA6 (multiHead_self_attention (1, 197, 768) 2362368 ['LayerNorm6_1[0][0]']
)
add6_1 (Add) (1, 197, 768) 0 ['MSA6[0][0]',
'add5_2[0][0]']
LayerNorm6_2 (LayerNormalizati (1, 197, 768) 1536 ['add6_1[0][0]']
on)
MLP6 (MLP) (1, 197, 768) 0 ['LayerNorm6_2[0][0]']
add6_2 (Add) (1, 197, 768) 0 ['MLP6[0][0]',
'add6_1[0][0]']
LayerNorm7_1 (LayerNormalizati (1, 197, 768) 1536 ['add6_2[0][0]']
on)
MSA7 (multiHead_self_attention (1, 197, 768) 2362368 ['LayerNorm7_1[0][0]']
)
add7_1 (Add) (1, 197, 768) 0 ['MSA7[0][0]',
'add6_2[0][0]']
LayerNorm7_2 (LayerNormalizati (1, 197, 768) 1536 ['add7_1[0][0]']
on)
MLP7 (MLP) (1, 197, 768) 0 ['LayerNorm7_2[0][0]']
add7_2 (Add) (1, 197, 768) 0 ['MLP7[0][0]',
'add7_1[0][0]']
LayerNorm8_1 (LayerNormalizati (1, 197, 768) 1536 ['add7_2[0][0]']
on)
MSA8 (multiHead_self_attention (1, 197, 768) 2362368 ['LayerNorm8_1[0][0]']
)
add8_1 (Add) (1, 197, 768) 0 ['MSA8[0][0]',
'add7_2[0][0]']
LayerNorm8_2 (LayerNormalizati (1, 197, 768) 1536 ['add8_1[0][0]']
on)
MLP8 (MLP) (1, 197, 768) 0 ['LayerNorm8_2[0][0]']
add8_2 (Add) (1, 197, 768) 0 ['MLP8[0][0]',
'add8_1[0][0]']
LayerNorm9_1 (LayerNormalizati (1, 197, 768) 1536 ['add8_2[0][0]']
on)
MSA9 (multiHead_self_attention (1, 197, 768) 2362368 ['LayerNorm9_1[0][0]']
)
add9_1 (Add) (1, 197, 768) 0 ['MSA9[0][0]',
'add8_2[0][0]']
LayerNorm9_2 (LayerNormalizati (1, 197, 768) 1536 ['add9_1[0][0]']
on)
MLP9 (MLP) (1, 197, 768) 0 ['LayerNorm9_2[0][0]']
add9_2 (Add) (1, 197, 768) 0 ['MLP9[0][0]',
'add9_1[0][0]']
LayerNorm10_1 (LayerNormalizat (1, 197, 768) 1536 ['add9_2[0][0]']
ion)
MSA10 (multiHead_self_attentio (1, 197, 768) 2362368 ['LayerNorm10_1[0][0]']
n)
add10_1 (Add) (1, 197, 768) 0 ['MSA10[0][0]',
'add9_2[0][0]']
LayerNorm10_2 (LayerNormalizat (1, 197, 768) 1536 ['add10_1[0][0]']
ion)
MLP10 (MLP) (1, 197, 768) 0 ['LayerNorm10_2[0][0]']
add10_2 (Add) (1, 197, 768) 0 ['MLP10[0][0]',
'add10_1[0][0]']
LayerNorm11_1 (LayerNormalizat (1, 197, 768) 1536 ['add10_2[0][0]']
ion)
MSA11 (multiHead_self_attentio (1, 197, 768) 2362368 ['LayerNorm11_1[0][0]']
n)
add11_1 (Add) (1, 197, 768) 0 ['MSA11[0][0]',
'add10_2[0][0]']
LayerNorm11_2 (LayerNormalizat (1, 197, 768) 1536 ['add11_1[0][0]']
ion)
MLP11 (MLP) (1, 197, 768) 0 ['LayerNorm11_2[0][0]']
add11_2 (Add) (1, 197, 768) 0 ['MLP11[0][0]',
'add11_1[0][0]']
LayerNorm12_1 (LayerNormalizat (1, 197, 768) 1536 ['add11_2[0][0]']
ion)
MSA12 (multiHead_self_attentio (1, 197, 768) 2362368 ['LayerNorm12_1[0][0]']
n)
add12_1 (Add) (1, 197, 768) 0 ['MSA12[0][0]',
'add11_2[0][0]']
LayerNorm12_2 (LayerNormalizat (1, 197, 768) 1536 ['add12_1[0][0]']
ion)
MLP12 (MLP) (1, 197, 768) 0 ['LayerNorm12_2[0][0]']
add12_2 (Add) (1, 197, 768) 0 ['MLP12[0][0]',
'add12_1[0][0]']
tf.__operators__.getitem (Slic (1, 768) 0 ['add12_2[0][0]']
ingOpLambda)
classifier (Dense) (1, 5) 3845 ['tf.__operators__.getitem[0][0]'
]
softmax_12 (Softmax) (1, 5) 0 ['classifier[0][0]']
==================================================================================================
Total params: 29,131,781
Trainable params: 29,131,781
Non-trainable params: 0
_________________________________________________________________________________________
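To check that everything is wired up correctly end to end, one quick test is to push a random image through the model and then compile it for training. The snippet below is a minimal sketch that reuses the VisionTransformer defined above; the optimizer, learning rate and loss are illustrative placeholder choices. Note that the model already ends in a Softmax, so the loss must not be configured with from_logits=True.

import tensorflow as tf

model = VisionTransformer(input_shape=[224, 224, 3], num_classes=5)

# Dummy forward pass: one random image in, a 5-way probability vector out
dummy = tf.random.normal([1, 224, 224, 3])
print(model(dummy).shape)  # (1, 5)

# Typical compile step for integer labels; hyperparameters here are just placeholders
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])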