【记录】使用transformers从头开始训练bert
这篇记录主要记录使用transformers库训练从头开始训练自己的bert预训练模型;
bert训练任务;
bert预训练模型包含两个任务:
- mask词预测
- 相邻句子预测
使用的API
使用的api为BertForPreTraining
from transformers import BertConfig, BertForPreTraining
config = BertConfig(vocab_size=len(WORDS) + 1)
model = BertForPreTraining(config)
for epoch in range(200):
for data in data_loader:
next_sentence_label = data['next_sentence_label'].to(device).long()
input_ids = data['input_ids'].to(device).long()
token_type_ids = data['token_type_ids'].to(device).long()
attention_mask = data['attention_mask'].to(device).long()
labels = data['bert_label'].to(device).long()
outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
labels=labels, next_sentence_label=next_sentence_label)
loss = outputs['loss']
optim.zero_grad()
loss.backward()
optim.step()
步骤
1)构建数据集,bert的数据集不用标注,因为他的训练任务标签可以程序生成,使用爬取的文章就行;
2)构建字典,一般来说是基于数据集来构建字典,但要注意bert中预留的几个特殊字符(用来对应任务,这里当然也可以自定义,但还是建议按照原版的来),;或者你也可以直接用下载好的字典;
3)构建数据集,和训练数据;这是重点,注意数据的组织形式;
首先是针对预训练任务,
1)next sentence, 你需要一次性输入两个句子,标签为句子是否相邻,注意输入的时候正负均衡采样;输入拼接如下:
input_token: 【cls】+sentence1+【sep】+sentence2+【sep】
next_sentence_label: 1&0
2)mask语言模型:bert针对上述句子,随机maks或者替换掉15%的词,这部分标签,则为mask位置的单词的真实编号,需要注意的是,标签的其他位置用-100代替,如下示例; 操作完之后记得padding;
def random_word(sentence):
"""mask language model, 添加15%的mask"""
tokens = [char for char in sentence]
output_label = []
for i, token in enumerate(tokens):
prob = random.random()
if prob < 0.15:
prob /= 0.15
if prob < 0.8:
tokens[i] = voc['[MASK]']
elif prob < 0.9:
tokens[i] = random.randrange(len(voc))
else:
tokens[i] = voc.get(token, voc['[UNK]'])
output_label.append(voc.get(token, voc['[UNK]']))
else:
tokens[i] = voc.get(token, voc['[UNK]'])
output_label.append(-100)
return tokens, output_label
3)构建token_type_id和attention_mask;token_typeid用来区分第一句还是第二句,attention_mask记录非padding部分;格式如下;
输入组合:【cls】你好吗【sep】我很好啊【sep】【pad】...
input_token: [102],[15],[17],[19],[103][14],[12],[17],[20],[103][0]...
token_type_id: 0, 0,0,0, 0, 1,1,1,1,1[0]...
attention_mask: 1,1,1,1,1,1,1,1,1,1,0,...
需要注意两点:
句子1包含了头部的【cls】和一个【sep】,句子2包含最后一个【sep】,所以token_type_id要对应到上,不能错了位置;bert_label不包含这个标识符,标识符位置用【pad】
参考代码如下:
def __getitem__(self, idx):
t1, t2, is_next_label = self.get_sentence(idx)
t1_random, t1_label = self.random_word(t1)
t2_random, t2_label = self.random_word(t2)
t1 = [self.vocab['[CLS]']] + t1_random + [self.vocab['[SEP]']]
t2 = t2_random + [self.vocab['[SEP]']]
t1_label = [self.vocab['[PAD]']] + t1_label + [self.vocab['[PAD]']]
t2_label = t2_label + [self.vocab['[PAD]']]
segment_label = ([0 for _ in range(len(t1))] + [1 for _ in range(len(t2))])[:self.seq_len]
bert_input = (t1 + t2)[:self.seq_len]
bert_label = (t1_label + t2_label)[:self.seq_len]
padding = [self.vocab['[PAD]'] for _ in range(self.seq_len - len(bert_input))]
attention_mask = len(bert_input) * [1] + len(padding) * [0]
bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(padding)
attention_mask = np.array(attention_mask)
bert_input = np.array(bert_input)
segment_label = np.array(segment_label)
bert_label = np.array(bert_label)
is_next_label = np.array(is_next_label)
output = {"input_ids": bert_input,
"token_type_ids": segment_label,
'attention_mask': attention_mask,
"bert_label": bert_label}, is_next_label
return output
|