IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【深度学习】Transformer中的mask机制超详细讲解 -> 正文阅读

[人工智能]【深度学习】Transformer中的mask机制超详细讲解

mask机制

  1. encoder中对输入序列的长度进行pad 0到max_src_len,在计算自注意力的时候,只对有效序列长度进行attention计算,pad的0需要mask; 【endoer_mhsa_mas—— s h a p e : ( b a t c h _ s i z e ( B ) , s r c _ s e q _ l e n ( N s ) , N s ) \rm shape:(batch\_size(B), src\_seq\_len(N_s), N_s) shape:(batch_size(B),src_seq_len(Ns?),Ns?)
  2. decoder中的第一个masked多头自注意力模块输入序列为了不能看到当前token之后的信息,需要对当前toekn之后的tokens进行mask;【decoder_mhsa_mask—— s h a p e : ( b a t c h _ s i z e ( B ) , t g t _ s e q _ l e n ( N t ) , N t ) \rm shape:(batch\_size(B), tgt\_seq\_len(N_t), N_t) shape:(batch_size(B),tgt_seq_len(Nt?),Nt?)
  3. decoder中第二个多头交叉注意力模块中query来自decoder的输入的当前token,key-value来自encoder的输出,综合上述两种mask机制,应该对不需要计算注意力的位置进行mask。【decoder_mhca_mask—— s h a p e : ( b a t c h _ s i z e ( B ) , N t , N s ) \rm shape:(batch\_size(B), N_t, N_s) shape:(batch_size(B),Nt?,Ns?)

    上述三种mask机制对应原始论文中的自注意力层如上图所示。

Pytorch代码实现

预定义输入输出序列

# 词嵌入向量维度
d_model = 512

# 单词表大小
vocab_size = 1000

# dropout比例
dropout = 0.1

# 构造两个序列 序列的长度分别为4 5, 0索引为无效
padding_idx = 0

tgt_len = [4, 2, 6]
src_len = [4, 5, 3]

# 随机生成3个序列 不足最大序列长度的 在序列结尾pad 0
src_seq = x = torch.cat([
    F.pad(torch.randint(1, vocab_size, (1, L)), (0, max(src_len) - L)) for L in src_len
])
# [4, 5, 3]
# tensor([[129, 490, 572, 764,   0],
#         [636, 151, 572, 482, 666],
#         [439, 757,  18,   0,   0]])
tgt_seq = y = torch.cat([
    F.pad(torch.randint(1, vocab_size, (1, L)), (0, max(tgt_len) - L)) for L in tgt_len
])
# [4, 2, 6]
# tensor([[509, 360, 486,  88,   0,   0],
#         [415, 609,   0,   0,   0,   0],
#         [767, 817,  59, 990, 853, 101]])

encoder自注意力中的mask

1/True表示该位置要mask, 0/False表示该位置不需要mask

方法1

该方法利用向量之间的相似性 即 (n, 1) @ (1, n) -> (n, n)就能得到每个维度之间的相关性 最后取反即可得到mask矩阵
这种方法看起来比较直观 类似于求两个向量之间的协方差 X X T \rm XX^T XXT

# ----------------------------------
# encoder multi-head self-attn mask
# 在计算输入x token之间的attn_score时 需要忽略pad 0
valid_encoder_mhsa_pos = torch.vstack([
    F.pad(torch.ones(L), (0, max(src_len) - L)) for L in src_len
]).unsqueeze(-1)  # 扩展维度 用于批量计算mask矩阵 (B, Ns, 1) x (B, 1, Ns) -> (B, Ns, Ns)
# print(f'valid_encoder_mhsa_pos: {valid_encoder_mhsa_pos.shape}')

encoder_mhsa_mask = 1 - torch.bmm(valid_encoder_mhsa_pos, valid_encoder_mhsa_pos.transpose(-2, -1))
print(f'encoder_mhsa_mask:\n{encoder_mhsa_mask}')
# ----------------------------------

