skroed committed
Commit cb8faa1
1 Parent(s): dfc87a6

Feat: Support monaural large model

Files changed (2):
  1. handler.py +122 -13
  2. local_enpoint_test.py +24 -0
handler.py CHANGED
@@ -1,25 +1,134 @@
+import logging
 from typing import Any, Dict
 
+import numpy as np
+import torch
 from audiocraft.models import MusicGen
 
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 
 class EndpointHandler:
     def __init__(self, path=""):
-        # load model and processor from path
-        self.model = MusicGen.get_pretrained("small")
+        if torch.cuda.is_available():
+            self.device = "cuda"
+        else:
+            self.device = "cpu"
+        # If you want to use a different model, change the model name here.
+        # Stereo models are also supported, but you need to change the channels to 2.
+        self.channels = 1
+        self.model = MusicGen.get_pretrained(
+            "facebook/musicgen-large", device=self.device
+        )
+        self.sample_rate = self.model.sample_rate
 
-    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, np.ndarray]:
         """
+        This method is called by the endpoint. It takes in a payload and returns an audio signal.
+        Its main advantage is that it supports generation of audio in chunks,
+        so the model's 30-second generation limit no longer caps the total duration.
+        The payload should be a dictionary with the following keys:
+            prompt: The prompt to generate audio for.
+            generation_params: A dictionary of generation parameters. The following keys are supported:
+                duration: The duration of audio to generate, in seconds. Default: 30
+                temperature: The temperature to use for generation. Default: 0.8
+                top_p: The top-p value to use for generation. Default: 0.0
+                top_k: The top-k value to use for generation. Default: 250
+                cfg_coef: The amount of classifier-free guidance to use. Default: 0.0
+                These values are passed to the model's set_generation_params function. Other
+                values can be passed as well if they are supported by the model.
+            audio_window: The amount of audio, in seconds, to use as the prompt for the next chunk. Default: 20
+            chunk_size: The size of each chunk, in seconds. Default: 30
+
         Args:
+            data (Dict[str, Any]): The payload to generate audio for.
+
+        Raises:
+            ValueError: If chunk_size is less than audio_window,
+                or if duration - chunk_size is not a multiple of chunk_size - audio_window.
+
+        Returns:
+            Dict[str, np.ndarray]: A dictionary with the generated audio.
         """
-        # process input
-        output = (
-            self.model.generate_unconditional(num_samples=2, progress=True)[0]
-            .cpu()
-            .numpy()
-            .tolist()
-        )
+        prompt = data["prompt"]
+
+        generation_params = data.get("generation_params", {})
+
+        duration = generation_params.get("duration", 30)
+
+        if duration <= 30:
+            logger.info(f"Generating audio with duration {duration} in one go.")
+            self.model.set_generation_params(**generation_params)
+            final_audio = self.model.generate([prompt], progress=True)
+        else:
+            logger.info(f"Generating audio with duration {duration} in chunks.")
+
+            audio_window = data.get("audio_window", 20)
+            chunk_size = data.get("chunk_size", 30)
+            continuation = chunk_size - audio_window
+            final_duration = duration
+
+            if chunk_size < audio_window:
+                raise ValueError(
+                    f"Chunk size {chunk_size} must be greater than audio window {audio_window}"
+                )
+
+            if (final_duration - chunk_size) % continuation != 0:
+                raise ValueError(
+                    f"Duration ({duration} secs) - chunk size ({chunk_size} secs)"
+                    f" must be a multiple of continuation ({continuation} secs)"
+                )
+
+            generation_params["duration"] = chunk_size
+            self.model.set_generation_params(**generation_params)
+
+            logger.info(
+                f"Generating total audio {final_duration} secs with chunks of {chunk_size} secs "
+                f"and continuation of {continuation} secs."
+            )
+
+            # Initialize the final audio buffer with the first chunk.
+            logger.info(f"Initializing final audio with {chunk_size} secs of audio.")
+            final_audio = torch.zeros(
+                (
+                    self.channels,
+                    self.sample_rate * final_duration,
+                ),
+                dtype=torch.float,
+            ).to(self.device)
+
+            final_audio[
+                :,
+                : chunk_size * self.sample_rate,
+            ] = self.model.generate([prompt], progress=True)
+
+            n_hops = (final_duration - chunk_size) // continuation
+            for i_hop in range(n_hops):
+                logger.info(f"Generating audio for hop {i_hop}")
+
+                prompt_stop = chunk_size + i_hop * continuation
+                prompt_start = prompt_stop - audio_window
+
+                audio_prompt = final_audio[
+                    :, prompt_start * self.sample_rate : prompt_stop * self.sample_rate
+                ].reshape(1, self.channels, -1)
+
+                output = self.model.generate_continuation(
+                    audio_prompt,
+                    self.sample_rate,
+                    [prompt],
+                    progress=True,
+                )
+
+                final_audio[
+                    :,
+                    prompt_stop
+                    * self.sample_rate : (prompt_stop + continuation)
+                    * self.sample_rate,
+                ] = output[..., audio_window * self.sample_rate :]
+                logger.info(
+                    f"Finished generating audio up to {prompt_stop + continuation} secs."
+                )
 
-        return [{"generated_audio": output}]
+        return {"generated_audio": final_audio.cpu().numpy().transpose()}
local_enpoint_test.py ADDED
@@ -0,0 +1,24 @@
+import math
+
+from handler import EndpointHandler
+from scipy.io.wavfile import write
+
+# init handler
+my_handler = EndpointHandler(path=".")
+
+generation_params = {
+    "duration": 12,
+}
+
+# prepare sample payload
+payload = {
+    "prompt": "rock, rock and rock",
+    "generation_params": generation_params,
+    "audio_window": 2,
+    "chunk_size": 4,
+}
+
+# test the handler
+test = my_handler(payload)
+
+write("test.wav", 32000, test["generated_audio"])