In deep learning we often run into scenarios that require a mask, for example:
- In NLP, sequences in a batch are padded to a common length for alignment, but the padded positions are masked out when computing attention so that they do not take part in the attention computation;
- Mask-based pre-training tasks such as MLM and MAE mask out selected tokens, and the model is trained to predict the masked content;
- In Swin Transformer, after the window shift operation, masked attention is needed to prevent regions that were not physically adjacent in the original feature map from interacting with each other;
- When computing a loss, we may want to ignore samples that should not contribute to that loss (a minimal sketch of this follows the list).
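As an illustration of the last point, here is a minimal sketch of a masked loss: the per-sample losses are computed without reduction, multiplied by a 0/1 mask, and averaged only over the kept samples. The function name masked_cross_entropy and the valid_mask convention (1 = keep, 0 = ignore) are assumptions for this sketch, not from the original text.

import torch
import torch.nn.functional as F

def masked_cross_entropy(logits, targets, valid_mask):
    # logits: (B, num_classes), targets: (B,), valid_mask: (B,) with 1 = keep, 0 = ignore
    per_sample = F.cross_entropy(logits, targets, reduction='none')  # per-sample loss, no reduction
    valid_mask = valid_mask.to(per_sample.dtype)
    # Zero out ignored samples and normalize by the number of kept samples
    return (per_sample * valid_mask).sum() / valid_mask.sum().clamp(min=1.0)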
Example
In the attention operation, before the softmax over the attention logits, the logits at masked positions are set to a very negative value such as -10000. After the softmax these positions receive near-zero weight, which suppresses their contribution. The code is as follows:
import torch
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., with_qkv=True):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5  # 1 / sqrt(d) scaling factor
        self.with_qkv = with_qkv
        if self.with_qkv:
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
            self.proj = nn.Linear(dim, dim)
            self.proj_drop = nn.Dropout(proj_drop)
        self.attn_drop = nn.Dropout(attn_drop)

    def forward(self, x, attention_mask=None):
        # x: (B, N, C); attention_mask uses 1 = keep, 0 = mask out and must be
        # broadcastable to (B, num_heads, N, N), e.g. (B, 1, 1, N) for padding masks.
        B, N, C = x.shape
        if self.with_qkv:
            # Project to q, k, v and split into heads: (3, B, num_heads, N, C // num_heads)
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]
        else:
            # No projection: reuse the input itself as q, k and v
            qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
            q, k, v = qkv, qkv, qkv
        # Scaled dot-product attention logits: (B, num_heads, N, N)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if attention_mask is not None:
            # Masked positions (0 in the mask) receive a large negative bias,
            # so their weights are driven toward zero by the softmax.
            attention_mask = attention_mask.to(dtype=attn.dtype)
            attention_mask = (1.0 - attention_mask) * -10000.0
            attn = attn + attention_mask
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # Weighted sum of values, then merge heads back to (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        if self.with_qkv:
            x = self.proj(x)
            x = self.proj_drop(x)
        return x
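For reference, a hypothetical usage sketch (not part of the original code): a padding mask of shape (B, N), with 1 for real tokens and 0 for padding, is expanded to (B, 1, 1, N) so it broadcasts against the (B, num_heads, N, N) logits inside forward. The names seq_lens and pad_mask below are illustrative.

import torch

B, N, dim = 2, 5, 64
x = torch.randn(B, N, dim)

# Hypothetical per-sequence valid lengths; positions beyond the length are padding.
seq_lens = torch.tensor([5, 3])
pad_mask = (torch.arange(N)[None, :] < seq_lens[:, None]).float()  # (B, N), 1 = keep

# Expand so it broadcasts over heads and query positions inside forward()
attn_mask = pad_mask[:, None, None, :]  # (B, 1, 1, N)

attn_layer = Attention(dim, num_heads=8)
out = attn_layer(x, attention_mask=attn_mask)  # (B, N, dim)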