import gradio as gr
import os, json
from generator import HijackedMusicGen
from audiocraft.data.audio import audio_write
from audio import predict
from itertools import zip_longest
def split_prompt(bigly_prompt, num_segments):
    """Split a ``,,``-delimited prompt into exactly ``num_segments`` prompts.

    Sub-prompts are separated by the two-character token ``,,``. If there are
    fewer sub-prompts than segments, the last sub-prompt is repeated to fill
    the remainder. If there are more, the extras are dropped.

    Args:
        bigly_prompt: The full prompt text, e.g. ``"calm piano,,heavy drums"``.
        num_segments: Desired number of segments. Accepts an int, a float such
            as ``3.0``, or a numeric string such as ``"3"`` or ``"3.0"``.
            Values below zero are treated as zero.

    Returns:
        A list of exactly ``max(int(num_segments), 0)`` prompt strings.
    """
    # str.split always returns at least one element, so prompts[-1] is safe.
    prompts = bigly_prompt.split(',,')
    # Normalise to a non-negative int. float() first so that "3.0" is
    # accepted. The clamp matters because a negative slice bound would
    # otherwise mean "drop the last k prompts" rather than "no segments".
    count = max(int(float(num_segments)), 0)
    # Pad with the last prompt (a non-positive multiplier yields []), then cut
    # to the exact length requested.
    return (prompts + [prompts[-1]] * (count - len(prompts)))[:count]
# Most recently loaded model. model_interface replaces it only when it is
# still None or when a different model name is requested.
loaded_model = None
# Every result returned by audio_write in model_interface, in generation order.
audio_files = list()
def model_interface(model_name, top_k, top_p, temperature, cfg_coef, segments, overlap, duration, optional_audio, prompt):
global loaded_model
if loaded_model is None or loaded_model.name != model_name:
loaded_model = HijackedMusicGen.get_pretrained(None, name=model_name)
print(optional_audio)
loaded_model.set_generation_params(
use_sampling=True,
duration=duration,
top_p=top_p,
top_k=top_k,
temperature=temperature,
cfg_coef=cfg_coef,
)
extension_parameters = {"segments":segments, "overlap":overlap}
optional_audio_parameters = {"optional_audio":optional_audio, "sample_rate":loaded_model.sample_rate}
prompts = split_prompt(prompt, segments)
first_prompt = prompts[0]
sample_rate, audio = predict(loaded_model, prompts, duration, optional_audio_parameters, extension_parameters)
counter = 1
audio_path = "static/"
audio_name = first_prompt
while os.path.exists(audio_path + audio_name + ".wav"):
audio_name = f"{first_prompt}({counter})"
counter += 1
file = audio_write(audio_path + audio_name, audio.squeeze(), sample_rate, strategy="loudness")
audio_files.append(file)
audio_list_html = "
".join([
f'''