File size: 5,510 Bytes
cb8faa1
dfc87a6
 
cb8faa1
 
dfc87a6
 
cb8faa1
 
 
dfc87a6
 
 
cb8faa1
 
 
 
 
 
 
 
 
 
 
dfc87a6
cb8faa1
dfc87a6
cb8faa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfc87a6
cb8faa1
 
 
 
 
 
 
 
dfc87a6
245bc3c
cb8faa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfc87a6
cb8faa1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import logging
from typing import Any, Dict

import numpy as np
import torch
from audiocraft.models import MusicGen

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Inference-endpoint wrapper around Meta's MusicGen.

    Generates audio from a text prompt. For durations beyond the model's
    native 30-second window, it stitches the output together chunk by
    chunk, feeding the tail of each chunk back in as an audio prompt via
    ``generate_continuation``.
    """

    def __init__(self, path=""):
        # Prefer GPU when available; generation on CPU is far slower.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # If you want to use a different model, change the model name here.
        # Stereo models are also supported but you need to change the
        # channels to 2.
        self.channels = 1
        self.model = MusicGen.get_pretrained(
            "facebook/musicgen-large", device=self.device
        )
        # Native model sample rate (Hz); all second->sample conversions
        # below use this value.
        self.sample_rate = self.model.sample_rate

    def __call__(self, data: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """
        This call function is called by the endpoint. It takes in a payload and returns an audio signal.
        The main advantage of this function is that it supports generation of audio in chunks,
        so the limitation of 30s audio generation is removed for the model.
        The payload should be a dictionary with the following keys:
        inputs: The prompt to generate audio for.
        generation_params: A dictionary of generation parameters. The following keys are supported:
            duration: The duration of audio to generate in seconds. Default: 30
            temperature: The temperature to use for generation.
            top_p: The top p value to use for generation.
            top_k: The top k value to use for generation.
            cfg_coef: The amount of classifier free guidance to use.
            These values are passed straight to the model's set_generation_params
            function, so their defaults are whatever the installed audiocraft
            version defines (NOTE: confirm against that version). Other keys
            supported by set_generation_params can be passed as well.
        audio_window: Seconds of already-generated audio used as the prompt
            for the next chunk. Default: 20
        chunk_size: The size of each chunk in seconds. Default: 30

        Args:
            data (Dict[str, Any]): The payload to generate audio for.

        Raises:
            ValueError: If chunk_size is not strictly greater than audio_window,
                or if (duration - chunk_size) is not a multiple of
                (chunk_size - audio_window).

        Returns:
            Dict[str, np.ndarray]: {"generated_audio": array of shape
            (samples, channels)}.
        """
        prompt = data["inputs"]

        generation_params = data.get("generation_params", {})

        duration = generation_params.get("duration", 30)

        if duration <= 30:
            logger.info("Generating audio with duration %s in one go.", duration)
            self.model.set_generation_params(**generation_params)
            # generate() returns a (batch, channels, samples) tensor; drop
            # the batch dim so both branches yield (channels, samples) and
            # the transposed output shape is consistent.
            final_audio = self.model.generate([prompt], progress=True)[0]
        else:
            logger.info("Generating audio with duration %s in chunks.", duration)

            audio_window = data.get("audio_window", 20)
            chunk_size = data.get("chunk_size", 30)
            continuation = chunk_size - audio_window
            final_duration = duration

            # Strict inequality required: chunk_size == audio_window would
            # make continuation zero and crash the modulo check below with
            # ZeroDivisionError instead of a clear error.
            if chunk_size <= audio_window:
                raise ValueError(
                    f"Chunk size {chunk_size} must be greater than audio window {audio_window}"
                )

            if (final_duration - chunk_size) % continuation != 0:
                raise ValueError(
                    f"Duration ({duration} secs) - chunksize ({chunk_size} secs)"
                    f" must be a multiple of continuation ({continuation} secs)"
                )

            # Each model call produces chunk_size seconds; hops past the
            # first chunk only add `continuation` new seconds each.
            generation_params["duration"] = chunk_size
            self.model.set_generation_params(**generation_params)

            logger.info(
                "Generating total audio %s secs with chunks of %s secs "
                "and continuation of %s secs.",
                final_duration,
                chunk_size,
                continuation,
            )

            # Initialize final audio buffer on the target device up front.
            logger.info("Initializing final audio with %s secs of audio.", chunk_size)
            final_audio = torch.zeros(
                (
                    self.channels,
                    self.sample_rate * final_duration,
                ),
                dtype=torch.float,
                device=self.device,
            )

            # First chunk: plain generation; [0] strips the batch dim to
            # match the (channels, samples) buffer slice.
            final_audio[
                :,
                : chunk_size * self.sample_rate,
            ] = self.model.generate([prompt], progress=True)[0]

            n_hops = (final_duration - chunk_size) // continuation
            for i_hop in range(n_hops):
                logger.info("Generating audio for hop %s", i_hop)

                # Prompt window = last audio_window secs of audio so far.
                prompt_stop = chunk_size + i_hop * continuation
                prompt_start = prompt_stop - audio_window

                audio_prompt = final_audio[
                    :, prompt_start * self.sample_rate : prompt_stop * self.sample_rate
                ].reshape(1, self.channels, -1)

                output = self.model.generate_continuation(
                    audio_prompt,
                    self.sample_rate,
                    [prompt],
                    progress=True,
                )

                # The continuation output re-includes the audio_window-long
                # prompt; keep only the newly generated tail.
                final_audio[
                    :,
                    prompt_stop
                    * self.sample_rate : (prompt_stop + continuation)
                    * self.sample_rate,
                ] = output[..., audio_window * self.sample_rate :]
                logger.info(
                    "finished generating audio till %s secs.", prompt_stop + continuation
                )

        # (channels, samples) -> (samples, channels) for the response payload.
        return {"generated_audio": final_audio.cpu().numpy().transpose()}