入门小菜鸟,希望像做笔记记录自己学的东西,也希望能帮助到同样入门的人,更希望大佬们帮忙纠错啦~侵权立删。
目录
一、原理
1、总体介绍
2、具体实现
(1)不采取稀疏处理(默认)
(2)采取稀疏训练
?(3)稀疏推断
二、代码解析
1、__init__
(1)参数设定
(2)存储激活检查点标志
(3)定义输出层初始化方法
(4)Position embedding
(5)窗口定义
(6)Transformer layers设置
(7)将 num_layer 个 transformer layer打包在一起,以列表形式保存
(8)output层的LayerNorm处理
(9)激活点检查
2、forward
(1)获取最终的输入层的相关信息
(2)attention mask建立
(3)稀疏训练or推断准备
(4)对输入层的处理
(5)这次是否有产生记忆模块
(6)获取下一层的输入——分为是否采取检查点激活两种情况来分析
(7)最后一层norm
(8)记忆模块更新
(9)返回这一层的输出结果和记忆模块
一、原理
1、总体介绍
将n个的 transformer blocks 打包在一起,即 n * transformer layer + final layernorm 两部分组成
2、具体实现
(1)不采取稀疏处理(默认)
?(2)采取稀疏训练
?新建的rmask(k为输入的总列数;w为窗口大小;t为调整窗口数量所用)
?(3)稀疏推断
二、代码解析
1、__init__
(1)参数设定
class GPT2ParallelTransformer(torch.nn.Module):
"""GPT-2 transformer.
This module takes input from embedding layer and it's output can
be used directly by a logit layer. It consists of L (num-layers)
blocks of:
layer norm
self attention
residual connection
layer norm
mlp
residual connection
followed by a final layer norm.
Arguments:
num_layers: Number of transformer layers.
hidden_size: The hidden size of the self attention.
num_attention_heads: number of attention head in the self
attention.
attention_dropout_prob: dropout probability of the attention
score in self attention.
output_dropout_prob: dropout probability for the outputs
after self attention and final output.
checkpoint_activations: if True, checkpoint activations.
checkpoint_num_layers: number of layers to checkpoint. This
is basically the chunk size in checkpoitning.
layernorm_epsilon: epsilon used in layernorm to avoid
division by zero.
init_method_std: standard deviation of the init method which has
the form N(0, std).
use_scaled_init_for_output_weights: If Ture use 1/sqrt(2*num_layers)
scaling for the output weights (
output of self attention and mlp).
"""
def __init__(self,
num_layers,
hidden_size,
num_attention_heads,
max_sequence_length,
max_memory_length,
embedding_dropout_prob,
attention_dropout_prob,
output_dropout_prob,
checkpoint_activations,
checkpoint_num_layers=1,
layernorm_epsilon=1.0e-5,
init_method_std=0.02,
use_scaled_init_for_output_weights=True,
query_window=128,
key_window_times=6,
num_pivot=768
):
super(GPT2ParallelTransformer, self).__init__()
- num_layers:transformer层的数量;
- hidden_size:自我注意力模块的隐藏大小(嵌入向量的维度);
- num_attention_heads:自我注意力模块中attention head的数量;
- max_sequence_length:词典大小;
- max_memory_length:最大记忆长度;
- embedding_dropout_prob:嵌入层(该模块的输入部分)中元素被dropout的概率(为了解决过拟合问题而随机丢弃一部分元素);
- attention_dropout_prob:同样道理,注意力模块中注意力得分被dropout的概率;
- output_dropout_prob:同理,输出层后的输出被dropout的概率;
- checkpoint_activations:是否执行检查点激活;
- checkpoint_num_layers:检查点的层数。这基本上是checkpoitning中的块大小;
- layernorm_epsilon:在layernform中用于避免被零除的ε(用于防止分母为0);
- init_method_std:初始化方法(使用让权重呈现正态分布的方法)中正态分布的方差;
- use_scaled_init_for_output_weights:是否对自注意力和mlp的输出的权重调用scaled_init_method进行初始化;
- query_window:稀疏处理中的窗口大小;
- key_window_times:用于调整窗口数量;
- num_pivot:transformer里图像token和文本token的总和数量
(2)存储激活检查点标志
# Store activation checkpoiting flag.
#首先先记录是否执行检查点激活,检查点的层数,最大记忆长度和最大序列长度信息
self.checkpoint_activations = checkpoint_activations
self.checkpoint_num_layers = checkpoint_num_layers
self.max_memory_length = max_memory_length
self.max_sequence_length = max_sequence_length
(3)定义输出层初始化方法
由use_scaled_init_for_output_weights决定,若为False则不进行初始化缩放,若为true则调用scaled_init_method进行初始化
#输出层初始化方法定义——由use_scaled_init_for_output_weights决定,若为False则不进行初始化缩放,若为true则调用scaled_init_method进行初始化
output_layer_init_method = None
if use_scaled_init_for_output_weights:
output_layer_init_method = scaled_init_method(init_method_std,
num_layers)
scaled_init_method函数——返回初始化方法:初始权重呈均值为0,方差为init_method_std//sqrt(2*num_layers)的正态分布。
def scaled_init_method(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
(4)Position embedding
先进行嵌入层的dropout(防止过拟合),然后调用torch.nn.Embedding()方法按词典大小max_sequence_length和嵌入向量的维度hidden_size来定义词向量格式,然后将词向量的值初始化为呈以0为均值,以init_method_std为方差的正态分布。
# Embeddings dropout嵌入层dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
# Position embedding (serial).初始化含位置信息的词向量方法
self.position_embeddings = torch.nn.Embedding(max_sequence_length,
hidden_size)#随机以max_sequence_length为词典的大小(词的个数),以hidden_size来嵌入向量的维度(即用多少维来表示一个符号)初始化词向量,默认词向量值在正态分布N(0,1)中随机取值
# Initialize the position embeddings.词向量值在正态分布N(0,init_method_std)中随机取值
torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
(5)窗口定义
self.query_window = query_window
self.key_window_times = key_window_times
self.num_pivot = num_pivot
首先定义了一个get_layer()函数来获得对应层id的网络层(transformer layer)
#获得对应层id的网络层
def get_layer(layer_id):
return GPT2ParallelTransformerLayer(
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
layernorm_epsilon,
unscaled_init_method(init_method_std),
output_layer_init_method=output_layer_init_method,
query_window=query_window,
key_window_times=key_window_times,
scale_normalization=True
)
这里调用了GPT2ParallelTransformerLayer类
# Transformer layers.
self.layers = torch.nn.ModuleList(
[get_layer(layer_id) for layer_id in range(num_layers)])
(8)output层的LayerNorm处理
# Final layer norm before output.
self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
(9)激活点检查
if deepspeed.checkpointing.is_configured():
global get_cuda_rng_tracker, checkpoint
get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
checkpoint = deepspeed.checkpointing.checkpoint
self.rmask = None#是否进行稀疏处理
2、forward
def forward(self, hidden_states, position_ids, attention_mask, txt_indices_bool, img_indices_bool, is_sparse=0, *mems):
'''''
hidden_states:输入的网络层;
position_ids:位置编码;
attention_mask;
txt_indices_bool:选取文本token有效的索引矩阵
img_indices_bool:选取图像token有效的索引矩阵
is_sparse:是否稀疏处理,稀疏训练,稀疏推断
mems:记忆模块;
'''''
(1)获取最终的输入层的相关信息
获取b,s和最终的输入列数(hidden_states和记忆模块的concat的结果)
batch_size, query_length = hidden_states.size()[:2]#获取batchsize(b)和读取的序列长度(s)
memory_length = mems[0].size(1) if mems else 0#获取记忆模块的序列长度(模块列数)
key_length = query_length + memory_length#得到最终的序列长度(类似concat维数增加)
(2)attention mask建立
最终shape[1,1,s,s](无记忆模块情况下,有记忆为[1,1,s,s+m],m为memory_length)
# conventional transformer
#建立常规transformer的attention mask
def build_mask_matrix(query_length, key_length, sep):
m = torch.ones((1, query_length, key_length), device=hidden_states.device, dtype=hidden_states.dtype)#初始化为全一矩阵
assert query_length <= key_length
m[0, :, -query_length:] = torch.tril(m[0, :, -query_length:])#返回m[0, :, -query_length:]区域(最后两维)是下三角矩阵的矩阵
m[0, :, :sep + (key_length - query_length)] = 1#注意力标记
m = m.unsqueeze(1)#[1,s,s+m]->[1,1,s,s+m]
return m
#生成attention_mask,无记忆模块是[1,1,s,s],有记忆是[1,1,s,s+m]
attention_mask = build_mask_matrix(query_length, key_length, sep)
(3)稀疏训练or推断准备
?获取稀疏训练的rmask
#启用稀疏训练生成rmask
if is_sparse == 1 and (self.rmask is None):
w, times = self.query_window, self.key_window_times#滑动窗口大小+窗口数的减少量获取
g = key_length // w#获取全局attention窗口个数
tmp = torch.ones((g-times+1, w , w), device=hidden_states.device, dtype=hidden_states.dtype)#初始化rmask(可理解为g-times+1个窗口)
tmp = torch.tril(1 - torch.block_diag(*tmp))#*将三维矩阵变成二维矩阵列表;torch.block_diag将g-times+1个w*w矩阵组合成一个块对角矩阵,1-使得中间块为0,其余为1;torch.tril返回下三角矩阵。shape为((g-times+1)*w,(g-times+1)*w)
self.rmask = torch.nn.functional.pad(tmp, (0, (times-1)*w, (times-1)*w, 0)) # pad (left, right, top, bottom),这四个元素的位置代表了填充的位置,大小为填充的行数,默认填0,所以最终shape为(g*w,g*w),左下角为一个((g-times+1)*w,(g-times+1)*w)大小的下三角矩阵
?获取左边界和支点
if is_sparse == 2:#稀疏推断
left_boundary = max(0, key_length - self.key_window_times * self.query_window)#获取左边界(将key_length分为n份query_window的块块,做除法后的余数部分为左边界
window_idx = torch.arange(left_boundary, key_length, device=hidden_states.device, dtype=torch.long).expand(batch_size, -1)#torch.arange获得[left_boundary,...,key_length-1];expand(batch_size, -1)获得batchsize条[left_boundary,...,key_length-1],获得shape为(batchsize*key_length-left_boundary)
elif is_sparse == 1:#稀疏训练
left_boundary = key_length#获取左边界
num_pivot = self.num_pivot#transformer里图像token和文本token的总和数量获取
?选取每个batch中对应有效的index的image token和txt token
#选取每个batch中对应有效的index的image token和txt token
if is_sparse: # 1 or 2
# select out the real indices for sampling
img_indices = [img_indices_bool[i][:left_boundary].nonzero(as_tuple=False).view(-1) for i in range(batch_size)]#.nonzero(as_tuple=False)取出非0元素的索引(即取出有效索引);.view(-1)将其展平
txt_indices = [txt_indices_bool[i][:left_boundary].nonzero(as_tuple=False).view(-1) for i in range(batch_size)]
?稀疏推断支点数目设定
#稀疏推断支点数目设定(总token数量增加)
if is_sparse == 2:
ratio = self.num_pivot / self.max_sequence_length#支点比例获取
max_text_num = max(len(text_idx) for text_idx in txt_indices)#获取batch中最长的有效文本token长度
num_pivot = max_text_num + int((left_boundary - max_text_num) * ratio)#支点数目更新
(4)对输入层的处理
给输入层加入初始化的位置信息词向量并且进行dropout操作
#对输入层的处理
position_embeddings = self.position_embeddings(position_ids)#对位置信息position_ids进行词向量的初始化
hidden_states = hidden_states + position_embeddings#输入层加入初始化的位置信息词向量
hidden_states = self.embedding_dropout(hidden_states)#对输入层进行dropout
(5)这次是否有产生记忆模块
若拥有最大记忆长度,则产生的记忆模块是输入层,但不需要计算其梯度
#这次是否有产生记忆模块
if self.max_memory_length > 0:#若拥有最大记忆长度,
mem_layers = [hidden_states.detach()]#记忆模块赋为输入层,但不需要计算其梯度
else:#否则没有记忆模块
mem_layers = []
然后保存一下attention mask
attention_mask_saved = attention_mask#保存attention mask
(6)获取下一层的输入——分为是否采取检查点激活两种情况来分析
(都要利用get_layer来实现,所以都要先获取相应的参数输入才可调用)
?采取检查点激活
①首先是必要的初始化和参数获取
l = 0#初始化start层id
num_layers = len(self.layers)#Transformer layers的数量获取
chunk_length = self.checkpoint_num_layers#检查点的层数
循环获取层
while l < num_layers:
②稀疏训练or推断情况下获取下一层的输入的参数
if is_sparse > 0:#稀疏训练or推断
🌳获取pivot的索引(pivot即随机抽取的token,用于代表全局整幅图片)
# ===================== Pivot Mask ======================== #
pivot_idx = torch.stack([
torch.cat((
text_idx,
img_indices[i][
torch.tensor(random.sample(range(len(img_indices[i])), k=num_pivot - len(text_idx)), dtype=torch.long, device=text_idx.device)
]
), dim=0)
for i, text_idx in enumerate(txt_indices)
])
#首先由random.sample随机抽取(预设支点数量-该batch的有效文本token长度)=该batch的有效图像token长度个图像token索引,并且将文本token和图像token拼接在一起
🌳然后对于稀疏训练:获取pivot_attention_mask,进而获取输入所需的参数列表
if is_sparse == 1: # sparse training
assert key_length == query_length#断言最终的序列长度和读取的序列长度(s)是否相同
b, s = batch_size, key_length
pivot_attention_mask = self.rmask.expand(b, s, s).gather(dim=-1, index=pivot_idx.unsqueeze(1).expand(b, s, self.num_pivot))#生成针对随机选取的token的注意力矩阵——pivot attention mask
#expand()函数扩展维度,其余不变。
# 相当于先由b个原来的s*s(s=g*w)大小的rmask(即每个batch里都有rmask)拼成一个大小为(b,s,s)的矩阵;再由gather函数根据 index 参数(即是索引)返回矩阵里面对应位置的值(即挑出随机选中的token对应索引值的rmask值)——针对的是每个batch的s*s的rmask;最后再由expand函数展成大小为(b,s,随机选取的token数量)的矩阵
args = [hidden_states, pivot_attention_mask, pivot_idx, torch.tensor(is_sparse)]#参数列表记录
🌳然后对于稀疏推理:获取全部需要注意的token的idx,并形成参数列表
elif is_sparse == 2: # sparse inference
pw_idx = torch.cat((pivot_idx, window_idx), dim=-1)#获取随机选取的token的idx矩阵与额外标记注意的窗口的idx矩阵concat后的需要attention的idx矩阵
args = [hidden_states, attention_mask_saved, pw_idx, torch.tensor(is_sparse)]#参数列表记录
🌳错误提示
else:
raise NotImplementedError
③非稀疏处理情况下参数列表获取
else:
args = [hidden_states, attention_mask_saved]#非稀疏处理的参数列表记录(输入层和attention mask)
④记忆模块对参数列表的补充
#对于记忆模块的参数补充
if mems:
args += mems[l: l + chunk_length]
⑤获得下一层的输入并进行检查,且start层idx(l)更新
#检查点激活并得到下一层的输入层
hidden_states = checkpoint(custom(l, l + chunk_length), *args)
#start为第l层,end为第l + chunk_length层(共检查点层数数量)
l += chunk_length#下一个检查点的开始层数
这里调用custom函数——用于获取下一层的输入
def custom(start, end):
def custom_forward(*inputs):
layers_ = self.layers[start:end]#获取对应的层序列
x_, inputs = inputs[0], inputs[1:]#将他们分成两份(头和其余)
if is_sparse > 0:#稀疏处理
inputs, mems_ = inputs[:3], inputs[3:]#输入为前3层,其余为记忆模块
else:#不采取稀疏处理
inputs, mems_ = inputs[:1], inputs[1:]#输入为第1层,其余为记忆模块
for i, layer in enumerate(layers_):
mem_i_ = mems_[i] if mems_ else None#获取第i层的记忆模块
x_ = layer(x_, *inputs, mem=mem_i_)#调用get_layer中GPT2ParallelTransformerLayer的forward——x_对应hidden_states(输入), inputs对应ltor_mask(attention mask)
if self.max_memory_length > 0:
mem_layers.append(x_.detach())#记忆模块添加(不参与梯度计算)
return x_
return custom_forward
?不采取检查点激活
思路和上面的检查点激活类似,只是不考虑了检查点层数和checkpoint
else:#不进行检查点激活
assert is_sparse != 1, 'Please use checkpoint_activations for sparse attention training.'
for i, layer in enumerate(self.layers):#遍历Transformer layers
if is_sparse == 0:#非稀疏处理——获取下一步传入的参数列表
args = [hidden_states, attention_mask_saved]
elif is_sparse == 2:#稀疏推断
pivot_idx = torch.stack([
torch.cat((
text_idx,
img_indices[i][
torch.tensor(random.sample(range(len(img_indices[i])), k=num_pivot - len(text_idx)), dtype=torch.long, device=text_idx.device)
]
), dim=0)
for i, text_idx in enumerate(txt_indices)
])#首先由random.sample随机抽取(预设支点数量-该batch的有效文本token长度)=该batch的有效图像token长度个图像token索引,并且将文本token和图像token拼接在一起
pw_idx = torch.cat((pivot_idx, window_idx), dim=-1)#获取随机选取的token的idx矩阵与额外标记注意的窗口的idx矩阵concat后的需要attention的idx矩阵
args = [hidden_states, attention_mask_saved, pw_idx, torch.tensor(is_sparse)]#参数列表记录
mem_i = mems[i] if mems else None#对应层的记忆模块
hidden_states = layer(*args, mem=mem_i)#下一层的输入层获取
if self.max_memory_length > 0:#记忆层添加
mem_layers.append(hidden_states.detach())
(7)最后一层norm
作Layernorm操作
# Final layer norm.
output = self.final_layernorm(hidden_states)#即对下一层的输入(这一层的输出)做一个LayerNorm规范
(8)记忆模块更新
#更新记忆模块
if self.max_memory_length > 0:
mem_layers = self.update_mems(mem_layers, mems)
这里调用update_mems进行更新
def update_mems(self, hiddens, mems):
memory_length = mems[0].size(1) if mems else 0#原记忆模块的长度(列数)
query_length = hiddens[0].size(1)#新待加入的记忆模块长度(列数)
new_memory_length = min(self.max_memory_length, memory_length + query_length)#新的记忆模块的长度确定
new_mems = []
with torch.no_grad():
for i in range(len(hiddens)):
if new_memory_length <= query_length:#说明选中的是self.max_memory_length(记忆模块完全为新的记忆层组成)。取每一层的每一行的后new_memory_length组成新的记忆矩阵
new_mems.append(hiddens[i][:, -new_memory_length:])
else:#说明选中的是memory_length + query_length。取原来的记忆模块和新加入的进行拼接(沿列拼接)
new_mems.append(torch.cat((mems[i][:, -new_memory_length+query_length:], hiddens[i]), dim=1))
return new_mems
(9)返回这一层的输出结果和记忆模块
return (output, *mem_layers)#返回下一层的输入(这一层的输出结果)和记忆模块
欢迎大家在评论区批评指正,谢谢~
|