Posting a Transformer reproduction I wrote myself, based on the original paper and many reference projects:
import math

import torch
import torch.nn as nn
from torch.nn import functional as F


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=0.1)
        # Precompute the sinusoidal table once; even columns get sin, odd columns get cos.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model], broadcasts over the batch dimension
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
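For reference, the table built above is the sinusoidal encoding from the original paper; div_term in the code is the 1 / 10000^(2i / d_model) factor computed in log space:

PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))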
class PositionWiseFeedForward(nn.Module):
    def __init__(self, dim, ff_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, dim)

    def forward(self, x):
        # (B, S, D) -> (B, S, D_ff) -> (B, S, D)
        return self.fc2(F.relu(self.fc1(x)))
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)
        self.num_heads = num_heads
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_Q, input_K, input_V, pad_mask=None):
        """
        input_Q: [batch_size, len_q, d_model]
        input_K: [batch_size, len_k, d_model]
        input_V: [batch_size, len_v(=len_k), d_model]
        pad_mask: bool tensor broadcastable to [batch_size, len_q, len_k], True at positions to mask out
        return: [batch_size, len_q, d_model]
        """
        assert self.d_model % self.num_heads == 0
        dk = dv = self.d_model // self.num_heads
        batch_size = input_Q.size(0)
        # Project all heads in one matmul, then split into heads: [B, H, L, dk]
        Q = self.W_Q(input_Q).view(batch_size, -1, self.num_heads, dk).permute(0, 2, 1, 3)
        K = self.W_K(input_K).view(batch_size, -1, self.num_heads, dk).permute(0, 2, 1, 3)
        V = self.W_V(input_V).view(batch_size, -1, self.num_heads, dv).permute(0, 2, 1, 3)
        # Scaled dot-product attention
        scores = Q @ K.transpose(-1, -2) / math.sqrt(dk)
        if pad_mask is not None:
            pad_mask = pad_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
            scores = scores.masked_fill(pad_mask, -1e9)
        scores = self.dropout(F.softmax(scores, dim=-1))
        # Merge the heads back: [B, H, len_q, dv] -> [B, len_q, d_model]
        h = (scores @ V).transpose(1, 2).contiguous()
        h = h.view(batch_size, -1, self.d_model)
        h = self.fc(h)
        return h
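A quick shape check I use for the attention module; the toy sizes here are arbitrary:

# shape check for MultiHeadAttention; d_model=16, num_heads=4 are arbitrary toy sizes
mha = MultiHeadAttention(d_model=16, num_heads=4, dropout=0.1)
x = torch.randn(2, 5, 16)  # [batch=2, seq_len=5, d_model=16]
out = mha(x, x, x)         # no mask: plain self-attention
print(out.shape)           # torch.Size([2, 5, 16])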
class EncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.pwff = PositionWiseFeedForward(d_model, d_model)  # note: the original paper uses a wider inner layer, e.g. 4 * d_model
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Post-norm residual blocks: sublayer -> dropout -> residual add -> LayerNorm
        a1 = self.norm1(self.dropout1(self.attn(x, x, x, mask)) + x)
        a2 = self.norm2(self.dropout2(self.pwff(a1)) + a1)
        return a2
class DecoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        self.attn1 = MultiHeadAttention(d_model, num_heads, dropout)
        self.attn2 = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.pwff = PositionWiseFeedForward(d_model, d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask, memory_mask):
        # Masked self-attention over the target sequence
        x1 = self.norm1(self.dropout1(self.attn1(tgt, tgt, tgt, tgt_mask)) + tgt)
        # Cross-attention; note the argument order is Q, K, V (queries from the decoder, keys/values from the encoder memory)
        x2 = self.norm2(self.dropout2(self.attn2(x1, memory, memory, memory_mask)) + x1)
        x3 = self.norm3(self.dropout3(self.pwff(x2)) + x2)
        return x3
class Decoder(nn.Module):
    def __init__(self, vocab_size, d_model, N, num_heads, dropout):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pe = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderBlock(d_model, num_heads, dropout) for _ in range(N)])
        self.N = N

    def forward(self, tgt, memory, memory_mask, tgt_mask):
        x = self.embed(tgt)
        x = self.pe(x)
        for i in range(self.N):
            x = self.layers[i](x, memory, tgt_mask, memory_mask)
        return x
class Encoder(nn.Module):
    """Input embedding -> positional encoding -> N x EncoderBlock"""
    def __init__(self, vocab_size, d_model, N, num_heads, dropout):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pe = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([EncoderBlock(d_model, num_heads, dropout) for _ in range(N)])
        self.N = N

    def forward(self, src, mask):
        x = self.embed(src)
        x = self.pe(x)
        for i in range(self.N):
            x = self.layers[i](x, mask)
        return x
class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model, N, num_heads, dropout):
        super().__init__()
        self.encoder = Encoder(src_vocab, d_model, N, num_heads, dropout)
        self.decoder = Decoder(tgt_vocab, d_model, N, num_heads, dropout)
        self.out = nn.Linear(d_model, tgt_vocab)

    def forward(self, src, tgt, src_mask, tgt_mask):
        memory = self.encoder(src, src_mask)
        d_output = self.decoder(tgt, memory, src_mask, tgt_mask)
        output = self.out(d_output)
        return output
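A minimal end-to-end sketch of how the model can be driven. The mask convention (True marks positions to ignore, matching masked_fill above), pad id 0, and all hyperparameters are my own assumptions for illustration, not part of the model code:

# minimal usage sketch; pad_id=0 and all sizes below are assumptions for illustration
def make_pad_mask(seq, pad_id=0):
    # [B, 1, len_k]: True where the key token is padding; broadcasts over the query dimension
    return (seq == pad_id).unsqueeze(1)

def make_subsequent_mask(tgt):
    # [len_t, len_t]: True above the diagonal, i.e. future positions are masked
    len_t = tgt.size(1)
    return torch.triu(torch.ones(len_t, len_t, dtype=torch.bool), diagonal=1)

model = Transformer(src_vocab=1000, tgt_vocab=1000, d_model=512, N=6, num_heads=8, dropout=0.1)
src = torch.randint(1, 1000, (2, 10))   # [batch, src_len]
tgt = torch.randint(1, 1000, (2, 8))    # [batch, tgt_len]
src[:, -2:] = 0                         # pretend the last two source tokens are padding
src_mask = make_pad_mask(src)                                # used for encoder self-attention and cross-attention
tgt_mask = make_pad_mask(tgt) | make_subsequent_mask(tgt)    # padding + causal mask for decoder self-attention
logits = model(src, tgt, src_mask, tgt_mask)
print(logits.shape)  # torch.Size([2, 8, 1000])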