import torch
import torch.nn.functional as F
from src.utils import generate_square_subsequent_mask
import math
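# NOTE: generate_square_subsequent_mask(sz) is assumed to return an (sz, sz) causal mask
# (as in the PyTorch seq2seq transformer tutorial) that blocks attention to future target
# positions; below it is cast to bool before being passed to the decoder as tgt_mask.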
class Translator():
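    """Beam-search and greedy decoding helpers for a trained sequence-to-sequence translation model."""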
def __init__(self, bos_idx, eos_idx, device, max_steps=64, beam_size=3, length_norm_coefficient=0.6):
'''
        length_norm_coefficient: coefficient for normalizing decoded sequences' scores by their lengths
'''
self.bos_idx = bos_idx
self.eos_idx = eos_idx
self.beam_size = beam_size
self.device = device
self.max_steps = max_steps
self.length_norm_coefficient = length_norm_coefficient
def beam_translate(self, old_model, enc_inputs, tokenizer):
"""
Translates a source language sequence to the target language, with beam search decoding.
        :param old_model: the trained model (possibly wrapped in DataParallel / DistributedDataParallel)
        :param enc_inputs: tokenized source inputs; enc_inputs['input_ids'] has shape [1, src_len]
        :param tokenizer: tokenizer used to decode the predicted token ids
        :return: the best hypothesis, and all candidate hypotheses
"""
if hasattr(old_model, "module"):
model = old_model.module
else:
model = old_model
with torch.no_grad():
# Beam size
k = self.beam_size
# Minimum number of hypotheses to complete
n_completed_hypotheses = min(k, 10)
# Vocab size
vocab_size = len(tokenizer)
# Encode
memory = model.encode(**enc_inputs) # (1, source_sequence_length, d_model)
# Our hypothesis to begin with is just <BOS>
hypotheses = torch.ones(k,1).fill_(self.bos_idx).long().to(self.device) # (k, 1)
# Tensor to store hypotheses' scores; now it's just 0
hypotheses_scores = torch.zeros(k).to(self.device) # (k)
# Lists to store completed hypotheses and their scores
completed_hypotheses = list()
completed_hypotheses_scores = list()
# Start decoding
step = 1
# Assume "s" is the number of incomplete hypotheses currently in the bag; a number less than or equal to "k"
# At this point, s is 1, because we only have 1 hypothesis to work with, i.e. "<BOS>"
while True:
                num_hyp = hypotheses.size(0)  # number of hypotheses currently in the beam (acts as the batch size)
hyp_mask = (generate_square_subsequent_mask(hypotheses.size(1))
.type(torch.bool)).to(self.device)
                # (num_hyp, step, vocab_size)
decoder_sequences = model.decode(tgt=hypotheses,
src=enc_inputs['input_ids'],
memory=memory.repeat(num_hyp, 1, 1),
tgt_mask=hyp_mask,
)
# Scores at this step
scores = decoder_sequences[:, -1, :] # (num_hyp, vocab_size)
                scores = F.log_softmax(scores, dim=-1)  # (num_hyp, vocab_size), log-probabilities (assumes decode returns raw logits)
# Add hypotheses' scores from last step to scores at this step to get scores for all possible new hypotheses
scores = hypotheses_scores.unsqueeze(1) + scores # (num_hyp, vocab_size)
# Unroll and find top k scores, and their unrolled indices
                if step == 1:  # at step 1 all k hypotheses are identical, so take the top-k from a single row
top_k_hypotheses_scores, unrolled_indices = scores[0].topk(num_hyp, 0, True, True) # (k)
else:
top_k_hypotheses_scores, unrolled_indices = scores.view(-1).topk(num_hyp, 0, True, True) # (k)
# Convert unrolled indices to actual indices of the scores tensor which yielded the best scores
                prev_word_indices = torch.div(unrolled_indices, vocab_size, rounding_mode="floor")  # (num_hyp), row of the source hypothesis
                next_word_indices = unrolled_indices % vocab_size  # (num_hyp), id of the appended word
                # Construct the new top k hypotheses from these indices
top_k_hypotheses = torch.cat([hypotheses[prev_word_indices], next_word_indices.unsqueeze(1)],
dim=1) # (num_hyp, step + 1)
# Which of these new hypotheses are complete (reached <EOS>)?
complete = next_word_indices == self.eos_idx # (num_hyp), bool
# Set aside completed hypotheses and their scores normalized by their lengths
# For the length normalization formula, see
# "Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation"
completed_hypotheses.extend(top_k_hypotheses[complete].tolist())
                # GNMT-style length normalization is currently disabled (norm = 1.0); to enable it, use:
                # norm = math.pow(((5 + step) / (5 + 1)), self.length_norm_coefficient)
                norm = 1.0
completed_hypotheses_scores.extend((top_k_hypotheses_scores[complete] / norm).tolist())
# Stop if we have completed enough hypotheses
if len(completed_hypotheses) >= n_completed_hypotheses:
break
# Else, continue with incomplete hypotheses
hypotheses = top_k_hypotheses[~complete] # (s, step + 1)
hypotheses_scores = top_k_hypotheses_scores[~complete] # (s)
# Stop if things have been going on for too long
if step > self.max_steps:
break
step += 1
# If there is not a single completed hypothesis, use partial hypotheses
if len(completed_hypotheses) == 0:
completed_hypotheses = hypotheses.tolist()
completed_hypotheses_scores = hypotheses_scores.tolist()
# Decode the hypotheses
all_hypotheses = list()
for i, com_hyp in enumerate(completed_hypotheses):
predict_seq = tokenizer.decode(com_hyp, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                predict_seq = predict_seq.replace("<S>", "").replace("</S>", "").strip()
all_hypotheses.append({"hypothesis": predict_seq, "score": completed_hypotheses_scores[i]})
# Find the best scoring completed hypothesis
i = completed_hypotheses_scores.index(max(completed_hypotheses_scores))
best_hypothesis = all_hypotheses[i]["hypothesis"]
return best_hypothesis, all_hypotheses
    # Generate the output sequence with greedy decoding
def greedy_translate(self, old_model, enc_inputs, tokenizer):
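        """
        Translates a source language sequence to the target language, with greedy decoding.
        :param enc_inputs: tokenized source inputs; enc_inputs['input_ids'] has shape [1, src_len]
        :return: the decoded target sequence as a string
        """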
if hasattr(old_model, "module"):
model = old_model.module
else:
model = old_model
with torch.no_grad():
memory = model.encode(**enc_inputs)
ys = torch.ones(1, 1).fill_(self.bos_idx).type(torch.long).to(self.device)
for i in range(self.max_steps):
tgt_mask = (generate_square_subsequent_mask(ys.size(1)).bool()).to(self.device) # [1, tgt_len, tgt_len]
                out = model.decode(tgt=ys, src=enc_inputs['input_ids'], memory=memory, tgt_mask=tgt_mask) # [1, tgt_len, vocab_size]
prob = out[:,-1]
_, next_word = torch.max(prob, dim=1)
next_word = next_word.item()
ys = torch.cat([ys, torch.ones(1, 1).fill_(next_word).type_as(ys)], dim=-1)
if next_word == self.eos_idx:
break
predict_seq = tokenizer.decode(ys.squeeze().cpu().tolist(),
skip_special_tokens=True,
clean_up_tokenization_spaces=True)
            predict_seq = predict_seq.replace("<S>", "").replace("</S>", "").strip()
return predict_seq
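# Example usage (hypothetical names; assumes a seq2seq model exposing encode()/decode() and a
# HuggingFace-style tokenizer, neither of which is defined in this file):
#
#     translator = Translator(bos_idx=tokenizer.bos_token_id,
#                             eos_idx=tokenizer.eos_token_id,
#                             device=device)
#     enc_inputs = tokenizer(src_sentence, return_tensors="pt").to(device)
#     best, candidates = translator.beam_translate(model, enc_inputs, tokenizer)
#     greedy_output = translator.greedy_translate(model, enc_inputs, tokenizer)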