from onmt_modules.decoder_transformer import TransformerDecoder
from onmt_modules.misc import sequence_mask


class OnmtDecoder_1(TransformerDecoder):
    # Override forward: no teacher forcing for the stop decision.
    # At train time the target padding mask is built from explicit
    # target lengths rather than from gold stop tokens.
    def forward(self, tgt, memory_bank, step=None, **kwargs):
        """Decode, possibly stepwise."""
        if step == 0:
            # First decoding step: set up the per-layer attention cache.
            self._init_cache(memory_bank)

        if step is None:
            # Full-sequence (training) pass: mask targets by length.
            tgt_lens = kwargs["tgt_lengths"]
        else:
            # Stepwise (inference) pass: mask targets by word ids.
            tgt_words = kwargs["tgt_words"]
        emb = self.embeddings(tgt, step=step)
        assert emb.dim() == 3  # len x batch x embedding_dim

        output = emb.transpose(0, 1).contiguous()
        src_memory_bank = memory_bank.transpose(0, 1).contiguous()
        pad_idx = self.embeddings.word_padding_idx
        src_lens = kwargs["memory_lengths"]
        src_max_len = self.state["src"].shape[0]
        src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)
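        # Assumed semantics of sequence_mask (as in OpenNMT-py): True at
        # valid positions, e.g. sequence_mask(tensor([1, 3]), 4) ->
        #   [[True, False, False, False],
        #    [True, True,  True,  False]]
        # so `~(...).unsqueeze(1)` gives a (B, 1, src_len) mask that is
        # True at padded source positions.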
        if step is None:
            # Length-based target mask: independent of any stop token.
            tgt_max_len = tgt_lens.max()
            tgt_pad_mask = ~sequence_mask(tgt_lens, tgt_max_len).unsqueeze(1)
        else:
            # Token-based target mask for the stepwise case.
            tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1)

        with_align = kwargs.pop('with_align', False)
        attn_aligns = []
        for i, layer in enumerate(self.transformer_layers):
            layer_cache = self.state["cache"]["layer_{}".format(i)] \
                if step is not None else None
            output, attn, attn_align = layer(
                output,
                src_memory_bank,
                src_pad_mask,
                tgt_pad_mask,
                layer_cache=layer_cache,
                step=step,
                with_align=with_align)
            if attn_align is not None:
                attn_aligns.append(attn_align)
        output = self.layer_norm(output)
        dec_outs = output.transpose(0, 1).contiguous()
        attn = attn.transpose(0, 1).contiguous()

        attns = {"std": attn}
        if self._copy:
            attns["copy"] = attn
        if with_align:
            attns["align"] = attn_aligns[self.alignment_layer]  # (B, Q, K)
            # attns["align"] = torch.stack(attn_aligns, 0).mean(0)  # All avg

        # TODO: change the way attns is returned dict => list or tuple (onnx)
        return dec_outs, attns
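

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). `decoder`
# is assumed to be an already-constructed OnmtDecoder_1, tensors follow
# OpenNMT-py's time-major (len, batch, feat) convention, and the helper
# name and all shapes below are hypothetical.
# ---------------------------------------------------------------------------
def _usage_sketch(decoder, src, memory_bank, src_lens, tgt, tgt_lens):
    # forward() reads self.state["src"], so init_state must run first.
    decoder.init_state(src, memory_bank, None)

    # Full-sequence (training) pass: the target pad mask comes from
    # tgt_lengths, so no gold stop token is teacher-forced.
    dec_outs, attns = decoder(
        tgt, memory_bank,
        memory_lengths=src_lens,
        tgt_lengths=tgt_lens)

    # Stepwise (inference) pass: one position at a time, using the
    # per-layer cache initialised at step 0.
    decoder.init_state(src, memory_bank, None)
    ys = tgt[:1]  # assumed begin-of-sequence slice, shape (1, B, 1)
    for step in range(tgt.size(0)):
        out, attns = decoder(
            ys[-1:], memory_bank, step=step,
            memory_lengths=src_lens,
            tgt_words=ys[-1:, :, 0].transpose(0, 1))  # (B, 1) word ids
        # ...choose the next symbol from `out` and append it to `ys`.
    return dec_outs, attns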