jhj0517 committed on
Commit
08d7176
1 Parent(s): ada247c

add audio loader

Browse files
Files changed (1) hide show
  1. modules/diarize/audio_loader.py +161 -0
modules/diarize/audio_loader.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from functools import lru_cache
4
+ from typing import Optional, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
def exact_div(x, y):
    """
    Divide ``x`` by ``y``, requiring the division to leave no remainder.

    Parameters
    ----------
    x: int
        The dividend

    y: int
        The divisor

    Returns
    -------
    The quotient ``x // y``.

    Raises
    ------
    AssertionError
        If ``x`` is not an exact multiple of ``y``.
    """
    # Raised explicitly instead of through an `assert` statement so the check still
    # runs under `python -O`; the exception type is unchanged for existing callers.
    if x % y != 0:
        raise AssertionError(f"{x} is not evenly divisible by {y}")
    return x // y
13
+
14
# Fixed audio hyperparameters shared by the loading and spectrogram helpers
SAMPLE_RATE = 16000  # samples per second
N_FFT = 400  # STFT window size, in samples
HOP_LENGTH = 160  # samples between consecutive STFT frames
CHUNK_LENGTH = 30  # seconds per chunk
N_SAMPLES = SAMPLE_RATE * CHUNK_LENGTH  # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input

N_SAMPLES_PER_TOKEN = 2 * HOP_LENGTH  # the initial convolutions have stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token
25
+
26
+
27
def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.

    Raises
    ------
    RuntimeError
        If ffmpeg exits with a non-zero status; the message carries ffmpeg's stderr.
    """
    # Launches a subprocess to decode audio while down-mixing and resampling as necessary.
    # Requires the ffmpeg CLI to be installed.
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads",
        "0",
        "-i",
        file,
        "-f",
        "s16le",
        "-ac",
        "1",
        "-acodec",
        "pcm_s16le",
        "-ar",
        str(sr),
        "-",
    ]
    try:
        out = subprocess.run(cmd, capture_output=True, check=True).stdout
    except subprocess.CalledProcessError as e:
        # ffmpeg's stderr is not guaranteed to be valid UTF-8; replace undecodable
        # bytes so a UnicodeDecodeError cannot mask the actual decoding failure.
        raise RuntimeError(
            f"Failed to load audio: {e.stderr.decode(errors='replace')}"
        ) from e

    # The command requests "s16le" output, i.e. little-endian signed 16-bit samples,
    # so read it with an explicit byte order and scale by 32768 to get float32 values.
    return np.frombuffer(out, dtype="<i2").astype(np.float32) / 32768.0
68
+
69
+
70
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Force `array` to hold exactly `length` entries along `axis`.

    Longer inputs keep only their leading `length` entries; shorter inputs are
    zero-padded at the end. Works on both torch Tensors and NumPy arrays. With the
    default `length` this yields N_SAMPLES entries, as expected by the encoder.
    """
    current = array.shape[axis]
    if current == length:
        return array

    if current > length:
        # Too long: keep indices 0 .. length - 1 along `axis`.
        if torch.is_tensor(array):
            keep = torch.arange(length, device=array.device)
            return array.index_select(dim=axis, index=keep)
        return array.take(indices=range(length), axis=axis)

    # Too short: append zeros along `axis` only.
    widths = [(0, 0)] * array.ndim
    widths[axis] = (0, length - current)
    if torch.is_tensor(array):
        # F.pad takes a flat list of (before, after) pairs starting from the last dim.
        flat_widths = [amount for pair in reversed(widths) for amount in pair]
        return F.pad(array, flat_widths)
    return np.pad(array, widths)
94
+
95
+
96
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )

    The archive is read from "assets/mel_filters.npz" next to this module and must hold
    one entry per supported size, named "mel_80" and "mel_128" (the latter presumably
    produced the same way with n_mels=128).

    Parameters
    ----------
    device
        The device the filterbank tensor is moved to

    n_mels: int
        The number of Mel-frequency filters; must be 80 or 128

    Returns
    -------
    torch.Tensor
        The filterbank on `device`. Results are cached per (device, n_mels), so
        repeated calls return the same tensor object.

    Raises
    ------
    AssertionError
        If `n_mels` is not one of the supported values.
    """
    # Raised explicitly instead of through an `assert` statement so the check still
    # runs under `python -O`; the exception type and message are unchanged.
    if n_mels not in (80, 128):
        raise AssertionError(f"Unsupported n_mels: {n_mels}")

    filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
    with np.load(filters_path) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
112
+
113
+
114
def log_mel_spectrogram(
    audio: Union[str, os.PathLike, np.ndarray, torch.Tensor],
    n_mels: int,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the log-Mel spectrogram of an audio waveform

    Parameters
    ----------
    audio: Union[str, os.PathLike, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters; 80 and 128 are the values accepted by mel_filters

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    if not torch.is_tensor(audio):
        # Paths (plain strings or os.PathLike objects) are decoded through ffmpeg first;
        # anything else is expected to already be a NumPy waveform.
        if isinstance(audio, (str, os.PathLike)):
            audio = load_audio(os.fspath(audio))
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    # Drop the last STFT frame and take the squared magnitude of the remaining ones.
    magnitudes = stft[..., :-1].abs() ** 2

    filters = mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes

    # Clamp before the logarithm so that zero energy does not produce -inf.
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    # Raise every value to at least 8.0 below the global maximum, then rescale.
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec