mask机制
- encoder中对输入序列的长度进行pad 0到max_src_len,在计算自注意力的时候,只对有效序列长度进行attention计算,pad的0需要mask; 【endoer_mhsa_mas——
s
h
a
p
e
:
(
b
a
t
c
h
_
s
i
z
e
(
B
)
,
s
r
c
_
s
e
q
_
l
e
n
(
N
s
)
,
N
s
)
\rm shape:(batch\_size(B), src\_seq\_len(N_s), N_s)
shape:(batch_size(B),src_seq_len(Ns?),Ns?)】
- decoder中的第一个masked多头自注意力模块输入序列为了不能看到当前token之后的信息,需要对当前toekn之后的tokens进行mask;【decoder_mhsa_mask——
s
h
a
p
e
:
(
b
a
t
c
h
_
s
i
z
e
(
B
)
,
t
g
t
_
s
e
q
_
l
e
n
(
N
t
)
,
N
t
)
\rm shape:(batch\_size(B), tgt\_seq\_len(N_t), N_t)
shape:(batch_size(B),tgt_seq_len(Nt?),Nt?)】
- decoder中第二个多头交叉注意力模块中query来自decoder的输入的当前token,key-value来自encoder的输出,综合上述两种mask机制,应该对不需要计算注意力的位置进行mask。【decoder_mhca_mask——
s
h
a
p
e
:
(
b
a
t
c
h
_
s
i
z
e
(
B
)
,
N
t
,
N
s
)
\rm shape:(batch\_size(B), N_t, N_s)
shape:(batch_size(B),Nt?,Ns?)】
上述三种mask机制对应原始论文中的自注意力层如上图所示。
Pytorch代码实现
预定义输入输出序列
d_model = 512
vocab_size = 1000
dropout = 0.1
padding_idx = 0
tgt_len = [4, 2, 6]
src_len = [4, 5, 3]
src_seq = x = torch.cat([
F.pad(torch.randint(1, vocab_size, (1, L)), (0, max(src_len) - L)) for L in src_len
])
tgt_seq = y = torch.cat([
F.pad(torch.randint(1, vocab_size, (1, L)), (0, max(tgt_len) - L)) for L in tgt_len
])
encoder自注意力中的mask
1/True 表示该位置要mask, 0/False 表示该位置不需要mask
方法1
该方法利用向量之间的相似性 即 (n, 1) @ (1, n) -> (n, n) 就能得到每个维度之间的相关性 最后取反即可得到mask矩阵 这种方法看起来比较直观 类似于求两个向量之间的协方差
X
X
T
\rm XX^T
XXT
valid_encoder_mhsa_pos = torch.vstack([
F.pad(torch.ones(L), (0, max(src_len) - L)) for L in src_len
]).unsqueeze(-1)
encoder_mhsa_mask = 1 - torch.bmm(valid_encoder_mhsa_pos, valid_encoder_mhsa_pos.transpose(-2, -1))
print(f'encoder_mhsa_mask:\n{encoder_mhsa_mask}')
输出如下:
encoder_mhsa_mask:
tensor([[[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1.],
[1., 1., 1., 1., 1.]],
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]],
[[0., 0., 0., 1., 1.],
[0., 0., 0., 1., 1.],
[0., 0., 0., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]])
方法2
def get_pad_mask(seq_k, pad_idx):
return (seq_k == pad_idx).unsqueeze(-2)
def get_attn_pad_mask(seq_q, seq_k):
batch_size, len_q = seq_q.size()
batch_size, len_k = seq_k.size()
padding_attn_mask = seq_k.data.eq(0).unsqueeze(1)
return padding_attn_mask.expand(batch_size, len_q, len_k)
seq_k = torch.Tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]])
seq_v = torch.Tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]])
seq_q = torch.Tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]])
src_mask1 = get_pad_mask(seq_q, 0)
src_mask2 = get_attn_padding_mask(seq_q, seq_q)
attn = torch.randn(seq_q.size(1), seq_q.size(1))
attn1 = attn.masked_fill(src_mask1 == 1, -1e9)
attn2 = attn.masked_fill(src_mask2 == 1, -1e9)
print(attn1 == attn2)
输出如下
query的无效长度对key的有效长度的注意力没有被mask 但是应该不影响最终的结果
tensor([[[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1]],
[[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]], dtype=torch.uint8)
encoder_mhsa_mask:
tensor([[[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1.],
[1., 1., 1., 1., 1.]],
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]])
decoder中的masked自注意力中的mask
方法1
先形成一个下三角矩阵, 其他位置pad 0, 然后取反就得到了mask 这样看起来也很直观,对无效长度的地方也进行了mask
decoder_mhsa_mask = 1 - torch.stack([
F.pad(torch.tril(torch.ones(L, L)), (0, max(tgt_len) - L, 0, max(tgt_len) - L)) \
for L in tgt_len
])
print(f'decoder_mhsa_mask:\n{decoder_mhsa_mask.shape}')
输出如下
tensor([[[0., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1.],
[0., 0., 0., 1., 1., 1.],
[0., 0., 0., 0., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.]],
[[0., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.]]])
方法2
def get_pad_mask(seq_k, pad_idx):
return (seq_k == pad_idx).unsqueeze(-2)
def get_subsequent_mask(seq):
sz_b, len_s = seq.size()
subsequent_mask = (torch.triu(
torch.ones((1, len_s, len_s)), diagonal=1)).bool()
return subsequent_mask
seq_k = torch.Tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]])
seq_v = torch.Tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]])
seq_q = torch.Tensor([[1, 1, 1, 1, 0, 0], [1, 1, 0, 0, 0, 0]])
mask = get_pad_mask(seq_k, 0) | get_subsequent_mask(seq_k)
print(get_pad_mask(seq_q, 0).byte())
print(get_subsequent_mask(seq_q).byte())
print(mask)
输出如下
tensor([[[0, 0, 0, 0, 1, 1]],
[[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
torch.Size([2, 1, 6])
tensor([[[0, 1, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
torch.Size([1, 6, 6])
tensor([[[0, 1, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 1, 1]],
[[0, 1, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
tensor([[[0., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1.],
[0., 0., 0., 1., 1., 1.],
[0., 0., 0., 0., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.]],
[[0., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.]]])
decoder交叉注意力中的mask
方法1
valid_decoder_mhca_pos = torch.vstack([
F.pad(torch.ones(L), (0, max(tgt_len) - L)) for L in tgt_len
]).unsqueeze(-1)
decoder_mhca_mask = 1 - torch.matmul(valid_decoder_mhca_pos, valid_encoder_mhsa_pos.transpose(-2, -1))
print(f'decoder_mhca_mask:\n{decoder_mhca_mask.shape}')
输出如下
tensor([[[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]],
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]])
方法2
def get_pad_mask(seq_k, pad_idx):
return (seq_k == pad_idx).unsqueeze(-2)
def get_subsequent_mask(seq):
sz_b, len_s = seq.size()
subsequent_mask = (torch.triu(
torch.ones((1, len_s, len_s)), diagonal=1)).bool()
return subsequent_mask
seq_k = torch.Tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]])
seq_v = torch.Tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]])
seq_q = torch.Tensor([[1, 1, 1, 1, 0, 0], [1, 1, 0, 0, 0, 0]])
src_mask1 = get_pad_mask(seq_k, 0)
src_mask2 = get_attn_padding_mask(seq_q, seq_k)
attn = torch.randn(seq_q.size(1), seq_k.size(1))
attn1 = attn.masked_fill(src_mask1 == 1, -1e9)
attn2 = attn.masked_fill(src_mask2 == 1, -1e9)
输出如下
tensor([[[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1]],
[[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]], dtype=torch.uint8)
tensor([[[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]],
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]])
|