ViT网络paddle代码
加入位置信息
在ViT中引入一个额外的token用来学习全局信息从而进行分类
Multi-Head Attention
#基于paddle
#2021/12/13
#注:该代码是paddlepaddle官方开的ViT课程中老师编写的,我只是把它搬运过来以防丢失,方便随时查找
import paddle
import paddle.nn as nn
import numpy as np
from PIL import Image
from attention import Attention
# Select the CPU as the execution device. This is a module-level call, so it
# takes effect as soon as the file is imported or run.
paddle.set_device('cpu')
class Identity(nn.Layer):
    """Pass-through layer: ``forward`` hands back its input untouched."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x
class PatchEmbedding(nn.Layer):
    """Convert an image batch into a token sequence with class token and positions.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size`` patches,
    each patch is projected to ``embed_dim`` channels, a learnable class token is
    prepended, and a learnable position embedding is added.

    Args:
        image_size: Side length of the (square) input image; used to size the
            position embedding.
        patch_size: Side length of each square patch.
        in_channels: Number of channels of the input image.
        embed_dim: Dimension of each output token.
        dropout: Dropout probability applied to the embedded tokens.
    """

    def __init__(self, image_size, patch_size, in_channels, embed_dim, dropout=0.):
        super().__init__()
        self.embed_dim = embed_dim
        n_patches = (image_size // patch_size) * (image_size // patch_size)
        # Key point: kernel_size == stride gives a non-overlapping convolution,
        # which is effectively a linear projection applied to every patch.
        self.patch_embed = nn.Conv2D(in_channels=in_channels,
                                     out_channels=embed_dim,
                                     kernel_size=patch_size,
                                     stride=patch_size,
                                     bias_attr=False)
        self.dropout = nn.Dropout(dropout)
        # Learnable class token, shared across the batch and initialised to zero.
        self.class_token = paddle.create_parameter(
            shape=[1, 1, embed_dim],
            dtype='float32',
            default_initializer=nn.initializer.Constant(0.))
        # Learnable position embedding: one vector per patch plus one for the
        # class token.
        self.position_embedding = paddle.create_parameter(
            shape=[1, n_patches + 1, embed_dim],
            dtype='float32',
            default_initializer=nn.initializer.TruncatedNormal(std=.02)
        )

    def forward(self, x):
        """Embed ``x`` of shape [n, c, h, w] into [n, num_patches + 1, embed_dim].

        The spatial size of ``x`` must yield the same number of patches that the
        position embedding was built for, otherwise the addition below fails.
        """
        # Replicate the single class token for every sample in the batch.
        cls_tokens = self.class_token.expand([x.shape[0], -1, -1])
        x = self.patch_embed(x)       # [n, embed_dim, h', w']
        x = x.flatten(2)              # [n, embed_dim, num_patches]
        x = x.transpose([0, 2, 1])    # [n, num_patches, embed_dim]
        x = paddle.concat([cls_tokens, x], axis=1)  # [n, num_patches + 1, embed_dim]
        x = x + self.position_embedding
        # The dropout layer was constructed but never used before; apply it to
        # the embedded tokens (a no-op with the default probability of 0).
        x = self.dropout(x)
        return x
class Mlp(nn.Layer):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        embed_dim: Input and output feature dimension.
        mlp_ratio: Hidden width as a multiple of ``embed_dim`` (truncated to int).
        dropout: Dropout probability used after the activation and after the
            second linear layer.
    """

    def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to the hidden width, activate, then project back to ``embed_dim``."""
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
class EncoderLayer(nn.Layer):
    """One pre-norm transformer encoder layer: attention and MLP, each with a residual.

    Args:
        embed_dim: Token feature dimension.
        num_heads: Number of attention heads passed to ``Attention``.
        qkv_bias: Accepted for interface compatibility; not forwarded to
            ``Attention`` (see the TODO in ``__init__``).
        mlp_ratio: Hidden-width multiplier of the MLP.
        dropout: Dropout probability used inside the MLP.
        attention_dropout: Accepted for interface compatibility; not forwarded
            to ``Attention`` (see the TODO in ``__init__``).
    """

    def __init__(self, embed_dim=768, num_heads=4, qkv_bias=True, mlp_ratio=4.0, dropout=0., attention_dropout=0.):
        super().__init__()
        # TODO: `Attention` comes from the local `attention` module, whose
        # signature is not visible here, so qkv_bias / attention_dropout are
        # not forwarded yet. Confirm its parameters before wiring them in.
        self.attn = Attention(embed_dim, num_heads)
        self.attn_norm = nn.LayerNorm(embed_dim)
        # Forward `dropout` so the argument is actually honoured by the MLP
        # (previously the MLP always ran with its default of 0).
        self.mlp = Mlp(embed_dim, mlp_ratio, dropout)
        self.mlp_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Apply attention and MLP sub-blocks, each as ``x + f(norm(x))``.

        Assumes ``Attention`` returns a tensor with the same shape as its input
        so that the residual addition is valid -- TODO confirm.
        """
        # Pre-norm: normalise first, then run the sub-block, then add the residual.
        h = x
        x = self.attn_norm(x)
        x = self.attn(x)
        x = h + x

        h = x
        x = self.mlp_norm(x)
        x = self.mlp(x)
        x = h + x
        return x
