import torch
import torch.nn.functional as F
from torch import nn

# Assumed to be provided by the surrounding project: the stateful Module /
# ModuleList containers (exposing register_state and _is_stateful),
# MeshedDecoderLayer, and the sinusoid_encoding_table positional-encoding
# helper, e.g.:
# from models.containers import Module, ModuleList
# from models.transformer.utils import sinusoid_encoding_table


class MeshedDecoder_5(Module):
def __init__(self, vocab_size, max_len, N_dec, padding_idx, args, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1,
self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
super(MeshedDecoder_5, self).__init__()
self.d_model = d_model
self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True)
self.layers = ModuleList(
[MeshedDecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module,
enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs,
enc_att_module_kwargs=enc_att_module_kwargs, args=args) for _ in range(N_dec)])
self.fc = nn.Linear(d_model, vocab_size, bias=False)
self.max_len = max_len
self.padding_idx = padding_idx
self.N = N_dec
        # per-layer confidence heads: each maps the concatenation of a decoder
        # layer's output and the input (word + positional) embedding to a scalar gate
        self.conf = ModuleList(
            [nn.Sequential(nn.Linear(2 * d_model, 2 * d_model),
                           nn.ReLU(),
                           nn.Linear(2 * d_model, 1)) for _ in range(N_dec)])
self.args = args
        # decoding-time buffers used when the module is run statefully
        # (step-by-step generation)
        self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte())
        self.register_state('running_seq', torch.zeros((1,)).long())

def forward(self, input, encoder_output, mask_encoder, saliency_feats=None):
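        """Decode target tokens against the meshed encoder memory.

        Args:
            input: (b_s, seq_len) target token ids.
            encoder_output: meshed encoder memory (one feature map per encoder layer).
            mask_encoder: padding mask over the encoder regions.
            saliency_feats: optional extra features; unused in this forward pass.

        Returns:
            (b_s, seq_len, vocab_size) log-probabilities over the vocabulary.
        """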
b_s, seq_len = input.shape[:2]
mask_queries = (input != self.padding_idx).unsqueeze(-1).float() # (b_s, seq_len, 1)
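        # Causal self-attention mask: position i may only attend to positions <= i,
        # and padded target tokens are masked out as well. Non-zero / True entries
        # mark positions that must NOT be attended to.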
mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device),
diagonal=1)
mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte()
mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len)
if self._is_stateful:
self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention], -1)
mask_self_attention = self.running_mask_self_attention
seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len)
seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0)
if self._is_stateful:
self.running_seq.add_(1)
seq = self.running_seq
        # word + positional embeddings of the target sequence; kept in `emb` so the
        # confidence heads can compare each layer's output against the input embedding
        emb = self.word_emb(input) + self.pos_emb(seq)
        out = emb
        outs = []
        pconf = []
        # learnable per-layer confidence weights
        for i, l in enumerate(self.layers):
            out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder)
            # confidence score for this layer, broadcast over the feature dimension
            p = self.conf[i](torch.cat([out, emb], dim=2))  # (b_s, seq_len, 1)
            p = p.repeat(1, 1, self.d_model)                # (b_s, seq_len, d_model)
            pconf.append(p.unsqueeze(0))
            outs.append(out)
        pconf = torch.cat(pconf, 0)        # (N_dec, b_s, seq_len, d_model)
        pconf = torch.softmax(pconf, 0)    # normalise the gates across layers
        # confidence-weighted sum of the outputs of all decoder layers
        out_sum = sum(pconf[i] * outs[i] for i in range(self.N))
out_fc = self.fc(out_sum)
return F.log_softmax(out_fc, dim=-1)
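
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). It assumes the project's meshed encoder
# returns a stacked memory with one feature map per encoder layer plus a region
# padding mask; `args`, `enc_out` and `enc_mask` below are placeholder names,
# and the shapes are examples, not guarantees of this codebase's interfaces.
#
#   decoder = MeshedDecoder_5(vocab_size=10000, max_len=54, N_dec=3,
#                             padding_idx=1, args=args)
#   tokens   = torch.randint(2, 10000, (8, 20))             # (b_s, seq_len) target ids
#   enc_out  = torch.randn(8, 3, 50, 512)                   # e.g. (b_s, N_enc, regions, d_model)
#   enc_mask = torch.zeros(8, 1, 1, 50, dtype=torch.bool)   # no padded regions
#   log_probs = decoder(tokens, enc_out, enc_mask)          # (8, 20, 10000) log-probabilities
# ---------------------------------------------------------------------------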