File size: 2,699 Bytes
24f6ec0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import os
import argparse
import yaml
import torch
from audioldm import LatentDiffusion, seed_everything
from audioldm.utils import default_audioldm_config
import time
def make_batch_for_text_to_audio(text, batchsize=1):
    """Build the batch tuple used for text-conditioned sampling.

    Only the text entry carries information. The tensor entries are
    zero-filled placeholders that keep the batch layout intact; the original
    code marks them as unused.

    Args:
        text: Prompt string, repeated ``batchsize`` times.
        batchsize: Number of copies of the prompt. Values below 1 are
            clamped to 1, with a printed warning.

    Returns:
        Tuple ``(fbank, stft, None, fname, waveform, text)`` where ``text``
        and ``fname`` are lists of length ``batchsize``.
    """
    # Clamp before any use of batchsize. Previously the list was built first
    # (yielding an empty prompt list), and a negative value still reached
    # torch.zeros and raised.
    if batchsize < 1:
        print("Warning: Batchsize must be at least 1. Batchsize is set to 1.")
        batchsize = 1
    text = [text] * batchsize
    fbank = torch.zeros((batchsize, 1024, 64))  # Not used, here to keep the code format
    stft = torch.zeros((batchsize, 1024, 512))  # Not used
    waveform = torch.zeros((batchsize, 160000))  # Not used
    fname = [""] * batchsize  # Not used
    batch = (
        fbank,
        stft,
        None,
        fname,
        waveform,
        text,
    )
    return batch
def build_model(
    ckpt_path=None,
    config=None,
    model_name="audioldm-s-full"
):
    """Instantiate a text-conditioned LatentDiffusion model and load its weights.

    Args:
        ckpt_path: Path to the checkpoint file. When None, defaults to
            ``ckpt/<model_name>.ckpt``. Before this change the argument was
            accepted but ignored, and that default path was always used.
        config: Optional path (``str`` or ``os.PathLike``) to a YAML config
            file. When None, ``default_audioldm_config(model_name)`` is used.
        model_name: Model identifier used for logging, for the default
            config lookup and for the default checkpoint path.

    Returns:
        The LatentDiffusion model in eval mode, moved to ``cuda:0`` when CUDA
        is available (otherwise CPU), with its condition stage set to text.
    """
    print("Load AudioLDM: %s" % model_name)
    if ckpt_path is None:
        ckpt_path = "ckpt/%s.ckpt" % model_name
    # NOTE: the checkpoint is not downloaded automatically; it must already
    # exist at ckpt_path.

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    if config is not None:
        assert isinstance(config, (str, os.PathLike)), "config must be a path to a YAML file"
        # Use a context manager so the config file handle is closed.
        with open(config, "r") as config_file:
            config = yaml.load(config_file, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config(model_name)

    # Use text as condition instead of using waveform during training
    config["model"]["params"]["device"] = device
    config["model"]["params"]["cond_stage_key"] = "text"

    latent_diffusion = LatentDiffusion(**config["model"]["params"])
    # torch.load unpickles the file, so only load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location=device)
    latent_diffusion.load_state_dict(checkpoint["state_dict"])
    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)
    latent_diffusion.cond_stage_model.embed_mode = "text"
    return latent_diffusion
def duration_to_latent_t_size(duration):
    """Return the latent time-axis length for ``duration``.

    The length is ``duration * 25.6``, truncated toward zero by ``int``.
    Presumably ``duration`` is in seconds (25.6 latent frames per second);
    confirm against the model config.
    """
    latent_frames_per_unit = 25.6
    scaled = duration * latent_frames_per_unit
    return int(scaled)
def text_to_audio(
    latent_diffusion,
    text,
    seed=42,
    duration=10,
    batchsize=1,
    guidance_scale=2.5,
    n_candidate_gen_per_text=3,
    config=None,
):
    """Generate audio for a text prompt with an already-built model.

    Side effect: sets ``latent_diffusion.latent_t_size`` from ``duration``.

    Args:
        latent_diffusion: Model providing ``generate_sample`` (for example,
            the object returned by ``build_model``).
        text: Prompt string.
        seed: Passed through ``int`` and then to ``seed_everything``.
        duration: Passed to ``generate_sample`` and used to size the latent.
        batchsize: Number of prompt copies in the batch.
        guidance_scale: Forwarded as ``unconditional_guidance_scale``.
        n_candidate_gen_per_text: Forwarded unchanged to ``generate_sample``.
        config: Accepted for interface compatibility; not used here.

    Returns:
        Whatever ``latent_diffusion.generate_sample`` returns, computed
        under ``torch.no_grad()``.
    """
    seed_everything(int(seed))
    prompt_batch = make_batch_for_text_to_audio(text, batchsize=batchsize)
    latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
    sampling_options = {
        "unconditional_guidance_scale": guidance_scale,
        "n_candidate_gen_per_text": n_candidate_gen_per_text,
        "duration": duration,
    }
    with torch.no_grad():
        return latent_diffusion.generate_sample([prompt_batch], **sampling_options)
|