# Spaces: Paused
import os
from typing import Tuple

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

from model import get_hertz_dev_config
from tokenizer import make_tokenizer
from utils import load_ckpt, print_colored
# Process-wide cache for the model and tokenizer. Both start as None and are
# filled in by init_model() so the weights are only loaded once.
global_generator = None
global_tokenizer = None

# Audio file offered as the default input in the UI when it exists on disk.
default_audio_path = "sample.wav"  # Changed from "testingtesting.wav"
def init_model(use_pure_audio_ablation: bool = False) -> Tuple[nn.Module, object]:
    """Build the generator model and audio tokenizer, reusing cached instances.

    The first successful call stores both objects in the module-level
    ``global_generator`` / ``global_tokenizer``; every later call returns that
    cached pair without rebuilding anything.

    Args:
        use_pure_audio_ablation: Forwarded to ``get_hertz_dev_config``. It is
            only consulted when the model is actually built; once a pair is
            cached the flag has no effect.

    Returns:
        The ``(generator, tokenizer)`` pair.
    """
    global global_generator, global_tokenizer

    if global_generator is not None and global_tokenizer is not None:
        return global_generator, global_tokenizer

    device = 'cuda' if T.cuda.is_available() else 'cpu'
    if device == 'cuda':
        T.cuda.set_device(0)

    print_colored("Initializing model and tokenizer...", "blue")
    tokenizer = make_tokenizer(device)
    model_config = get_hertz_dev_config(is_split=False, use_pure_audio_ablation=use_pure_audio_ablation)
    generator = model_config().eval().to(T.bfloat16).to(device)

    # Publish to the cache only after both objects are fully built, so a
    # failure part-way through never leaves a half-initialized pair behind.
    global_generator, global_tokenizer = generator, tokenizer
    print_colored("Model initialization complete!", "green")
    return global_generator, global_tokenizer
def process_audio(audio_path: str, sr: int) -> T.Tensor:
    """Load an audio file as mono 16 kHz audio, truncated to five minutes.

    Args:
        audio_path: Path of the audio file to load with ``torchaudio.load``.
        sr: Unused. The file's real sample rate is read from the file itself
            and the output is always 16 kHz; the parameter is kept only so
            existing callers keep working.

    Returns:
        The waveform with a leading batch dimension added, i.e.
        ``(1, 1, num_samples)`` given that ``torchaudio.load`` yields
        ``(channels, num_samples)``.
    """
    target_sr = 16000
    max_samples = target_sr * 60 * 5  # 5 minutes

    audio_tensor, source_sr = torchaudio.load(audio_path)

    # Downmix any multi-channel input (not just stereo) to a single channel.
    if audio_tensor.shape[0] > 1:
        audio_tensor = audio_tensor.mean(dim=0, keepdim=True)

    if source_sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=source_sr, new_freq=target_sr)
        audio_tensor = resampler(audio_tensor)

    # Slicing is a no-op when the clip is already short enough.
    audio_tensor = audio_tensor[:, :max_samples]
    return audio_tensor.unsqueeze(0)
