import os
import math

import yaml
import torch

from torch import autocast
from einops import repeat

from audioldm import LatentDiffusion, seed_everything
from audioldm.utils import (
    default_audioldm_config,
    get_duration,
    get_bit_depth,
    get_metadata,
    download_checkpoint,
)
from audioldm.audio import wav_to_fbank, TacotronSTFT, read_wav_file
from audioldm.latent_diffusion.ddim import DDIMSampler
def make_batch_for_text_to_audio(text, waveform=None, fbank=None, batchsize=1):
    if batchsize < 1:
        print("Warning: Batchsize must be at least 1. Batchsize is set to 1.")
        batchsize = 1

    text = [text] * batchsize

    if fbank is None:
        fbank = torch.zeros((batchsize, 1024, 64))  # Placeholder; unused for plain text-to-audio
    else:
        fbank = torch.FloatTensor(fbank)
        fbank = fbank.expand(batchsize, 1024, 64)
        assert fbank.size(0) == batchsize

    stft = torch.zeros((batchsize, 1024, 512))  # Placeholder; unused

    if waveform is None:
        waveform = torch.zeros((batchsize, 160000))  # Placeholder; unused
    else:
        waveform = torch.FloatTensor(waveform)
        waveform = waveform.expand(batchsize, -1)
        assert waveform.size(0) == batchsize

    fname = [""] * batchsize  # Placeholder file names

    batch = (
        fbank,
        stft,
        None,
        fname,
        waveform,
        text,
    )
    return batch
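
# Usage sketch (illustrative, not part of the original module): build a batch
# of two copies of the same prompt and check the placeholder shapes.
#
#     fbank, stft, label, fname, waveform, text = make_batch_for_text_to_audio(
#         "a dog barking", batchsize=2
#     )
#     # fbank.shape == (2, 1024, 64); waveform.shape == (2, 160000)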


def round_up_duration(duration):
    # Round up to the next multiple of 2.5 seconds.
    return math.ceil(duration / 2.5) * 2.5
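
# For example, round_up_duration(6.2) == 7.5, while an exact multiple is kept:
# round_up_duration(10.0) == 10.0.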


def build_model(ckpt_path=None, config=None, model_name="audioldm-s-full"):
    print("Load AudioLDM: %s" % model_name)

    if ckpt_path is None:
        ckpt_path = get_metadata()[model_name]["path"]

    if not os.path.exists(ckpt_path):
        download_checkpoint(model_name)

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    if config is not None:
        assert isinstance(config, str)
        with open(config, "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config(model_name)

    config["model"]["params"]["device"] = device
    config["model"]["params"]["cond_stage_key"] = "text"

    latent_diffusion = LatentDiffusion(**config["model"]["params"])

    checkpoint = torch.load(ckpt_path, map_location=device)
    latent_diffusion.load_state_dict(checkpoint["state_dict"])

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)

    latent_diffusion.cond_stage_model.embed_mode = "text"
    return latent_diffusion
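
# Typical usage (sketch; downloads the default "audioldm-s-full" checkpoint on
# first use if it is not cached locally):
#
#     model = build_model()
#     wav = text_to_audio(model, "a hammer hitting a wooden surface")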


def duration_to_latent_t_size(duration):
    # The mel spectrogram runs at 102.4 frames per second (a 10 s clip gives
    # 1024 frames); the latent is 4x shorter in time, i.e. 25.6 steps per second.
    return int(duration * 25.6)


def set_cond_audio(latent_diffusion):
    latent_diffusion.cond_stage_key = "waveform"
    latent_diffusion.cond_stage_model.embed_mode = "audio"
    return latent_diffusion


def set_cond_text(latent_diffusion):
    latent_diffusion.cond_stage_key = "text"
    latent_diffusion.cond_stage_model.embed_mode = "text"
    return latent_diffusion


def text_to_audio(
    latent_diffusion,
    text,
    original_audio_file_path=None,
    seed=42,
    ddim_steps=200,
    duration=10,
    batchsize=1,
    guidance_scale=2.5,
    n_candidate_gen_per_text=3,
    config=None,  # Not used by this function
):
    seed_everything(int(seed))
    waveform = None
    if original_audio_file_path is not None:
        # Target sample count: (duration * 102.4) mel frames at a 160-sample hop.
        waveform = read_wav_file(original_audio_file_path, int(duration * 102.4) * 160)

    batch = make_batch_for_text_to_audio(text, waveform=waveform, batchsize=batchsize)

    latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)

    if waveform is not None:
        print("Generate audio that has similar content as %s" % original_audio_file_path)
        latent_diffusion = set_cond_audio(latent_diffusion)
    else:
        print("Generate audio using text %s" % text)
        latent_diffusion = set_cond_text(latent_diffusion)

    with torch.no_grad():
        waveform = latent_diffusion.generate_sample(
            [batch],
            unconditional_guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            n_candidate_gen_per_text=n_candidate_gen_per_text,
            duration=duration,
        )
    return waveform
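
# Example (sketch): 5 seconds of audio from a prompt with a fixed seed.
#
#     wav = text_to_audio(model, "birds singing in a forest", duration=5, seed=0)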


def style_transfer(
    latent_diffusion,
    text,
    original_audio_file_path,
    transfer_strength,
    seed=42,
    duration=10,
    batchsize=1,
    guidance_scale=2.5,
    ddim_steps=200,
    config=None,
):
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    assert original_audio_file_path is not None, "You need to provide the original audio file path"

    audio_file_duration = get_duration(original_audio_file_path)

    assert get_bit_depth(original_audio_file_path) == 16, (
        "The bit depth of the original audio file %s must be 16" % original_audio_file_path
    )

    if duration > audio_file_duration:
        print(
            "Warning: the requested duration (%s seconds) exceeds the audio file duration (%s seconds)."
            % (duration, audio_file_duration)
        )
        duration = round_up_duration(audio_file_duration)
        print("Set new duration as %s-seconds" % duration)

    latent_diffusion = set_cond_text(latent_diffusion)

    if config is not None:
        assert isinstance(config, str)
        with open(config, "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config()

    seed_everything(int(seed))
    latent_diffusion.cond_stage_model.embed_mode = "text"

    fn_STFT = TacotronSTFT(
        config["preprocessing"]["stft"]["filter_length"],
        config["preprocessing"]["stft"]["hop_length"],
        config["preprocessing"]["stft"]["win_length"],
        config["preprocessing"]["mel"]["n_mel_channels"],
        config["preprocessing"]["audio"]["sampling_rate"],
        config["preprocessing"]["mel"]["mel_fmin"],
        config["preprocessing"]["mel"]["mel_fmax"],
    )

    mel, _, _ = wav_to_fbank(
        original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT
    )
    mel = mel.unsqueeze(0).unsqueeze(0).to(device)
    mel = repeat(mel, "1 ... -> b ...", b=batchsize)
    init_latent = latent_diffusion.get_first_stage_encoding(
        latent_diffusion.encode_first_stage(mel)
    )
    if torch.max(torch.abs(init_latent)) > 1e2:
        # Clip rare outlier latents to keep the sampler numerically stable.
        init_latent = torch.clip(init_latent, min=-10, max=10)

    sampler = DDIMSampler(latent_diffusion)
    sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=1.0, verbose=False)

    # Noise the source latent for t_enc of the ddim_steps, then denoise it back
    # under the text condition.
    t_enc = int(transfer_strength * ddim_steps)
    prompts = text

    with torch.no_grad():
        with autocast("cuda"):
            with latent_diffusion.ema_scope():
                uc = None
                if guidance_scale != 1.0:
                    uc = latent_diffusion.cond_stage_model.get_unconditional_condition(
                        batchsize
                    )

                c = latent_diffusion.get_learned_conditioning([prompts] * batchsize)
                z_enc = sampler.stochastic_encode(
                    init_latent, torch.tensor([t_enc] * batchsize).to(device)
                )
                samples = sampler.decode(
                    z_enc,
                    c,
                    t_enc,
                    unconditional_guidance_scale=guidance_scale,
                    unconditional_conditioning=uc,
                )
                # Trim the last three time steps of the latent before decoding.
                x_samples = latent_diffusion.decode_first_stage(samples[:, :, :-3, :])
                waveform = latent_diffusion.first_stage_model.decode_to_waveform(
                    x_samples
                )

    return waveform
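
# Example (sketch; "input.wav" is a hypothetical 16-bit file): re-render an
# existing recording to match a text prompt. transfer_strength in [0, 1]:
# higher values noise the source latent further, so the result follows the
# text more and the source audio less.
#
#     wav = style_transfer(model, "jazz piano with vinyl crackle", "input.wav", 0.5)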


def super_resolution_and_inpainting(
    latent_diffusion,
    text,
    original_audio_file_path=None,  # Required despite the default, as is duration
    seed=42,
    ddim_steps=200,
    duration=None,
    batchsize=1,
    guidance_scale=2.5,
    n_candidate_gen_per_text=3,
    # Start/end ratios of the masked (regenerated) region on each axis;
    # (1.0, 1.0) leaves that axis untouched.
    time_mask_ratio_start_and_end=(0.10, 0.15),
    freq_mask_ratio_start_and_end=(1.0, 1.0),
    config=None,
):
    seed_everything(int(seed))
    if config is not None:
        assert isinstance(config, str)
        with open(config, "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config()

    fn_STFT = TacotronSTFT(
        config["preprocessing"]["stft"]["filter_length"],
        config["preprocessing"]["stft"]["hop_length"],
        config["preprocessing"]["stft"]["win_length"],
        config["preprocessing"]["mel"]["n_mel_channels"],
        config["preprocessing"]["audio"]["sampling_rate"],
        config["preprocessing"]["mel"]["mel_fmin"],
        config["preprocessing"]["mel"]["mel_fmax"],
    )

    mel, _, _ = wav_to_fbank(
        original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT
    )

    batch = make_batch_for_text_to_audio(text, fbank=mel[None, ...], batchsize=batchsize)

    latent_diffusion = set_cond_text(latent_diffusion)

    with torch.no_grad():
        waveform = latent_diffusion.generate_sample_masked(
            [batch],
            unconditional_guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            n_candidate_gen_per_text=n_candidate_gen_per_text,
            duration=duration,
            time_mask_ratio_start_and_end=time_mask_ratio_start_and_end,
            freq_mask_ratio_start_and_end=freq_mask_ratio_start_and_end,
        )
    return waveform
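

if __name__ == "__main__":
    # Minimal end-to-end sketch (illustrative only): "degraded.wav" is a
    # hypothetical 16 kHz, 16-bit input file, and the mask values below are
    # example choices, not defaults from the original module.
    model = build_model()
    wav = super_resolution_and_inpainting(
        model,
        "clean speech",
        "degraded.wav",
        duration=10,
        time_mask_ratio_start_and_end=(1.0, 1.0),  # keep the time axis intact
        freq_mask_ratio_start_and_end=(0.85, 1.0),  # regenerate the top frequency band
    )
    print("Generated waveform of shape:", wav.shape)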