|
import logging |
|
from typing import Any, Dict |
|
|
|
import numpy as np |
|
import torch |
|
from audiocraft.models import MusicGen |
|
|
|
# Configure root logging once at import time so endpoint logs are emitted
# at INFO level in the serving environment.
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)
|
|
|
|
|
class EndpointHandler:
    """Inference endpoint wrapper around MusicGen.

    Generates audio from a text prompt. Requests longer than the model's
    native 30s window are produced in chunks: each new chunk is conditioned
    on the tail of the audio generated so far via ``generate_continuation``.
    """

    def __init__(self, path=""):
        """Load facebook/musicgen-large onto GPU when available, else CPU.

        Args:
            path: Unused; present for endpoint-handler interface compatibility.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # musicgen-large outputs mono audio.
        self.channels = 1
        self.model = MusicGen.get_pretrained(
            "facebook/musicgen-large", device=self.device
        )
        self.sample_rate = self.model.sample_rate

    def __call__(self, data: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """
        This call function is called by the endpoint. It takes in a payload and returns an audio signal.
        The main advantage of this function is that it supports generation of audio in chunks,
        so the limitation of 30s audio generation is removed for the model.

        The payload should be a dictionary with the following keys:
            prompt: The prompt to generate audio for.
            generation_params: A dictionary of generation parameters. The following keys are supported:
                duration: The duration of audio to generate in seconds. Default: 30
                temperature: The temperature to use for generation. Default: 0.8
                top_p: The top p value to use for generation. Default: 0.0
                top_k: The top k value to use for generation. Default: 250
                cfg_coef: The amount of classifier free guidance to use. Default: 0.0
                These values are passed to the model's set_generation_params function. Other
                values can be passed as well if they are supported by the model.
            audio_window: Seconds of already-generated audio used as the prompt
                for the next chunk. Default: 20. NOTE: read from the top-level
                payload, not from generation_params.
            chunk_size: The size of each chunk in seconds. Default: 30. Also
                read from the top-level payload.

        Args:
            data (Dict[str, Any]): The payload to generate audio for.

        Raises:
            ValueError: If chunk_size is not strictly greater than audio_window,
                or if (duration - chunk_size) is not a multiple of
                (chunk_size - audio_window).

        Returns:
            Dict[str, np.ndarray]: {"generated_audio": float array of shape
            (samples, channels)}.
        """
        prompt = data["prompt"]

        generation_params = data.get("generation_params", {})

        duration = generation_params.get("duration", 30)

        if duration <= 30:
            logger.info("Generating audio with duration %s in one go.", duration)
            self.model.set_generation_params(**generation_params)
            # generate() returns a (batch, channels, samples) tensor; drop the
            # batch dim so both branches hold (channels, samples) and the final
            # transpose below yields (samples, channels) in every case.
            final_audio = self.model.generate([prompt], progress=True)[0]
        else:
            logger.info("Generating audio with duration %s in chunks.", duration)

            audio_window = data.get("audio_window", 20)
            chunk_size = data.get("chunk_size", 30)
            continuation = chunk_size - audio_window
            final_duration = duration

            # chunk_size == audio_window would make continuation zero: no
            # forward progress per hop and a ZeroDivisionError in the modulo
            # check below — so require strictly greater, not just >=.
            if chunk_size <= audio_window:
                raise ValueError(
                    f"Chunk size {chunk_size} must be greater than audio window {audio_window}"
                )

            if (final_duration - chunk_size) % continuation != 0:
                raise ValueError(
                    f"Duration ({duration} secs) - chunksize ({chunk_size} secs)"
                    f" must be a multiple of continuation ({continuation} secs)"
                )

            # Each model call produces exactly one chunk of audio.
            generation_params["duration"] = chunk_size
            self.model.set_generation_params(**generation_params)

            logger.info(
                "Generating total audio %s secs with chunks of %s secs "
                "and continuation of %s secs.",
                final_duration,
                chunk_size,
                continuation,
            )

            logger.info("Initializing final audio with %s secs of audio.", chunk_size)
            # Pre-allocated (channels, samples) buffer the chunks are written into.
            final_audio = torch.zeros(
                (
                    self.channels,
                    self.sample_rate * final_duration,
                ),
                dtype=torch.float,
            ).to(self.device)

            # First chunk: plain generation from the text prompt. [0] drops the
            # batch dimension so the assignment shapes match for any channel count.
            final_audio[
                :,
                : chunk_size * self.sample_rate,
            ] = self.model.generate([prompt], progress=True)[0]

            n_hops = (final_duration - chunk_size) // continuation
            for i_hop in range(n_hops):
                logger.info("Generating audio for hop %s", i_hop)

                # Condition on the last `audio_window` secs generated so far.
                prompt_stop = chunk_size + i_hop * continuation
                prompt_start = prompt_stop - audio_window

                audio_prompt = final_audio[
                    :, prompt_start * self.sample_rate : prompt_stop * self.sample_rate
                ].reshape(1, self.channels, -1)

                output = self.model.generate_continuation(
                    audio_prompt,
                    self.sample_rate,
                    [prompt],
                    progress=True,
                )

                # output is (1, channels, chunk_size * sample_rate); keep only
                # the newly generated tail beyond the audio-window prompt.
                final_audio[
                    :,
                    prompt_stop
                    * self.sample_rate : (prompt_stop + continuation)
                    * self.sample_rate,
                ] = output[0, :, audio_window * self.sample_rate :]
                logger.info(
                    "finished generating audio till %s secs.",
                    prompt_stop + continuation,
                )

        return {"generated_audio": final_audio.cpu().numpy().transpose()}
|
|