输出如下:

encoder_mhsa_mask:
tensor([[[0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [1., 1., 1., 1., 1.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 1., 1.],
         [0., 0., 0., 1., 1.],
         [0., 0., 0., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]])

方法2

# 在对attn进行mask的时候可以直接进行广播 而且只需要关注key的pad即可 无需对query中的pad进行mask 与方法一不同
# 因为query是在和key求相似度,只要把key的无效长度mask即可 这样就能得到与key有效长度的注意力分数
# 当然也可以对query进行mask 类似方法一 这样可能就多一些额外的对query的处理
def get_pad_mask(seq_k, pad_idx):
    return (seq_k == pad_idx).unsqueeze(-2)  
    # 添加一个维度,和attn进行mask操作的时候可以进行broadcast

# 该函数将得到mask进行了expand操作,复制了seq_q的长度份 
def get_attn_pad_mask(seq_q, seq_k):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()

    padding_attn_mask = seq_k.data.eq(0).unsqueeze(1)
    return padding_attn_mask.expand(batch_size, len_q, len_k)
# 相比来说,get_pad_mask则更简洁 推荐使用
# 此处为了简化表示 将query key value 都只用了(B, seq_len)表示有效位置 及pad的0
# 实际情况中需要加上embed_dim,也就是每个token对应的嵌入向量维度
seq_k = torch.Tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]])
seq_v = torch.Tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]])
seq_q = torch.Tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]])


# 对query求自注意力的mask
src_mask1 = get_pad_mask(seq_q, 0)
src_mask2 = get_attn_padding_mask(seq_q, seq_q)
attn = torch.randn(seq_q.size(1), seq_q.size(1))
attn1 = attn.masked_fill(src_mask1 == 1, -1e9)
attn2 = attn.masked_fill(src_mask2 == 1, -1e9)
print(attn1 == attn2)  # 可以验证两者相同

输出如下

query的无效长度对key的有效长度的注意力没有被mask 但是应该不影响最终的结果

# seq_k的自注意力mask
tensor([[[0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1]],

        [[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]]], dtype=torch.uint8)
# 可以将mask矩阵的0维度看做q, 1维度看做k False表示q对k计算了相似度

encoder_mhsa_mask: # 对比上述mask
tensor([[[0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [1., 1., 1., 1., 1.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])

decoder中的masked自注意力中的mask

方法1

先形成一个下三角矩阵, 其他位置pad 0, 然后取反就得到了mask
这样看起来也很直观,对无效长度的地方也进行了mask

# ----------------------------------
# decoder multi-head self-attn mask
# decoder输入需要形成一个下三角矩阵mask (B, Nt, Nt)
# 不能看到当前token之后的信息
decoder_mhsa_mask = 1 - torch.stack([
    F.pad(torch.tril(torch.ones(L, L)), (0, max(tgt_len) - L, 0, max(tgt_len) - L)) \
        for L in tgt_len
])

print(f'decoder_mhsa_mask:\n{decoder_mhsa_mask.shape}')
# ----------------------------------

输出如下

tensor([[[0., 1., 1., 1., 1., 1.],
         [0., 0., 1., 1., 1., 1.],
         [0., 0., 0., 1., 1., 1.],
         [0., 0., 0., 0., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.]],

        [[0., 1., 1., 1., 1., 1.],
         [0., 0., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.]]])

方法2

def get_pad_mask(seq_k, pad_idx):
    return (seq_k == pad_idx).unsqueeze(-2)  
    # 添加一个维度,和attn进行mask操作的时候可以进行broadcast

# 生成一个上三角矩阵
def get_subsequent_mask(seq):
    sz_b, len_s = seq.size()
    subsequent_mask = (torch.triu(
        torch.ones((1, len_s, len_s)), diagonal=1)).bool()
    return subsequent_mask
# 但是想要和方法1一样对key中的无效长度部分mask 需要如下操作

seq_k = torch.Tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]])
seq_v = torch.Tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]])
seq_q = torch.Tensor([[1, 1, 1, 1, 0, 0], [1, 1, 0, 0, 0, 0]])

