t5模型是常用于文本生成部分的一个模型,也是目前我看到的各个nlp模型之中,唯一完整地使用transformer的所有完整结构(encoder部分加上decoder部分)的一个模型,接下来聊一下t5模型的生成优化过程。
优化的部分
首先对于生成这一块,最慢的速度在于推断而不在于训练,所以t5模型的优化部分在推断内容部分进行优化,推断部分使用的是transformer中的decoder结构,这里我们先看一下t5的decoder主要构成,我将它的结构图简化如下:
DecoderLayerTransformers DecoderLayerAttention
decoder部分的结构图---
DecoderCrossTransformers DecoderCrossAttention
1.decoderlayerattention的拼接
在这里的DecoderLayerAttention中优化的过程,采用的是计算完key和value的值之后,将之前同一网络层的key和value与现在网络层的key和value值拼接在一块,这里的关键点在于每一次对于下一个单词进行预测的时候,实际上只需要预测当前单词的概率即可,并不需要把所有的词语的概率全部都预测出来。 具体分析: 这里输入的query值只是由当前的单词id所构成,而key和value通过拼接之后,实际上跟原始的key的value的值相同
if past_key_value != None:
key = torch.cat([past_key_value[0],key],dim=2)
if past_key_value != None:
value = torch.cat([past_key_value[1],value],dim=2)
首先t5模型没有position_embedding,只有word_embedding的情况下,a的embedding和b的embedding拼接在一起,跟a+b的embedding拼接在一起的结果是一样的,这就保证了第一次decoderlayerattention网络层中的输入一样。 其次关键的在于,在attention的公式之中
A
t
t
e
n
t
i
o
n
(
K
,
Q
,
V
)
=
S
o
f
t
m
a
x
(
Q
K
T
d
k
)
V
Attention(K,Q,V) = Softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V
Attention(K,Q,V)=Softmax(dk?
?QKT?)V 对于每一个句子中单独的词来说,这个句子中每一个词的信息都会影响到当前词语的信息,因此在softmax的公式之中,这里的前后信息交互的关键点在于两点 1.
Q
K
T
QK^{T}
QKT 这里的Q代表每一个词,而K向量代表与其他信息交互的句子中所有的词语,因为我们只需要关注当前的词,所以这里Q只需要取出当前的词即可,因为在我们预测下一个词的概率时,我们只需要知道当前词语的信息。但是这里的K代表的整个句子各个词语的词向量的信息,这些词语会对Q的当前这个词产生影响,因此这里的K必须是完整的,否则会造成只有部分的词向量对当前的K产生影响,造成结果的不准确。 2.
S
o
f
t
m
a
x
(
Q
K
T
d
k
)
Softmax(\frac{QK^{T}}{\sqrt{d_{k}}})
Softmax(dk?
?QKT?)的结果乘上V 这里的V与上面的K向量同理,是用来交互的,因此V也必须保持向量的完整性,也就需要进行拼接。
2.decoderlayerattention不拼接完整的计算过程
为了读懂不拼接的计算过程,这里我专门比较了一下拼接与不拼接的整个计算过程,下面从整个模型具体的走一遍。 这里我们假设batch_size为1,从第二次生成的输入开始,即max_length = 2,而优化的部分max_length的值永远为1。
步骤1.原始输入
原始输入为input_ids = (1,2,768)(从第二次输入开始),则优化后的输入为input_ids = (1,1,768),这里的(1,1,768)为(1,2,768)的后面维度的矩阵信息。
步骤2.经过query,key,value网络层的输出结果
输入的input_ids分别经过query,key,value三个网络层,得到优化之前的输入为query = (1,2,768),key = (1,2,768),value = (1,2,768),如果是优化后的矩阵向量,则query = (1,1,768),key = (1,1,768),value = (1,1,768)
步骤3.拼接past_key和past_value的值
之前聊到key和value记录了整个句子完整的向量信息,所以key和value的值需要保持完整,因此这里优化的部分需要拼接上之前计算的key和value的信息,得到与未优化之前相同的key和value的值 key = (1,2,768),value = (1,2,768)
步骤4.将query,key,value的值进行拆解
key = key.view(batch_size,-1,self.config.num_heads,self.config.size_per_head)
value = value.view(batch_size,-1,self.config.num_heads,self.config.size_per_head)
这里只是拆解最后一个维度,跟前面的内容没有关系, 未优化之后拆解出的向量维度为 query = (1,2,12,64),key = (1,2,12,64),value = (1,2,12,64) 优化之后拆解出的向量维度为 query = (1,2,12,64)(这里的1为未优化2的最后一维),key = (1,2,12,64),value = (1,2,12,64)
步骤5.transpose操作
query = query.transpose(1,2)
key = key.transpose(1,2)
value = value.transpose(1,2)
未优化的向量 query = (1,12,2,64),key = (1,12,2,64),value = (1,12,2,64) 优化后的向量 query = (1,12,1,64),key = (1,12,2,64),value = (1,12,2,64) 这里的未优化的query与优化的query的关系是未优化的query中每间隔一波会得到去取优化的query,具体内容如下
[[[[1.4670e-02, 1.2135e-01, -1.9623e-01, -6.7619e-02, -2.3180e-01,
-7.0569e-02, 3.0198e-02, -1.6405e-01, -2.1121e-01, 3.5561e-02,
5.2837e-02, -4.2778e-02, 4.4988e-02, -2.1911e-01, 9.9120e-03,
1.7280e-01, 9.6410e-02, -7.0861e-02, 1.2775e-01, -1.2981e-01,
-2.9778e-02, -1.5738e-01, -6.2758e-02, 1.1868e-02, -6.1655e-04,
-4.7120e-04, 2.2410e-01, 7.7863e-02, -1.5514e-01, -6.7375e-02,
-2.8295e-02, -3.7375e-02, 6.5420e-02, -7.0116e-02, -9.4598e-02,
7.6665e-03, 9.0203e-03, 9.5351e-03, 6.0339e-03, 4.3080e-02,
-3.7418e-03, 7.9097e-03, -2.0718e-02, 6.0211e-02, -2.0378e-02,
2.0389e-02, 1.0511e-01, -1.4655e-01, -1.4036e-01, -9.6260e-02,
-1.9263e-03, -9.3218e-02, -1.2116e-02, -2.5003e-01, 1.2911e-01,
2.0643e-01, -3.3275e-02, -3.7536e-02, -2.0530e-01, 6.0305e-02,
-5.0186e-02, 9.7535e-02, 1.2266e-01, -5.5510e-02],
(下面为优化的参数部分1)
[-4.9299e-04, -2.6518e-02, -1.0284e-01, -2.8316e-02, 1.4845e-01,
2.4035e-02, -1.1961e-01, 1.4806e-02, -1.0239e-02, -1.2139e-01,
-9.8680e-02, 8.0918e-02, 1.0025e-01, 2.6667e-03, -9.8165e-03,
4.2049e-02, -7.7295e-03, -1.2139e-01, -9.3672e-02, -9.5083e-02,
4.5946e-02, -3.0294e-02, 1.1112e-01, 1.0422e-03, 1.6938e-01,
5.7986e-02, -3.0207e-02, 1.9465e-01, 1.0476e-01, 1.9391e-01,
-1.1847e-02, 8.8533e-02, 6.9837e-02, 8.0209e-02, -2.8369e-03,
-1.0628e-01, 3.3870e-02, 2.6993e-02, 9.9969e-02, -5.6613e-02,
-1.8507e-01, -5.4061e-02, 1.0212e-01, 7.1761e-02, -7.0027e-02,
-8.0300e-02, 1.0698e-01, -2.2551e-02, -5.0615e-02, 4.5206e-02,
1.2523e-01, 2.4763e-02, -3.0133e-02, 1.3311e-01, -6.7099e-02,
2.0527e-02, -7.5543e-02, 2.0281e-02, 3.3036e-02, -4.6203e-02,
-3.1867e-02, 2.7087e-02, 4.4259e-02, -4.5838e-02]],
[[-1.4408e-01, 1.5186e-01, -3.5974e-03, -6.0452e-02, -7.9654e-02,
-4.7384e-02, -2.2503e-02, 3.4365e-01, 1.2694e-01, -4.5820e-02,
3.0887e-02, 1.0342e-01, -1.0327e-01, -4.4748e-03, -1.8945e-01,
-2.8120e-02, 1.4307e-01, 2.0421e-02, 1.0495e-01, 3.9390e-02,
-1.6610e-01, -8.9004e-02, -8.5324e-02, 8.1240e-02, -8.8805e-02,
5.4473e-02, 2.4430e-01, -1.9869e-01, 1.1048e-01, -5.5874e-02,
1.5152e-01, 7.5828e-02, -2.1933e-01, -2.9484e-01, 3.2189e-03,
8.5885e-02, -5.4767e-02, -1.8218e-01, 1.7896e-01, -6.2724e-02,
4.0281e-02, -1.1383e-01, -1.2164e-01, -2.7832e-01, 1.3230e-01,
-2.9016e-02, -1.6377e-01, 1.7774e-01, 1.0014e-01, 1.1170e-01,
8.6232e-03, 2.3320e-01, 5.7124e-03, -4.9258e-02, 1.1669e-02,
-8.8721e-02, -2.6996e-02, -2.5208e-02, 6.9340e-02, -4.9958e-03,
-8.7542e-02, 1.2076e-01, -1.4579e-02, -7.5249e-02],
(下面为优化的参数部分2)
[-7.5944e-02, 3.1948e-02, -1.3581e-01, 1.5007e-01, -2.0339e-02,
-8.8263e-02, -2.0876e-02, 9.4324e-02, 8.8024e-02, -7.6570e-02,
2.4468e-02, -1.4054e-01, 1.4406e-01, -6.7122e-02, -3.1915e-01,
1.3064e-01, -1.0095e-02, 7.6921e-02, 1.3721e-01, 1.6839e-01,
1.2220e-01, 1.2142e-01, -1.0998e-01, 1.4507e-01, 1.7634e-04,
-1.6147e-01, 4.2507e-02, 1.6338e-01, 6.4832e-02, -3.5208e-02,
-9.1079e-02, 5.4273e-04, -9.4308e-03, 5.4123e-02, 4.8480e-02,
1.0629e-01, -1.5671e-02, -3.0359e-02, 3.4089e-02, 1.0919e-02,
-7.4085e-02, -8.4118e-02, -3.2656e-02, -5.7829e-02, 7.5287e-02,
1.6278e-01, 6.6263e-02, -9.5057e-02, -6.9226e-02, -1.0631e-01,
-3.3601e-02, -1.8329e-02, 1.4080e-01, -2.6989e-02, 2.4369e-01,
-1.7388e-01, -1.0928e-01, -1.9072e-01, 1.3444e-02, 1.0209e-01,
-7.9701e-02, -1.8037e-02, 9.1845e-02, -9.0854e-02]],
可以看出这里的参数是隔着相等的
步骤6.相乘得到scores的内容
scores = torch.matmul(
query,key.transpose(3,2)
)
得到scores优化后的内容和未曾优化后的内容 未曾优化后得到的scores内容
scores =
tensor([[[[-0.5559, 0.4717],
(下面是优化后的参数内容)
[ 1.9015, 2.6016]],
[[ 2.4006, -5.8395],
(下面是优化后的参数内容)
[ 2.7282, 2.8212]],
[[-4.8745, -3.9924],
(下面是优化后的参数内容)
[-0.6129, 0.8469]],
[[-1.1932, -6.7089],
(下面是优化后的参数内容)
[ 3.9086, -7.9154]],
[[-5.8440, 3.6701],
(下面是优化后的参数内容)
[ 8.3781, 7.1280]],
[[-5.2619, -3.4999],
(下面是优化后的参数内容)
[-0.4299, -1.0349]],
[[-2.2012, 0.9375],
(下面是优化后的参数内容)
[ 8.4696, 4.2147]],
[[-3.0179, -1.1522],
(下面是优化后的参数内容)
[ 2.1133, 1.3674]],
[[ 0.6151, 1.4965],
(下面是优化后的参数内容)
[ 3.1316, 5.0822]],
[[-3.9366, -1.4958],
(下面是优化后的参数内容)
[ 1.8155, 0.0675]],
[[ 3.9352, 0.9258],
(下面是优化后的参数内容)
[ 1.8104, 5.5592]],
[[-5.1672, -2.3799],
(下面是优化后的参数内容)
[-1.6228, -1.7248]]],
步骤7.scores+position_bias值:没变化
position_bias相等,所以没变化 这里的值未优化时scores.shape = (2,12,2,2),优化之后scores.shape = (2,12,1,2)间隔相等。
步骤8.计算attn_weights与value相乘
attn_output = torch.matmul(attn_weights,value)
这里的value优化与未优化的参数值是一样的,所以经过相乘之后,得到attn_output的值优化与未优化的还是间隔相等的 这里未优化的情况下,attn_weights = (2,12,2,2),value = (2,12,2,64),相乘之后 attn_output =
(
2
,
12
,
2
,
2
)
?
(
2
,
12
,
2
,
64
)
=
(
2
,
12
,
2
,
64
)
(2,12,2,2)*(2,12,2,64) = (2,12,2,64)
(2,12,2,2)?(2,12,2,64)=(2,12,2,64) 优化的情况下,attn_weights = (2,12,1,2),value = (2,12,2,64),相乘之后attn_output =
(
2
,
12
,
1
,
2
)
?
(
2
,
12
,
2
,
64
)
=
(
2
,
12
,
1
,
64
)
(2,12,1,2)*(2,12,2,64) = (2,12,1,64)
(2,12,1,2)?(2,12,2,64)=(2,12,1,64) 此时还是间隔相等
步骤9.计算attn_output
attn_output = attn_output.transpose(1,2)
这里transpose(1,2)之后,如果优化的情况下attn_output = (2,12,2,64)->(2,2,12,64),不优化的情况下attn_output = (2,12,1,64)->(2,1,12,64),此时由于transpose翻转矩阵的存在,本身矩阵由间隔相等变为了最后一维度相等 开头一次transpose将矩阵的形状翻转
Q
T
Q^T
QT,结尾的时候又调用了一次矩阵的翻转 这里的结果翻转最主要的是与开头的翻转进行抵消,开头有这样一段翻转
query = query.transpose(1,2)
key = key.transpose(1,2)
value = value.transpose(1,2)
而结尾的时候将相乘出来的结果翻转一次,能够将开头的翻转抵消掉
attn_output = attn_output.transpose(1,2)
3.decodercrossattention的拼接
这里的拼接过程较为简单,只需要拼接上之前在encoderlayerattention网络层部分的输出即可,所以直接保存encoderlayerattention网络层之前的输出经过当前网络层的内容,避免重复计算。 注意这里保存的也是每一层的encoder编码部分的输出内容 这里的key和value都是当前网络层经过key_layer和value_layer线性层的输出
attn_output = attn_output.transpose(1,2).contiguous().view(batch_size,-1,self.config.num_heads*self.config.size_per_head)
4.简化运算,从另外一个角度来看只取出最后一个维度的计算结果
A
t
t
e
n
t
i
o
n
(
K
,
Q
,
V
)
=
S
o
f
t
m
a
x
(
Q
K
T
d
k
)
V
Attention(K,Q,V) = Softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V
Attention(K,Q,V)=Softmax(dk?
?QKT?)V 这里我们去除掉对于维度变化等结果没有影响的操作过程,将公式简化为如下操作
Q
K
T
V
QK^{T}V
QKTV 这样原始变化为
Q
K
T
QK^{T}
QKT =
(
1
,
12
,
5
,
64
)
?
(
1
,
12
,
64
,
5
)
=
(
1
,
12
,
5
,
5
)
(1,12,5,64)*(1,12,64,5) = (1,12,5,5)
(1,12,5,64)?(1,12,64,5)=(1,12,5,5) 优化之后的变换为
Q
K
T
QK^{T}
QKT =
(
1
,
12
,
1
,
64
)
?
(
1
,
12
,
64
,
5
)
=
(
1
,
12
,
1
,
5
)
(1,12,1,64)*(1,12,64,5) = (1,12,1,5)
(1,12,1,64)?(1,12,64,5)=(1,12,1,5) 这样优化之后的变换
(
1
,
12
,
1
,
5
)
(1,12,1,5)
(1,12,1,5)正好为
(
1
,
12
,
5
,
5
)
(1,12,5,5)
(1,12,5,5)的最后一波,所以在每一个(1,12)内,这里的优化后的5向量是每隔5个位置出现一波 接着操作
(
Q
K
T
)
V
(QK^{T})V
(QKT)V,未优化的情况下等于
(
1
,
12
,
5
,
5
)
?
(
1
,
12
,
5
,
64
)
=
(
1
,
12
,
5
,
64
)
(1,12,5,5)*(1,12,5,64) = (1,12,5,64)
(1,12,5,5)?(1,12,5,64)=(1,12,5,64), 优化下的情况等于
(
1
,
12
,
1
,
5
)
?
(
1
,
12
,
5
,
64
)
=
(
1
,
12
,
1
,
64
)
(1,12,1,5)*(1,12,5,64) = (1,12,1,64)
(1,12,1,5)?(1,12,5,64)=(1,12,1,64) 这里正好(1,64)就是(5,64)的最后一维度,所以每隔5波出现一次,总共出现12次 最后这里翻转就是将这些间隔的相同内容聚集在一起
(
1
,
12
,
5
,
64
)
?
>
(
1
,
5
,
12
,
64
)
?
>
(
1
,
5
,
768
)
(1,12,5,64) -> (1,5,12,64)->(1,5,768)
(1,12,5,64)?>(1,5,12,64)?>(1,5,768),
(
1
,
12
,
1
,
64
)
?
>
(
1
,
1
,
12
,
64
)
?
>
(
1
,
5
,
768
)
(1,12,1,64) -> (1,1,12,64)->(1,5,768)
(1,12,1,64)?>(1,1,12,64)?>(1,5,768),因此这里最后一个维度的内容相同
5.随便说说
今天在力扣中无意间看到桶排序,感觉可以用于生成的topk计算算法优化 桶排序算法 不过其实整体数据也没多少,所以感觉topk优化的话提升的效率也不大
|