GPT2论文:《Language Models are Unsupervised Multitask Learners》。GPT2模型结构:transformers库中的GPT2模型源码主要由以下几部分组成:GPT2LMHeadModel类、GPT2Model类、Block类、MLP类与Attention类。其中,每个Transformer的Block类中包含了Attention类与MLP类,而GPT2Model类又以Block类为基础,堆叠12层Block(层数由config.n_layer决定,GPT2-small默认为12层)构成模型的主体结构。
GPT2LMHeadModel的主体由GPT2Model类(self.transformer)与一个输出层self.lm_head组成:
GPT2Model类用来完成12层Block的计算;
输出层self.lm_head则将GPT2Model类输出的隐藏状态hidden_states张量(最后一个Block层的输出)的最后一个维度由768(config.n_embd)投影为词表大小(config.vocab_size,GPT2默认为50257)。hidden_states经过输出层投影后即为lm_logits。
当使用GPT2LMHeadModel类进行自回归预训练时,可以传入labels。GPT2LMHeadModel类先用GPT2Model类(self.transformer)与输出层self.lm_head计算得出最终的lm_logits。随后,lm_logits张量与传入的labels张量按自回归方式错开一位,即取第1至n-1个位置的lm_logits与第2至n个位置的labels,用每个位置的输出预测下一个token,并据此计算自回归交叉熵损失loss。labels中取值为-100的位置不参与损失计算。该loss可用于反向传播计算梯度,最终优化整个GPT2模型。
需要注意的是,代码中的config为transformers库configuration_gpt2模块中的GPT2Config类,GPT2Config类保存了GPT2模型的各种超参数。若使用GPT2模型时需要修改某一超参数,只需修改传入GPT2模型的config(GPT2Config类)中对应的超参数即可。
class GPT2LMHeadModel(GPT2PreTrainedModel):
    """GPT-2 transformer with a language-modeling head on top.

    ``self.transformer`` (a ``GPT2Model``) produces hidden states whose last
    dimension is ``config.n_embd``. ``self.lm_head``, a bias-free linear layer,
    projects them to ``config.vocab_size`` to obtain ``lm_logits``. When
    ``labels`` are passed to :meth:`forward`, a shifted (causal) cross-entropy
    loss is computed as well.
    """

    # Regex patterns for parameter names. Presumably GPT2PreTrainedModel uses
    # them to tolerate these keys being absent from a loaded checkpoint
    # (lm_head.weight is likely tied to the input embeddings) -- confirm
    # against the base class.
    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]

    def __init__(self, config):
        """Build the transformer body and the LM head from ``config`` (a GPT2Config)."""
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.init_weights()

    def get_output_embeddings(self):
        """Return the output projection layer ``self.lm_head``."""
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        """Assemble the keyword arguments for a single :meth:`forward` call during generation.

        Args:
            input_ids: Token ids processed so far, indexed as ``[batch, position]``.
            past: Cached key/value states from earlier steps, or ``None``.
            **kwargs: May contain ``token_type_ids``, ``attention_mask``,
                ``position_ids`` and ``use_cache``.

        Returns:
            dict: Keyword arguments ``input_ids``, ``past_key_values``,
            ``use_cache``, ``position_ids``, ``attention_mask`` and
            ``token_type_ids``.

        Note:
            Position ids are derived from ``attention_mask`` only when an
            attention mask is given and ``position_ids`` is not. In every other
            case ``None`` is passed, so a caller-supplied ``position_ids`` is
            never forwarded.
        """
        token_type_ids = kwargs.get("token_type_ids", None)
        # With a non-empty cache only the last token (and its token type) is
        # fed; earlier tokens are presumably represented in ``past``.
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # Build position ids on the fly. The running count of attended
            # tokens means masked (attention_mask == 0) positions do not
            # advance the counter. Masked positions get the placeholder value 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None

        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="gpt2",
        output_type=CausalLMOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            ``labels = input_ids``. Indices are selected in ``[-100, 0, ..., config.vocab_size - 1]``. All labels
            set to ``-100`` are ignored (masked); the loss is only computed for labels in
            ``[0, ..., config.vocab_size - 1]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # The first element of the transformer output is the final hidden states.
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Put labels on the logits' device so the loss also works when the
            # two arrive on different devices; this is a no-op when they match.
            labels = labels.to(lm_logits.device)
            # Causal shift: the logits at position t are scored against the
            # label at position t + 1.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # CrossEntropyLoss ignores targets equal to -100 by default (ignore_index).
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPastAndCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )
|