mask = get_pad_mask(seq_k, 0) | get_subsequent_mask(seq_k)
print(get_pad_mask(seq_q, 0).byte())
print(get_subsequent_mask(seq_q).byte())
print(mask)

输出如下

# print(get_pad_mask(seq_q, 0).byte())
tensor([[[0, 0, 0, 0, 1, 1]],

        [[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
torch.Size([2, 1, 6])

# print(get_subsequent_mask(seq_q).byte())
# 上三角矩阵
tensor([[[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
torch.Size([1, 6, 6])

# 上述两个矩阵进行或操作即可得到最终的mask 也就是保证query和key的有效长度计算注意力
# 两个seq对应的上三角矩阵相同 序列长度不一样 因此广播之后就得到了两个不同序列的有效上三角矩阵
tensor([[[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 1, 1]],

        [[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
# query 的无效长度处的token也对key的有效长度计算了注意力 本来应该无 有也不影响 

# 对比方法1
tensor([[[0., 1., 1., 1., 1., 1.],
         [0., 0., 1., 1., 1., 1.],
         [0., 0., 0., 1., 1., 1.],
         [0., 0., 0., 0., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.]],

        [[0., 1., 1., 1., 1., 1.],
         [0., 0., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.]]])

decoder交叉注意力中的mask

方法1

# ----------------------------------
# decoder multi-head cross-attn mask
# Q --> decoder
# K V --> encoder
# mask shape (B, Nt, Ns)
valid_decoder_mhca_pos = torch.vstack([
    F.pad(torch.ones(L), (0, max(tgt_len) - L)) for L in tgt_len
]).unsqueeze(-1)  # 同encoder
# (B, Nt, 1) x (B, 1, Ns) -> (B, Nt, Ns)
decoder_mhca_mask = 1 - torch.matmul(valid_decoder_mhca_pos, valid_encoder_mhsa_pos.transpose(-2, -1))
print(f'decoder_mhca_mask:\n{decoder_mhca_mask.shape}')
# ----------------------------------

输出如下

tensor([[[0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]])

方法2


def get_pad_mask(seq_k, pad_idx):
    return (seq_k == pad_idx).unsqueeze(-2)  
    # 添加一个维度,和attn进行mask操作的时候可以进行broadcast

# 生成一个上三角矩阵
def get_subsequent_mask(seq):
    sz_b, len_s = seq.size()
    subsequent_mask = (torch.triu(
        torch.ones((1, len_s, len_s)), diagonal=1)).bool()
    return subsequent_mask

seq_k = torch.Tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]])
seq_v = torch.Tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]])
seq_q = torch.Tensor([[1, 1, 1, 1, 0, 0], [1, 1, 0, 0, 0, 0]])
# get_attn_padding_mask(seq_q, seq_k).byte()

# 求交叉注意力mask
src_mask1 = get_pad_mask(seq_k, 0)
src_mask2 = get_attn_padding_mask(seq_q, seq_k)
attn = torch.randn(seq_q.size(1), seq_k.size(1))
attn1 = attn.masked_fill(src_mask1 == 1, -1e9)
attn2 = attn.masked_fill(src_mask2 == 1, -1e9)
# print(attn1 == attn2)  # 可以验证两者相同

输出如下

tensor([[[0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1]],

        [[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]]], dtype=torch.uint8)

# 对比方法1
tensor([[[0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]])

# 相当于seq_q的无效长度处tokens对seq_k的有效长度的token也进行了注意力计算但没mask
# 方法1则mask掉了
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-06-20 23:00:25  更:2022-06-20 23:01:55 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 3:33:24-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码