Whole Word Masking (wwm)
本文代码部分参考github项目: https://github.com/BSlience/search-engine-zerotohero/tree/main/public/bert_wwm_pretrain
Whole Word Masking (wwm),暂翻译为全词Mask或整词Mask,是谷歌在2019年5月31日发布的一项BERT的升级版本,主要更改了原预训练阶段的训练样本生成策略。我们先看下BERT原文的遮蔽语言模型。
BERT–遮蔽语言模型
在BERT之前,标准的条件语言模型只能从左到右或从右到左进行训练,因为双向条件作用将允许每个单词在多层上下文中间接地看到自己,为了训练深度双向表示,BERT采用了一种简单的方法,即随机遮蔽一定比例的输入标记,然后仅预测那些被遮蔽的标记,这一过程被称为遮蔽语言模型(MLM, masked language model),尽管在文献中它通常被称为完型填词任务。
在这种情况下,就像在标准语言模型中一样,与遮蔽标记相对应的最终隐藏向量被输入到与词汇表对应的输出 softmax 中(也就是要把被遮蔽的标记对应为词汇表中的一个词语)。在所有的实验中,BERT在每个序列中随机遮蔽 15% 的标记。
虽然这确实允许我们获得一个双向预训练模型,但这种方法有两个缺点。第一个缺点是,我们在预训练和微调之间造成了不匹配,因为 [MASK] 标记在微调期间从未出现过。为了缓和这种情况,我们并不总是用真的用 [MASK] 标记替换被选择的单词。而是,训练数据生成器随机选择 15% 的标记,例如,在my dog is hairy 这句话中,它选择 hairy。然后执行以下步骤:
- 80% 的情况下:用 [MASK] 替换被选择的单词,例如,my dog is hairy → my dog is [MASK]
- 10% 的情况下:用一个随机单词替换被选择的单词,例如,my dog is hairy → my dog is apple
- 10% 的情况下:保持被选择的单词不变,例如,my dog is hairy → my dog is hairy。这样做的目
的是使表示偏向于实际观察到的词。
Transformer 编码器不知道它将被要求预测哪些单词,或者哪些单词已经被随机单词替换,因此它被迫保持每个输入标记的分布的上下文表示。另外,因为随机替换只发生在 1.5% 的标记(即,15% 的10%)这似乎不会损害模型的语言理解能力。
第二个缺点是,使用 Transformer 的每批次数据中只有 15% 的标记被预测,这意味着模型可能需要更多的预训练步骤来收敛。在 5.3 节中,我们证明了 Transformer 确实比从左到右的模型(预测每个标记)稍微慢一点,但是 Transformer 模型的实验效果远远超过了它增加的预训练模型的成本。
WordPiece
BERT原文中的遮蔽语言模型是基于wordPiece拆词后的子词进行MASK,所谓的wordPiece其实是把word再进一步的拆分,拆分为piece,得到更细粒度。
比如**“loved”,“loving”,“loves”**这三个单词。其实本身的语义都是“爱”的意思,但是如果我们以单词为单位,那它们就算作是不一样的词,在英语中不同后缀的词非常的多,就会使得词表变的很大,训练速度变慢,训练的效果也不是太好。
WordPiece与BPE(Byte-Pair Encoding)双字节编码算法比较相似,它们是两种不同的子词切分算法,主要区别在于如何选择两个子词进行合并。
例如WordPiece(或BPE)通过训练,能够把上面的”loved”,”loving”,”loves”3个单词拆分成”lov”,”ed”,”ing”,”es”几部分,这样可以把词的本身的意思和时态分开,有效的减少了词表的数量。
Whole Word Masking策略
在BERT中,原有基于WordPiece的分词方式会把一个完整的词切分成若干个子词,在生成训练样本时,这些被分开的子词会随机被mask。 在全词Mask中,如果一个完整的词的部分WordPiece子词被mask,则同属该词的其他部分也会被mask,即全词Mask。
需要注意的是,这里的mask指的是广义的mask(替换成[MASK];保持原词汇;随机替换成另外一个词),并非只局限于单词替换成[MASK]标签的情况。
由于谷歌官方发布的BERT-base, Chinese中,中文是以字为粒度进行切分,没有考虑到传统NLP中的中文分词(CWS, chinese word segment),所以全词Mask可以用在中文预训练中。
数据示例(方便理解)
- 原始文本: 使用语言模型来预测下一个词的probability。
- 分词文本: 使用 语言 模型 来 预测 下 一个 词 的 probability 。
- 原始Mask输入(mlm): 使 用 语 言 [MASK] 型 来 [MASK] 测 下 一 个 词 的 pro [MASK] ##lity 。
- 全词Mask输入(wwm): 使 用 语 言 [MASK] [MASK] 来 [MASK] [MASK] 下 一 个 词 的 [MASK] [MASK] [MASK] 。
代码实现
因为后面我会针对huggingface transformer中的chinese_bert wwm模型进行fine tune,该模型使用的是wwm(也就是全词MASK方法),所以这里记录whole Word Masking的一种实现方式。
huggingface transformer中有一个data collator的概念,数据整理器(data collator)是通过使用数据集元素列表作为输入来形成批次的对象。这些元素与train_dataset或eval_dataset的元素类型相同。
为了能够构建批处理,数据整理器可能会应用一些处理(如填充、截断)。其中一些(如DataCollatorForLanguageModeling)还对所形成的批处理应用了一些随机数据扩充(如随机屏蔽)。
huggingface transformer中关于data collator的文档
当然MASK操作也属于数据整理器的功能之一,整个data collator的步骤如下:
- 先获得这个批次数据的最大长度max_seq_len;
- 对句子进行补齐和截断;
- 对于每个样本的input_ids,随机选择20%字(token),认为其和前面一个词可能组成词;
- 在对应的token前添加特殊符号**##**比如 4 -> ##4
- 将带特征符号##的token传入mask方法(这里是self._whole_word_mask),随机选择15%的字认为是需要mask的,如果选到的字是带##标记的,那么就把它前面的字一起mask,返回mask_label;
- 根据mask_label和input_ids进行mask(80%进行mask掉,10%进行随机替换,10%选择保持不变)
注意:步骤3中选择的20%,是认为可能组成词的字(并不是需要mask的字),因为是随机选的,所以可能根本不是词,因为参考的这个项目就是这么实现的,所以在我看来是一个不完整的实现方案,如果有能力、有兴趣的小伙伴可以完整实现,也就是找到真正的词,可以借助一些分词工具。
下面是实现代码。
class DataCollator:
def __init__(self, max_seq_len: int, tokenizer: BertTokenizer, mlm_probability=0.15):
self.max_seq_len = max_seq_len
self.tokenizer = tokenizer
self.mlm_probability = mlm_probability
def truncate_and_pad(self, input_ids_list, token_type_ids_list, attention_mask_list, max_seq_len):
input_ids = torch.zeros((len(input_ids_list), max_seq_len), dtype=torch.long)
token_type_ids = torch.zeros_like(input_ids)
attention_mask = torch.zeros_like(input_ids)
for i in range(len(input_ids_list)):
seq_len = len(input_ids_list[i])
if seq_len <= max_seq_len:
input_ids[i, :seq_len] = torch.tensor(input_ids_list[i][:seq_len], dtype=torch.long)
else:
input_ids[i, :seq_len] = torch.tensor(input_ids_list[i][:max_seq_len - 1] +
[self.tokenizer.sep_token_id], dtype=torch.long)
print(input_ids[i])
seq_len = min(len(input_ids_list[i]), max_seq_len)
token_type_ids[i, :seq_len] = torch.tensor(token_type_ids_list[i][:seq_len], dtype=torch.long)
attention_mask[i, :seq_len] = torch.tensor(attention_mask_list[i][:seq_len], dtype=torch.long)
return input_ids, token_type_ids, attention_mask
def _whole_word_mask(self, input_ids_list: List[str], max_seq_len: int, max_predictions=512):
cand_indexes = []
for (i, token) in enumerate(input_ids_list):
if (token == str(self.tokenizer.cls_token_id)
or token == str(self.tokenizer.sep_token_id)):
continue
if len(cand_indexes) >= 1 and token.startswith("##"):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
random.shuffle(cand_indexes)
num_to_predict = min(max_predictions, max(1, int(round(len(input_ids_list) * self.mlm_probability))))
masked_lms = []
covered_indexes = set()
for index_set in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_lms.append(index)
assert len(covered_indexes) == len(masked_lms)
mask_labels = [1 if i in covered_indexes else 0 for i in range(min(len(input_ids_list), max_seq_len))]
mask_labels += [0] * (max_seq_len - len(mask_labels))
return torch.tensor(mask_labels)
def whole_word_mask(self, input_ids_list: List[list], max_seq_len: int) -> torch.Tensor:
mask_labels = []
for input_ids in input_ids_list:
wwm_id = random.choices(range(len(input_ids)), k=int(len(input_ids)*0.2))
input_id_str = [f'##{id_}' if i in wwm_id else str(id_) for i, id_ in enumerate(input_ids)]
mask_label = self._whole_word_mask(input_id_str, max_seq_len)
mask_labels.append(mask_label)
return torch.stack(mask_labels, dim=0)
def mask_tokens(self, inputs: torch.Tensor, mask_labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
labels = inputs.clone()
probability_matrix = mask_labels
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
if self.tokenizer.pad_token is not None:
padding_mask = labels.eq(self.tokenizer.pad_token_id)
probability_matrix.masked_fill_(padding_mask, value=0.0)
masked_indices = probability_matrix.bool()
labels[~masked_indices] = -100
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]
return inputs, labels
def __call__(self, examples: list) -> dict:
input_ids_list, token_type_ids_list, attention_mask_list = list(zip(*examples))
cur_max_seq_len = max(len(input_id) for input_id in input_ids_list)
max_seq_len = min(cur_max_seq_len, self.max_seq_len)
input_ids, token_type_ids, attention_mask = self.truncate_and_pad(
input_ids_list, token_type_ids_list, attention_mask_list, max_seq_len
)
batch_mask = self.whole_word_mask(input_ids_list, max_seq_len)
input_ids, mlm_labels = self.mask_tokens(input_ids, batch_mask)
data_dict = {
'input_ids': input_ids,
'attention_mask': attention_mask,
'token_type_ids': token_type_ids,
'labels': mlm_labels
}
return data_dict
测试数据样例
输入data collator的数据一般是BertTokenizer encode_plus得到的输出,即
- input_ids:输入句子中每个词的编号(在词表中的序号),101代表[cls],102代表[sep];
- token_type_ids:单词属于哪个句子,第一个句子为0,第二句子为1;
- attention_mask:需要对哪些单词做self_attention。
input_ids = [
[101, 4078, 3828, 7029, 4344, 2768, 2642, 8024, 1220, 4289, 924, 2844, 5442, 1316, 2456, 21128, 7344, 4344, 7270, 1814, 21129, 2828, 3315, 1759, 4289, 4905, 1750, 1075, 2768, 4635, 4590, 102],
...
]
token_type_ids = [
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
...
]
attention_mask = [
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
...
]
data = (input_ids, token_type_ids, attention_mask)
data_collator = DataCollator()
data_collator(data)
|