def generate_completion(
    audio_file,
    prompt_len_seconds: float = 3.0,
    num_completions: int = 5,
    generation_seconds: float = 20.0,
    token_temp: float = 0.8,
    categorical_temp: float = 0.5,
    gaussian_temp: float = 0.1,
    progress=gr.Progress(track_tqdm=True)
) -> list:
    """Generate several audio continuations of the start of an input clip.

    Args:
        audio_file: Path of the input audio file.
        prompt_len_seconds: How much of the clip (from its start) is used as
            the prompt.
        num_completions: Number of independent continuations to sample.
        generation_seconds: Length of audio to generate per continuation.
        token_temp: First element of the ``temps`` tuple given to
            ``generator.completion``.
        categorical_temp: First element of the nested ``temps`` pair.
        gaussian_temp: Second element of the nested ``temps`` pair.
        progress: Gradio progress reporter.

    Returns:
        A list with one ``(16000, samples)`` tuple per continuation, where
        ``samples`` is a float NumPy array laid out as (samples, channels).
        Each clip starts up to one second before the end of the prompt.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if audio_file is None:
        raise gr.Error("Please upload or record an input audio clip first.")

    sample_rate = 16000
    # Seconds are converted to latent-sequence lengths at 8 steps per second,
    # which at 16 kHz corresponds to 2000 audio samples per step.
    latents_per_second = 8
    samples_per_latent = sample_rate // latents_per_second

    device = 'cuda' if T.cuda.is_available() else 'cpu'
    # init_model() returns the cached pair, or builds it if nothing has
    # initialized the globals yet.
    generator, audio_tokenizer = init_model()
    # UI sliders may deliver a float; range() needs an int.
    num_completions = int(num_completions)

    progress(0, desc="Processing input audio...")
    prompt_audio = process_audio(audio_file, sr=sample_rate)
    prompt_len = int(prompt_len_seconds * latents_per_second)

    progress(0.2, desc="Encoding prompt...")
    # Autocast must target the device actually in use, not always CUDA.
    with T.no_grad(), T.autocast(device_type=device, dtype=T.bfloat16):
        encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))

    # The prompt slice is identical for every completion, so take it once.
    encoded_prompt = encoded_prompt_audio[:, :prompt_len]
    # The clip may be shorter than the requested prompt; use the real length
    # so the trim below never eats into generated audio.
    actual_prompt_len = min(prompt_len, encoded_prompt.shape[1])
    # Keep up to one second of prompt audio as lead-in before the generation.
    trim_start = max(actual_prompt_len * samples_per_latent - sample_rate, 0)
    gen_len = int(generation_seconds * latents_per_second)

    completions = []
    for i in range(num_completions):
        # Report the work finished so far (encoding counts as the first 20%).
        progress(
            0.2 + 0.8 * i / num_completions,
            desc=f"Generating completion {i+1}/{num_completions}",
        )
        with T.no_grad(), T.autocast(device_type=device, dtype=T.bfloat16):
            completed_audio_batch = generator.completion(
                encoded_prompt,
                temps=(token_temp, (categorical_temp, gaussian_temp)),
                use_cache=True,
                gen_len=gen_len
            )
            decoded_completion = audio_tokenizer.data_from_latent(completed_audio_batch.bfloat16())

        # Bring the audio to CPU as float32 with an explicit channel axis.
        audio_tensor = decoded_completion.cpu().squeeze()
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0)
        audio_tensor = audio_tensor.float()

        # Scale down only when the signal would otherwise clip.
        peak = audio_tensor.abs().max()
        if peak > 1:
            audio_tensor = audio_tensor / peak

        output_audio = audio_tensor[:, trim_start:]
        completions.append((sample_rate, output_audio.numpy().T))

    progress(1.0, desc="Generation complete!")
    return completions
def create_interface():
    """Build the Gradio app for generating audio completions.

    Loads the model up front, lays out the input controls and up to ten
    output players, and wires the slider and button events.

    Returns:
        The configured ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    # Single source of truth for how many output players exist.
    max_completions = 10
    default_num_completions = 5

    # Initialize model at startup
    init_model()

    with gr.Blocks(title="Audio Completion Generator") as app:
        gr.Markdown(
            "# Audio Completion Generator\n"
            "Upload an audio file (or use the default) and generate AI completions based on the prompt."
        )

        with gr.Row():
            with gr.Column():
                # Load the default audio if it exists
                default_value = default_audio_path if os.path.exists(default_audio_path) else None
                audio_input = gr.Audio(
                    label="Input Audio",
                    type="filepath",
                    sources=["microphone", "upload"],
                    value=default_value,
                )

                with gr.Row():
                    prompt_len = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=0.5,
                        label="Prompt Length (seconds)",
                    )
                    num_completions = gr.Slider(
                        minimum=1,
                        maximum=max_completions,
                        value=default_num_completions,
                        step=1,
                        label="Number of Completions",
                    )
                    gen_length = gr.Slider(
                        minimum=5,
                        maximum=60,
                        value=20,
                        step=5,
                        label="Generation Length (seconds)",
                    )

                with gr.Row():
                    token_temp = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.8,
                        step=0.1,
                        label="Token Temperature",
                    )
                    cat_temp = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.5,
                        step=0.1,
                        label="Categorical Temperature",
                    )
                    gauss_temp = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.1,
                        step=0.1,
                        label="Gaussian Temperature",
                    )

                generate_btn = gr.Button("Generate Completions")
                status_text = gr.Markdown("Ready")

            with gr.Column():
                output_audios = [
                    gr.Audio(
                        label=f"Generated Completion {i+1}",
                        type="numpy",
                        visible=False,
                    )
                    for i in range(max_completions)
                ]

        def update_visibility(num):
            """Show the first ``num`` output players and hide the rest."""
            return [gr.update(visible=(i < num)) for i in range(max_completions)]

        def generate_with_status(
            audio_file,
            prompt_seconds,
            completion_count,
            generation_seconds,
            token_temperature,
            categorical_temperature,
            gaussian_temperature,
            progress=gr.Progress(track_tqdm=True),
        ):
            """Run generation, streaming status text and audio to the UI.

            Assigning to ``status_text.value`` inside a handler does not reach
            the browser, so the status is delivered as an output instead. The
            explicit ``progress`` parameter lets Gradio track progress for
            this event.
            """
            # First update: change only the status, leave the players as-is.
            unchanged = [gr.update() for _ in range(max_completions)]
            yield ["Processing input audio...", *unchanged]

            completions = generate_completion(
                audio_file,
                prompt_seconds,
                completion_count,
                generation_seconds,
                token_temperature,
                categorical_temperature,
                gaussian_temperature,
                progress=progress,
            )

            # One value per player: real clips first, None for unused players.
            shown = completions[:max_completions]
            padded = shown + [None] * (max_completions - len(shown))
            yield ["Generation complete!", *padded]

        # Set initial visibility on load
        app.load(
            fn=update_visibility,
            inputs=[num_completions],
            outputs=output_audios,
        )

        # Update visibility when slider changes
        num_completions.change(
            fn=update_visibility,
            inputs=[num_completions],
            outputs=output_audios,
        )

        generate_btn.click(
            fn=generate_with_status,
            inputs=[
                audio_input,
                prompt_len,
                num_completions,
                gen_length,
                token_temp,
                cat_temp,
                gauss_temp,
            ],
            outputs=[status_text, *output_audios],
        )

    return app
# Script entry point: build the UI and serve it when run directly.
if __name__ == "__main__":
    app = create_interface()
    # share=True asks Gradio to also create a public share link.
    app.launch(share=True)