sanchit-gandhi HF staff commited on
Commit
1deba83
1 Parent(s): b4d4d63

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -14
app.py CHANGED
@@ -12,13 +12,9 @@ import gradio as gr
12
  import spaces
13
 
14
 
15
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
16
-
17
- model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
18
  processor = MusicgenProcessor.from_pretrained("facebook/musicgen-small")
19
 
20
- if device == "cuda:0":
21
- model.to(device).half();
22
 
23
  class MusicgenStreamer(BaseStreamer):
24
  def __init__(
@@ -143,25 +139,26 @@ class MusicgenStreamer(BaseStreamer):
143
  sampling_rate = model.audio_encoder.config.sampling_rate
144
  frame_rate = model.audio_encoder.config.frame_rate
145
 
 
 
 
146
 
147
  @spaces.GPU
148
  def generate_audio(text_prompt, audio_length_in_s=10.0, play_steps_in_s=2.0):
149
- inputs = processor(
150
- text=text_prompt,
151
- padding=True,
152
- return_tensors="pt",
153
- )
154
-
155
  max_new_tokens = int(frame_rate * audio_length_in_s)
156
  play_steps = int(frame_rate * play_steps_in_s)
157
 
158
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
159
-
160
  if device != model.device:
161
  model.to(device)
 
 
162
 
163
- if device == "cuda:0":
164
- model.to(device).half();
 
 
 
165
 
166
  streamer = MusicgenStreamer(model, device=device, play_steps=play_steps)
167
 
 
12
  import spaces
13
 
14
 
15
+ model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small", low_cpu_mem_usage=True)
 
 
16
  processor = MusicgenProcessor.from_pretrained("facebook/musicgen-small")
17
 
 
 
18
 
19
  class MusicgenStreamer(BaseStreamer):
20
  def __init__(
 
139
  sampling_rate = model.audio_encoder.config.sampling_rate
140
  frame_rate = model.audio_encoder.config.frame_rate
141
 
142
+ target_dtype = np.int16
143
+ max_range = np.iinfo(target_dtype).max
144
+
145
 
146
  @spaces.GPU
147
  def generate_audio(text_prompt, audio_length_in_s=10.0, play_steps_in_s=2.0):
 
 
 
 
 
 
148
  max_new_tokens = int(frame_rate * audio_length_in_s)
149
  play_steps = int(frame_rate * play_steps_in_s)
150
 
151
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
152
  if device != model.device:
153
  model.to(device)
154
+ if device == "cuda:0":
155
+ model.half();
156
 
157
+ inputs = processor(
158
+ text=text_prompt,
159
+ padding=True,
160
+ return_tensors="pt",
161
+ )
162
 
163
  streamer = MusicgenStreamer(model, device=device, play_steps=play_steps)
164