guyyariv committed on
Commit
44620f0
1 Parent(s): 04757be

AudioTokenDemo

Browse files
Files changed (2) hide show
  1. app.py +16 -2
  2. requirements.txt +1 -0
app.py CHANGED
@@ -8,6 +8,7 @@ from diffusers.models.attention_processor import LoRAAttnProcessor
8
  from diffusers import StableDiffusionPipeline
9
  import numpy as np
10
  import gradio as gr
 
11
 
12
 
13
  class AudioTokenWrapper(torch.nn.Module):
@@ -90,10 +91,23 @@ class AudioTokenWrapper(torch.nn.Module):
90
 
91
 
92
  def greet(audio):
93
- audio = audio[-1].astype(np.float32, order='C') / 32768.0
 
 
 
94
  if audio.ndim == 2:
95
  audio = audio.sum(axis=1) / 2
96
 
 
 
 
 
 
 
 
 
 
 
97
  weight_dtype = torch.float32
98
  prompt = 'a photo of <*>'
99
 
@@ -143,6 +157,6 @@ if __name__ == "__main__":
143
  outputs="image",
144
  title='AudioToken',
145
  description=description,
146
- examples=examples
147
  )
148
  demo.launch()
 
8
  from diffusers import StableDiffusionPipeline
9
  import numpy as np
10
  import gradio as gr
11
+ from scipy import signal
12
 
13
 
14
  class AudioTokenWrapper(torch.nn.Module):
 
91
 
92
 
93
  def greet(audio):
94
+ sample_rate, audio = audio
95
+ audio = audio.astype(np.float32, order='C') / 32768.0
96
+ desired_sample_rate = 16000
97
+
98
  if audio.ndim == 2:
99
  audio = audio.sum(axis=1) / 2
100
 
101
+ if sample_rate != desired_sample_rate:
102
+ # Calculate the resampling ratio
103
+ resample_ratio = desired_sample_rate / sample_rate
104
+
105
+ # Determine the new length of the audio data after downsampling
106
+ new_length = int(len(audio) * resample_ratio)
107
+
108
+ # Downsample the audio data using resample
109
+ audio = signal.resample(audio, new_length)
110
+
111
  weight_dtype = torch.float32
112
  prompt = 'a photo of <*>'
113
 
 
157
  outputs="image",
158
  title='AudioToken',
159
  description=description,
160
+ # examples=examples
161
  )
162
  demo.launch()
requirements.txt CHANGED
@@ -9,3 +9,4 @@ Pillow
9
  pandas
10
  torchaudio
11
  datasets
 
 
9
  pandas
10
  torchaudio
11
  datasets
12
+ scipy