import os
import argparse
import time

import yaml
import torch

from audioldm import LatentDiffusion, seed_everything
from audioldm.utils import default_audioldm_config
def make_batch_for_text_to_audio(text, batchsize=1):
    if batchsize < 1:
        print("Warning: Batchsize must be at least 1. Batchsize is set to 1.")
        batchsize = 1
    text = [text] * batchsize
    # Placeholder tensors: only the text prompt conditions the model here,
    # but generate_sample expects a full batch tuple.
    fbank = torch.zeros((batchsize, 1024, 64))  # mel filterbank (frames x mel bins)
    stft = torch.zeros((batchsize, 1024, 512))  # magnitude STFT
    waveform = torch.zeros((batchsize, 160000))  # 10 s of audio at 16 kHz
    fname = [""] * batchsize
    batch = (
        fbank,
        stft,
        None,
        fname,
        waveform,
        text,
    )
    return batch
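
# Example (a minimal sketch): a one-item batch for a single prompt.
#   batch = make_batch_for_text_to_audio("a dog barking", batchsize=1)
#   batch[0].shape  -> torch.Size([1, 1024, 64])   # fbank placeholder
#   batch[5]        -> ["a dog barking"]           # the prompt, repeated per item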


def build_model(
    ckpt_path=None,
    config=None,
    model_name="audioldm-s-full"
):
    print("Loading AudioLDM: %s" % model_name)

    # Fall back to the default checkpoint location when no explicit path is given.
    if ckpt_path is None:
        ckpt_path = "ckpt/%s.ckpt" % model_name

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    if config is not None:
        assert isinstance(config, str)
        with open(config, "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config(model_name)

    config["model"]["params"]["device"] = device
    config["model"]["params"]["cond_stage_key"] = "text"

    latent_diffusion = LatentDiffusion(**config["model"]["params"])

    checkpoint = torch.load(ckpt_path, map_location=device)
    latent_diffusion.load_state_dict(checkpoint["state_dict"])

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)

    # Use text (rather than audio) embeddings for conditioning.
    latent_diffusion.cond_stage_model.embed_mode = "text"
    return latent_diffusion


def duration_to_latent_t_size(duration):
    # The latent sequence runs at 25.6 time steps per second of audio.
    return int(duration * 25.6)
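
# Worked example: duration_to_latent_t_size(10) == 256 latent time steps,
# which lines up with the 1024-frame fbank placeholder above after a
# (presumed) 4x temporal downsampling in the first stage.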


def text_to_audio(
    latent_diffusion,
    text,
    seed=42,
    duration=10,
    batchsize=1,
    guidance_scale=2.5,
    n_candidate_gen_per_text=3,
    config=None,
):
    seed_everything(int(seed))
    batch = make_batch_for_text_to_audio(text, batchsize=batchsize)

    # Resize the latent to match the requested duration before sampling.
    latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
    with torch.no_grad():
        waveform = latent_diffusion.generate_sample(
            [batch],
            unconditional_guidance_scale=guidance_scale,
            n_candidate_gen_per_text=n_candidate_gen_per_text,
            duration=duration,
        )
    return waveform
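
# Usage sketch (an illustration, not part of the original pipeline): build the
# model and synthesize one clip. Assumes a checkpoint is available at
# ckpt/audioldm-s-full.ckpt, as expected by build_model above.
if __name__ == "__main__":
    audioldm = build_model(model_name="audioldm-s-full")
    waveform = text_to_audio(
        audioldm,
        "a hammer hitting a wooden surface",
        duration=10,
        guidance_scale=2.5,
    )
    print(waveform.shape)  # inspect the generated waveform(s)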