NLP实践——基于SIFRank的中文关键短语抽取
0. 本文介绍
本文在《SIFRank: A New Baseline for Unsupervised Keyphrase Extraction Based on Pre-Trained Language Model》的基础上,借鉴原作者的思想,重写实现了一个好用的中文关键短语抽取工具。 首先声明一下,这篇论文我并没有看过,所有的理解全都是基于作者开源出来的代码,因而不保证所有的思想都与原作者保持一致。
这篇论文是一个抽取式的关键短语模型,相比近两年备受关注的生成式关键短语模型,其技术理念已经相对落后,但是在实际应用的生产环境中,尤其是对于无监督的垂直领域,我们更关心的是模型的可解释性以及抽取结果的可控性,因而抽取式的模型相比生成式,能够更加让我感到安心,这也是选择这篇论文作为参考的主要原因。在尝试这个思路之前,也对textrank,yake,autophrasex,UCphrase等关键短语抽取工具进行了尝试,但是效果都不太理想。
下面贴出原项目的地址: https://github.com/yukuotc/SIFRank_zh
原项目的时间比较久,其中所应用到的elmo编码器的预训练模型的下载地址已经失效,并且词性标注模型也比较旧了,所以在此项目的基础上,我从中借鉴了一部分代码,并参考作者的思路,提出并实现了自己的解决方法,主要做出的修改如下:
熟悉我写作风格的同学们应该比较了解,我很少进行理论介绍,我的博客主要从易用的角度,关注一个具体功能的实现,接下来我将从运行环境开始讲起,介绍如何实现这一关键短语抽取模型。
1. 运行环境
首先介绍一下环境配置,我的运行环境如下:
torch 1.8.1
ltp 4.1.4
thulac 0.2.1
nltk 3.5
transformers 4.9.2
sentence-transformers 2.0.0
其中,
- thulac是参考原作者的环境,如果完全按照我的方法去做,不考虑原作者的方法,可以不安装;
- sentence-transformers是用于自监督训练,如果对领域迁移不感兴趣,可以不安装;
- transformers高版本是sentence-transformers的要求,如果不安装后者,估计前者4.0以上即可;
- ltp最好采用4.1或以上版本,其新版与旧版在效率和准确度上都有很大的差异;
- torch满足相应版本的ltp和transformers即可;
- nltk的版本相对随意,一般也不会与其他模块冲突。
2. 项目目录
然后介绍一下项目目录。建立一个项目根目录keyphrase_extractor,在此目录下建立一个jupyter笔记或py文件,建立一个utils.py(其中的内容后边会介绍),以及一个文件夹resources;
resources中,建立一个ner_usr_dict.txt,其中存放分词时的用户自定义实体表,每行写一个实体,例如:
南京市长
江大桥
这个文件的作用是,让分词模型在分词的时候,把“南京市长江大桥”分为[“南京市长”, “江大桥”],而非[“南京市”, “长江大桥”]。
然后去原项目中,下载auxiliary_data下的dict.txt,放在我们的resources下,命名为pretrained_weight_dict.txt。
再去huggingface下载一个你觉得顺眼的模型,比如bert-base,我这里用的例子是electra,然后把整个模型的所有文件放在resources中的一个目录下。(注意:不要用基于Roberta的模型,Roberta的tokenizer比较特殊,我没有进行适配)
全部准备好之后,整个项目目录应该是这个样子:
keyphrase_extractor
|--keyphrase_extract.ipynb
|--utils.py
|--resources
|--ner_usr_dict.txt
|--pretrained_weight_dict.txt
|--chinese-electra-180g-small-discriminator
|--config.json
|--tokenizer_config.json
|--tokenizer.json
|--added_tokens.json
|--special_tokens_map.json
|--vocab.txt
|--pytorch_model.bin
3. 代码实现
终于来到了喜闻乐见的代码环节,在这一环节中的所有代码,除了3.1中,全部依次丢进keyphrase_extract.ipynb中运行即可。 代码的基本逻辑我随手花了一个图,同学们凑合着看。
3.1 utils
首先完善一下我们的辅助类函数,打开utils.py,加入以下三个函数:
- get_word_weight:用于获取词权重
- process_long_input:用于将bert支持的长度从512扩展为1024
- rematch:用于token-level到char-level的匹配
这三个函数是到处借鉴来的,其中1是本项目中改写的,2是此论文所述项目中搬来的,3是从bert4keras中搬来的。
import numpy as np
import unicodedata, re
import torch
import torch.nn.functional as F
def get_word_weight(weightfile="", weightpara=2.7e-4):
"""
Get the weight of words by word_fre/sum_fre_words
:param weightfile
:param weightpara
:return: word2weight[word]=weight : a dict of word weight
"""
if weightpara <= 0:
weightpara = 1.0
word2weight = {}
word2fre = {}
with open(weightfile, encoding='UTF-8') as f:
lines = f.readlines()
sum_fre_words = 0
for line in lines:
word_fre = line.split()
if (len(word_fre) >= 2):
word2fre[word_fre[0]] = float(word_fre[1])
sum_fre_words += float(word_fre[1])
else:
print(line)
for key, value in word2fre.items():
word2weight[key] = weightpara / (weightpara + value / sum_fre_words)
return word2weight
def process_long_input(model, input_ids, attention_mask, start_tokens, end_tokens):
"""
Parameters
----------
model: 编码模型
input_ids: (b, l)
attention_mask: (b, l)
start_tokens: 对bert而言就是[101]
end_tokens: [102]
Returns
-------
"""
n, c = input_ids.size()
start_tokens = torch.tensor(start_tokens).to(input_ids)
end_tokens = torch.tensor(end_tokens).to(input_ids)
len_start = start_tokens.size(0)
len_end = end_tokens.size(0)
if c <= 512:
output = model(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=True,
)
sequence_output = output[0]
attention = output[-1][-1]
else:
new_input_ids, new_attention_mask, num_seg = [], [], []
seq_len = attention_mask.sum(1).cpu().numpy().astype(np.int32).tolist()
for i, l_i in enumerate(seq_len):
if l_i <= 512:
new_input_ids.append(input_ids[i, :512])
new_attention_mask.append(attention_mask[i, :512])
num_seg.append(1)
else:
input_ids1 = torch.cat([input_ids[i, :512 - len_end], end_tokens], dim=-1)
input_ids2 = torch.cat([start_tokens, input_ids[i, (l_i - 512 + len_start): l_i]], dim=-1)
attention_mask1 = attention_mask[i, :512]
attention_mask2 = attention_mask[i, (l_i - 512): l_i]
new_input_ids.extend([input_ids1, input_ids2])
new_attention_mask.extend([attention_mask1, attention_mask2])
num_seg.append(2)
input_ids = torch.stack(new_input_ids, dim=0)
attention_mask = torch.stack(new_attention_mask, dim=0)
output = model(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=True,
)
sequence_output = output[0]
attention = output[-1][-1]
i = 0
new_output, new_attention = [], []
for (n_s, l_i) in zip(num_seg, seq_len):
if n_s == 1:
output = F.pad(sequence_output[i], (0, 0, 0, c - 512))
att = F.pad(attention[i], (0, c - 512, 0, c - 512))
new_output.append(output)
new_attention.append(att)
elif n_s == 2:
output1 = sequence_output[i][:512 - len_end]
mask1 = attention_mask[i][:512 - len_end]
att1 = attention[i][:, :512 - len_end, :512 - len_end]
output1 = F.pad(output1, (0, 0, 0, c - 512 + len_end))
mask1 = F.pad(mask1, (0, c - 512 + len_end))
att1 = F.pad(att1, (0, c - 512 + len_end, 0, c - 512 + len_end))
output2 = sequence_output[i + 1][len_start:]
mask2 = attention_mask[i + 1][len_start:]
att2 = attention[i + 1][:, len_start:, len_start:]
output2 = F.pad(output2, (0, 0, l_i - 512 + len_start, c - l_i))
mask2 = F.pad(mask2, (l_i - 512 + len_start, c - l_i))
att2 = F.pad(att2, [l_i - 512 + len_start, c - l_i, l_i - 512 + len_start, c - l_i])
mask = mask1 + mask2 + 1e-10
output = (output1 + output2) / mask.unsqueeze(-1)
att = (att1 + att2)
att = att / (att.sum(-1, keepdim=True) + 1e-10)
new_output.append(output)
new_attention.append(att)
i += n_s
sequence_output = torch.stack(new_output, dim=0)
attention = torch.stack(new_attention, dim=0)
return sequence_output, attention
def rematch(text, tokens, do_lower_case=True):
if do_lower_case:
text = text.lower()
def is_control(ch):
return unicodedata.category(ch) in ('Cc', 'Cf')
def is_special(ch):
return bool(ch) and (ch[0] == '[') and (ch[-1] == ']')
def stem(token):
if token[:2] == '##':
return token[2:]
else:
return token
normalized_text, char_mapping = '', []
for i, ch in enumerate(text):
if do_lower_case:
ch = unicodedata.normalize('NFD', ch)
ch = ''.join([c for c in ch if unicodedata.category(c) != 'mn'])
ch = ''.join([c for c in ch if not (ord(c) == 0 or ord(c) == 0xfffd or is_control(c))])
normalized_text += ch
char_mapping.extend([i] * len(ch))
text, token_mapping, offset = normalized_text, [], 0
for token in tokens:
if token.startswith('▁'):
token = token[1:]
if is_special(token):
token_mapping.append([])
else:
token = stem(token)
if do_lower_case:
token = token.lower()
try:
start = text[offset:].index(token) + offset
except Exception as e:
print(e)
print(token)
end = start + len(token)
token_mapping.append(char_mapping[start: end])
offset = end
return token_mapping
3.2 初始化各类组件
先import:
import time
import numpy as np
import thulac
import nltk
from nltk.corpus import stopwords
from ltp import LTP
import torch
import torch.nn.functional as F
from transformers import ElectraModel, ElectraTokenizerFast
from sentence_transformers.util import pytorch_cos_sim
from utils import get_word_weight, process_long_input, rematch
3.2.1 标点和停用词
english_punctuations = [',', '.', ':', ';', '?', '(', ')', '[', ']', '&', '!', '*', '@', '#', '$', '%']
chinese_punctuations = '!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗?????〝〞????–—‘’?“”??…?﹏.'
punctuations = ''.join(i for i in english_punctuations) + chinese_punctuations
stop_words = stopwords.words('english')
3.2.2 预训练词汇权重
weightfile_pretrain = './resources/pretrained_weight_dict.txt'
weightpara_pretrain = 2.7e-4
word2weight_pretrain = get_word_weight(weightfile_pretrain, weightpara_pretrain)
3.2.3 分词/词性标注模型
如果采用SIFRank原作者的策略,则实例化一个lac模型
lac_model = thulac.thulac()
我采用的是ltp模型,首先把自定义词表和模型路径加载一下。
ltp_model_path = '/ltp4_data/base/'
ltp_ner_usr_dict_path = './resources/ner_usr_dict.txt'
usr_dict = []
with open(ltp_ner_usr_dict_path) as f:
for line in f.readlines():
usr_dict.append(line.split('\n')[0])
3.2.4 候选短语抽取模型
这个模型的作用是以nltk的正则工具抽取候选关键短语。我在原项目的基础上做了一点点修改,原项目每次抽取都重新实例化抽取器,让我觉得很别扭。
class CandidateExtractor:
"""
参考SIFRank项目的词性正则抽取候选短语
"""
def __init__(self):
grammar = """ NP:
{<n.*|a|uw|i|j|x>*<n.*|uw|x>|<x|j><-><m|q>} # Adjective(s)(optional) + Noun(s)"""
self.parser = nltk.RegexpParser(grammar)
def extract_candidates(self, tokens_tagged):
keyphrase_candidate = []
np_pos_tag_tokens = self.parser.parse(tokens_tagged)
count = 0
for token in np_pos_tag_tokens:
if (isinstance(token, nltk.tree.Tree) and token._label == "NP"):
np = ''.join(word for word, tag in token.leaves())
length = len(token.leaves())
start_end = (count, count + length)
count += length
keyphrase_candidate.append((np, start_end))
else:
count += 1
return keyphrase_candidate
candidate_extractor = CandidateExtractor()
3.2.5 词形还原模型
这个没什么好说的,就是一个简单的词形还原,对中文来说作用不大。
lemma_model = nltk.WordNetLemmatizer()
3.2.6 编码模型
这里可以采用多种编码模型,可以多实验几个预训练模型测试一下效果。注意,Roberta系列的模型和XMLRoberta系列的模型由于tokenizer比较特殊,我没有做相应的适配。
Electra模型:
electra_path = './resources/chinese-electra-180g-small-discriminator'
electra_tokenizer = ElectraTokenizerFast.from_pretrained(electra_path)
electra_model = ElectraModel.from_pretrained(electra_path)
Bert模型:
from transformers import BertTokenizerFast, BertModel
bert_path = './resources/bert-base-chinese/'
bert_model = BertModel.from_pretrained(bert_path)
bert_tokenizer = BertTokenizerFast.from_pretrained(bert_path)
Sentence-bert提供的一个语义相似度预训练bert:
from transformers import DistilBertTokenizerFast, DistilBertModel
distil_bert_path = './finetune_embedding_model/SimCSE/4500/'
distil_bert_model = DistilBertModel.from_pretrained(distil_bert_path)
distil_bert_tokenizer = DistilBertTokenizerFast.from_pretrained(distil_bert_path)
这些模型都可以在huggingface网站上找到,参考本文第2部分。
3.3 建立关键短语抽取模型
万事俱备,接下来就把这些组件放在一起,构建一个大类,用于抽取关键短语。这个大类包含一下几个方法:
- 构造方法:加载3.2中构建的各个组件;
- 添加新的停用词和标点词;
- 获取每个token的编码特征列表;
- 获取每个token的权重列表;
- 获取候选短语列表;
- 从候选短语抽取关键短语;
- 调用方法,给入文本,抽取关键短语;
- 静态方法:获取一个候选的加权表征;
- 静态方法:输入文本预处理。
以上方法将会依次呈现在下面的类中:
class SIFRank:
"""
用于抽取关键短语的SIFRank模型
[步骤]
1. 对原句进行tokenize和词性标注
2. 对原句进行编码,并根据1中tokenize的结果获取embedding_list
3. 根据1中tokenize的结果获取weight_list
4. 抽取原句中的候选关键短语
5. 对候选关键短语进行评分,得到关键短语
---------------
ver: 2021-11-02
by: changhongyu
"""
def __init__(self, tokenize_and_postag_model, candidate_extractor, lemma_model,
encoding_model, encoding_tokenizer, encoding_pooling, encoding_device,
word2weight_pretrain, stop_words, punctuations):
"""
:param tokenize_and_postag_model: 分词和词性标注模型
:param candidate_extractor: 用于抽取候选短语的模型
:param lemma_model: 用于词根还原的模型, 如果None,则忽略
:param encoding_model: PretrainedModel: 编码预训练模型
:param encoding_tokenizer: PretrainedTokenizer: 编码时的tokenizer
:param encoding_pooling: str: 编码时的池化策略, 'mean'或'max'
:param encoding_device: str: 编码时的设备, 'cpu'或'cuda'
:param word2weight_pretrain: dict: 词汇对应权重的大list
:param stop_words: list: 停用词表
:param punctuations: list: 标点符号表
"""
assert encoding_pooling in ['mean', 'max'], Exception("Pooling must be either mean or max.")
assert encoding_device.startswith('cuda') or encoding_device == 'cpu'
self.tokenize_and_postag_model = tokenize_and_postag_model
self.extractor = candidate_extractor
self.lemma_model = lemma_model
self.encoding_model = encoding_model
self.encoding_tokenizer = encoding_tokenizer
self.encoding_pooling = encoding_pooling
self.encoding_device = torch.device(encoding_device)
self.word2weight_pretrain = word2weight_pretrain
self.stop_words = stop_words
self.punctuations = punctuations
print(self)
def __repr__(self):
infos = ['------SIFRank for key-phrase extract------\n',
'SETTINGS: \n'
'tokenize_and_postag_model: {}\n'.format(str(type(self.tokenize_and_postag_model)).replace("'>", "").split('.')[-1]),
'lemma_model: {}\n'.format(str(type(self.lemma_model)).replace("'>", "").split('.')[-1]),
'encoding_model: {}\n'.format(str(type(self.encoding_model)).replace("'>", "").split('.')[-1]),
'encoding_device: {}\n'.format(self.encoding_device),
'encoding_pooling: {}\n'.format(self.encoding_pooling),
]
return ''.join(info for info in infos)
def add_stopword(self, stop_word):
"""
添加停用词,注意停用词是指英文停用词
"""
self.stop_words.append(stop_word)
def add_punctuation(self, punctuation):
"""
添加标点符
"""
self.punctuations.append(punctuation)
def _get_embedding_list(self, text, target_tokens):
"""
获取以token为划分的embedding的list
TODO: 对原句进行清洗,过滤掉对encoding_tokenizer而言OOV的词(耗时太长)
:param text: str: 原文
:param target_tokens: list: tokenize_and_postag_model对当前输入的分词结果
"""
embedding_list = []
self.encoding_model.to(self.encoding_device)
features = self.encoding_tokenizer(text.lower().replace(' ', '-'),
max_length=1024,
truncation=True,
padding='longest',
return_tensors='pt')
input_ids = features['input_ids'].to(self.encoding_device)
attention_mask = features['attention_mask'].to(self.encoding_device)
with torch.no_grad():
enconding_out, _ = process_long_input(self.encoding_model,
input_ids,
attention_mask,
[self.encoding_tokenizer.cls_token_id],
[self.encoding_tokenizer.sep_token_id])
last_hidden_state = enconding_out.squeeze(0).detach().cpu().numpy()
t_mapping = rematch(text, target_tokens, do_lower_case=True)
s_mapping = rematch(text, self.encoding_tokenizer.tokenize(text), do_lower_case=True)
token_lens = []
t_pointer = 0
t = t_mapping[t_pointer]
cur_len = 0
cur_in_t = 0
for s in s_mapping:
if s == t[cur_in_t: cur_in_t + len(s)]:
cur_len += 1
cur_in_t += len(s)
if cur_in_t == len(t):
token_lens.append(cur_len)
cur_len = 0
cur_in_t = 0
t_pointer += 1
if t_pointer >= len(t_mapping):
break
t = t_mapping[t_pointer]
assert len(token_lens) == len(target_tokens), \
Exception("Token_lens and target_tokens shape unmatch: {} vs {}.".format(len(token_lens), len(target_tokens)))
cur_pos = 0
for token_len in token_lens:
if token_len == 0:
cur_emb = np.zeros(last_hidden_state.shape[1])
embedding_list.append(cur_emb)
continue
if self.encoding_pooling == 'mean':
cur_emb = np.mean(last_hidden_state[cur_pos: cur_pos + token_len][:], axis=0)
elif self.encoding_pooling == 'max':
cur_emb = np.max(last_hidden_state[cur_pos: cur_pos + token_len][:], axis=0)
else:
raise ValueError("Pooling Strategy must be either mean or max.")
cur_pos += token_len
embedding_list.append(cur_emb)
assert len(embedding_list) == len(target_tokens), \
Exception("Result embedding list must have same length as target.")
return embedding_list
def _get_weight_list(self, target_tokens):
"""
获取weight列表
:param target_tokens: list: tokenize_and_postag_model对当前输入的分词结果
:return weight_list: list of float: 每个token对应的预训练权重列表
"""
weight_list = []
_max = 0.
for token in target_tokens:
token = token.lower()
if token in self.stop_words or token in self.punctuations:
weight = 0.
elif token in self.word2weight_pretrain:
weight = word2weight_pretrain[token]
else:
weight = _max
_max = max(weight, _max)
weight_list.append(weight)
return weight_list
def _get_candidate_list(self, target_tokens, target_poses):
"""
用词性正则抽取候选关键短语列表
:param target_tokens: list: tokenize_and_postag_model对当前输入的分词结果
:param target_poses: list: tokenize_and_postag_model对当前输入词性标注结果
:return candidates: list of tuples like: ('自然语言', (5, 7))
NOTE: tuple[1]是在target_tokens中的span,对target_tokens索引,得到tuple[0]
"""
assert len(target_tokens) == len(target_poses)
tokens_tagged = [(tok, pos) for tok, pos in zip(target_tokens, target_poses)]
candidates = self.extractor.extract_candidates(tokens_tagged)
return candidates
def _extract_keyphrase(self, candidates, weight_list, embedding_list, max_keyphrase_num):
"""
对候选的关键短语计算与原文编码的相似度,获取关键短语
:param candidates: list of tuples: 候选关键短语list
:param weight_list: list of float: 每个token的预训练权重列表
:param embedding_list: list of array: 每个token的编码结果
:param max_keyphrase_num: int: 最多保留的关键词个数
:return key_phrases: list of tuple: [(k1, 0.9), ...]
"""
assert len(weight_list) == len(embedding_list)
candidate_embeddings_list = []
for cand in candidates:
cand_emb = self.get_candidate_weight_avg(weight_list, embedding_list, cand[1])
candidate_embeddings_list.append(cand_emb)
sent_embeddings = self.get_candidate_weight_avg(weight_list, embedding_list, (0, len(embedding_list)))
sim_list = []
for i, emb in enumerate(candidate_embeddings_list):
sim = float(pytorch_cos_sim(sent_embeddings, candidate_embeddings_list[i]).squeeze().numpy())
sim_list.append(sim)
dict_all = {}
for i, cand in enumerate(candidates):
if self.lemma_model:
cand_lemma = self.lemma_model.lemmatize(cand[0].lower()).replace('▲', ' ')
else:
cand_lemma = cand[0].lower().replace('▲', ' ')
if cand_lemma in dict_all:
dict_all[cand_lemma].append(sim_list[i])
else:
dict_all[cand_lemma] = [sim_list[i]]
final_dict = {}
for cand, sim_list in dict_all.items():
sum_sim = sum(sim_list)
final_dict[cand] = sum_sim / len(sim_list)
return sorted(final_dict.items(), key=lambda x: x[1], reverse=True)[: max_keyphrase_num]
def __call__(self, text, max_keyphrase_num):
"""
抽取关键词
:param text: str: 待抽取原文
:param max_keyphrase_num: int: 最多保留的关键词个数
:return key_phrases: list of tuple: [(k1, 0.9), ...]
"""
text = self.preprocess_input_text(text)
t0 = time.time()
token_and_pos = self.tokenize_and_postag_model.cut(text)
target_tokens = [t_p[0] for t_p in token_and_pos]
target_poses = [t_p[1] for t_p in token_and_pos]
for i, token in enumerate(target_tokens):
if token in self.stop_words:
target_poses[i] = "u"
if token == '-':
target_poses[i] = "-"
if token in ['"', "'"]:
target_poses[i] = '"'
t1 = time.time()
print("耗时统计")
print("<1. 对原句进行tokenize和词性标注: ", round(t1 - t0, 2), 's')
embedding_list = self._get_embedding_list(text, target_tokens)
t2 = time.time()
print("<2. 对原句进行编码: ", round(t2 - t1, 2), 's')
weight_list = self._get_weight_list(target_tokens)
t3 = time.time()
print("<3. 结果获取weight_list: ", round(t3 - t2, 2), 's')
candidate_list = self._get_candidate_list(target_tokens, target_poses)
t4 = time.time()
print("<4. 抽取原句中的候选关键短语: ", round(t4 - t3, 2), 's')
key_phrases = self._extract_keyphrase(candidate_list, weight_list,
embedding_list, max_keyphrase_num)
t5 = time.time()
print("<5. 对候选关键短语进行评分: ", round(t5 - t4, 2), 's')
return key_phrases
@staticmethod
def get_candidate_weight_avg(weight_list, embedding_list, candidate_span):
"""
获取一个候选词的加权表征
:param weight_list: list of float: 每个token的预训练权重列表
:param embedding_list: list of array: 每个token的编码结果
:param candidate_span: tuple: 候选短语的start和end
"""
assert len(weight_list) == len(embedding_list)
start, end = candidate_span
num_words = end - start
embedding_size = embedding_list[0].shape[0]
sum_ = np.zeros(embedding_size)
for i in range(start, end):
tmp = embedding_list[i] * weight_list[i]
sum_ += tmp
return sum_
@staticmethod
def preprocess_input_text(text):
"""
对输入原文进行预处理,主要防止两个tokenizer对齐时出现问题
"""
text = text.lower()
text = text.replace('“', '"').replace('”', '"')
text = text.replace('‘', "'").replace('’', "'")
text = text.replace('?', '-')
text = text.replace('\u3000', ' ').replace('\n', ' ')
text = text.replace(' ', '▲')
return text[: 1024]
注意,在上面的类中调用了sentence-transformer中的pytorch_cos_sim方法计算两个张量之间的余弦相似度,如果没有安装这个包,可以自己写个方法实现余弦相似度的计算,这个不难,可以直接百度到。
3.4 抽取应用
将上面的大类实例化:
keyphrase_extractor = SIFRank(tokenize_and_postag_model=ltp_pos_model,
candidate_extractor=candidate_extractor,
lemma_model=lemma_model,
encoding_model=electra_model,
encoding_tokenizer=electra_tokenizer,
encoding_pooling='mean',
encoding_device='cuda:1',
word2weight_pretrain=word2weight_pretrain,
stop_words=stop_words,
punctuations=punctuations)
然后对输入的text,调用:
keyphrase_extractor(text, max_keyphrase_num=10)
即可返回关键短语的降序排列,以及每个关键短语对应的得分。
4. 改进
4.1 增加候选关键短语
候选关键短语是通过正则的方式对词性进行匹配得到的,其关键代码在这一句:
grammar = """ NP:
{<n.*|a|uw|i|j|x>*<n.*|uw|x>|<x|j><-><m|q>} # Adjective(s)(optional) + Noun(s)"""
通过修改正则语句,我们可以获得自己想要的候选短语。例如,我希望拿到*’"花岗岩"超声速反舰导弹*这样的短语作为关键短语,通过观察词性发现,这类短语的词性构成是:引号+名词+引号+若干名词,翻译成正则语句就是:
<"><n.*><"><n.*>*<n.*>
把它拼接到原来的语句上:
grammar = """ NP:
{<n.*|a|uw|i|j|x>*<n.*|uw|x>|<x|j><-><m|q>|<"><n.*><"><n.*>*<n.*>}"""
然后看一下修改之后得效果:
text = '近日,俄罗斯海军在美国阿拉斯加州附近外海进行了涉及数十艘舰艇和飞机的大型演习,这是自冷战结束后在该地区举行的最大规模演习。\n演习期间,一艘俄罗斯核潜艇在阿拉斯加外海突然上浮,这一不同寻常的举动引起了美军的高度关注。\n据美联社8月29日报道,俄罗斯海军司令尼古拉·叶夫梅诺夫上将说,有50多艘军舰和约40架飞机参加了正在白令海举行的演习,演习中涉及多次导弹发射练习。\n“瓦良格”号导弹巡洋舰发射“玄武岩”反舰导弹。\n俄海军在白令海举行演习\n叶夫梅诺夫在俄罗斯国防部发表的一份声明中说:“这是我们有史以来第一次在那里举行如此大规模的演习。”叶夫梅诺夫强调,这些演习是为了加强俄罗斯在北极地区的存在和保护俄罗斯的资源。他说:“我们正在建立我们的力量,以确保该地区的经济发展”,“ 我们正在适应北极。”\n目前还不清楚演习何时开始,也不清楚演习是否已经结束\n俄罗斯太平洋舰队参加了此次演习,该舰队表示,作为演习的一部分,“鄂木斯克”号核潜艇和“瓦良格”号导弹巡洋舰向白令海的一个练习目标发射了巡航导弹。演习还从楚科奇半岛海岸向阿纳德尔湾的一个练习目标发射了“玛瑙石”岸舰导弹。\n“这两种舰艇都是冷战时期苏联针对美国航母研制的武器,因此此次演习的科目包括演练水面舰艇和潜艇联合打击航母等大型水面战舰。”军事专家韩东分析认为。\n“瓦良格”号导弹巡洋舰是俄太平洋舰队的旗舰,配备了16枚“玄武岩”超声速反舰导弹,最大射程约500千米,而“鄂木斯克”号巡航导弹核潜艇则配备24枚“花岗岩”超声速反舰导弹,射程超过500千米。\n在演习进行期间,美军27日发现一艘俄罗斯潜艇在阿拉斯加附近浮出水面。美国北方司令部发言人比尔·刘易斯指出,俄罗斯的军事演习是在美国境外的国际水域内进行的。刘易斯说,北美航空航天防御司令部和美国北方司令部正在密切监视这艘潜艇。他还说,他们还没有收到俄罗斯海军的任何援助请求,但随时准备为遇难者提供帮助。\n俄罗斯国家通讯社援引俄罗斯太平洋舰队消息人士的话说,“鄂木斯克”号核潜艇浮出水面是例行公事。\n除了海上的行动,同样是在27日晚些时候,北美防空司令部派出F-22战斗机拦截了靠近阿拉斯加的两架俄图-142海上巡逻机。俄军飞机在该地区停留了约5个小时。俄相关负责人说,俄军机仍在国际空域,任何时候都没有进入美国或加拿大主权空域。\n对于俄军机的飞行行动。美国北美防空司令格伦·范赫克(Glen D。 VanHerck)将军在一份声明中说:“随着我们的竞争对手继续扩大军事存在并探测我们的防御,我们的北部方面力量增加了对外国军事活动的监视……今年,我们进行了十多次拦截,这是近年来最多的一次。我们继续努力在北部进行防空行动的重要性从未如此显著。”\n俄军冷战后首次在位于俄罗斯远东领土与美国阿拉斯加州之间的白令海举行联合军事演习引发了外界的关注。俄新社援引俄罗斯海军前参谋长、退役上将维克多·克拉夫琴科评价说:“这是一个信号,表明我们没有沉睡,我们想去哪里就去哪里。”\n美联社报道称,俄海军在白令海的演习引起了美国商业渔船的骚动。美国海岸警卫队发言人基普·瓦德洛27日表示:“我们接到多艘在白令海外作业的渔船的通知,他们遇到了俄军舰艇,并感到担忧。”海岸警卫队联系了位于埃尔曼多夫-理查森联合基地的阿拉斯加司令部,该司令部证实,出现在那里的船只(指俄军舰艇)是俄罗斯预先计划的军事演习的一部分,美国军方官员已经知道。\n俄军苏-27拦截美国B-52H轰炸机。\n美俄相互在对方“家门口”展示“肌肉”\n俄罗斯海空军在阿拉斯加“家门口”的白令海举行演习,美国在俄罗斯“家门口”也有不少军事动作。\n据美国“战区”网站8月25日报道,美国海军发布照片显示,美国“海狼”号核潜艇近日浮出水面,出现在挪威近海。据悉,“海狼”级核潜艇堪称美海军最强核潜艇,主要用于开展情报活动和执行特别任务,这次的公开露面非常罕见,其出没的海域是俄罗斯核潜艇从俄西北部基地前往大洋的必经之地。\n“美海军很少公布执行任务的核潜艇行动信息,此次公布‘海狼’级核潜艇在挪威近海活动应该是对俄核潜艇在阿拉斯加外海活动的一个回应,相互亮‘肌肉’。”韩东认为。'
修改前:
[('俄罗斯太平洋舰队消息人士', 0.8885042016937243),
('俄罗斯太平洋舰队', 0.8724251117358179),
('俄罗斯海军司令尼古拉·叶夫梅诺夫上将', 0.8673940796984003),
('俄罗斯核潜艇', 0.8647546231314885),
('俄太平洋舰队', 0.8500425508933432),
('冷战时期苏联', 0.8485381218751973),
('司令部发言人比尔·刘易斯', 0.8474181641803855),
('俄罗斯潜艇', 0.839906844646745),
('号导弹巡洋舰', 0.8221001479460699),
('超声速反舰导弹', 0.8202724547732473),
('俄罗斯国防部', 0.8166525573126924),
('俄罗斯国家通讯社', 0.8151591420925066),
('号核潜艇', 0.8121318797386072),
('俄罗斯海军', 0.8052707858141447),
('f-22战斗机', 0.804186666125132)]
修改后:
[('俄罗斯太平洋舰队消息人士', 0.8882256970130232),
('俄罗斯太平洋舰队', 0.871818117861476),
('"花岗岩"超声速反舰导弹', 0.8703524536754964),
('俄罗斯海军司令尼古拉·叶夫梅诺夫上将', 0.8683305549970817),
('"玄武岩"超声速反舰导弹', 0.8668547029210916),
('俄罗斯核潜艇', 0.8660240382292157),
('冷战时期苏联', 0.8490640498535709),
('俄太平洋舰队', 0.8487545318893517),
('司令部发言人比尔·刘易斯', 0.8470155663497334),
('俄罗斯潜艇', 0.8408020518321848),
('俄罗斯国家通讯社', 0.8278665326955672),
('俄罗斯国防部', 0.8175267308579581),
('"瓦良格"号导弹巡洋舰', 0.805345079953326),
('f-22战斗机', 0.8047087460463858),
('俄罗斯海军', 0.8034005012063903)]
4.2 自监督训练
SimCSE等自监督训练可以参考我之前的这篇博客,我用SimCSE在6000条军事新闻数据上随便训练了一下,效果并不好。
'《亚洲防务评论》网站报道,日本公开了所拍摄到的中国最新型早期空中预警机照片,这是日本首次在空中遭遇这个新型号。\n日本防卫省联合参谋本部在3月23日发布的新闻中,公开了日本航空自卫队近距离拍摄的空警-500早期预警和指挥机,拍摄位置是在东中国海上空。\n日本方面发布的地图显示,这架空警-500当时飞行在日本冲绳县石垣岛以北、冲绳岛以西,遭到了日本航空自卫队的拦截。\n日本航空自卫队从冲绳岛那霸基地出动了两架F-15J/DJ战斗机前往拦截,当时中国军机已经进入了日本的防空识别区。\n日本航空自卫队截击机拍摄到的照片显示,这架空警-500在垂尾上喷涂有09字样的机号,这意味着她属于中国人民解放军海军航空兵。最近解放军海军航空兵开始用两位数机号逐步取代传统的五位数机号。\n[日本航空自卫队拍摄到的空警-500]\n新的两位数机号使外界更加难以辨别某架飞机的上级单位,但是日本防卫省指出,根据这架空警-500的来去航迹,可以认定她是从上海附近的某个基地起飞的。\n2018年的卫星图像显示,在上海附近的某个军用机场上停放着多架基于运-9运输机的特种飞机,其中包括空警-500。\n在2019年美国麦克萨尔技术公司发布的卫星图像中,能看到两架空警-500和其余4架基于运-9的特种任务飞机。此前的图像还能看到两架KQ-200远程海上巡逻/反潜特种机。\n[国外卫星拍摄的某军用机场]\n空警-500于2014年服役,取代了早期的空警-200和空警-2000早期空中预警和控制飞机。空警-500的特点是采用了不旋转的背负式雷达天线,以及鼻部天线,机身两侧可能还携带了用于电子情报采集的设备。背部的雷达由三部独立的相控阵雷达组成,每一部覆盖120度范围。空警-500采用运-9运输机基础,解放军海军已经把这个型号部署在了南海方向,以海南岛以南的岛屿机场为基地。'
训练前:
[('特种任务飞机', 0.9544675871738232),
('特种飞机', 0.9507889602043011),
('上级单位', 0.9419766881979212),
('日本航空自卫队截击机', 0.9391973945150388),
('运-9', 0.936841137717237),
('空警-200', 0.92588820146344),
('背负式雷达天线', 0.9231113884480613),
('军用机场', 0.9227828802688063),
('解放军海军', 0.9204134273344802),
('机号', 0.9188464054733505),
('空警-2000', 0.9187249554053616),
('电子情报', 0.9115211723773197),
('相控阵雷达', 0.9092908332876184),
('机身', 0.905688648586135),
('空警-500', 0.9055558900237239)]
训练后:
[('解放军海军航空兵', 0.9988492165651826),
('日本航空自卫队截击机', 0.9987936317255659),
('特种飞机', 0.998727913875621),
('机号', 0.9985039091472351),
('上级单位', 0.9983802383804824),
('运-9', 0.998351921860428),
('航迹', 0.9983295566286031),
('中国人民解放军海军航空兵', 0.9978012409215729),
('垂尾', 0.997654260001456),
('上海', 0.9975840676529512),
('数机号', 0.9974602590292685),
('外界', 0.9972179770236596),
('识别区', 0.9969144902282997),
('运-9运输机', 0.9968683425786463),
('卫星图像', 0.9960682017937763)]
关于mask language model的训练,可以参考huggingface官方的文档,我最近可能会整理一版比较方便的代码。如果整理了,可能会更新在博客上。
以上就是本期的全部内容了,总的来说SIFRank这个工具虽然没有那么“智能”,但可以充分做到可控,使用者对抽取结果可以从多方面进行干预和调整,是一个非常好用的关键词抽取工具。
如果这篇文章对你有所帮助,记得点个免费的赞,我们下期再见。
|