import logging
from typing import Any, Dict

import numpy as np
import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Inference-endpoint handler wrapping MusicGen for text-to-music generation.

    Supports generation beyond the model's 30-second limit by producing audio in
    overlapping chunks, where each new chunk is conditioned on the tail of the
    audio generated so far.
    """

    def __init__(self, path: str = ""):
        """Load the pretrained MusicGen model onto GPU if available, else CPU.

        Args:
            path: Unused; kept for the endpoint-handler interface.
        """
        # Imported lazily so this module can be imported (e.g. for unit tests of
        # the request-validation logic) without audiocraft installed; MusicGen is
        # referenced nowhere else in this file.
        from audiocraft.models import MusicGen

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # If you want to use a different model, change the model name here.
        # Stereo models are also supported but you need to change channels to 2.
        self.channels = 1
        self.model = MusicGen.get_pretrained(
            "facebook/musicgen-large", device=self.device
        )
        self.sample_rate = self.model.sample_rate

    def __call__(self, data: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """Generate audio for a prompt, optionally in overlapping chunks.

        Chunked generation removes the model's 30 s limitation: after the first
        chunk, every hop re-prompts the model with the last ``audio_window``
        seconds of already-generated audio and keeps only the newly generated
        tail.

        The payload dictionary supports the following keys:
            inputs: The text prompt to generate audio for.
            generation_params: Dict forwarded to ``model.set_generation_params``.
                Supported keys include duration (default 30), temperature
                (default 0.8), top_p (default 0.0), top_k (default 250) and
                cfg_coef (default 0.0); other keys are passed through if the
                model supports them.
            audio_window: Seconds of previous audio used as prompt for the next
                chunk. Default: 20.
            chunk_size: The size of each chunk in seconds. Default: 30.

        Args:
            data: The request payload described above.

        Raises:
            ValueError: If ``chunk_size`` is not strictly greater than
                ``audio_window``, or if ``duration - chunk_size`` is not a
                multiple of ``chunk_size - audio_window``.

        Returns:
            A dict with key ``"generated_audio"`` mapping to a numpy array of
            shape (samples, channels).
        """
        prompt = data["inputs"]
        generation_params = data.get("generation_params", {})
        duration = generation_params.get("duration", 30)

        if duration <= 30:
            logger.info(f"Generating audio with duration {duration} in one go.")
            self.model.set_generation_params(**generation_params)
            # generate() returns a batched (1, channels, samples) tensor; drop
            # the batch dim so both branches yield (channels, samples) and the
            # final transpose() consistently produces (samples, channels).
            final_audio = self.model.generate([prompt], progress=True)[0]
        else:
            logger.info(f"Generating audio with duration {duration} in chunks.")
            audio_window = data.get("audio_window", 20)
            chunk_size = data.get("chunk_size", 30)
            # Seconds of genuinely new audio produced per hop after the first chunk.
            continuation = chunk_size - audio_window
            final_duration = duration
            # Strictly greater is required: chunk_size == audio_window would make
            # continuation zero (no forward progress) and the modulo check below
            # would raise ZeroDivisionError instead of a clear error.
            if chunk_size <= audio_window:
                raise ValueError(
                    f"Chunk size {chunk_size} must be greater than audio window {audio_window}"
                )
            if (final_duration - chunk_size) % continuation != 0:
                raise ValueError(
                    f"Duration ({duration} secs) - chunksize ({chunk_size} secs)"
                    f" must be a multiple of continuation ({continuation} secs)"
                )
            # Each model call generates chunk_size seconds at a time.
            generation_params["duration"] = chunk_size
            self.model.set_generation_params(**generation_params)
            logger.info(
                f"Generating total audio {final_duration} secs with chunks of {chunk_size} secs "
                f"and continuation of {continuation} secs."
            )
            # Initialize final audio buffer and fill in the first chunk.
            logger.info(f"Initializing final audio with {chunk_size} secs of audio.")
            final_audio = torch.zeros(
                (self.channels, self.sample_rate * final_duration),
                dtype=torch.float,
            ).to(self.device)
            final_audio[:, : chunk_size * self.sample_rate] = self.model.generate(
                [prompt], progress=True
            )
            n_hops = (final_duration - chunk_size) // continuation
            for i_hop in range(n_hops):
                logger.info(f"Generating audio for hop {i_hop}")
                # Condition on the trailing audio_window seconds generated so far.
                prompt_stop = chunk_size + i_hop * continuation
                prompt_start = prompt_stop - audio_window
                audio_prompt = final_audio[
                    :, prompt_start * self.sample_rate : prompt_stop * self.sample_rate
                ].reshape(1, self.channels, -1)
                output = self.model.generate_continuation(
                    audio_prompt,
                    self.sample_rate,
                    [prompt],
                    progress=True,
                )
                # generate_continuation returns prompt + new audio; keep only the
                # newly generated tail past the audio_window prompt.
                final_audio[
                    :,
                    prompt_stop
                    * self.sample_rate : (prompt_stop + continuation)
                    * self.sample_rate,
                ] = output[..., audio_window * self.sample_rate :]
                logger.info(
                    f"finished generating audio till {(prompt_stop + continuation)} secs."
                )
        return {"generated_audio": final_audio.cpu().numpy().transpose()}