import torch
import torch.nn as nn
import torch.nn.functional as F
from onmt_modules.misc import sequence_mask


class DecodeFunc_Sp(object):
    """Decoding functions for a decoder followed by a gate/frame postnet.

    The postnet output is split along its last dimension: channel ``0`` is the
    stop-gate logit and channels ``1:`` are the predicted frame.

    ``type_out`` selects which hparams configure the output:

    * ``'Sp'``: frame size ``hparams.dim_freq``, step limit ``hparams.dec_steps_sp``
    * ``'Tx'``: frame size ``hparams.dim_code``, step limit ``hparams.dec_steps_tx``
    """

    def __init__(self, hparams, type_out):
        """
        Args:
            hparams: config object providing ``dim_freq``/``dec_steps_sp``
                (for ``'Sp'``), ``dim_code``/``dec_steps_tx`` (for ``'Tx'``)
                and ``gate_threshold``.
            type_out: ``'Sp'`` or ``'Tx'``.

        Raises:
            ValueError: if ``type_out`` is neither ``'Sp'`` nor ``'Tx'``.
        """
        if type_out == 'Sp':
            self.dim_freq = hparams.dim_freq
            self.max_decoder_steps = hparams.dec_steps_sp
        elif type_out == 'Tx':
            self.dim_freq = hparams.dim_code
            self.max_decoder_steps = hparams.dec_steps_tx
        else:
            raise ValueError(
                f"type_out must be 'Sp' or 'Tx', got {type_out!r}")
        self.gate_threshold = hparams.gate_threshold
        self.type_out = type_out

    @staticmethod
    def _split_spect_gate(spect_gate):
        """Split postnet output into (frame channels 1:, gate channel 0)."""
        return spect_gate[:, :, 1:], spect_gate[:, :, :1]

    def _quantize_stop(self, gate):
        """Turn gate logits into stop flags.

        For ``gate_threshold`` in [0, 1] the result is 1 where
        ``sigmoid(gate) > gate_threshold`` and 0 otherwise (``torch.round``
        rounds half to even, so exact equality maps to 0).
        """
        return (torch.sigmoid(gate) - self.gate_threshold + 0.5).round()

    def __call__(self, tgt, memory_bank, memory_lengths, decoder, postnet):
        """Decode the whole target in a single decoder call (``step=None``).

        Presumably the teacher-forced / training path — confirm against callers.

        Returns:
            (spect, gate): postnet output split into predicted frames
            (channels ``1:``) and stop-gate logits (channel ``0``).
        """
        dec_outs, _ = decoder(tgt, memory_bank, step=None,
                              memory_lengths=memory_lengths)
        return self._split_spect_gate(postnet(dec_outs))

    def infer(self, tgt_real, memory_bank, memory_lengths, decoder, postnet):
        """Autoregressively decode until every sequence has emitted a stop.

        Args:
            tgt_real: unused; kept for interface compatibility.
            memory_bank: decoder memory; its dim 1 is taken as the batch size.
            memory_lengths: forwarded unchanged to ``decoder``.
            decoder: called once per step as
                ``decoder(prev_frame, memory_bank, t,
                memory_lengths=..., tgt_words=...)``.
            postnet: maps decoder output to ``[gate | frame]`` channels.

        Returns:
            spect_outputs: ``(max_decoder_steps, B, dim_freq)`` predicted
                frames; steps after decoding stopped remain zero.
            len_spect: ``(B,)`` count of steps before each sequence's first
                stop flag (``max_decoder_steps`` if it never stopped).
            gate_outputs: ``(max_decoder_steps, B, 1)`` gate logits; steps
                after decoding stopped remain zero.
        """
        batch_size = memory_bank.size(1)
        device = memory_bank.device
        spect_outputs = torch.zeros(
            (self.max_decoder_steps, batch_size, self.dim_freq),
            dtype=torch.float, device=device)
        gate_outputs = torch.zeros(
            (self.max_decoder_steps, batch_size, 1),
            dtype=torch.float, device=device)
        # Previous step's stop flags, fed back to the decoder as (B, 1).
        tgt_words = torch.zeros([batch_size, 1], dtype=torch.float,
                                device=device)
        # Decoding is seeded with an all-zero frame.
        current_pred = torch.zeros([1, batch_size, self.dim_freq],
                                   dtype=torch.float, device=device)
        # Sticky record of which sequences have stopped at any step so far.
        finished = torch.zeros([1, batch_size, 1], dtype=torch.bool,
                               device=device)

        for t in range(self.max_decoder_steps):
            dec_outs, _ = decoder(current_pred, memory_bank, t,
                                  memory_lengths=memory_lengths,
                                  tgt_words=tgt_words)
            spect, gate = self._split_spect_gate(postnet(dec_outs))
            spect_outputs[t:t+1] = spect
            gate_outputs[t:t+1] = gate
            stop = self._quantize_stop(gate)
            current_pred = spect.detach()
            tgt_words = stop.squeeze(-1).t()
            # len_spect depends only on each sequence's FIRST stop, so we can
            # halt once every sequence has stopped at least once, instead of
            # waiting for all of them to stop on the same step.
            finished = finished | (stop == 1)
            if finished.all():
                break
        else:
            # Only reached when the loop never broke out early.
            print(f"Warning! {self.type_out} reached max decoder steps")

        stop_quant = self._quantize_stop(gate_outputs.detach()).squeeze(-1)
        # Per column, count the leading steps before the first stop flag.
        len_spect = (stop_quant.cumsum(dim=0) == 0).sum(dim=0)
        return spect_outputs, len_spect, gate_outputs