import os
from typing import Tuple

import gradio as gr
import torch as T
import torch.nn as nn
import torchaudio

from utils import print_colored
from tokenizer import make_tokenizer
from model import get_hertz_dev_config

# Globals so the model and tokenizer are loaded once and reused across requests
global_generator = None
global_tokenizer = None
default_audio_path = "sample.wav"


def init_model(use_pure_audio_ablation: bool = False) -> Tuple[nn.Module, object]:
    """Initialize the model and tokenizer (idempotent: returns cached globals if present)."""
    global global_generator, global_tokenizer

    if global_generator is not None and global_tokenizer is not None:
        return global_generator, global_tokenizer

    device = 'cuda' if T.cuda.is_available() else 'cpu'
    if device == 'cuda':
        T.cuda.set_device(0)

    print_colored("Initializing model and tokenizer...", "blue")
    global_tokenizer = make_tokenizer(device)
    model_config = get_hertz_dev_config(is_split=False, use_pure_audio_ablation=use_pure_audio_ablation)
    global_generator = model_config()
    global_generator = global_generator.eval().to(T.bfloat16).to(device)
    print_colored("Model initialization complete!", "green")

    return global_generator, global_tokenizer


def process_audio(audio_path: str, sr: int) -> T.Tensor:
    """Load an audio file, downmix to mono, resample to `sr` Hz, and cap length at 5 minutes."""
    audio_tensor, orig_sr = torchaudio.load(audio_path)
    if audio_tensor.shape[0] > 1:  # downmix stereo/multichannel to mono
        audio_tensor = audio_tensor.mean(dim=0, keepdim=True)
    if orig_sr != sr:
        resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=sr)
        audio_tensor = resampler(audio_tensor)
    max_samples = sr * 60 * 5  # 5 minutes
    if audio_tensor.shape[1] > max_samples:
        audio_tensor = audio_tensor[:, :max_samples]
    return audio_tensor.unsqueeze(0)  # add batch dim: (1, 1, num_samples)


def generate_completion(
    audio_file,
    prompt_len_seconds: float = 3.0,
    num_completions: int = 5,
    generation_seconds: float = 20.0,
    token_temp: float = 0.8,
    categorical_temp: float = 0.5,
    gaussian_temp: float = 0.1,
    progress=gr.Progress(track_tqdm=True),
) -> list:
    """Generate audio completions from the input audio."""
    device = 'cuda' if T.cuda.is_available() else 'cpu'

    # Reuse the globally initialized model and tokenizer
    generator, audio_tokenizer = global_generator, global_tokenizer

    progress(0, desc="Processing input audio...")
    prompt_audio = process_audio(audio_file, sr=16000)
    prompt_len = int(prompt_len_seconds * 8)  # the latent sequence runs at 8 frames/second

    progress(0.2, desc="Encoding prompt...")
    with T.autocast(device_type=device, dtype=T.bfloat16):
        encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))

    completions = []
    for i in range(num_completions):
        progress((i + 1) / num_completions, desc=f"Generating completion {i+1}/{num_completions}")
        encoded_prompt = encoded_prompt_audio[:, :prompt_len]
        with T.autocast(device_type=device, dtype=T.bfloat16):
            completed_audio_batch = generator.completion(
                encoded_prompt,
                temps=(token_temp, (categorical_temp, gaussian_temp)),
                use_cache=True,
                gen_len=int(generation_seconds * 8),
            )
            decoded_completion = audio_tokenizer.data_from_latent(completed_audio_batch.bfloat16())

        # Normalize shape and dtype for output
        audio_tensor = decoded_completion.cpu().squeeze()
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0)
        audio_tensor = audio_tensor.float()

        # Peak-normalize if the decoded audio clips
        if audio_tensor.abs().max() > 1:
            audio_tensor = audio_tensor / audio_tensor.abs().max()

        # Trim most of the prompt, keeping its final second as context before the
        # generated portion (each latent frame covers 2000 samples at 16 kHz)
        output_audio = audio_tensor[:, max(prompt_len * 2000 - 16000, 0):]
        completions.append((16000, output_audio.numpy().T))

    progress(1.0, desc="Generation complete!")
    return completions
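
# A minimal headless sketch of the same pipeline, useful for smoke-testing the
# model without the UI. The file names "sample.wav" and "completion_0.wav" are
# assumptions; the no-op progress callback stands in for gr.Progress, which
# expects a live Gradio event context.
def example_headless_completion(path: str = "sample.wav") -> None:
    init_model()
    completions = generate_completion(
        path,
        prompt_len_seconds=3.0,
        num_completions=1,
        generation_seconds=10.0,
        progress=lambda *args, **kwargs: None,  # no-op progress callback
    )
    sr, audio = completions[0]  # audio: (num_samples, num_channels) float32
    # torchaudio.save expects (num_channels, num_samples)
    torchaudio.save("completion_0.wav", T.from_numpy(audio.T).contiguous(), sr)
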
def create_interface():
    # Initialize the model at startup so the first generation doesn't pay the load cost
    init_model()

    with gr.Blocks(title="Audio Completion Generator") as app:
        gr.Markdown(
            """
            # Audio Completion Generator
            Upload an audio file (or use the default) and generate AI completions based on the prompt.
            """
        )

        with gr.Row():
            with gr.Column():
                # Preload the default audio if it exists
                default_value = default_audio_path if os.path.exists(default_audio_path) else None
                audio_input = gr.Audio(
                    label="Input Audio",
                    type="filepath",
                    sources=["microphone", "upload"],
                    value=default_value,
                )
                with gr.Row():
                    prompt_len = gr.Slider(
                        minimum=1, maximum=10, value=3, step=0.5,
                        label="Prompt Length (seconds)",
                    )
                    num_completions = gr.Slider(
                        minimum=1, maximum=10, value=5, step=1,
                        label="Number of Completions",
                    )
                    gen_length = gr.Slider(
                        minimum=5, maximum=60, value=20, step=5,
                        label="Generation Length (seconds)",
                    )
                with gr.Row():
                    token_temp = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.8, step=0.1,
                        label="Token Temperature",
                    )
                    cat_temp = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.5, step=0.1,
                        label="Categorical Temperature",
                    )
                    gauss_temp = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.1, step=0.1,
                        label="Gaussian Temperature",
                    )
                generate_btn = gr.Button("Generate Completions")
                status_text = gr.Markdown("Ready")

            with gr.Column():
                # Pre-create the maximum number of audio components; visibility is
                # toggled to match the "Number of Completions" slider.
                output_audios = []
                for i in range(10):
                    output_audios.append(gr.Audio(
                        label=f"Generated Completion {i+1}",
                        type="numpy",
                        visible=False,
                    ))

        def update_visibility(num):
            return [gr.update(visible=(i < num)) for i in range(10)]

        def generate_with_status(*args):
            completions = generate_completion(*args)
            # Pad with None so every audio component receives a value; the final
            # element updates the status markdown (mutating status_text.value inside
            # a handler would not propagate to the frontend).
            outputs = [completions[i] if i < len(completions) else None for i in range(10)]
            return outputs + ["Generation complete!"]

        # Set initial visibility on load
        app.load(
            fn=update_visibility,
            inputs=[num_completions],
            outputs=output_audios,
        )

        # Update visibility when the slider changes
        num_completions.change(
            fn=update_visibility,
            inputs=[num_completions],
            outputs=output_audios,
        )

        generate_btn.click(
            fn=generate_with_status,
            inputs=[
                audio_input, prompt_len, num_completions, gen_length,
                token_temp, cat_temp, gauss_temp,
            ],
            outputs=output_audios + [status_text],
        )

    return app


if __name__ == "__main__":
    app = create_interface()
    app.launch(share=True)
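
# Usage note: running this script starts a local Gradio server; share=True
# additionally requests a temporary public gradio.live link. Set share=False
# (or drop the argument) to keep the app local-only.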