class Encoder(nn.Layer):
    """Stack of ``depth`` encoder layers followed by a final LayerNorm.

    Args:
        embed_dim: Token feature dimension, used for every encoder layer and
            for the final normalisation.
        depth: Number of ``EncoderLayer`` instances to stack.
    """

    def __init__(self, embed_dim, depth):
        super().__init__()
        # Pass embed_dim through: previously every layer was built with the
        # EncoderLayer default (768), which mismatched any other embed_dim.
        self.layers = nn.LayerList(
            [EncoderLayer(embed_dim) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Run ``x`` through every layer in order, then normalise the result."""
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return x
class Visualtransformer(nn.Layer):
    """Vision Transformer classifier: patch embedding -> encoder -> linear head.

    Args:
        image_size: Side length of the square input image.
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels.
        num_classes: Number of output classes of the classifier head.
        embed_dim: Token feature dimension.
        depth: Number of encoder layers.
        num_heads: Accepted but not forwarded; ``Encoder`` takes only
            ``embed_dim`` and ``depth``.
        mlp_ratio: Accepted but not forwarded (same reason as ``num_heads``).
        qkv_bias: Accepted but not forwarded (same reason as ``num_heads``).
        dropout: Dropout probability handed to the patch embedding.
        attention_dropout: Accepted but not forwarded (same reason as
            ``num_heads``).
        droppath: Accepted but not used in this class.
    """

    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 in_channels=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=3,
                 num_heads=8,
                 mlp_ratio=4,
                 qkv_bias=True,
                 dropout=0.,
                 attention_dropout=0.,
                 droppath=0.):
        super().__init__()
        # Forward `dropout` instead of silently dropping the argument.
        self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, embed_dim, dropout)
        self.encoder = Encoder(embed_dim, depth)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Classify a batch of images ``x`` of shape [N, C, H, W].

        Returns:
            Tensor of shape [N, num_classes] computed from the class token.
        """
        x = self.patch_embedding(x)   # [N, num_patches + 1, embed_dim]
        x = self.encoder(x)
        # Index 0 along the token axis is the class token prepended by
        # PatchEmbedding; only that token feeds the classifier.
        x = self.classifier(x[:, 0])
        return x
def main():
    """Smoke test: build the default model, print its summary and output shape."""
    batch_shape = (4, 3, 224, 224)
    dummy_images = paddle.randn(list(batch_shape))
    vit = Visualtransformer()
    paddle.summary(vit, batch_shape)
    logits = vit(dummy_images)
    print(logits.shape)
# Run the smoke test only when this file is executed as a script.
if __name__ == '__main__':
    main()
|