|
|
|
|
|
|
|
|
|
|
|
import os |
|
import torch |
|
from tqdm import tqdm |
|
from collections import OrderedDict |
|
|
|
from models.tts.base.tts_inferece import TTSInference |
|
from models.tts.jets.jets_dataset import JetsTestDataset, JetsTestCollator |
|
from utils.util import load_config |
|
from utils.io import save_audio |
|
from models.tts.jets.jets import Jets |
|
from models.vocoders.vocoder_inference import synthesis |
|
from pathlib import Path |
|
from processors.phone_extractor import phoneExtractor |
|
from text.text_token_collation import phoneIDCollation |
|
import numpy as np |
|
import json |
|
import time |
|
|
|
|
|
class JetsInference(TTSInference): |
|
def __init__(self, args, cfg): |
|
TTSInference.__init__(self, args, cfg) |
|
self.args = args |
|
self.cfg = cfg |
|
self.infer_type = args.mode |
|
|
|
def _build_model(self): |
|
self.model = Jets(self.cfg) |
|
return self.model |
|
|
|
def _build_test_dataset(self): |
|
return JetsTestDataset, JetsTestCollator |
|
|
|
def inference_for_batches(self): |
|
|
|
n_batch = len(self.test_dataloader) |
|
now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())) |
|
print( |
|
"Model eval time: {}, batch_size = {}, n_batch = {}".format( |
|
now, self.test_batch_size, n_batch |
|
) |
|
) |
|
self.model.eval() |
|
|
|
|
|
pred_res = [] |
|
with torch.no_grad(): |
|
for i, batch_data in enumerate( |
|
self.test_dataloader if n_batch == 1 else tqdm(self.test_dataloader) |
|
): |
|
outputs = self.model.inference(batch_data) |
|
|
|
audios, d_predictions = outputs |
|
d_predictions = d_predictions.unsqueeze(-1) |
|
|
|
for idx in range(audios.size(0)): |
|
audio = audios[idx, 0, :].data.cpu().float() |
|
duration = d_predictions[idx, :, :] |
|
audio_length = ( |
|
duration.sum([0, 1]).long() * self.cfg.preprocess.hop_size |
|
) |
|
audio_length = audio_length.cpu().numpy() |
|
audio = audio[:audio_length] |
|
pred_res.append(audio) |
|
|
|
return pred_res |
|
|