目录
1.__init__方法
2.forward方法
将bert和crf模型结合起来,简单来说就是,设置好Bert模型,以及参数,得到的输出结果给CRF模型即可。
1.__init__方法
这里面主要是bert的参数的定义及导入,还有bert模型的导入。
MODEL_NAME = 'bert-base-chinese-model.bin'
CONFIG_NAME = 'bert-base-chinese-config.json'
VOB_NAME = 'bert-base-chinese-vocab.txt'
class BertCrf(nn.Module):
def __init__(self, config_name: str, model_name:str = None, num_tags: int = 2, batch_first: bool = True) -> None:
self.batch_first = batch_first
# 模型配置文件、模型预训练参数文件判断
if not os.path.exists(config_name):
raise ValueError(
"未找到模型配置文件 '{}'".format(config_name)
)
else:
self.config_name = config_name
if model_name is not None:
if not os.path.exists(model_name):
raise ValueError(
"未找到模型预训练参数文件 '{}'".format(model_name)
)
else:
self.model_name = model_name
else:
self.model_name = None
if num_tags <= 0:
raise ValueError(f'invalid number of tags: {num_tags}')
super().__init__()
# 配置bert的config文件
self.bert_config = BertConfig.from_pretrained(self.config_name)
self.bert_config.num_labels = num_tags
self.model_kwargs = {'config': self.bert_config}
# 如果模型不存在
if self.model_name is not None:
self.bertModel = BertForTokenClassification.from_pretrained(self.model_name, **self.model_kwargs)
else:
self.bertModel = BertForTokenClassification(self.bert_config)
self.crf_model = CRF(num_tags=num_tags, batch_first=batch_first)
2.forward方法
输出的结果,经过处理后,输入CRF函数,返回loss即可。
def forward(self, input_ids: torch.Tensor,
tags: torch.Tensor = None,
attention_mask: Optional[torch.ByteTensor] = None,
token_type_ids=torch.Tensor,
decode:bool = True,
reduction: str = 'mean')->List:
emissions = self.bertModel(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0]
# 去掉开头的[CLS]以及结尾,结尾可能有两种情况:1、<pad> 2、[SEP]
new_emissions = emissions[:, 1:-1]
new_mask = attention_mask[:, 2:].bool()
# tags为None, 是预测过程,不能求loss
if tags is None:
loss = None
pass
else:
new_tags = tags[:, 1:-1]
loss = self.crf_model(emissions=new_emissions, tags=new_tags, mask=new_mask, reduction=reduction)
if decode:
tag_list = self.crf_model.decode(emissions=new_emissions, mask=new_mask)
return [loss, tag_list]
return [loss]
|