Irsh Vijayvargia committed on
Commit
42a4544
1 Parent(s): 3e34e2e

Add application file

app.py ADDED
@@ -0,0 +1,132 @@
+ import torch
+ import librosa
+ import numpy as np
+ import os
+ import webrtcvad
+ import wave
+ import contextlib
+ import gradio as gr
+
+ from utils.VAD_segments import *
+ from utils.hparam import hparam as hp
+ from utils.speech_embedder_net import *
+ from utils.evaluation import *
+
+ def read_wave(audio_data):
+     """Reads audio data and returns (float audio array, PCM byte data).
+     Assumes the input is a tuple (sample_rate, numpy_array).
+     If the sample rate is unsupported, resamples to 16000 Hz.
+     """
+     sample_rate, data = audio_data
+
+     # Ensure data is in the correct shape
+     assert len(data.shape) == 1, "Audio data must be a 1D array"
+
+     # Convert to floating point if necessary
+     if not np.issubdtype(data.dtype, np.floating):
+         data = data.astype(np.float32) / np.iinfo(data.dtype).max
+
+     # Supported sample rates
+     supported_sample_rates = (8000, 16000, 32000, 48000)
+
+     # If sample rate is not supported, resample to 16000 Hz
+     if sample_rate not in supported_sample_rates:
+         data = librosa.resample(data, orig_sr=sample_rate, target_sr=16000)
+         sample_rate = 16000
+
+     # Convert numpy array to PCM format
+     pcm_data = (data * np.iinfo(np.int16).max).astype(np.int16).tobytes()
+
+     return data, pcm_data
+
+
+ def VAD_chunk(aggressiveness, data):
+     audio, byte_audio = read_wave(data)
+     vad = webrtcvad.Vad(int(aggressiveness))
+     frames = frame_generator(20, byte_audio, hp.data.sr)
+     frames = list(frames)
+     times = vad_collector(hp.data.sr, 20, 200, vad, frames)
+     speech_times = []
+     speech_segs = []
+     for i, time in enumerate(times):
+         start = np.round(time[0], decimals=2)
+         end = np.round(time[1], decimals=2)
+         j = start
+         while j + .4 < end:
+             end_j = np.round(j + .4, decimals=2)
+             speech_times.append((j, end_j))
+             speech_segs.append(audio[int(j*hp.data.sr):int(end_j*hp.data.sr)])
+             j = end_j
+         else:
+             speech_times.append((j, end))
+             speech_segs.append(audio[int(j*hp.data.sr):int(end*hp.data.sr)])
+     return speech_times, speech_segs
+
+
+ def get_embedding(data, embedder_net, device, n_threshold=-1):
+     times, segs = VAD_chunk(0, data)
+     if not segs:
+         print('No voice activity detected')
+         return None
+     concat_seg = concat_segs(times, segs)
+     if not concat_seg:
+         print('No concatenated segments')
+         return None
+     STFT_frames = get_STFTs(concat_seg)
+     if not STFT_frames:
+         #print('No STFT frames')
+         return None
+     STFT_frames = np.stack(STFT_frames, axis=2)
+     STFT_frames = torch.tensor(np.transpose(STFT_frames, axes=(2, 1, 0)), device=device)
+
+     with torch.no_grad():
+         embeddings = embedder_net(STFT_frames)
+         embeddings = embeddings[:n_threshold, :]
+
+     avg_embedding = torch.mean(embeddings, dim=0, keepdim=True).cpu().numpy()
+     return avg_embedding
+
+
+ model_path = "./speech_id_checkpoint/saved_02.model"
+
+
+ device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
+
+ embedder_net = SpeechEmbedder().to(device)
+ embedder_net.load_state_dict(torch.load(model_path, map_location=device))
+ embedder_net.eval()
+
+ def process_audio(audio1, audio2, threshold):
+     e1 = get_embedding(audio1, embedder_net, device)
+     if e1 is None:
+         return "No Voice Detected in file 1"
+     e2 = get_embedding(audio2, embedder_net, device)
+     if e2 is None:
+         return "No Voice Detected in file 2"
+
+     cosi = cosine_similarity(e1, e2)
+
+     if cosi > threshold:
+         return "Same Speaker"
+     else:
+         return "Different Speaker"
+
+ # Define the Gradio interface
+ def gradio_interface(audio1, audio2, threshold):
+     output_text = process_audio(audio1, audio2, threshold)
+     return output_text
+
+ # Create the Gradio interface with microphone inputs
+ iface = gr.Interface(
+     fn=gradio_interface,
+     inputs=[gr.Audio("microphone", type="numpy", label="Audio File 1"),
+             gr.Audio("microphone", type="numpy", label="Audio File 2"),
+             gr.Slider(0.0, 1.0, value=0.85, step=0.01, label="Threshold")
+             ],
+     outputs="text",
+     title="Gujarati Text Independent Speaker Verification",
+     description="Record two audio files and get the text output from the model."
+ )
+
+ # Launch the interface
+ iface.launch(share=False)
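
A minimal local smoke test for the verification pipeline above. The WAV paths are hypothetical placeholders, and the snippet assumes it runs in the same module as process_audio (or after "from app import process_audio"); librosa.load with mono=True returns the 1-D float array that read_wave expects.

import librosa

def load_as_numpy_tuple(path, sr=16000):
    # Gradio's "numpy" audio format is (sample_rate, samples).
    data, _ = librosa.load(path, sr=sr, mono=True)
    return (sr, data)

audio1 = load_as_numpy_tuple("samples/enroll.wav")   # hypothetical path
audio2 = load_as_numpy_tuple("samples/verify.wav")   # hypothetical path
print(process_audio(audio1, audio2, threshold=0.85))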
config/config.yaml ADDED
@@ -0,0 +1,40 @@
1
+ training: !!bool "false"
2
+ device: "mps"
3
+ unprocessed_data: './DATA_DIR/*/*.wav'
4
+ ---
5
+ data:
6
+ train_path: './train_tisv'
7
+ train_path_unprocessed: './TIMIT/TRAIN/*/*/*.wav'
8
+ test_path: './test_tisv'
9
+ test_path_unprocessed: './TIMIT/TEST/*/*/*.wav'
10
+ data_preprocessed: !!bool "true"
11
+ sr: 16000
12
+ nfft: 512 #For mel spectrogram preprocess
13
+ window: 0.025 #(s)
14
+ hop: 0.01 #(s)
15
+ nmels: 40 #Number of mel energies
16
+ tisv_frame: 180 #Max number of time steps in input after preprocess
17
+ ---
18
+ model:
19
+ hidden: 768 #Number of LSTM hidden layer units
20
+ num_layer: 3 #Number of LSTM layers
21
+ proj: 256 #Embedding size
22
+ model_path: './speech_id_checkpoint/ckpt_epoch_840_batch_id_6.pth' #Model path for testing, inference, or resuming training
23
+ ---
24
+ train:
25
+ N : 4 #Number of speakers in batch
26
+ M : 6 #Number of utterances per speaker
27
+ num_workers: 0 #number of workers for dataloader
28
+ lr: 0.01
29
+ epochs: 1000 #Max training speaker epoch
30
+ log_interval: 30 #Epochs before printing progress
31
+ log_file: './speech_id_checkpoint/Stats'
32
+ checkpoint_interval: 100 #Save model after x speaker epochs
33
+ checkpoint_dir: './speech_id_checkpoint'
34
+ restore: !!bool "true" #Resume training from previous model path
35
+ ---
36
+ test:
37
+ N : 4 #Number of speakers in batch
38
+ M : 6 #Number of utterances per speaker
39
+ num_workers: 8 #number of workers for data laoder
40
+ epochs: 10 #testing speaker epochs
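
For reference, the hparam loader in utils/hparam.py (included later in this commit) merges every YAML document above into a single dot-accessible object, so these values are read as attributes:

from utils.hparam import hparam as hp

print(hp.data.sr)                      # 16000
print(hp.data.nmels)                   # 40
print(hp.model.hidden, hp.model.proj)  # 768 256
print(hp.train.N, hp.train.M)          # 4 6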
gradio.ipynb ADDED
@@ -0,0 +1,292 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "23237138-936a-44b4-9eb6-f16045d2c91d",
6
+ "metadata": {},
7
+ "source": [
8
+ "### **Gradio Demo | LSTM Speaker Embedding Model for Gujarati Speaker Verification**\n",
9
+ "****\n",
10
+ "**Author:** Irsh Vijay <br>\n",
11
+ "**Organization**: Wadhwani Institute for Artificial Intelligence <br>\n",
12
+ "****\n",
13
+ "This notebook has the required code to run a gradio demo."
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": 8,
19
+ "id": "1d2cfd8b-9498-4236-9d32-718e9e0597cb",
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "import torch\n",
24
+ "import librosa\n",
25
+ "import numpy as np\n",
26
+ "import os\n",
27
+ "import webrtcvad\n",
28
+ "import wave\n",
29
+ "import contextlib\n",
30
+ "\n",
31
+ "from utils.VAD_segments import *\n",
32
+ "from utils.hparam import hparam as hp\n",
33
+ "from utils.speech_embedder_net import *\n",
34
+ "from utils.evaluation import *"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": 9,
40
+ "id": "3e9e1006-83d2-4492-a210-26b2c3717cd5",
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "def read_wave(audio_data):\n",
45
+ " \"\"\"Reads audio data and returns (PCM audio data, sample rate).\n",
46
+ " Assumes the input is a tuple (sample_rate, numpy_array).\n",
47
+ " If the sample rate is unsupported, resamples to 16000 Hz.\n",
48
+ " \"\"\"\n",
49
+ " sample_rate, data = audio_data\n",
50
+ "\n",
51
+ " # Ensure data is in the correct shape\n",
52
+ " assert len(data.shape) == 1, \"Audio data must be a 1D array\"\n",
53
+ "\n",
54
+ " # Convert to floating point if necessary\n",
55
+ " if not np.issubdtype(data.dtype, np.floating):\n",
56
+ " data = data.astype(np.float32) / np.iinfo(data.dtype).max\n",
57
+ " \n",
58
+ " # Supported sample rates\n",
59
+ " supported_sample_rates = (8000, 16000, 32000, 48000)\n",
60
+ " \n",
61
+ " # If sample rate is not supported, resample to 16000 Hz\n",
62
+ " if sample_rate not in supported_sample_rates:\n",
63
+ " data = librosa.resample(data, orig_sr=sample_rate, target_sr=16000)\n",
64
+ " sample_rate = 16000\n",
65
+ " \n",
66
+ " # Convert numpy array to PCM format\n",
67
+ " pcm_data = (data * np.iinfo(np.int16).max).astype(np.int16).tobytes()\n",
68
+ "\n",
69
+ " return data, pcm_data"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": 10,
75
+ "id": "0b56a2fc-83c3-4b36-95b8-5f1b656150ed",
76
+ "metadata": {},
77
+ "outputs": [],
78
+ "source": [
79
+ "def VAD_chunk(aggressiveness, data):\n",
80
+ " audio, byte_audio = read_wave(data)\n",
81
+ " vad = webrtcvad.Vad(int(aggressiveness))\n",
82
+ " frames = frame_generator(20, byte_audio, hp.data.sr)\n",
83
+ " frames = list(frames)\n",
84
+ " times = vad_collector(hp.data.sr, 20, 200, vad, frames)\n",
85
+ " speech_times = []\n",
86
+ " speech_segs = []\n",
87
+ " for i, time in enumerate(times):\n",
88
+ " start = np.round(time[0],decimals=2)\n",
89
+ " end = np.round(time[1],decimals=2)\n",
90
+ " j = start\n",
91
+ " while j + .4 < end:\n",
92
+ " end_j = np.round(j+.4,decimals=2)\n",
93
+ " speech_times.append((j, end_j))\n",
94
+ " speech_segs.append(audio[int(j*hp.data.sr):int(end_j*hp.data.sr)])\n",
95
+ " j = end_j\n",
96
+ " else:\n",
97
+ " speech_times.append((j, end))\n",
98
+ " speech_segs.append(audio[int(j*hp.data.sr):int(end*hp.data.sr)])\n",
99
+ " return speech_times, speech_segs"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": 11,
105
+ "id": "72f257cf-7d3f-4ec5-944a-57779ba377e6",
106
+ "metadata": {},
107
+ "outputs": [],
108
+ "source": [
109
+ "def get_embedding(data, embedder_net, device, n_threshold=-1):\n",
110
+ " times, segs = VAD_chunk(0, data)\n",
111
+ " if not segs:\n",
112
+ " print(f'No voice activity detected')\n",
113
+ " return None\n",
114
+ " concat_seg = concat_segs(times, segs)\n",
115
+ " if not concat_seg:\n",
116
+ " print(f'No concatenated segments')\n",
117
+ " return None\n",
118
+ " STFT_frames = get_STFTs(concat_seg)\n",
119
+ " if not STFT_frames:\n",
120
+ " #print(f'No STFT frames')\n",
121
+ " return None\n",
122
+ " STFT_frames = np.stack(STFT_frames, axis=2)\n",
123
+ " STFT_frames = torch.tensor(np.transpose(STFT_frames, axes=(2, 1, 0)), device=device)\n",
124
+ "\n",
125
+ " with torch.no_grad():\n",
126
+ " embeddings = embedder_net(STFT_frames)\n",
127
+ " embeddings = embeddings[:n_threshold, :]\n",
128
+ " \n",
129
+ " avg_embedding = torch.mean(embeddings, dim=0, keepdim=True).cpu().numpy()\n",
130
+ " return avg_embedding"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": 12,
136
+ "id": "200df766-407d-4367-b0fb-7a6118653731",
137
+ "metadata": {},
138
+ "outputs": [],
139
+ "source": [
140
+ "model_path = \"./speech_id_checkpoint/saved_01.model\""
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": 13,
146
+ "id": "db7613e6-67a8-4920-a999-caca4a0de360",
147
+ "metadata": {},
148
+ "outputs": [
149
+ {
150
+ "data": {
151
+ "text/plain": [
152
+ "SpeechEmbedder(\n",
153
+ " (LSTM_stack): LSTM(40, 768, num_layers=3, batch_first=True)\n",
154
+ " (projection): Linear(in_features=768, out_features=256, bias=True)\n",
155
+ ")"
156
+ ]
157
+ },
158
+ "execution_count": 13,
159
+ "metadata": {},
160
+ "output_type": "execute_result"
161
+ }
162
+ ],
163
+ "source": [
164
+ "device = torch.device(\"mps\" if torch.backends.mps.is_available() else \"cpu\")\n",
165
+ "\n",
166
+ "embedder_net = SpeechEmbedder().to(device)\n",
167
+ "embedder_net.load_state_dict(torch.load(model_path, map_location=device))\n",
168
+ "embedder_net.eval()"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": 14,
174
+ "id": "8a7dd9bd-7b40-41f9-8e2f-d68be18f2111",
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "import gradio as gr"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": 28,
184
+ "id": "bd6c073d-eab8-4ae6-8ba6-d90a0ec54c0e",
185
+ "metadata": {},
186
+ "outputs": [
187
+ {
188
+ "name": "stdout",
189
+ "output_type": "stream",
190
+ "text": [
191
+ "Running on local URL: http://127.0.0.1:7868\n",
192
+ "\n",
193
+ "To create a public link, set `share=True` in `launch()`.\n"
194
+ ]
195
+ },
196
+ {
197
+ "data": {
198
+ "text/html": [
199
+ "<div><iframe src=\"http://127.0.0.1:7868/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
200
+ ],
201
+ "text/plain": [
202
+ "<IPython.core.display.HTML object>"
203
+ ]
204
+ },
205
+ "metadata": {},
206
+ "output_type": "display_data"
207
+ },
208
+ {
209
+ "data": {
210
+ "text/plain": []
211
+ },
212
+ "execution_count": 28,
213
+ "metadata": {},
214
+ "output_type": "execute_result"
215
+ }
216
+ ],
217
+ "source": [
218
+ "def process_audio(audio1, audio2, threshold):\n",
219
+ " e1 = get_embedding(audio1, embedder_net, device)\n",
220
+ " if(e1 is None):\n",
221
+ " return \"No Voice Detected in file 1\"\n",
222
+ " e2 = get_embedding(audio2, embedder_net, device)\n",
223
+ " if(e2 is None):\n",
224
+ " return \"No Voice Detected in file 2\"\n",
225
+ "\n",
226
+ " cosi = cosine_similarity(e1, e2)\n",
227
+ "\n",
228
+ " if(cosi > threshold):\n",
229
+ " return f\"Same Speaker\" \n",
230
+ " else:\n",
231
+ " return f\"Different Speaker\" \n",
232
+ "\n",
233
+ "# Define the Gradio interface\n",
234
+ "def gradio_interface(audio1, audio2, threshold):\n",
235
+ " output_text = process_audio(audio1, audio2, threshold)\n",
236
+ " return output_text\n",
237
+ "\n",
238
+ "# Create the Gradio interface with microphone inputs\n",
239
+ "iface = gr.Interface(\n",
240
+ " fn=gradio_interface,\n",
241
+ " inputs=[gr.Audio(\"microphone\", type=\"numpy\", label=\"Audio File 1\"),\n",
242
+ " gr.Audio(\"microphone\", type=\"numpy\", label=\"Audio File 2\"),\n",
243
+ " gr.Slider(0.0, 1.0, value=0.85, step=0.01, label=\"Threshold\")\n",
244
+ " ],\n",
245
+ " outputs=\"text\",\n",
246
+ " title=\"LSTM Based Speaker Verification\",\n",
247
+ " description=\"Record two audio files and get the text output from the model.\"\n",
248
+ ")\n",
249
+ "\n",
250
+ "# Launch the interface\n",
251
+ "iface.launch(share=False)"
252
+ ]
253
+ },
254
+ {
255
+ "cell_type": "code",
256
+ "execution_count": null,
257
+ "id": "a098495c-9e7b-4232-86fc-55a1890c5e27",
258
+ "metadata": {},
259
+ "outputs": [],
260
+ "source": []
261
+ },
262
+ {
263
+ "cell_type": "code",
264
+ "execution_count": null,
265
+ "id": "b99a253e-9b91-4210-b934-8bd1b6a2d912",
266
+ "metadata": {},
267
+ "outputs": [],
268
+ "source": []
269
+ }
270
+ ],
271
+ "metadata": {
272
+ "kernelspec": {
273
+ "display_name": "Python 3 (ipykernel)",
274
+ "language": "python",
275
+ "name": "python3"
276
+ },
277
+ "language_info": {
278
+ "codemirror_mode": {
279
+ "name": "ipython",
280
+ "version": 3
281
+ },
282
+ "file_extension": ".py",
283
+ "mimetype": "text/x-python",
284
+ "name": "python",
285
+ "nbconvert_exporter": "python",
286
+ "pygments_lexer": "ipython3",
287
+ "version": "3.9.19"
288
+ }
289
+ },
290
+ "nbformat": 4,
291
+ "nbformat_minor": 5
292
+ }
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ librosa
+ numpy
+ webrtcvad
+ gradio
+ PyYAML
speech_id_checkpoint/saved_02.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:51b96ce4d80a01ebe039ed6bc67c1a9731315742d5814fed842d4a22785c5836
+ size 48543874
utils/.ipynb_checkpoints/VAD_segments-checkpoint.py ADDED
@@ -0,0 +1,153 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Tue Dec 18 16:22:41 2018
5
+
6
+ @author: Harry
7
+ Modified from https://github.com/wiseman/py-webrtcvad/blob/master/example.py
8
+ """
9
+
10
+ import collections
11
+ import contextlib
12
+ import numpy as np
13
+ import sys
14
+ import librosa
15
+ import wave
16
+
17
+ import webrtcvad
18
+
19
+ from utils.hparam import hparam as hp
20
+
21
+ def read_wave(path, sr):
22
+ """Reads a .wav file.
23
+ Takes the path, and returns (PCM audio data, sample rate).
24
+ Assumes sample width == 2
25
+ """
26
+ with contextlib.closing(wave.open(path, 'rb')) as wf:
27
+ num_channels = wf.getnchannels()
28
+ assert num_channels == 1
29
+ sample_width = wf.getsampwidth()
30
+ assert sample_width == 2
31
+ sample_rate = wf.getframerate()
32
+ assert sample_rate in (8000, 16000, 32000, 48000)
33
+ pcm_data = wf.readframes(wf.getnframes())
34
+ data, _ = librosa.load(path, sr=sr)
35
+ assert len(data.shape) == 1
36
+ assert sr in (8000, 16000, 32000, 48000)
37
+ return data, pcm_data
38
+
39
+ class Frame(object):
40
+ """Represents a "frame" of audio data."""
41
+ def __init__(self, bytes, timestamp, duration):
42
+ self.bytes = bytes
43
+ self.timestamp = timestamp
44
+ self.duration = duration
45
+
46
+
47
+ def frame_generator(frame_duration_ms, audio, sample_rate):
48
+ """Generates audio frames from PCM audio data.
49
+ Takes the desired frame duration in milliseconds, the PCM data, and
50
+ the sample rate.
51
+ Yields Frames of the requested duration.
52
+ """
53
+ n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
54
+ offset = 0
55
+ timestamp = 0.0
56
+ duration = (float(n) / sample_rate) / 2.0
57
+ while offset + n < len(audio):
58
+ yield Frame(audio[offset:offset + n], timestamp, duration)
59
+ timestamp += duration
60
+ offset += n
61
+
62
+
63
+ def vad_collector(sample_rate, frame_duration_ms,
64
+ padding_duration_ms, vad, frames):
65
+ """Filters out non-voiced audio frames.
66
+ Given a webrtcvad.Vad and a source of audio frames, yields only
67
+ the voiced audio.
68
+ Uses a padded, sliding window algorithm over the audio frames.
69
+ When more than 90% of the frames in the window are voiced (as
70
+ reported by the VAD), the collector triggers and begins yielding
71
+ audio frames. Then the collector waits until 90% of the frames in
72
+ the window are unvoiced to detrigger.
73
+ The window is padded at the front and back to provide a small
74
+ amount of silence or the beginnings/endings of speech around the
75
+ voiced frames.
76
+ Arguments:
77
+ sample_rate - The audio sample rate, in Hz.
78
+ frame_duration_ms - The frame duration in milliseconds.
79
+ padding_duration_ms - The amount to pad the window, in milliseconds.
80
+ vad - An instance of webrtcvad.Vad.
81
+ frames - a source of audio frames (sequence or generator).
82
+ Returns: A generator that yields PCM audio data.
83
+ """
84
+ num_padding_frames = int(padding_duration_ms / frame_duration_ms)
85
+ # We use a deque for our sliding window/ring buffer.
86
+ ring_buffer = collections.deque(maxlen=num_padding_frames)
87
+ # We have two states: TRIGGERED and NOTTRIGGERED. We start in the
88
+ # NOTTRIGGERED state.
89
+ triggered = False
90
+
91
+ voiced_frames = []
92
+ for frame in frames:
93
+ is_speech = vad.is_speech(frame.bytes, sample_rate)
94
+
95
+ if not triggered:
96
+ ring_buffer.append((frame, is_speech))
97
+ num_voiced = len([f for f, speech in ring_buffer if speech])
98
+ # If we're NOTTRIGGERED and more than 90% of the frames in
99
+ # the ring buffer are voiced frames, then enter the
100
+ # TRIGGERED state.
101
+ if num_voiced > 0.9 * ring_buffer.maxlen:
102
+ triggered = True
103
+ start = ring_buffer[0][0].timestamp
104
+ # We want to yield all the audio we see from now until
105
+ # we are NOTTRIGGERED, but we have to start with the
106
+ # audio that's already in the ring buffer.
107
+ for f, s in ring_buffer:
108
+ voiced_frames.append(f)
109
+ ring_buffer.clear()
110
+ else:
111
+ # We're in the TRIGGERED state, so collect the audio data
112
+ # and add it to the ring buffer.
113
+ voiced_frames.append(frame)
114
+ ring_buffer.append((frame, is_speech))
115
+ num_unvoiced = len([f for f, speech in ring_buffer if not speech])
116
+ # If more than 90% of the frames in the ring buffer are
117
+ # unvoiced, then enter NOTTRIGGERED and yield whatever
118
+ # audio we've collected.
119
+ if num_unvoiced > 0.9 * ring_buffer.maxlen:
120
+ triggered = False
121
+ yield (start, frame.timestamp + frame.duration)
122
+ ring_buffer.clear()
123
+ voiced_frames = []
124
+ # If we have any leftover voiced audio when we run out of input,
125
+ # yield it.
126
+ if voiced_frames:
127
+ yield (start, frame.timestamp + frame.duration)
128
+
129
+
130
+ def VAD_chunk(aggressiveness, path):
131
+ audio, byte_audio = read_wave(path, sr=hp.data.sr)
132
+ vad = webrtcvad.Vad(int(aggressiveness))
133
+ frames = frame_generator(20, byte_audio, hp.data.sr)
134
+ frames = list(frames)
135
+ times = vad_collector(hp.data.sr, 20, 200, vad, frames)
136
+ speech_times = []
137
+ speech_segs = []
138
+ for i, time in enumerate(times):
139
+ start = np.round(time[0],decimals=2)
140
+ end = np.round(time[1],decimals=2)
141
+ j = start
142
+ while j + .4 < end:
143
+ end_j = np.round(j+.4,decimals=2)
144
+ speech_times.append((j, end_j))
145
+ speech_segs.append(audio[int(j*hp.data.sr):int(end_j*hp.data.sr)])
146
+ j = end_j
147
+ else:
148
+ speech_times.append((j, end))
149
+ speech_segs.append(audio[int(j*hp.data.sr):int(end*hp.data.sr)])
150
+ return speech_times, speech_segs
151
+
152
+ if __name__ == '__main__':
153
+ speech_times, speech_segs = VAD_chunk(sys.argv[1], sys.argv[2])
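
A short usage sketch for this module; the path is a placeholder for any mono, 16-bit, 16 kHz WAV file:

from utils.VAD_segments import VAD_chunk

# Aggressiveness ranges 0-3; returns up-to-0.4 s speech windows and their (start, end) times in seconds.
times, segs = VAD_chunk(2, "samples/enroll.wav")  # hypothetical path
print(len(segs), times[:3])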
utils/.ipynb_checkpoints/__init__-checkpoint.py ADDED
File without changes
utils/.ipynb_checkpoints/data_load-checkpoint.py ADDED
@@ -0,0 +1,57 @@
1
+ """
2
+ Mostly copied from https://github.com/HarryVolek/PyTorch_Speaker_Verification
3
+ """
4
+ import glob
5
+ import numpy as np
6
+ import os
7
+ import random
8
+ from random import shuffle
9
+ import torch
10
+ from torch.utils.data import Dataset
11
+
12
+ from utils.hparam import hparam as hp
13
+ from utils.utils import mfccs_and_spec
14
+
15
+ class GujaratiSpeakerVerificationDataset(Dataset):
16
+
17
+ def __init__(self, shuffle=True, utter_start=0, split='train'):
18
+ # data path
19
+ if split!='val':
20
+ self.path = hp.data.train_path
21
+ self.utter_num = hp.train.M
22
+ else:
23
+ self.path = hp.data.test_path
24
+ self.utter_num = hp.test.M
25
+ self.file_list = os.listdir(self.path)
26
+ self.shuffle=shuffle
27
+ self.utter_start = utter_start
28
+ self.split = split
29
+
30
+ def __len__(self):
31
+ return len(self.file_list)
32
+
33
+ def __getitem__(self, idx):
34
+
35
+ np_file_list = os.listdir(self.path)
36
+
37
+ if self.shuffle:
38
+ selected_file = random.sample(np_file_list, 1)[0] # select random speaker
39
+ else:
40
+ selected_file = np_file_list[idx]
41
+
42
+ utters = np.load(os.path.join(self.path, selected_file))
43
+
44
+ # load utterance spectrogram of selected speaker
45
+ if self.shuffle:
46
+ utter_index = np.random.randint(0, utters.shape[0], self.utter_num) # select M utterances per speaker
47
+ utterance = utters[utter_index]
48
+ else:
49
+ utterance = utters[self.utter_start: self.utter_start+self.utter_num] # utterances of a speaker [batch(M), n_mels, frames]
50
+
51
+ utterance = utterance[:,:,:160] # TODO implement variable length batch size
52
+
53
+ utterance = torch.tensor(np.transpose(utterance, axes=(0,2,1))) # transpose [batch, frames, n_mels]
54
+ return utterance
55
+
56
+ def __repr__(self):
57
+ return f"{self.__class__.__name__}(split={self.split!r}, num_speakers={len(self.file_list)}, num_utterances={self.utter_num})"
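
A quick shape check for this dataset, assuming preprocessed .npy spectrograms already exist under hp.data.train_path and that the non-checkpoint copy of this file lives at utils/data_load.py:

from torch.utils.data import DataLoader
from utils.hparam import hparam as hp
from utils.data_load import GujaratiSpeakerVerificationDataset  # assumed module path

ds = GujaratiSpeakerVerificationDataset(split='train')
loader = DataLoader(ds, batch_size=hp.train.N, num_workers=hp.train.num_workers, drop_last=True)
batch = next(iter(loader))
print(batch.shape)  # (N, M, 160, nmels), e.g. torch.Size([4, 6, 160, 40])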
utils/.ipynb_checkpoints/evaluation-checkpoint.py ADDED
@@ -0,0 +1,192 @@
1
+ from torch.utils.data import Dataset
2
+ from tqdm.auto import tqdm
3
+ import os
4
+ import librosa
5
+ import numpy as np
6
+ import torch
7
+ import random
8
+ from numpy.linalg import norm
9
+
10
+ from utils.VAD_segments import VAD_chunk
11
+ from utils.hparam import hparam as hp
12
+
13
+ class GujaratiSpeakerVerificationDatasetTest(Dataset):
14
+ def __init__(self, path, shuffle=True, utter_start=0):
15
+ # data path
16
+ self.path = path
17
+ self.file_list = os.listdir(self.path)
18
+ self.shuffle=shuffle
19
+ self.utter_start = utter_start
20
+ self.utter_num = 4
21
+
22
+ def __len__(self):
23
+ return len(self.file_list)
24
+
25
+ def __getitem__(self, idx):
26
+
27
+ np_file_list = self.file_list
28
+
29
+ selected_file = np_file_list[idx]
30
+
31
+ utters = np.load(os.path.join(self.path, selected_file))
32
+
33
+ # load utterance spectrogram of selected speaker
34
+ if self.shuffle:
35
+ utter_index = np.random.randint(0, utters.shape[0], self.utter_num) # select M utterances per speaker
36
+ utterance = utters[utter_index]
37
+ else:
38
+ utterance = utters[self.utter_start: self.utter_start+self.utter_num] # utterances of a speaker [batch(M), n_mels, frames]
39
+
40
+ utterance = utterance[:,:,:160] # TODO implement variable length batch size
41
+
42
+ utterance = torch.tensor(np.transpose(utterance, axes=(0,2,1))) # transpose [batch, frames, n_mels]
43
+ return utterance
44
+
45
+ def concat_segs(times, segs):
46
+ concat_seg = []
47
+ seg_concat = segs[0]
48
+ for i in range(0, len(times)-1):
49
+ if times[i][1] == times[i+1][0]:
50
+ seg_concat = np.concatenate((seg_concat, segs[i+1]))
51
+ else:
52
+ concat_seg.append(seg_concat)
53
+ seg_concat = segs[i+1]
54
+ else:
55
+ concat_seg.append(seg_concat)
56
+ return concat_seg
57
+
58
+
59
+ def get_STFTs(segs):
60
+ sr = 16000
61
+ STFT_frames = []
62
+ for seg in segs:
63
+ S = librosa.core.stft(y=seg, n_fft=hp.data.nfft,
64
+ win_length=int(hp.data.window * sr), hop_length=int(hp.data.hop * sr))
65
+ S = np.abs(S)**2
66
+ mel_basis = librosa.filters.mel(sr=sr, n_fft=hp.data.nfft, n_mels=hp.data.nmels)
67
+ S = np.log10(np.dot(mel_basis, S) + 1e-6)
68
+ for j in range(0, S.shape[1], int(.12/hp.data.hop)):
69
+ if j + 24 < S.shape[1]:
70
+ STFT_frames.append(S[:, j:j+24])
71
+ else:
72
+ break
73
+ return STFT_frames
74
+
75
+
76
+ def get_embedding(file_path, embedder_net, device, n_threshold=-1):
77
+ times, segs = VAD_chunk(2, file_path)
78
+ if not segs:
79
+ print(f'No voice activity detected in {file_path}')
80
+ return None
81
+ concat_seg = concat_segs(times, segs)
82
+ if not concat_seg:
83
+ print(f'No concatenated segments for {file_path}')
84
+ return None
85
+ STFT_frames = get_STFTs(concat_seg)
86
+ if not STFT_frames:
87
+ #print(f'No STFT frames for {file_path}')
88
+ return None
89
+ STFT_frames = np.stack(STFT_frames, axis=2)
90
+ STFT_frames = torch.tensor(np.transpose(STFT_frames, axes=(2, 1, 0)), device=device)
91
+
92
+ with torch.no_grad():
93
+ embeddings = embedder_net(STFT_frames)
94
+ embeddings = embeddings[:n_threshold, :]
95
+
96
+ avg_embedding = torch.mean(embeddings, dim=0, keepdim=True).cpu().numpy()
97
+ return avg_embedding
98
+
99
+ def get_speaker_embeddings_listdir(embedder_net, device, list_dir, k):
100
+ speaker_embeddings = {}
101
+ for speaker_name in tqdm(list_dir, leave = False):
102
+ speaker_dir = speaker_name
103
+ if os.path.isdir(speaker_dir) and os.path.basename(speaker_dir) != ".DS_Store":
104
+ speaker_embeddings[speaker_name] = []
105
+ for i in range(10):
106
+ embeddings = []
107
+ audio_files = [os.path.join(speaker_dir, f) for f in os.listdir(speaker_dir) if f.endswith('.wav')]
108
+ random.shuffle(audio_files)
109
+ count = 0
110
+ iter_ = 0
111
+ while(count <= k):
112
+ file_path = audio_files[iter_]
113
+ embedding = get_embedding(file_path, embedder_net, device)
114
+ try:
115
+ _ = embedding.shape
116
+ embeddings.append(embedding)
117
+ count+=1
118
+ iter_+=1
119
+ except:
120
+ iter_+=1
121
+ speaker_embeddings[speaker_name].append(np.mean(embeddings, axis=0))
122
+ return speaker_embeddings
123
+
124
+ def create_pairs(speaker_embeddings):
125
+ pairs = []
126
+ labels = []
127
+ speakers = list(speaker_embeddings.keys())
128
+
129
+ for i in range(len(speakers)):
130
+ for j in range(len(speakers)):
131
+ for k1 in range(10):
132
+ for k2 in range(10):
133
+ emb1 = speaker_embeddings[speakers[i]][k1]
134
+ emb2 = speaker_embeddings[speakers[j]][k2]
135
+ pairs.append((emb1, emb2))
136
+ if i == j and not((emb1 == emb2).all()):
137
+ labels.append(1) # Same speaker
138
+ else:
139
+ labels.append(0) # Different speakers
140
+ return pairs, labels
141
+
142
+ class EmbeddingPairDataset(Dataset):
143
+ def __init__(self, pairs, labels):
144
+ self.pairs = pairs
145
+ self.labels = labels
146
+
147
+ def __len__(self):
148
+ return len(self.pairs)
149
+
150
+ def __getitem__(self, idx):
151
+ emb1, emb2 = self.pairs[idx]
152
+ label = self.labels[idx]
153
+
154
+ emb1, emb2 = torch.tensor(emb1, dtype=torch.float32), torch.tensor(emb2, dtype=torch.float32)
155
+
156
+ concatenated = torch.cat((emb1, emb2), dim=1)
157
+
158
+ return concatenated.squeeze(), torch.tensor(label, dtype=torch.float32)
159
+
160
+ def __len__(self):
161
+ return len(self.labels)
162
+
163
+ def __repr__(self):
164
+ return f"{self.__class__.__name__}(length={self.__len__()})"
165
+
166
+
167
+ def cosine_similarity(A, B):
168
+ A = A.flatten().astype(np.float64)
169
+ B = B.flatten().astype(np.float64)
170
+ cosine = np.dot(A,B)/(norm(A)*norm(B))
171
+ return cosine
172
+
173
+
174
+ def create_subset(dataset, num_zeros):
175
+ pairs = dataset.pairs
176
+ labels = dataset.labels
177
+
178
+ pairs_1 = [pairs[i] for i in range(len(pairs)) if labels[i] == 1]
179
+ labels_1 = [labels[i] for i in range(len(labels)) if labels[i] == 1]
180
+
181
+ pairs_0 = [pairs[i] for i in range(len(pairs)) if labels[i] == 0]
182
+ labels_0 = [labels[i] for i in range(len(labels)) if labels[i] == 0]
183
+
184
+ num_zeros = min(num_zeros, len(pairs_0))
185
+
186
+ pairs_0 = pairs_0[:num_zeros]
187
+ labels_0 = labels_0[:num_zeros]
188
+
189
+ filtered_pairs = pairs_1 + pairs_0
190
+ filtered_labels = labels_1 + labels_0
191
+
192
+ return filtered_pairs, filtered_labels
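
The Same/Different decision in app.py reduces to the cosine_similarity helper above; a tiny worked example on stand-in vectors:

import numpy as np
from numpy.linalg import norm

a = np.array([[1.0, 0.0]])  # stand-ins for two speaker embeddings (256-dim in practice)
b = np.array([[1.0, 1.0]])
cos = np.dot(a.flatten(), b.flatten()) / (norm(a) * norm(b))
print(round(cos, 3))  # 0.707, below the default 0.85 threshold -> "Different Speaker"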
utils/.ipynb_checkpoints/hparam-checkpoint.py ADDED
@@ -0,0 +1,59 @@
1
+ # -*- coding: utf-8 -*-
2
+ #!/usr/bin/env python
3
+
4
+ import yaml
5
+
6
+ def load_hparam(filename):
7
+ stream = open(filename, 'r')
8
+ docs = yaml.load_all(stream, Loader=yaml.Loader)
9
+ hparam_dict = dict()
10
+ for doc in docs:
11
+ for k, v in doc.items():
12
+ hparam_dict[k] = v
13
+ return hparam_dict
14
+
15
+ def merge_dict(user, default):
16
+ if isinstance(user, dict) and isinstance(default, dict):
17
+ for k, v in default.items():
18
+ if k not in user:
19
+ user[k] = v
20
+ else:
21
+ user[k] = merge_dict(user[k], v)
22
+ return user
23
+
24
+
25
+ class Dotdict(dict):
26
+ """
27
+ a dictionary that supports dot notation
28
+ as well as dictionary access notation
29
+ usage: d = DotDict() or d = DotDict({'val1':'first'})
30
+ set attributes: d.val2 = 'second' or d['val2'] = 'second'
31
+ get attributes: d.val2 or d['val2']
32
+ """
33
+ __getattr__ = dict.__getitem__
34
+ __setattr__ = dict.__setitem__
35
+ __delattr__ = dict.__delitem__
36
+
37
+ def __init__(self, dct=None):
38
+ dct = dict() if not dct else dct
39
+ for key, value in dct.items():
40
+ if hasattr(value, 'keys'):
41
+ value = Dotdict(value)
42
+ self[key] = value
43
+
44
+
45
+ class Hparam(Dotdict):
46
+
47
+ def __init__(self, file='config/config.yaml'):
48
+ super(Dotdict, self).__init__()
49
+ hp_dict = load_hparam(file)
50
+ hp_dotdict = Dotdict(hp_dict)
51
+ for k, v in hp_dotdict.items():
52
+ setattr(self, k, v)
53
+
54
+ __getattr__ = Dotdict.__getitem__
55
+ __setattr__ = Dotdict.__setitem__
56
+ __delattr__ = Dotdict.__delitem__
57
+
58
+
59
+ hparam = Hparam()
utils/.ipynb_checkpoints/kan-checkpoint.py ADDED
@@ -0,0 +1,285 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+
5
+
6
+ class KANLinear(torch.nn.Module):
7
+ def __init__(
8
+ self,
9
+ in_features,
10
+ out_features,
11
+ grid_size=5,
12
+ spline_order=3,
13
+ scale_noise=0.1,
14
+ scale_base=1.0,
15
+ scale_spline=1.0,
16
+ enable_standalone_scale_spline=True,
17
+ base_activation=torch.nn.SiLU,
18
+ grid_eps=0.02,
19
+ grid_range=[-1, 1],
20
+ ):
21
+ super(KANLinear, self).__init__()
22
+ self.in_features = in_features
23
+ self.out_features = out_features
24
+ self.grid_size = grid_size
25
+ self.spline_order = spline_order
26
+
27
+ h = (grid_range[1] - grid_range[0]) / grid_size
28
+ grid = (
29
+ (
30
+ torch.arange(-spline_order, grid_size + spline_order + 1) * h
31
+ + grid_range[0]
32
+ )
33
+ .expand(in_features, -1)
34
+ .contiguous()
35
+ )
36
+ self.register_buffer("grid", grid)
37
+
38
+ self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
39
+ self.spline_weight = torch.nn.Parameter(
40
+ torch.Tensor(out_features, in_features, grid_size + spline_order)
41
+ )
42
+ if enable_standalone_scale_spline:
43
+ self.spline_scaler = torch.nn.Parameter(
44
+ torch.Tensor(out_features, in_features)
45
+ )
46
+
47
+ self.scale_noise = scale_noise
48
+ self.scale_base = scale_base
49
+ self.scale_spline = scale_spline
50
+ self.enable_standalone_scale_spline = enable_standalone_scale_spline
51
+ self.base_activation = base_activation()
52
+ self.grid_eps = grid_eps
53
+
54
+ self.reset_parameters()
55
+
56
+ def reset_parameters(self):
57
+ torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
58
+ with torch.no_grad():
59
+ noise = (
60
+ (
61
+ torch.rand(self.grid_size + 1, self.in_features, self.out_features)
62
+ - 1 / 2
63
+ )
64
+ * self.scale_noise
65
+ / self.grid_size
66
+ )
67
+ self.spline_weight.data.copy_(
68
+ (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
69
+ * self.curve2coeff(
70
+ self.grid.T[self.spline_order : -self.spline_order],
71
+ noise,
72
+ )
73
+ )
74
+ if self.enable_standalone_scale_spline:
75
+ # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
76
+ torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
77
+
78
+ def b_splines(self, x: torch.Tensor):
79
+ """
80
+ Compute the B-spline bases for the given input tensor.
81
+
82
+ Args:
83
+ x (torch.Tensor): Input tensor of shape (batch_size, in_features).
84
+
85
+ Returns:
86
+ torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
87
+ """
88
+ assert x.dim() == 2 and x.size(1) == self.in_features
89
+
90
+ grid: torch.Tensor = (
91
+ self.grid
92
+ ) # (in_features, grid_size + 2 * spline_order + 1)
93
+ x = x.unsqueeze(-1)
94
+ bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
95
+ for k in range(1, self.spline_order + 1):
96
+ bases = (
97
+ (x - grid[:, : -(k + 1)])
98
+ / (grid[:, k:-1] - grid[:, : -(k + 1)])
99
+ * bases[:, :, :-1]
100
+ ) + (
101
+ (grid[:, k + 1 :] - x)
102
+ / (grid[:, k + 1 :] - grid[:, 1:(-k)])
103
+ * bases[:, :, 1:]
104
+ )
105
+
106
+ assert bases.size() == (
107
+ x.size(0),
108
+ self.in_features,
109
+ self.grid_size + self.spline_order,
110
+ )
111
+ return bases.contiguous()
112
+
113
+ def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
114
+ """
115
+ Compute the coefficients of the curve that interpolates the given points.
116
+
117
+ Args:
118
+ x (torch.Tensor): Input tensor of shape (batch_size, in_features).
119
+ y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).
120
+
121
+ Returns:
122
+ torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
123
+ """
124
+ assert x.dim() == 2 and x.size(1) == self.in_features
125
+ assert y.size() == (x.size(0), self.in_features, self.out_features)
126
+
127
+ A = self.b_splines(x).transpose(
128
+ 0, 1
129
+ ) # (in_features, batch_size, grid_size + spline_order)
130
+ B = y.transpose(0, 1) # (in_features, batch_size, out_features)
131
+ solution = torch.linalg.lstsq(
132
+ A, B
133
+ ).solution # (in_features, grid_size + spline_order, out_features)
134
+ result = solution.permute(
135
+ 2, 0, 1
136
+ ) # (out_features, in_features, grid_size + spline_order)
137
+
138
+ assert result.size() == (
139
+ self.out_features,
140
+ self.in_features,
141
+ self.grid_size + self.spline_order,
142
+ )
143
+ return result.contiguous()
144
+
145
+ @property
146
+ def scaled_spline_weight(self):
147
+ return self.spline_weight * (
148
+ self.spline_scaler.unsqueeze(-1)
149
+ if self.enable_standalone_scale_spline
150
+ else 1.0
151
+ )
152
+
153
+ def forward(self, x: torch.Tensor):
154
+ assert x.size(-1) == self.in_features
155
+ original_shape = x.shape
156
+ x = x.view(-1, self.in_features)
157
+
158
+ base_output = F.linear(self.base_activation(x), self.base_weight)
159
+ spline_output = F.linear(
160
+ self.b_splines(x).view(x.size(0), -1),
161
+ self.scaled_spline_weight.view(self.out_features, -1),
162
+ )
163
+ output = base_output + spline_output
164
+
165
+ output = output.view(*original_shape[:-1], self.out_features)
166
+ return output
167
+
168
+ @torch.no_grad()
169
+ def update_grid(self, x: torch.Tensor, margin=0.01):
170
+ assert x.dim() == 2 and x.size(1) == self.in_features
171
+ batch = x.size(0)
172
+
173
+ splines = self.b_splines(x) # (batch, in, coeff)
174
+ splines = splines.permute(1, 0, 2) # (in, batch, coeff)
175
+ orig_coeff = self.scaled_spline_weight # (out, in, coeff)
176
+ orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out)
177
+ unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out)
178
+ unreduced_spline_output = unreduced_spline_output.permute(
179
+ 1, 0, 2
180
+ ) # (batch, in, out)
181
+
182
+ # sort each channel individually to collect data distribution
183
+ x_sorted = torch.sort(x, dim=0)[0]
184
+ grid_adaptive = x_sorted[
185
+ torch.linspace(
186
+ 0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
187
+ )
188
+ ]
189
+
190
+ uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
191
+ grid_uniform = (
192
+ torch.arange(
193
+ self.grid_size + 1, dtype=torch.float32, device=x.device
194
+ ).unsqueeze(1)
195
+ * uniform_step
196
+ + x_sorted[0]
197
+ - margin
198
+ )
199
+
200
+ grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
201
+ grid = torch.concatenate(
202
+ [
203
+ grid[:1]
204
+ - uniform_step
205
+ * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
206
+ grid,
207
+ grid[-1:]
208
+ + uniform_step
209
+ * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
210
+ ],
211
+ dim=0,
212
+ )
213
+
214
+ self.grid.copy_(grid.T)
215
+ self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
216
+
217
+ def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
218
+ """
219
+ Compute the regularization loss.
220
+
221
+ This is a dumb simulation of the original L1 regularization as stated in the
222
+ paper, since the original one requires computing absolutes and entropy from the
223
+ expanded (batch, in_features, out_features) intermediate tensor, which is hidden
224
+ behind the F.linear function if we want a memory efficient implementation.
225
+
226
+ The L1 regularization is now computed as mean absolute value of the spline
227
+ weights. The authors' implementation also includes this term in addition to the
228
+ sample-based regularization.
229
+ """
230
+ l1_fake = self.spline_weight.abs().mean(-1)
231
+ regularization_loss_activation = l1_fake.sum()
232
+ p = l1_fake / regularization_loss_activation
233
+ regularization_loss_entropy = -torch.sum(p * p.log())
234
+ return (
235
+ regularize_activation * regularization_loss_activation
236
+ + regularize_entropy * regularization_loss_entropy
237
+ )
238
+
239
+
240
+ class KAN(torch.nn.Module):
241
+ def __init__(
242
+ self,
243
+ layers_hidden,
244
+ grid_size=5,
245
+ spline_order=3,
246
+ scale_noise=0.1,
247
+ scale_base=1.0,
248
+ scale_spline=1.0,
249
+ base_activation=torch.nn.SiLU,
250
+ grid_eps=0.02,
251
+ grid_range=[-1, 1],
252
+ ):
253
+ super(KAN, self).__init__()
254
+ self.grid_size = grid_size
255
+ self.spline_order = spline_order
256
+
257
+ self.layers = torch.nn.ModuleList()
258
+ for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
259
+ self.layers.append(
260
+ KANLinear(
261
+ in_features,
262
+ out_features,
263
+ grid_size=grid_size,
264
+ spline_order=spline_order,
265
+ scale_noise=scale_noise,
266
+ scale_base=scale_base,
267
+ scale_spline=scale_spline,
268
+ base_activation=base_activation,
269
+ grid_eps=grid_eps,
270
+ grid_range=grid_range,
271
+ )
272
+ )
273
+
274
+ def forward(self, x: torch.Tensor, update_grid=False):
275
+ for layer in self.layers:
276
+ if update_grid:
277
+ layer.update_grid(x)
278
+ x = layer(x)
279
+ return x
280
+
281
+ def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
282
+ return sum(
283
+ layer.regularization_loss(regularize_activation, regularize_entropy)
284
+ for layer in self.layers
285
+ )
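
A minimal forward-pass check for the KANLinear layer defined above (shapes only; weights are randomly initialized):

import torch
from utils.kan import KANLinear

layer = KANLinear(in_features=40, out_features=256)
x = torch.randn(16, 40)
print(layer(x).shape)               # torch.Size([16, 256])
print(layer.regularization_loss())  # scalar tensor combining the L1 and entropy terms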
utils/.ipynb_checkpoints/speech_embedder_net-checkpoint.py ADDED
@@ -0,0 +1,112 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Wed Sep 5 20:58:34 2018
5
+
6
+ @author: harry
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from utils.hparam import hparam as hp
13
+ from utils.utils import get_centroids, get_cossim, calc_loss
14
+ from utils.kan import KANLinear
15
+
16
+ class SpeechEmbedder(nn.Module):
17
+
18
+ def __init__(self):
19
+ super(SpeechEmbedder, self).__init__()
20
+ self.LSTM_stack = nn.LSTM(hp.data.nmels, hp.model.hidden, num_layers=hp.model.num_layer, batch_first=True)
21
+ for name, param in self.LSTM_stack.named_parameters():
22
+ if 'bias' in name:
23
+ nn.init.constant_(param, 0.0)
24
+ elif 'weight' in name:
25
+ nn.init.xavier_normal_(param)
26
+ self.projection = nn.Linear(hp.model.hidden, hp.model.proj)
27
+
28
+ def forward(self, x):
29
+ x, _ = self.LSTM_stack(x.float()) #(batch, frames, n_mels)
30
+ #only use last frame
31
+ x = x[:,x.size(1)-1]
32
+ x = self.projection(x.float())
33
+ x = x / torch.norm(x, dim=1).unsqueeze(1)
34
+ return x
35
+
36
+
37
+ class SpeechEmbedderGRU(nn.Module):
38
+ def __init__(self):
39
+ super(SpeechEmbedderGRU, self).__init__()
40
+ self.GRU_stack = nn.GRU(hp.data.nmels, hp.model.hidden, num_layers=hp.model.num_layer, batch_first=True)
41
+ for name, param in self.GRU_stack.named_parameters():
42
+ if 'bias' in name:
43
+ nn.init.constant_(param, 0.0)
44
+ elif 'weight' in name:
45
+ nn.init.xavier_normal_(param)
46
+ self.projection = nn.Linear(hp.model.hidden, hp.model.proj)
47
+
48
+ def forward(self, x):
49
+ x, _ = self.GRU_stack(x.float()) #(batch, frames, n_mels)
50
+ #only use last frame
51
+ x = x[:,x.size(1)-1]
52
+ x = self.projection(x.float())
53
+ x = x / torch.norm(x, dim=1).unsqueeze(1)
54
+ return x
55
+
56
+ class SpeechEmbedderKAN(nn.Module):
57
+ def __init__(self):
58
+ super(SpeechEmbedderKAN, self).__init__()
59
+ self.LSTM_stack = nn.LSTM(hp.data.nmels, hp.model.hidden, num_layers=hp.model.num_layer, batch_first=True)
60
+ for name, param in self.LSTM_stack.named_parameters():
61
+ if 'bias' in name:
62
+ nn.init.constant_(param, 0.0)
63
+ elif 'weight' in name:
64
+ nn.init.xavier_normal_(param)
65
+ self.projection = KANLinear(hp.model.hidden, hp.model.proj)
66
+
67
+ def forward(self, x):
68
+ x, _ = self.LSTM_stack(x.float()) #(batch, frames, n_mels)
69
+ #only use last frame
70
+ x = x[:,x.size(1)-1]
71
+ x = self.projection(x.float())
72
+ x = x / torch.norm(x, dim=1).unsqueeze(1)
73
+ return x
74
+
75
+
76
+
77
+ class SpeechEmbedderBidirectional(nn.Module):
78
+ def __init__(self):
79
+ super(SpeechEmbedderBidirectional, self).__init__()
80
+ self.LSTM_stack = nn.LSTM(hp.data.nmels, hp.model.hidden, num_layers=hp.model.num_layer, batch_first=True, bidirectional=True)
81
+ for name, param in self.LSTM_stack.named_parameters():
82
+ if 'bias' in name:
83
+ nn.init.constant_(param, 0.0)
84
+ elif 'weight' in name:
85
+ nn.init.xavier_normal_(param)
86
+ self.projection = nn.Linear(hp.model.hidden, hp.model.proj)
87
+
88
+ def forward(self, x):
89
+ x, _ = self.LSTM_stack(x.float()) #(batch, frames, n_mels)
90
+ #only use last frame
91
+ x = x[:, :, :hp.model.hidden]
92
+
93
+ x = x[:,x.size(1)-1]
94
+ x = self.projection(x.float())
95
+ x = x / torch.norm(x, dim=1).unsqueeze(1)
96
+ return x
97
+
98
+ class GE2ELoss(nn.Module):
99
+
100
+ def __init__(self, device):
101
+ super(GE2ELoss, self).__init__()
102
+ self.w = nn.Parameter(torch.tensor(10.0).to(device), requires_grad=True)
103
+ self.b = nn.Parameter(torch.tensor(-5.0).to(device), requires_grad=True)
104
+ self.device = device
105
+
106
+ def forward(self, embeddings):
107
+ self.w.data = torch.clamp(self.w.data, min=1e-6)  # keep the learnable scale positive
108
+ centroids = get_centroids(embeddings)
109
+ cossim = get_cossim(embeddings, centroids)
110
+ sim_matrix = self.w*cossim.to(self.device) + self.b
111
+ loss, _ = calc_loss(sim_matrix)
112
+ return loss
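
A shape sanity check for SpeechEmbedder under the config above (nmels=40, proj=256); the 24-frame windows match what get_STFTs in utils/evaluation.py produces:

import torch
from utils.speech_embedder_net import SpeechEmbedder

net = SpeechEmbedder().eval()
with torch.no_grad():
    out = net(torch.randn(8, 24, 40))  # (batch, frames, n_mels)
print(out.shape)        # torch.Size([8, 256])
print(out.norm(dim=1))  # each embedding is L2-normalized, so ~1.0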
utils/.ipynb_checkpoints/utils-checkpoint.py ADDED
@@ -0,0 +1,173 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Thu Sep 20 16:56:19 2018
5
+
6
+ @author: harry
7
+ """
8
+ import librosa
9
+ import numpy as np
10
+ import torch
11
+ import torch.autograd as grad
12
+ import torch.nn.functional as F
13
+
14
+ from utils.hparam import hparam as hp
15
+
16
+ def get_centroids_prior(embeddings):
17
+ centroids = []
18
+ for speaker in embeddings:
19
+ centroid = 0
20
+ for utterance in speaker:
21
+ centroid = centroid + utterance
22
+ centroid = centroid/len(speaker)
23
+ centroids.append(centroid)
24
+ centroids = torch.stack(centroids)
25
+ return centroids
26
+
27
+ def get_centroids(embeddings):
28
+ centroids = embeddings.mean(dim=1)
29
+ return centroids
30
+
31
+ def get_centroid(embeddings, speaker_num, utterance_num):
32
+ centroid = 0
33
+ for utterance_id, utterance in enumerate(embeddings[speaker_num]):
34
+ if utterance_id == utterance_num:
35
+ continue
36
+ centroid = centroid + utterance
37
+ centroid = centroid/(len(embeddings[speaker_num])-1)
38
+ return centroid
39
+
40
+ def get_utterance_centroids(embeddings):
41
+ """
42
+ Returns the centroids for each utterance of a speaker, where
43
+ the utterance centroid is the speaker centroid without considering
44
+ this utterance
45
+
46
+ Shape of embeddings should be:
47
+ (speaker_ct, utterance_per_speaker_ct, embedding_size)
48
+ """
49
+ sum_centroids = embeddings.sum(dim=1)
50
+ # we want to subtract out each utterance, prior to calculating the
51
+ # the utterance centroid
52
+ sum_centroids = sum_centroids.reshape(
53
+ sum_centroids.shape[0], 1, sum_centroids.shape[-1]
54
+ )
55
+ # we want the mean but not including the utterance itself, so -1
56
+ num_utterances = embeddings.shape[1] - 1
57
+ centroids = (sum_centroids - embeddings) / num_utterances
58
+ return centroids
59
+
60
+ def get_cossim_prior(embeddings, centroids):
61
+ # Calculates cosine similarity matrix. Requires (N, M, feature) input
62
+ cossim = torch.zeros(embeddings.size(0),embeddings.size(1),centroids.size(0))
63
+ for speaker_num, speaker in enumerate(embeddings):
64
+ for utterance_num, utterance in enumerate(speaker):
65
+ for centroid_num, centroid in enumerate(centroids):
66
+ if speaker_num == centroid_num:
67
+ centroid = get_centroid(embeddings, speaker_num, utterance_num)
68
+ output = F.cosine_similarity(utterance,centroid,dim=0)+1e-6
69
+ cossim[speaker_num][utterance_num][centroid_num] = output
70
+ return cossim
71
+
72
+ def get_cossim(embeddings, centroids):
73
+ # number of utterances per speaker
74
+ num_utterances = embeddings.shape[1]
75
+ utterance_centroids = get_utterance_centroids(embeddings)
76
+
77
+ # flatten the embeddings and utterance centroids to just utterance,
78
+ # so we can do cosine similarity
79
+ utterance_centroids_flat = utterance_centroids.view(
80
+ utterance_centroids.shape[0] * utterance_centroids.shape[1],
81
+ -1
82
+ )
83
+ embeddings_flat = embeddings.view(
84
+ embeddings.shape[0] * num_utterances,
85
+ -1
86
+ )
87
+ # the cosine distance between utterance and the associated centroids
88
+ # for that utterance
89
+ # this is each speaker's utterances against his own centroid, but each
90
+ # comparison centroid has the current utterance removed
91
+ cos_same = F.cosine_similarity(embeddings_flat, utterance_centroids_flat)
92
+
93
+ # now we get the cosine distance between each utterance and the other speakers'
94
+ # centroids
95
+ # to do so requires comparing each utterance to each centroid. To keep the
96
+ # operation fast, we vectorize by using matrices L (embeddings) and
97
+ # R (centroids) where L has each utterance repeated sequentially for all
98
+ # comparisons and R has the entire centroids frame repeated for each utterance
99
+ centroids_expand = centroids.repeat((num_utterances * embeddings.shape[0], 1))
100
+ embeddings_expand = embeddings_flat.unsqueeze(1).repeat(1, embeddings.shape[0], 1)
101
+ embeddings_expand = embeddings_expand.view(
102
+ embeddings_expand.shape[0] * embeddings_expand.shape[1],
103
+ embeddings_expand.shape[-1]
104
+ )
105
+ cos_diff = F.cosine_similarity(embeddings_expand, centroids_expand)
106
+ cos_diff = cos_diff.view(
107
+ embeddings.size(0),
108
+ num_utterances,
109
+ centroids.size(0)
110
+ )
111
+ # assign the cosine distance for same speakers to the proper idx
112
+ same_idx = list(range(embeddings.size(0)))
113
+ cos_diff[same_idx, :, same_idx] = cos_same.view(embeddings.shape[0], num_utterances)
114
+ cos_diff = cos_diff + 1e-6
115
+ return cos_diff
116
+
117
+ def calc_loss_prior(sim_matrix):
118
+ # Calculates loss from (N, M, K) similarity matrix
119
+ per_embedding_loss = torch.zeros(sim_matrix.size(0), sim_matrix.size(1))
120
+ for j in range(len(sim_matrix)):
121
+ for i in range(sim_matrix.size(1)):
122
+ per_embedding_loss[j][i] = -(sim_matrix[j][i][j] - ((torch.exp(sim_matrix[j][i]).sum()+1e-6).log_()))
123
+ loss = per_embedding_loss.sum()
124
+ return loss, per_embedding_loss
125
+
126
+ def calc_loss(sim_matrix):
127
+ same_idx = list(range(sim_matrix.size(0)))
128
+ pos = sim_matrix[same_idx, :, same_idx]
129
+ neg = (torch.exp(sim_matrix).sum(dim=2) + 1e-6).log_()
130
+ per_embedding_loss = -1 * (pos - neg)
131
+ loss = per_embedding_loss.sum()
132
+ return loss, per_embedding_loss
133
+
134
+ def normalize_0_1(values, max_value, min_value):
135
+ normalized = np.clip((values - min_value) / (max_value - min_value), 0, 1)
136
+ return normalized
137
+
138
+ def mfccs_and_spec(wav_file, wav_process = False, calc_mfccs=False, calc_mag_db=False):
139
+ sound_file, _ = librosa.core.load(wav_file, sr=hp.data.sr)
140
+ window_length = int(hp.data.window*hp.data.sr)
141
+ hop_length = int(hp.data.hop*hp.data.sr)
142
+ duration = hp.data.tisv_frame * hp.data.hop + hp.data.window
143
+
144
+ # Cut silence and fix length
145
+ if wav_process == True:
146
+ sound_file, index = librosa.effects.trim(sound_file, frame_length=window_length, hop_length=hop_length)
147
+ length = int(hp.data.sr * duration)
148
+ sound_file = librosa.util.fix_length(sound_file, size=length)
149
+
150
+ spec = librosa.stft(sound_file, n_fft=hp.data.nfft, hop_length=hop_length, win_length=window_length)
151
+ mag_spec = np.abs(spec)
152
+
153
+ mel_basis = librosa.filters.mel(hp.data.sr, hp.data.nfft, n_mels=hp.data.nmels)
154
+ mel_spec = np.dot(mel_basis, mag_spec)
155
+
156
+ mag_db = librosa.amplitude_to_db(mag_spec)
157
+ #db mel spectrogram
158
+ mel_db = librosa.amplitude_to_db(mel_spec).T
159
+
160
+ mfccs = None
161
+ if calc_mfccs:
162
+ mfccs = np.dot(librosa.filters.dct(40, mel_db.shape[0]), mel_db).T
163
+
164
+ return mfccs, mel_db, mag_db
165
+
166
+ if __name__ == "__main__":
167
+ w = grad.Variable(torch.tensor(1.0))
168
+ b = grad.Variable(torch.tensor(0.0))
169
+ embeddings = torch.tensor([[0,1,0],[0,0,1], [0,1,0], [0,1,0], [1,0,0], [1,0,0]]).to(torch.float).reshape(3,2,3)
170
+ centroids = get_centroids(embeddings)
171
+ cossim = get_cossim(embeddings, centroids)
172
+ sim_matrix = w*cossim + b
173
+ loss, per_embedding_loss = calc_loss(sim_matrix)
utils/VAD_segments.py ADDED
@@ -0,0 +1,153 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Tue Dec 18 16:22:41 2018
5
+
6
+ @author: Harry
7
+ Modified from https://github.com/wiseman/py-webrtcvad/blob/master/example.py
8
+ """
9
+
10
+ import collections
11
+ import contextlib
12
+ import numpy as np
13
+ import sys
14
+ import librosa
15
+ import wave
16
+
17
+ import webrtcvad
18
+
19
+ from utils.hparam import hparam as hp
20
+
21
+ def read_wave(path, sr):
22
+ """Reads a .wav file.
23
+ Takes the path, and returns (PCM audio data, sample rate).
24
+ Assumes sample width == 2
25
+ """
26
+ with contextlib.closing(wave.open(path, 'rb')) as wf:
27
+ num_channels = wf.getnchannels()
28
+ assert num_channels == 1
29
+ sample_width = wf.getsampwidth()
30
+ assert sample_width == 2
31
+ sample_rate = wf.getframerate()
32
+ assert sample_rate in (8000, 16000, 32000, 48000)
33
+ pcm_data = wf.readframes(wf.getnframes())
34
+ data, _ = librosa.load(path, sr=sr)
35
+ assert len(data.shape) == 1
36
+ assert sr in (8000, 16000, 32000, 48000)
37
+ return data, pcm_data
38
+
39
+ class Frame(object):
40
+ """Represents a "frame" of audio data."""
41
+ def __init__(self, bytes, timestamp, duration):
42
+ self.bytes = bytes
43
+ self.timestamp = timestamp
44
+ self.duration = duration
45
+
46
+
47
+ def frame_generator(frame_duration_ms, audio, sample_rate):
48
+ """Generates audio frames from PCM audio data.
49
+ Takes the desired frame duration in milliseconds, the PCM data, and
50
+ the sample rate.
51
+ Yields Frames of the requested duration.
52
+ """
53
+ n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
54
+ offset = 0
55
+ timestamp = 0.0
56
+ duration = (float(n) / sample_rate) / 2.0
57
+ while offset + n < len(audio):
58
+ yield Frame(audio[offset:offset + n], timestamp, duration)
59
+ timestamp += duration
60
+ offset += n
61
+
62
+
63
+ def vad_collector(sample_rate, frame_duration_ms,
64
+ padding_duration_ms, vad, frames):
65
+ """Filters out non-voiced audio frames.
66
+ Given a webrtcvad.Vad and a source of audio frames, yields only
67
+ the voiced audio.
68
+ Uses a padded, sliding window algorithm over the audio frames.
69
+ When more than 90% of the frames in the window are voiced (as
70
+ reported by the VAD), the collector triggers and begins yielding
71
+ audio frames. Then the collector waits until 90% of the frames in
72
+ the window are unvoiced to detrigger.
73
+ The window is padded at the front and back to provide a small
74
+ amount of silence or the beginnings/endings of speech around the
75
+ voiced frames.
76
+ Arguments:
77
+ sample_rate - The audio sample rate, in Hz.
78
+ frame_duration_ms - The frame duration in milliseconds.
79
+ padding_duration_ms - The amount to pad the window, in milliseconds.
80
+ vad - An instance of webrtcvad.Vad.
81
+ frames - a source of audio frames (sequence or generator).
82
+ Returns: A generator that yields PCM audio data.
83
+ """
84
+ num_padding_frames = int(padding_duration_ms / frame_duration_ms)
85
+ # We use a deque for our sliding window/ring buffer.
86
+ ring_buffer = collections.deque(maxlen=num_padding_frames)
87
+ # We have two states: TRIGGERED and NOTTRIGGERED. We start in the
88
+ # NOTTRIGGERED state.
89
+ triggered = False
90
+
91
+ voiced_frames = []
92
+ for frame in frames:
93
+ is_speech = vad.is_speech(frame.bytes, sample_rate)
94
+
95
+ if not triggered:
96
+ ring_buffer.append((frame, is_speech))
97
+ num_voiced = len([f for f, speech in ring_buffer if speech])
98
+ # If we're NOTTRIGGERED and more than 90% of the frames in
99
+ # the ring buffer are voiced frames, then enter the
100
+ # TRIGGERED state.
101
+ if num_voiced > 0.9 * ring_buffer.maxlen:
102
+ triggered = True
103
+ start = ring_buffer[0][0].timestamp
104
+ # We want to yield all the audio we see from now until
105
+ # we are NOTTRIGGERED, but we have to start with the
106
+ # audio that's already in the ring buffer.
107
+ for f, s in ring_buffer:
108
+ voiced_frames.append(f)
109
+ ring_buffer.clear()
110
+ else:
111
+ # We're in the TRIGGERED state, so collect the audio data
112
+ # and add it to the ring buffer.
113
+ voiced_frames.append(frame)
114
+ ring_buffer.append((frame, is_speech))
115
+ num_unvoiced = len([f for f, speech in ring_buffer if not speech])
116
+ # If more than 90% of the frames in the ring buffer are
117
+ # unvoiced, then enter NOTTRIGGERED and yield whatever
118
+ # audio we've collected.
119
+ if num_unvoiced > 0.9 * ring_buffer.maxlen:
120
+ triggered = False
121
+ yield (start, frame.timestamp + frame.duration)
122
+ ring_buffer.clear()
123
+ voiced_frames = []
124
+ # If we have any leftover voiced audio when we run out of input,
125
+ # yield it.
126
+ if voiced_frames:
127
+ yield (start, frame.timestamp + frame.duration)
128
+
129
+
130
+ def VAD_chunk(aggressiveness, path):
131
+ audio, byte_audio = read_wave(path, sr=hp.data.sr)
132
+ vad = webrtcvad.Vad(int(aggressiveness))
133
+ frames = frame_generator(20, byte_audio, hp.data.sr)
134
+ frames = list(frames)
135
+ times = vad_collector(hp.data.sr, 20, 200, vad, frames)
136
+ speech_times = []
137
+ speech_segs = []
138
+ for i, time in enumerate(times):
139
+ start = np.round(time[0],decimals=2)
140
+ end = np.round(time[1],decimals=2)
141
+ j = start
142
+ while j + .4 < end:
143
+ end_j = np.round(j+.4,decimals=2)
144
+ speech_times.append((j, end_j))
145
+ speech_segs.append(audio[int(j*hp.data.sr):int(end_j*hp.data.sr)])
146
+ j = end_j
147
+ else:
148
+ speech_times.append((j, end))
149
+ speech_segs.append(audio[int(j*hp.data.sr):int(end*hp.data.sr)])
150
+ return speech_times, speech_segs
151
+
152
+ if __name__ == '__main__':
153
+ speech_times, speech_segs = VAD_chunk(sys.argv[1], sys.argv[2])
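
A minimal usage sketch for the segmenter above (not part of the commit; it assumes the repository root is on PYTHONPATH and a 16-bit mono WAV, at one of the sample rates webrtcvad supports, at the hypothetical path sample.wav):

    from utils.VAD_segments import VAD_chunk

    # aggressiveness runs from 0 (least strict) to 3 (most strict), as defined by webrtcvad
    times, segs = VAD_chunk(2, "sample.wav")
    for (start, end), seg in zip(times, segs):
        print(f"voiced {start:.2f}s-{end:.2f}s ({len(seg)} samples)")
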
utils/__init__.py ADDED
File without changes
utils/__pycache__/VAD_segments.cpython-39.pyc ADDED
Binary file (4.68 kB).
 
utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (162 Bytes).
 
utils/__pycache__/data_load.cpython-39.pyc ADDED
Binary file (2.05 kB).
 
utils/__pycache__/evaluation.cpython-39.pyc ADDED
Binary file (6.62 kB).
 
utils/__pycache__/hparam.cpython-39.pyc ADDED
Binary file (1.98 kB).
 
utils/__pycache__/kan.cpython-39.pyc ADDED
Binary file (7.57 kB).
 
utils/__pycache__/speech_embedder_net.cpython-39.pyc ADDED
Binary file (4.45 kB).
 
utils/__pycache__/utils.cpython-39.pyc ADDED
Binary file (4.7 kB).
 
utils/data_load.py ADDED
@@ -0,0 +1,57 @@
1
+ """
2
+ Mostly copied from https://github.com/HarryVolek/PyTorch_Speaker_Verification
3
+ """
4
+ import glob
5
+ import numpy as np
6
+ import os
7
+ import random
8
+ from random import shuffle
9
+ import torch
10
+ from torch.utils.data import Dataset
11
+
12
+ from utils.hparam import hparam as hp
13
+ from utils.utils import mfccs_and_spec
14
+
15
+ class GujaratiSpeakerVerificationDataset(Dataset):
16
+
17
+ def __init__(self, shuffle=True, utter_start=0, split='train'):
18
+ # data path
19
+ if split!='val':
20
+ self.path = hp.data.train_path
21
+ self.utter_num = hp.train.M
22
+ else:
23
+ self.path = hp.data.test_path
24
+ self.utter_num = hp.test.M
25
+ self.file_list = os.listdir(self.path)
26
+ self.shuffle=shuffle
27
+ self.utter_start = utter_start
28
+ self.split = split
29
+
30
+ def __len__(self):
31
+ return len(self.file_list)
32
+
33
+ def __getitem__(self, idx):
34
+
35
+ np_file_list = os.listdir(self.path)
36
+
37
+ if self.shuffle:
38
+ selected_file = random.sample(np_file_list, 1)[0] # select random speaker
39
+ else:
40
+ selected_file = np_file_list[idx]
41
+
42
+ utters = np.load(os.path.join(self.path, selected_file))
43
+
44
+ # load utterance spectrogram of selected speaker
45
+ if self.shuffle:
46
+ utter_index = np.random.randint(0, utters.shape[0], self.utter_num) # select M utterances per speaker
47
+ utterance = utters[utter_index]
48
+ else:
49
+ utterance = utters[self.utter_start: self.utter_start+self.utter_num] # utterances of a speaker [batch(M), n_mels, frames]
50
+
51
+ utterance = utterance[:,:,:160] # TODO implement variable length batch size
52
+
53
+ utterance = torch.tensor(np.transpose(utterance, axes=(0,2,1))) # transpose [batch, frames, n_mels]
54
+ return utterance
55
+
56
+ def __repr__(self):
57
+ return f"{self.__class__.__name__}(split={self.split!r}, num_speakers={len(self.file_list)}, num_utterances={self.utter_num})"
utils/evaluation.py ADDED
@@ -0,0 +1,192 @@
1
+ from torch.utils.data import Dataset
2
+ from tqdm.auto import tqdm
3
+ import os
4
+ import librosa
5
+ import numpy as np
6
+ import torch
7
+ import random
8
+ from numpy.linalg import norm
9
+
10
+ from utils.VAD_segments import VAD_chunk
11
+ from utils.hparam import hparam as hp
12
+
13
+ class GujaratiSpeakerVerificationDatasetTest(Dataset):
14
+ def __init__(self, path, shuffle=True, utter_start=0):
15
+ # data path
16
+ self.path = path
17
+ self.file_list = os.listdir(self.path)
18
+ self.shuffle=shuffle
19
+ self.utter_start = utter_start
20
+ self.utter_num = 4
21
+
22
+ def __len__(self):
23
+ return len(self.file_list)
24
+
25
+ def __getitem__(self, idx):
26
+
27
+ np_file_list = self.file_list
28
+
29
+ selected_file = np_file_list[idx]
30
+
31
+ utters = np.load(os.path.join(self.path, selected_file))
32
+
33
+ # load utterance spectrogram of selected speaker
34
+ if self.shuffle:
35
+ utter_index = np.random.randint(0, utters.shape[0], self.utter_num) # select M utterances per speaker
36
+ utterance = utters[utter_index]
37
+ else:
38
+ utterance = utters[self.utter_start: self.utter_start+self.utter_num] # utterances of a speaker [batch(M), n_mels, frames]
39
+
40
+ utterance = utterance[:,:,:160] # TODO implement variable length batch size
41
+
42
+ utterance = torch.tensor(np.transpose(utterance, axes=(0,2,1))) # transpose [batch, frames, n_mels]
43
+ return utterance
44
+
45
+ def concat_segs(times, segs):
46
+ concat_seg = []
47
+ seg_concat = segs[0]
48
+ for i in range(0, len(times)-1):
49
+ if times[i][1] == times[i+1][0]:
50
+ seg_concat = np.concatenate((seg_concat, segs[i+1]))
51
+ else:
52
+ concat_seg.append(seg_concat)
53
+ seg_concat = segs[i+1]
54
+ else:
55
+ concat_seg.append(seg_concat)
56
+ return concat_seg
57
+
58
+
59
+ def get_STFTs(segs):
60
+ sr = 16000
61
+ STFT_frames = []
62
+ for seg in segs:
63
+ S = librosa.core.stft(y=seg, n_fft=hp.data.nfft,
64
+ win_length=int(hp.data.window * sr), hop_length=int(hp.data.hop * sr))
65
+ S = np.abs(S)**2
66
+ mel_basis = librosa.filters.mel(sr=sr, n_fft=hp.data.nfft, n_mels=hp.data.nmels)
67
+ S = np.log10(np.dot(mel_basis, S) + 1e-6)
68
+ for j in range(0, S.shape[1], int(.12/hp.data.hop)):
69
+ if j + 24 < S.shape[1]:
70
+ STFT_frames.append(S[:, j:j+24])
71
+ else:
72
+ break
73
+ return STFT_frames
74
+
75
+
76
+ def get_embedding(file_path, embedder_net, device, n_threshold=-1):
77
+ times, segs = VAD_chunk(2, file_path)
78
+ if not segs:
79
+ print(f'No voice activity detected in {file_path}')
80
+ return None
81
+ concat_seg = concat_segs(times, segs)
82
+ if not concat_seg:
83
+ print(f'No concatenated segments for {file_path}')
84
+ return None
85
+ STFT_frames = get_STFTs(concat_seg)
86
+ if not STFT_frames:
87
+ #print(f'No STFT frames for {file_path}')
88
+ return None
89
+ STFT_frames = np.stack(STFT_frames, axis=2)
90
+ STFT_frames = torch.tensor(np.transpose(STFT_frames, axes=(2, 1, 0)), device=device)
91
+
92
+ with torch.no_grad():
93
+ embeddings = embedder_net(STFT_frames)
94
+ embeddings = embeddings[:n_threshold, :]
95
+
96
+ avg_embedding = torch.mean(embeddings, dim=0, keepdim=True).cpu().numpy()
97
+ return avg_embedding
98
+
99
+ def get_speaker_embeddings_listdir(embedder_net, device, list_dir, k):
100
+ speaker_embeddings = {}
101
+ for speaker_name in tqdm(list_dir, leave = False):
102
+ speaker_dir = speaker_name
103
+ if os.path.isdir(speaker_dir) and os.path.basename(speaker_dir) != ".DS_Store":
104
+ speaker_embeddings[speaker_name] = []
105
+ for i in range(10):
106
+ embeddings = []
107
+ audio_files = [os.path.join(speaker_dir, f) for f in os.listdir(speaker_dir) if f.endswith('.wav')]
108
+ random.shuffle(audio_files)
109
+ count = 0
110
+ iter_ = 0
111
+ while(count <= k):
112
+ file_path = audio_files[iter_]
113
+ embedding = get_embedding(file_path, embedder_net, device)
114
+ try:
115
+ _ = embedding.shape
116
+ embeddings.append(embedding)
117
+ count+=1
118
+ iter_+=1
119
+ except AttributeError:  # get_embedding returned None, so it has no .shape
120
+ iter_+=1
121
+ speaker_embeddings[speaker_name].append(np.mean(embeddings, axis=0))
122
+ return speaker_embeddings
123
+
124
+ def create_pairs(speaker_embeddings):
125
+ pairs = []
126
+ labels = []
127
+ speakers = list(speaker_embeddings.keys())
128
+
129
+ for i in range(len(speakers)):
130
+ for j in range(len(speakers)):
131
+ for k1 in range(10):
132
+ for k2 in range(10):
133
+ emb1 = speaker_embeddings[speakers[i]][k1]
134
+ emb2 = speaker_embeddings[speakers[j]][k2]
135
+ pairs.append((emb1, emb2))
136
+ if i == j and not((emb1 == emb2).all()):
137
+ labels.append(1) # Same speaker
138
+ else:
139
+ labels.append(0) # Different speakers
140
+ return pairs, labels
141
+
142
+ class EmbeddingPairDataset(Dataset):
143
+ def __init__(self, pairs, labels):
144
+ self.pairs = pairs
145
+ self.labels = labels
146
+
147
+ def __len__(self):
148
+ return len(self.pairs)
149
+
150
+ def __getitem__(self, idx):
151
+ emb1, emb2 = self.pairs[idx]
152
+ label = self.labels[idx]
153
+
154
+ emb1, emb2 = torch.tensor(emb1, dtype=torch.float32), torch.tensor(emb2, dtype=torch.float32)
155
+
156
+ concatenated = torch.cat((emb1, emb2), dim=1)
157
+
158
+ return concatenated.squeeze(), torch.tensor(label, dtype=torch.float32)
159
+
163
+ def __repr__(self):
164
+ return f"{self.__class__.__name__}(length={self.__len__()})"
165
+
166
+
167
+ def cosine_similarity(A, B):
168
+ A = A.flatten().astype(np.float64)
169
+ B = B.flatten().astype(np.float64)
170
+ cosine = np.dot(A,B)/(norm(A)*norm(B))
171
+ return cosine
172
+
173
+
174
+ def create_subset(dataset, num_zeros):
175
+ pairs = dataset.pairs
176
+ labels = dataset.labels
177
+
178
+ pairs_1 = [pairs[i] for i in range(len(pairs)) if labels[i] == 1]
179
+ labels_1 = [labels[i] for i in range(len(labels)) if labels[i] == 1]
180
+
181
+ pairs_0 = [pairs[i] for i in range(len(pairs)) if labels[i] == 0]
182
+ labels_0 = [labels[i] for i in range(len(labels)) if labels[i] == 0]
183
+
184
+ num_zeros = min(num_zeros, len(pairs_0))
185
+
186
+ pairs_0 = pairs_0[:num_zeros]
187
+ labels_0 = labels_0[:num_zeros]
188
+
189
+ filtered_pairs = pairs_1 + pairs_0
190
+ filtered_labels = labels_1 + labels_0
191
+
192
+ return filtered_pairs, filtered_labels
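
A small, self-contained sketch of how the pairing utilities above fit together (not part of the commit; the toy 256-dimensional embeddings stand in for the averaged speaker embeddings that get_embedding produces):

    import numpy as np
    from utils.evaluation import create_pairs, EmbeddingPairDataset, cosine_similarity

    # two toy speakers, each with the 10 averaged embeddings create_pairs expects
    speaker_embeddings = {name: [np.random.randn(1, 256) for _ in range(10)]
                          for name in ("spk_a", "spk_b")}
    pairs, labels = create_pairs(speaker_embeddings)  # label 1 = same speaker, 0 = different
    dataset = EmbeddingPairDataset(pairs, labels)
    x, y = dataset[0]  # x: concatenated pair of shape (512,), y: 0.0 or 1.0
    print(len(dataset), x.shape, cosine_similarity(pairs[0][0], pairs[0][1]))
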
utils/hparam.py ADDED
@@ -0,0 +1,59 @@
1
+ # -*- coding: utf-8 -*-
2
+ #!/usr/bin/env python
3
+
4
+ import yaml
5
+
6
+ def load_hparam(filename):
7
+ stream = open(filename, 'r')
8
+ docs = yaml.load_all(stream, Loader=yaml.Loader)
9
+ hparam_dict = dict()
10
+ for doc in docs:
11
+ for k, v in doc.items():
12
+ hparam_dict[k] = v
13
+ return hparam_dict
14
+
15
+ def merge_dict(user, default):
16
+ if isinstance(user, dict) and isinstance(default, dict):
17
+ for k, v in default.items():
18
+ if k not in user:
19
+ user[k] = v
20
+ else:
21
+ user[k] = merge_dict(user[k], v)
22
+ return user
23
+
24
+
25
+ class Dotdict(dict):
26
+ """
27
+ a dictionary that supports dot notation
28
+ as well as dictionary access notation
29
+ usage: d = DotDict() or d = DotDict({'val1':'first'})
30
+ set attributes: d.val2 = 'second' or d['val2'] = 'second'
31
+ get attributes: d.val2 or d['val2']
32
+ """
33
+ __getattr__ = dict.__getitem__
34
+ __setattr__ = dict.__setitem__
35
+ __delattr__ = dict.__delitem__
36
+
37
+ def __init__(self, dct=None):
38
+ dct = dict() if not dct else dct
39
+ for key, value in dct.items():
40
+ if hasattr(value, 'keys'):
41
+ value = Dotdict(value)
42
+ self[key] = value
43
+
44
+
45
+ class Hparam(Dotdict):
46
+
47
+ def __init__(self, file='config/config.yaml'):
48
+ super(Dotdict, self).__init__()
49
+ hp_dict = load_hparam(file)
50
+ hp_dotdict = Dotdict(hp_dict)
51
+ for k, v in hp_dotdict.items():
52
+ setattr(self, k, v)
53
+
54
+ __getattr__ = Dotdict.__getitem__
55
+ __setattr__ = Dotdict.__setitem__
56
+ __delattr__ = Dotdict.__delitem__
57
+
58
+
59
+ hparam = Hparam()
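
A usage sketch for the configuration loader above (not part of the commit; the hp.data.* and hp.model.* keys are the ones referenced elsewhere in this commit and are assumed to be defined in config/config.yaml):

    from utils.hparam import hparam as hp

    # values are read once from config/config.yaml at import time and exposed with dot access
    print(hp.data.sr, hp.data.nfft, hp.data.nmels)
    print(hp.model.hidden, hp.model.proj)
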
utils/kan.py ADDED
@@ -0,0 +1,285 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+
5
+
6
+ class KANLinear(torch.nn.Module):
7
+ def __init__(
8
+ self,
9
+ in_features,
10
+ out_features,
11
+ grid_size=5,
12
+ spline_order=3,
13
+ scale_noise=0.1,
14
+ scale_base=1.0,
15
+ scale_spline=1.0,
16
+ enable_standalone_scale_spline=True,
17
+ base_activation=torch.nn.SiLU,
18
+ grid_eps=0.02,
19
+ grid_range=[-1, 1],
20
+ ):
21
+ super(KANLinear, self).__init__()
22
+ self.in_features = in_features
23
+ self.out_features = out_features
24
+ self.grid_size = grid_size
25
+ self.spline_order = spline_order
26
+
27
+ h = (grid_range[1] - grid_range[0]) / grid_size
28
+ grid = (
29
+ (
30
+ torch.arange(-spline_order, grid_size + spline_order + 1) * h
31
+ + grid_range[0]
32
+ )
33
+ .expand(in_features, -1)
34
+ .contiguous()
35
+ )
36
+ self.register_buffer("grid", grid)
37
+
38
+ self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
39
+ self.spline_weight = torch.nn.Parameter(
40
+ torch.Tensor(out_features, in_features, grid_size + spline_order)
41
+ )
42
+ if enable_standalone_scale_spline:
43
+ self.spline_scaler = torch.nn.Parameter(
44
+ torch.Tensor(out_features, in_features)
45
+ )
46
+
47
+ self.scale_noise = scale_noise
48
+ self.scale_base = scale_base
49
+ self.scale_spline = scale_spline
50
+ self.enable_standalone_scale_spline = enable_standalone_scale_spline
51
+ self.base_activation = base_activation()
52
+ self.grid_eps = grid_eps
53
+
54
+ self.reset_parameters()
55
+
56
+ def reset_parameters(self):
57
+ torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
58
+ with torch.no_grad():
59
+ noise = (
60
+ (
61
+ torch.rand(self.grid_size + 1, self.in_features, self.out_features)
62
+ - 1 / 2
63
+ )
64
+ * self.scale_noise
65
+ / self.grid_size
66
+ )
67
+ self.spline_weight.data.copy_(
68
+ (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
69
+ * self.curve2coeff(
70
+ self.grid.T[self.spline_order : -self.spline_order],
71
+ noise,
72
+ )
73
+ )
74
+ if self.enable_standalone_scale_spline:
75
+ # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
76
+ torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
77
+
78
+ def b_splines(self, x: torch.Tensor):
79
+ """
80
+ Compute the B-spline bases for the given input tensor.
81
+
82
+ Args:
83
+ x (torch.Tensor): Input tensor of shape (batch_size, in_features).
84
+
85
+ Returns:
86
+ torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
87
+ """
88
+ assert x.dim() == 2 and x.size(1) == self.in_features
89
+
90
+ grid: torch.Tensor = (
91
+ self.grid
92
+ ) # (in_features, grid_size + 2 * spline_order + 1)
93
+ x = x.unsqueeze(-1)
94
+ bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
95
+ for k in range(1, self.spline_order + 1):
96
+ bases = (
97
+ (x - grid[:, : -(k + 1)])
98
+ / (grid[:, k:-1] - grid[:, : -(k + 1)])
99
+ * bases[:, :, :-1]
100
+ ) + (
101
+ (grid[:, k + 1 :] - x)
102
+ / (grid[:, k + 1 :] - grid[:, 1:(-k)])
103
+ * bases[:, :, 1:]
104
+ )
105
+
106
+ assert bases.size() == (
107
+ x.size(0),
108
+ self.in_features,
109
+ self.grid_size + self.spline_order,
110
+ )
111
+ return bases.contiguous()
112
+
113
+ def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
114
+ """
115
+ Compute the coefficients of the curve that interpolates the given points.
116
+
117
+ Args:
118
+ x (torch.Tensor): Input tensor of shape (batch_size, in_features).
119
+ y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).
120
+
121
+ Returns:
122
+ torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
123
+ """
124
+ assert x.dim() == 2 and x.size(1) == self.in_features
125
+ assert y.size() == (x.size(0), self.in_features, self.out_features)
126
+
127
+ A = self.b_splines(x).transpose(
128
+ 0, 1
129
+ ) # (in_features, batch_size, grid_size + spline_order)
130
+ B = y.transpose(0, 1) # (in_features, batch_size, out_features)
131
+ solution = torch.linalg.lstsq(
132
+ A, B
133
+ ).solution # (in_features, grid_size + spline_order, out_features)
134
+ result = solution.permute(
135
+ 2, 0, 1
136
+ ) # (out_features, in_features, grid_size + spline_order)
137
+
138
+ assert result.size() == (
139
+ self.out_features,
140
+ self.in_features,
141
+ self.grid_size + self.spline_order,
142
+ )
143
+ return result.contiguous()
144
+
145
+ @property
146
+ def scaled_spline_weight(self):
147
+ return self.spline_weight * (
148
+ self.spline_scaler.unsqueeze(-1)
149
+ if self.enable_standalone_scale_spline
150
+ else 1.0
151
+ )
152
+
153
+ def forward(self, x: torch.Tensor):
154
+ assert x.size(-1) == self.in_features
155
+ original_shape = x.shape
156
+ x = x.view(-1, self.in_features)
157
+
158
+ base_output = F.linear(self.base_activation(x), self.base_weight)
159
+ spline_output = F.linear(
160
+ self.b_splines(x).view(x.size(0), -1),
161
+ self.scaled_spline_weight.view(self.out_features, -1),
162
+ )
163
+ output = base_output + spline_output
164
+
165
+ output = output.view(*original_shape[:-1], self.out_features)
166
+ return output
167
+
168
+ @torch.no_grad()
169
+ def update_grid(self, x: torch.Tensor, margin=0.01):
170
+ assert x.dim() == 2 and x.size(1) == self.in_features
171
+ batch = x.size(0)
172
+
173
+ splines = self.b_splines(x) # (batch, in, coeff)
174
+ splines = splines.permute(1, 0, 2) # (in, batch, coeff)
175
+ orig_coeff = self.scaled_spline_weight # (out, in, coeff)
176
+ orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out)
177
+ unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out)
178
+ unreduced_spline_output = unreduced_spline_output.permute(
179
+ 1, 0, 2
180
+ ) # (batch, in, out)
181
+
182
+ # sort each channel individually to collect data distribution
183
+ x_sorted = torch.sort(x, dim=0)[0]
184
+ grid_adaptive = x_sorted[
185
+ torch.linspace(
186
+ 0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
187
+ )
188
+ ]
189
+
190
+ uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
191
+ grid_uniform = (
192
+ torch.arange(
193
+ self.grid_size + 1, dtype=torch.float32, device=x.device
194
+ ).unsqueeze(1)
195
+ * uniform_step
196
+ + x_sorted[0]
197
+ - margin
198
+ )
199
+
200
+ grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
201
+ grid = torch.concatenate(
202
+ [
203
+ grid[:1]
204
+ - uniform_step
205
+ * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
206
+ grid,
207
+ grid[-1:]
208
+ + uniform_step
209
+ * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
210
+ ],
211
+ dim=0,
212
+ )
213
+
214
+ self.grid.copy_(grid.T)
215
+ self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
216
+
217
+ def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
218
+ """
219
+ Compute the regularization loss.
220
+
221
+ This is a dumb simulation of the original L1 regularization as stated in the
222
+ paper, since the original one requires computing absolutes and entropy from the
223
+ expanded (batch, in_features, out_features) intermediate tensor, which is hidden
224
+ behind the F.linear function if we want a memory-efficient implementation.
225
+
226
+ The L1 regularization is now computed as mean absolute value of the spline
227
+ weights. The authors implementation also includes this term in addition to the
228
+ sample-based regularization.
229
+ """
230
+ l1_fake = self.spline_weight.abs().mean(-1)
231
+ regularization_loss_activation = l1_fake.sum()
232
+ p = l1_fake / regularization_loss_activation
233
+ regularization_loss_entropy = -torch.sum(p * p.log())
234
+ return (
235
+ regularize_activation * regularization_loss_activation
236
+ + regularize_entropy * regularization_loss_entropy
237
+ )
238
+
239
+
240
+ class KAN(torch.nn.Module):
241
+ def __init__(
242
+ self,
243
+ layers_hidden,
244
+ grid_size=5,
245
+ spline_order=3,
246
+ scale_noise=0.1,
247
+ scale_base=1.0,
248
+ scale_spline=1.0,
249
+ base_activation=torch.nn.SiLU,
250
+ grid_eps=0.02,
251
+ grid_range=[-1, 1],
252
+ ):
253
+ super(KAN, self).__init__()
254
+ self.grid_size = grid_size
255
+ self.spline_order = spline_order
256
+
257
+ self.layers = torch.nn.ModuleList()
258
+ for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
259
+ self.layers.append(
260
+ KANLinear(
261
+ in_features,
262
+ out_features,
263
+ grid_size=grid_size,
264
+ spline_order=spline_order,
265
+ scale_noise=scale_noise,
266
+ scale_base=scale_base,
267
+ scale_spline=scale_spline,
268
+ base_activation=base_activation,
269
+ grid_eps=grid_eps,
270
+ grid_range=grid_range,
271
+ )
272
+ )
273
+
274
+ def forward(self, x: torch.Tensor, update_grid=False):
275
+ for layer in self.layers:
276
+ if update_grid:
277
+ layer.update_grid(x)
278
+ x = layer(x)
279
+ return x
280
+
281
+ def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
282
+ return sum(
283
+ layer.regularization_loss(regularize_activation, regularize_entropy)
284
+ for layer in self.layers
285
+ )
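
A minimal sketch of how the KAN layers above are used (not part of the commit; the feature sizes are illustrative):

    import torch
    from utils.kan import KAN, KANLinear

    layer = KANLinear(in_features=64, out_features=32)
    y = layer(torch.randn(8, 64))      # (8, 32)

    net = KAN([64, 128, 10])           # two stacked KANLinear layers
    out = net(torch.randn(8, 64))      # (8, 10)
    reg = net.regularization_loss()    # optional penalty to add to the training loss
    print(y.shape, out.shape, reg.item())
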
utils/speech_embedder_net.py ADDED
@@ -0,0 +1,112 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Wed Sep 5 20:58:34 2018
5
+
6
+ @author: harry
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from utils.hparam import hparam as hp
13
+ from utils.utils import get_centroids, get_cossim, calc_loss
14
+ from utils.kan import KANLinear
15
+
16
+ class SpeechEmbedder(nn.Module):
17
+
18
+ def __init__(self):
19
+ super(SpeechEmbedder, self).__init__()
20
+ self.LSTM_stack = nn.LSTM(hp.data.nmels, hp.model.hidden, num_layers=hp.model.num_layer, batch_first=True)
21
+ for name, param in self.LSTM_stack.named_parameters():
22
+ if 'bias' in name:
23
+ nn.init.constant_(param, 0.0)
24
+ elif 'weight' in name:
25
+ nn.init.xavier_normal_(param)
26
+ self.projection = nn.Linear(hp.model.hidden, hp.model.proj)
27
+
28
+ def forward(self, x):
29
+ x, _ = self.LSTM_stack(x.float()) #(batch, frames, n_mels)
30
+ #only use last frame
31
+ x = x[:,x.size(1)-1]
32
+ x = self.projection(x.float())
33
+ x = x / torch.norm(x, dim=1).unsqueeze(1)
34
+ return x
35
+
36
+
37
+ class SpeechEmbedderGRU(nn.Module):
38
+ def __init__(self):
39
+ super(SpeechEmbedderGRU, self).__init__()
40
+ self.GRU_stack = nn.GRU(hp.data.nmels, hp.model.hidden, num_layers=hp.model.num_layer, batch_first=True)
41
+ for name, param in self.GRU_stack.named_parameters():
42
+ if 'bias' in name:
43
+ nn.init.constant_(param, 0.0)
44
+ elif 'weight' in name:
45
+ nn.init.xavier_normal_(param)
46
+ self.projection = nn.Linear(hp.model.hidden, hp.model.proj)
47
+
48
+ def forward(self, x):
49
+ x, _ = self.GRU_stack(x.float()) #(batch, frames, n_mels)
50
+ #only use last frame
51
+ x = x[:,x.size(1)-1]
52
+ x = self.projection(x.float())
53
+ x = x / torch.norm(x, dim=1).unsqueeze(1)
54
+ return x
55
+
56
+ class SpeechEmbedderKAN(nn.Module):
57
+ def __init__(self):
58
+ super(SpeechEmbedderKAN, self).__init__()
59
+ self.LSTM_stack = nn.LSTM(hp.data.nmels, hp.model.hidden, num_layers=hp.model.num_layer, batch_first=True)
60
+ for name, param in self.LSTM_stack.named_parameters():
61
+ if 'bias' in name:
62
+ nn.init.constant_(param, 0.0)
63
+ elif 'weight' in name:
64
+ nn.init.xavier_normal_(param)
65
+ self.projection = KANLinear(hp.model.hidden, hp.model.proj)
66
+
67
+ def forward(self, x):
68
+ x, _ = self.LSTM_stack(x.float()) #(batch, frames, n_mels)
69
+ #only use last frame
70
+ x = x[:,x.size(1)-1]
71
+ x = self.projection(x.float())
72
+ x = x / torch.norm(x, dim=1).unsqueeze(1)
73
+ return x
74
+
75
+
76
+
77
+ class SpeechEmbedderBidirectional(nn.Module):
78
+ def __init__(self):
79
+ super(SpeechEmbedderBidirectional, self).__init__()
80
+ self.LSTM_stack = nn.LSTM(hp.data.nmels, hp.model.hidden, num_layers=hp.model.num_layer, batch_first=True, bidirectional=True)
81
+ for name, param in self.LSTM_stack.named_parameters():
82
+ if 'bias' in name:
83
+ nn.init.constant_(param, 0.0)
84
+ elif 'weight' in name:
85
+ nn.init.xavier_normal_(param)
86
+ self.projection = nn.Linear(hp.model.hidden, hp.model.proj)
87
+
88
+ def forward(self, x):
89
+ x, _ = self.LSTM_stack(x.float()) #(batch, frames, n_mels)
90
+ #only use last frame
91
+ x = x[:, :, :hp.model.hidden]
92
+
93
+ x = x[:,x.size(1)-1]
94
+ x = self.projection(x.float())
95
+ x = x / torch.norm(x, dim=1).unsqueeze(1)
96
+ return x
97
+
98
+ class GE2ELoss(nn.Module):
99
+
100
+ def __init__(self, device):
101
+ super(GE2ELoss, self).__init__()
102
+ self.w = nn.Parameter(torch.tensor(10.0).to(device), requires_grad=True)
103
+ self.b = nn.Parameter(torch.tensor(-5.0).to(device), requires_grad=True)
104
+ self.device = device
105
+
106
+ def forward(self, embeddings):
107
+ torch.clamp(self.w, 1e-6)
108
+ centroids = get_centroids(embeddings)
109
+ cossim = get_cossim(embeddings, centroids)
110
+ sim_matrix = self.w*cossim.to(self.device) + self.b
111
+ loss, _ = calc_loss(sim_matrix)
112
+ return loss
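
A training-step sketch for the embedders and loss above (not part of the commit; N, M and the frame count are illustrative, and hp.data.nmels comes from config/config.yaml):

    import torch
    from utils.speech_embedder_net import SpeechEmbedder, GE2ELoss
    from utils.hparam import hparam as hp

    device = torch.device("cpu")
    embedder = SpeechEmbedder().to(device)
    criterion = GE2ELoss(device)

    N, M, frames = 4, 5, 160                     # N speakers x M utterances per batch
    mels = torch.randn(N * M, frames, hp.data.nmels, device=device)
    emb = embedder(mels)                         # (N*M, proj), L2-normalised rows
    loss = criterion(emb.view(N, M, -1))         # GE2E loss over the batch
    loss.backward()
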
utils/utils.py ADDED
@@ -0,0 +1,173 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Thu Sep 20 16:56:19 2018
5
+
6
+ @author: harry
7
+ """
8
+ import librosa
9
+ import numpy as np
10
+ import torch
11
+ import torch.autograd as grad
12
+ import torch.nn.functional as F
13
+
14
+ from utils.hparam import hparam as hp
15
+
16
+ def get_centroids_prior(embeddings):
17
+ centroids = []
18
+ for speaker in embeddings:
19
+ centroid = 0
20
+ for utterance in speaker:
21
+ centroid = centroid + utterance
22
+ centroid = centroid/len(speaker)
23
+ centroids.append(centroid)
24
+ centroids = torch.stack(centroids)
25
+ return centroids
26
+
27
+ def get_centroids(embeddings):
28
+ centroids = embeddings.mean(dim=1)
29
+ return centroids
30
+
31
+ def get_centroid(embeddings, speaker_num, utterance_num):
32
+ centroid = 0
33
+ for utterance_id, utterance in enumerate(embeddings[speaker_num]):
34
+ if utterance_id == utterance_num:
35
+ continue
36
+ centroid = centroid + utterance
37
+ centroid = centroid/(len(embeddings[speaker_num])-1)
38
+ return centroid
39
+
40
+ def get_utterance_centroids(embeddings):
41
+ """
42
+ Returns the centroids for each utterance of a speaker, where
43
+ the utterance centroid is the speaker centroid without considering
44
+ this utterance
45
+
46
+ Shape of embeddings should be:
47
+ (speaker_ct, utterance_per_speaker_ct, embedding_size)
48
+ """
49
+ sum_centroids = embeddings.sum(dim=1)
50
+ # we want to subtract out each utterance, prior to calculating the
51
+ # the utterance centroid
52
+ sum_centroids = sum_centroids.reshape(
53
+ sum_centroids.shape[0], 1, sum_centroids.shape[-1]
54
+ )
55
+ # we want the mean but not including the utterance itself, so -1
56
+ num_utterances = embeddings.shape[1] - 1
57
+ centroids = (sum_centroids - embeddings) / num_utterances
58
+ return centroids
59
+
60
+ def get_cossim_prior(embeddings, centroids):
61
+ # Calculates cosine similarity matrix. Requires (N, M, feature) input
62
+ cossim = torch.zeros(embeddings.size(0),embeddings.size(1),centroids.size(0))
63
+ for speaker_num, speaker in enumerate(embeddings):
64
+ for utterance_num, utterance in enumerate(speaker):
65
+ for centroid_num, centroid in enumerate(centroids):
66
+ if speaker_num == centroid_num:
67
+ centroid = get_centroid(embeddings, speaker_num, utterance_num)
68
+ output = F.cosine_similarity(utterance,centroid,dim=0)+1e-6
69
+ cossim[speaker_num][utterance_num][centroid_num] = output
70
+ return cossim
71
+
72
+ def get_cossim(embeddings, centroids):
73
+ # number of utterances per speaker
74
+ num_utterances = embeddings.shape[1]
75
+ utterance_centroids = get_utterance_centroids(embeddings)
76
+
77
+ # flatten the embeddings and utterance centroids to just utterance,
78
+ # so we can do cosine similarity
79
+ utterance_centroids_flat = utterance_centroids.view(
80
+ utterance_centroids.shape[0] * utterance_centroids.shape[1],
81
+ -1
82
+ )
83
+ embeddings_flat = embeddings.view(
84
+ embeddings.shape[0] * num_utterances,
85
+ -1
86
+ )
87
+ # the cosine distance between utterance and the associated centroids
88
+ # for that utterance
89
+ # this is each speaker's utterances against his own centroid, but each
90
+ # comparison centroid has the current utterance removed
91
+ cos_same = F.cosine_similarity(embeddings_flat, utterance_centroids_flat)
92
+
93
+ # now we get the cosine distance between each utterance and the other speakers'
94
+ # centroids
95
+ # to do so requires comparing each utterance to each centroid. To keep the
96
+ # operation fast, we vectorize by using matrices L (embeddings) and
97
+ # R (centroids) where L has each utterance repeated sequentially for all
98
+ # comparisons and R has the entire centroids frame repeated for each utterance
99
+ centroids_expand = centroids.repeat((num_utterances * embeddings.shape[0], 1))
100
+ embeddings_expand = embeddings_flat.unsqueeze(1).repeat(1, embeddings.shape[0], 1)
101
+ embeddings_expand = embeddings_expand.view(
102
+ embeddings_expand.shape[0] * embeddings_expand.shape[1],
103
+ embeddings_expand.shape[-1]
104
+ )
105
+ cos_diff = F.cosine_similarity(embeddings_expand, centroids_expand)
106
+ cos_diff = cos_diff.view(
107
+ embeddings.size(0),
108
+ num_utterances,
109
+ centroids.size(0)
110
+ )
111
+ # assign the cosine distance for same speakers to the proper idx
112
+ same_idx = list(range(embeddings.size(0)))
113
+ cos_diff[same_idx, :, same_idx] = cos_same.view(embeddings.shape[0], num_utterances)
114
+ cos_diff = cos_diff + 1e-6
115
+ return cos_diff
116
+
117
+ def calc_loss_prior(sim_matrix):
118
+ # Calculates loss from (N, M, K) similarity matrix
119
+ per_embedding_loss = torch.zeros(sim_matrix.size(0), sim_matrix.size(1))
120
+ for j in range(len(sim_matrix)):
121
+ for i in range(sim_matrix.size(1)):
122
+ per_embedding_loss[j][i] = -(sim_matrix[j][i][j] - ((torch.exp(sim_matrix[j][i]).sum()+1e-6).log_()))
123
+ loss = per_embedding_loss.sum()
124
+ return loss, per_embedding_loss
125
+
126
+ def calc_loss(sim_matrix):
127
+ same_idx = list(range(sim_matrix.size(0)))
128
+ pos = sim_matrix[same_idx, :, same_idx]
129
+ neg = (torch.exp(sim_matrix).sum(dim=2) + 1e-6).log_()
130
+ per_embedding_loss = -1 * (pos - neg)
131
+ loss = per_embedding_loss.sum()
132
+ return loss, per_embedding_loss
133
+
134
+ def normalize_0_1(values, max_value, min_value):
135
+ normalized = np.clip((values - min_value) / (max_value - min_value), 0, 1)
136
+ return normalized
137
+
138
+ def mfccs_and_spec(wav_file, wav_process = False, calc_mfccs=False, calc_mag_db=False):
139
+ sound_file, _ = librosa.core.load(wav_file, sr=hp.data.sr)
140
+ window_length = int(hp.data.window*hp.data.sr)
141
+ hop_length = int(hp.data.hop*hp.data.sr)
142
+ duration = hp.data.tisv_frame * hp.data.hop + hp.data.window
143
+
144
+ # Cut silence and fix length
145
+ if wav_process:
146
+ sound_file, index = librosa.effects.trim(sound_file, frame_length=window_length, hop_length=hop_length)
147
+ length = int(hp.data.sr * duration)
148
+ sound_file = librosa.util.fix_length(sound_file, size=length)
149
+
150
+ spec = librosa.stft(sound_file, n_fft=hp.data.nfft, hop_length=hop_length, win_length=window_length)
151
+ mag_spec = np.abs(spec)
152
+
153
+ mel_basis = librosa.filters.mel(sr=hp.data.sr, n_fft=hp.data.nfft, n_mels=hp.data.nmels)
154
+ mel_spec = np.dot(mel_basis, mag_spec)
155
+
156
+ mag_db = librosa.amplitude_to_db(mag_spec)
157
+ #db mel spectrogram
158
+ mel_db = librosa.amplitude_to_db(mel_spec).T
159
+
160
+ mfccs = None
161
+ if calc_mfccs:
162
+ mfccs = np.dot(librosa.filters.dct(40, mel_db.shape[0]), mel_db).T  # librosa.filters.dct is only available in older librosa releases
163
+
164
+ return mfccs, mel_db, mag_db
165
+
166
+ if __name__ == "__main__":
167
+ w = grad.Variable(torch.tensor(1.0))
168
+ b = grad.Variable(torch.tensor(0.0))
169
+ embeddings = torch.tensor([[0,1,0],[0,0,1], [0,1,0], [0,1,0], [1,0,0], [1,0,0]]).to(torch.float).reshape(3,2,3)
170
+ centroids = get_centroids(embeddings)
171
+ cossim = get_cossim(embeddings, centroids)
172
+ sim_matrix = w*cossim + b
173
+ loss, per_embedding_loss = calc_loss(sim_matrix)
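
Finally, a shape walk-through of the GE2E helpers in utils/utils.py (not part of the commit; the sizes are illustrative):

    import torch
    from utils.utils import get_centroids, get_cossim, calc_loss

    N, M, d = 4, 5, 256                         # speakers, utterances per speaker, embedding size
    embeddings = torch.randn(N, M, d)

    centroids = get_centroids(embeddings)       # (N, d): mean over each speaker's utterances
    cossim = get_cossim(embeddings, centroids)  # (N, M, N): every utterance vs. every centroid
    loss, per_embedding = calc_loss(10.0 * cossim - 5.0)  # scalar, (N, M)
    print(centroids.shape, cossim.shape, per_embedding.shape, loss.item())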