audiocraft_handler / handler.py
skroed
Feat: Support monaural large model
cb8faa1
raw
history blame
5.51 kB
import logging
from typing import Any, Dict
import numpy as np
import torch
from audiocraft.models import MusicGen
# Configure root logging once at import time so the handler's progress
# messages are visible in the endpoint's logs.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Inference-endpoint handler wrapping Facebook's MusicGen model.

    Generates audio from a text prompt. Durations beyond the model's 30 s
    window are produced in chunks: each new chunk is generated as a
    continuation of the last ``audio_window`` seconds of audio produced so
    far, and only the newly generated tail is appended.
    """

    def __init__(self, path=""):
        # Prefer GPU when available; generation on CPU is very slow.
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"
        # If you want to use a different model, change the model name here
        # Stereo models are also supported but you need to change the channels to 2
        self.channels = 1
        self.model = MusicGen.get_pretrained(
            "facebook/musicgen-large", device=self.device
        )
        self.sample_rate = self.model.sample_rate

    def __call__(self, data: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """
        This call function is called by the endpoint. It takes in a payload and returns an audio signal.
        The main advantage of this function is that it supports generation of audio in chunks,
        so the limitation of 30s audio generation is removed for the model.
        The payload should be a dictionary with the following keys:
            prompt: The prompt to generate audio for.
            generation_params: A dictionary of generation parameters. The following keys are supported:
                duration: The duration of audio to generate in seconds. Default: 30
                temperature: The temperature to use for generation. Default: 0.8
                top_p: The top p value to use for generation. Default: 0.0
                top_k: The top k value to use for generation. Default: 250
                cfg_coef: The amount of classifier free guidance to use. Default: 0.0
                These values are passed to the model's set_generation_params function. Other
                values can be passed as well if they are supported by the model.
            audio_window: The amount of audio to use as prompt for the next chunk. Default: 20
            chunk_size: The size of each chunk in seconds. Default: 30
        Args:
            data (Dict[str, Any]): The payload to generate audio for.
        Raises:
            ValueError: If chunk_size is not greater than audio_window
                or if duration - chunk_size is not a multiple of chunk_size - audio_window
        Returns:
            Dict[str, np.ndarray]: A dictionary with the generated audio,
                shaped (samples, channels) after the final transpose.
        """
        prompt = data["prompt"]
        generation_params = data.get("generation_params", {})
        duration = generation_params.get("duration", 30)
        if duration <= 30:
            logger.info(f"Generating audio with duration {duration} in one go.")
            self.model.set_generation_params(**generation_params)
            # generate() returns a (batch, channels, samples) tensor; take the
            # single batch element so both branches produce (channels, samples)
            # and the caller always receives the same array rank.
            final_audio = self.model.generate([prompt], progress=True)[0]
        else:
            logger.info(f"Generating audio with duration {duration} in chunks.")
            audio_window = data.get("audio_window", 20)
            chunk_size = data.get("chunk_size", 30)
            continuation = chunk_size - audio_window
            final_duration = duration
            # Strict inequality is required: chunk_size == audio_window would
            # make continuation zero, so each hop would add no new audio and
            # the modulo below would divide by zero.
            if chunk_size <= audio_window:
                raise ValueError(
                    f"Chunk size {chunk_size} must be greater than audio window {audio_window}"
                )
            if (final_duration - chunk_size) % continuation != 0:
                raise ValueError(
                    f"Duration ({duration} secs) - chunksize ({chunk_size} secs)"
                    f" must be a multiple of continuation ({continuation} secs)"
                )
            # Each model call produces chunk_size seconds of audio.
            generation_params["duration"] = chunk_size
            self.model.set_generation_params(**generation_params)
            logger.info(
                f"Generating total audio {final_duration} secs with chunks of {chunk_size} secs "
                f"and continuation of {continuation} secs."
            )
            # Initialize final audio
            logger.info(f"Initializing final audio with {chunk_size} secs of audio.")
            final_audio = torch.zeros(
                (
                    self.channels,
                    self.sample_rate * final_duration,
                ),
                dtype=torch.float,
            ).to(self.device)
            # First chunk is generated from the text prompt alone.
            final_audio[
                :,
                : chunk_size * self.sample_rate,
            ] = self.model.generate([prompt], progress=True)
            # Each hop re-generates the last audio_window seconds plus
            # continuation seconds of new audio; only the new tail is kept.
            n_hops = (final_duration - chunk_size) // continuation
            for i_hop in range(n_hops):
                logger.info(f"Generating audio for hop {i_hop}")
                prompt_stop = chunk_size + i_hop * continuation
                prompt_start = prompt_stop - audio_window
                # generate_continuation expects a (batch, channels, samples) prompt.
                audio_prompt = final_audio[
                    :, prompt_start * self.sample_rate : prompt_stop * self.sample_rate
                ].reshape(1, self.channels, -1)
                output = self.model.generate_continuation(
                    audio_prompt,
                    self.sample_rate,
                    [prompt],
                    progress=True,
                )
                # Drop the audio_window seconds that merely echo the prompt;
                # append only the continuation seconds of new audio.
                final_audio[
                    :,
                    prompt_stop
                    * self.sample_rate : (prompt_stop + continuation)
                    * self.sample_rate,
                ] = output[..., audio_window * self.sample_rate :]
                logger.info(
                    f"finished generating audio till {(prompt_stop + continuation)} secs."
                )
        # (channels, samples) -> (samples, channels) numpy array on CPU.
        return {"generated_audio": final_audio.cpu().numpy().transpose()}