2021SC@SDUSC
In the source code analyses of the past few weeks, we have seen how DrFact preprocesses its corpus and what the first few steps of the DrFact algorithm do. This week, however, I will not walk through a specific piece of source code. While reviewing my earlier analyses, I noticed that the DrFact model borrows from other models to a significant degree, most visibly in its core source code, which calls functions already written for other models. What needs attention is therefore no longer just the drfact package itself: the source code of the rest of the OpenCSR project has to be examined as well.
In a sense, I had misjudged where the core of the OpenCSR codebase lies, which means understanding the project properly will take more effort. So before I reach a macro-level picture of the overall code structure and of how the individual models interconnect, I will postpone any further fine-grained analysis of DrFact's algorithm implementation and its remaining data-processing code. Instead, this post turns to the modules and function-level features that the models share with one another.
II. Functions Already Defined in the DrKit Model
To discuss the DrKit modules that contribute to DrFact, we first have to return to BERT. As mentioned in earlier analyses, BERT (Bidirectional Encoder Representations from Transformers) is a pre-trained language representation model. Its key departure from earlier approaches is that it does not pre-train a traditional unidirectional language model, nor shallowly concatenate two unidirectional ones; instead it uses a masked language model (MLM) objective, which allows it to learn deep bidirectional language representations.
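To make the MLM idea concrete, here is a minimal sketch of how masked inputs can be constructed. This is illustrative only: real BERT pre-training additionally replaces some selected tokens with random words or leaves them unchanged, and predicts vocabulary ids rather than strings.

```python
import random

def mask_tokens(tokens, mask_prob=0.15, mask_token="[MASK]"):
  """Replace a random subset of tokens with [MASK]; the model must recover
  them from both left and right context, which forces bidirectionality."""
  masked, targets = [], []
  for tok in tokens:
    if random.random() < mask_prob:
      masked.append(mask_token)
      targets.append(tok)    # training target at this position
    else:
      masked.append(tok)
      targets.append(None)   # no loss on unmasked positions
  return masked, targets

masked, targets = mask_tokens(
    "deep bidirectional language representations".split())
print(masked)
```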
With that background in place, we can look at the BERT-related modules defined in DrKit. The first is the bert_utils.py module, which defines a BERTPredictor class with seven member attributes and six member methods. Its definition is as follows:
```python
# (Module context: `FLAGS` comes from absl-py, `modeling` from the BERT
# codebase, and `run_dualencoder_lsf` / `FastPredict` are other modules in
# the DrKit code; `tf` is TensorFlow 1.x, `np` is NumPy, `tqdm` a progress bar.)
class BERTPredictor(object):

  def __init__(self, tokenizer, init_checkpoint, estimator=None):
    """Setup BERT model."""
    self.max_seq_length = FLAGS.max_seq_length
    self.max_qry_length = FLAGS.max_query_length
    self.max_ent_length = FLAGS.max_entity_length
    self.batch_size = FLAGS.predict_batch_size
    self.tokenizer = tokenizer
    if estimator is None:
      bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
      run_config = tf.estimator.tpu.RunConfig()
      qa_config = run_dualencoder_lsf.QAConfig(
          doc_layers_to_use=FLAGS.doc_layers_to_use,
          doc_aggregation_fn=FLAGS.doc_aggregation_fn,
          qry_layers_to_use=FLAGS.qry_layers_to_use,
          qry_aggregation_fn=FLAGS.qry_aggregation_fn,
          dropout=FLAGS.question_dropout,
          qry_num_layers=FLAGS.question_num_layers,
          projection_dim=FLAGS.projection_dim,
          normalize_emb=FLAGS.normalize_emb,
          reconstruction_weight=FLAGS.qry_reconstruction_weight,
          ent_decomp_weight=FLAGS.ent_decomp_weight,
          rel_decomp_weight=FLAGS.rel_decomp_weight,
          train_bert=FLAGS.train_bert,
          shared_bert_for_qry=FLAGS.shared_bert_for_qry,
          load_only_bert=FLAGS.load_only_bert)
      # Inference only: zero learning rate and no training/warmup steps.
      model_fn = run_dualencoder_lsf.model_fn_builder(
          bert_config=bert_config,
          qa_config=qa_config,
          init_checkpoint=init_checkpoint,
          learning_rate=0.0,
          num_train_steps=0,
          num_warmup_steps=0,
          use_tpu=False,
          use_one_hot_embeddings=False)
      # If TPU is not available, this will fall back to normal Estimator on CPU
      # or GPU.
      estimator = tf.estimator.tpu.TPUEstimator(
          use_tpu=False,
          model_fn=model_fn,
          config=run_config,
          train_batch_size=self.batch_size,
          predict_batch_size=self.batch_size)
    self.fast_predictor = FastPredict(estimator, self.get_input_fn)
    self.emb_dim = FLAGS.projection_dim
```
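Before moving on, here is a minimal usage sketch of how such a predictor might be constructed. This is my own illustration, not code from the project: the vocabulary and checkpoint paths are placeholders, it assumes the relevant FLAGS have already been parsed, and it assumes the tokenizer comes from the standard BERT codebase.

```python
# Hypothetical usage sketch -- paths are placeholders, FLAGS assumed parsed.
from bert import tokenization

tokenizer = tokenization.FullTokenizer(
    vocab_file="/path/to/vocab.txt", do_lower_case=True)
predictor = BERTPredictor(tokenizer, init_checkpoint="/path/to/bert_model.ckpt")
```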
```python
  def get_input_fn(self, generator):
    """Return an input_fn which accepts a generator."""

    def _input_fn(params):
      """Convert input into features."""
      del params
      seq_length = self.max_seq_length
      qry_length = self.max_qry_length
      ent_length = self.max_ent_length
      d = tf.data.Dataset.from_generator(
          generator,
          output_types={
              "unique_ids": tf.int32,
              "doc_input_ids": tf.int32,
              "doc_input_mask": tf.int32,
              "doc_segment_ids": tf.int32,
              "qry_input_ids": tf.int32,
              "qry_input_mask": tf.int32,
              "qry_segment_ids": tf.int32,
              "ent_input_ids": tf.int32,
              "ent_input_mask": tf.int32,
          },
          output_shapes={
              "unique_ids": tf.TensorShape([]),
              "doc_input_ids": tf.TensorShape([seq_length]),
              "doc_input_mask": tf.TensorShape([seq_length]),
              "doc_segment_ids": tf.TensorShape([seq_length]),
              "qry_input_ids": tf.TensorShape([qry_length]),
              "qry_input_mask": tf.TensorShape([qry_length]),
              "qry_segment_ids": tf.TensorShape([qry_length]),
              "ent_input_ids": tf.TensorShape([ent_length]),
              "ent_input_mask": tf.TensorShape([ent_length]),
          })
      d = d.batch(batch_size=self.batch_size)
      return d

    return _input_fn
```
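The reason get_input_fn takes a generator is the FastPredict object created in the constructor: estimator.predict() normally rebuilds its input pipeline on every call, so DrKit keeps one long-lived predict loop alive and pushes new examples into it through the generator. DrKit ships its own FastPredict implementation; the sketch below is my simplified reconstruction of the pattern, not the project's actual code.

```python
class FastPredictSketch(object):
  """Simplified reconstruction of the FastPredict pattern (illustrative)."""

  def __init__(self, estimator, input_fn_builder):
    self.estimator = estimator
    self.input_fn_builder = input_fn_builder
    self.next_features = None
    self.predictions = None  # lazily-started predict() iterator

  def _generator(self):
    # Endless generator: each predict() call refills self.next_features.
    while True:
      for feature_dict in self.next_features:
        yield feature_dict

  def predict(self, feature_dicts):
    self.next_features = feature_dicts
    if self.predictions is None:
      # Started once; the same TF graph then serves all later calls.
      self.predictions = self.estimator.predict(
          input_fn=self.input_fn_builder(self._generator))
    return [next(self.predictions) for _ in range(len(feature_dicts))]
```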
```python
  def _run_on_features(self, features):
    """Run predictions for given features."""
    current_size = len(features)
    if current_size < self.batch_size:
      features += [features[-1]] * (self.batch_size - current_size)
    return self.fast_predictor.predict(features)[:current_size]
```
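_run_on_features shows a common trick for fixed-batch-size estimators: if the final batch is short, it is padded by repeating the last example, and the surplus predictions are sliced off afterwards. A tiny self-contained illustration of the same arithmetic:

```python
# Illustrative only: pad a short batch to the fixed size, then trim results.
batch_size = 8
features = [{"uid": 0}, {"uid": 1}, {"uid": 2}]  # only 3 real examples
current_size = len(features)
padded = features + [features[-1]] * (batch_size - current_size)
assert len(padded) == batch_size                 # the model sees a full batch
results = padded                                 # stand-in for predict()
print(results[:current_size])                    # keep only the 3 real outputs
```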
```python
  def get_features(self, doc_tokens, qry_tokens, ent_tokens, uid):
    """Convert list of tokens to a feature dict."""
    # Reserve two positions for the [CLS] and [SEP] markers.
    max_tokens_doc = self.max_seq_length - 2
    max_tokens_qry = self.max_qry_length - 2
    max_tokens_ent = self.max_ent_length
    doc_input_ids = self.tokenizer.convert_tokens_to_ids(
        ["[CLS]"] + doc_tokens[:max_tokens_doc] + ["[SEP]"])
    doc_segment_ids = [1] * len(doc_input_ids)
    doc_input_mask = [1] * len(doc_input_ids)
    # Zero-pad ids, mask, and segment ids up to the fixed sequence length.
    while len(doc_input_ids) < self.max_seq_length:
      doc_input_ids.append(0)
      doc_input_mask.append(0)
      doc_segment_ids.append(0)
    qry_input_ids = self.tokenizer.convert_tokens_to_ids(
        ["[CLS]"] + qry_tokens[:max_tokens_qry] + ["[SEP]"])
    qry_segment_ids = [0] * len(qry_input_ids)
    qry_input_mask = [1] * len(qry_input_ids)
    while len(qry_input_ids) < self.max_qry_length:
      qry_input_ids.append(0)
      qry_input_mask.append(0)
      qry_segment_ids.append(0)
    # Entity spans get no [CLS]/[SEP] wrapping, only truncation and padding.
    ent_input_ids = self.tokenizer.convert_tokens_to_ids(
        ent_tokens[:max_tokens_ent])
    ent_input_mask = [1] * len(ent_input_ids)
    while len(ent_input_ids) < self.max_ent_length:
      ent_input_ids.append(0)
      ent_input_mask.append(0)
    return {
        "unique_ids": uid,
        "doc_input_ids": doc_input_ids,
        "doc_input_mask": doc_input_mask,
        "doc_segment_ids": doc_segment_ids,
        "qry_input_ids": qry_input_ids,
        "qry_input_mask": qry_input_mask,
        "qry_segment_ids": qry_segment_ids,
        "ent_input_ids": ent_input_ids,
        "ent_input_mask": ent_input_mask,
    }
```
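The truncate-then-pad pattern above is worth isolating. The small standalone helper below is my own illustration, not part of DrKit; it shows what each while-loop computes. The ids 101 and 102 are [CLS] and [SEP] in the standard BERT vocabulary.

```python
def pad_to(ids, max_len, pad_id=0):
  """Zero-pad an id list to a fixed length and build the matching mask."""
  mask = [1] * len(ids) + [0] * (max_len - len(ids))
  ids = ids + [pad_id] * (max_len - len(ids))
  return ids, mask

ids, mask = pad_to([101, 2054, 2003, 102], max_len=8)
print(ids)   # [101, 2054, 2003, 102, 0, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0, 0, 0]
```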
```python
  def get_doc_embeddings(self, docs):
    """Run BERT to get features for docs.

    Args:
      docs: List of list of tokens.

    Returns:
      embeddings: Numpy array of token features.
    """
    num_batches = (len(docs) // self.batch_size) + 1
    tf.logging.info("Total batches for BERT = %d", num_batches)
    embeddings = np.empty((len(docs), self.max_seq_length, self.emb_dim),
                          dtype=np.float32)
    for nb in tqdm(range(num_batches)):
      min_ = nb * self.batch_size
      max_ = (nb + 1) * self.batch_size
      if min_ >= len(docs):
        break
      if max_ > len(docs):
        max_ = len(docs)
      # Only the document side matters here; query and entity get dummy tokens.
      current_features = [
          self.get_features(docs[ii], ["dummy"], ["dummy"], ii)
          for ii in range(min_, max_)
      ]
      results = self._run_on_features(current_features)
      for ir, rr in enumerate(results):
        embeddings[min_ + ir, :, :] = rr["doc_features"]
    return embeddings[:, 1:, :]  # remove [CLS]
```
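A hypothetical call, reusing the predictor instance sketched earlier, makes the output shape concrete; the leading [CLS] row is stripped, hence max_seq_length - 1:

```python
# Hypothetical call on the sketched `predictor`; the shapes are the point.
docs = [["the", "sun", "is", "a", "star"],
        ["fish", "live", "in", "water"]]
doc_embs = predictor.get_doc_embeddings(docs)
print(doc_embs.shape)  # (2, max_seq_length - 1, projection_dim)
```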
```python
  def get_qry_embeddings(self, qrys, ents):
    """Run BERT to get features for queries.

    Args:
      qrys: List of list of tokens.
      ents: List of list of tokens.

    Returns:
      st_embeddings: Numpy array of token features.
      en_embeddings: Numpy array of token features.
      bow_embeddings: Numpy array of token features.
    """
    num_batches = (len(qrys) // self.batch_size) + 1
    tf.logging.info("Total batches for BERT = %d", num_batches)
    st_embeddings = np.empty((len(qrys), self.emb_dim), dtype=np.float32)
    en_embeddings = np.empty((len(qrys), self.emb_dim), dtype=np.float32)
    bow_embeddings = np.empty((len(qrys), self.emb_dim), dtype=np.float32)
    for nb in tqdm(range(num_batches)):
      min_ = nb * self.batch_size
      max_ = (nb + 1) * self.batch_size
      if min_ >= len(qrys):
        break
      if max_ > len(qrys):
        max_ = len(qrys)
      # This time the document side is the dummy input.
      current_features = [
          self.get_features(["dummy"], qrys[ii], ents[ii], ii)
          for ii in range(min_, max_)
      ]
      results = self._run_on_features(current_features)
      for ir, rr in enumerate(results):
        st_embeddings[min_ + ir, :] = rr["qry_st_features"]
        en_embeddings[min_ + ir, :] = rr["qry_en_features"]
        bow_embeddings[min_ + ir, :] = rr["qry_bow_features"]
    return st_embeddings, en_embeddings, bow_embeddings
```
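And the query-side counterpart, again as a hypothetical call on the sketched predictor; each of the three returned arrays holds one vector per query (start, end, and bag-of-words features, respectively):

```python
# Hypothetical call: one query together with the entity it mentions.
qrys = [["where", "do", "fish", "live"]]
ents = [["fish"]]
st, en, bow = predictor.get_qry_embeddings(qrys, ents)
print(st.shape, en.shape, bow.shape)  # each is (1, projection_dim)
```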