{ "cells": [ { "cell_type": "markdown", "id": "23237138-936a-44b4-9eb6-f16045d2c91d", "metadata": {}, "source": [ "### **Gradio Demo | LSTM Speaker Embedding Model for Gujarati Speaker Verification**\n", "****\n", "**Author:** Irsh Vijay
\n", "**Organization**: Wadhwani Institute for Artificial Intelligence
\n", "****\n", "This notebook has the required code to run a gradio demo." ] }, { "cell_type": "code", "execution_count": 8, "id": "1d2cfd8b-9498-4236-9d32-718e9e0597cb", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import librosa\n", "import numpy as np\n", "import os\n", "import webrtcvad\n", "import wave\n", "import contextlib\n", "\n", "from utils.VAD_segments import *\n", "from utils.hparam import hparam as hp\n", "from utils.speech_embedder_net import *\n", "from utils.evaluation import *" ] }, { "cell_type": "code", "execution_count": 9, "id": "3e9e1006-83d2-4492-a210-26b2c3717cd5", "metadata": {}, "outputs": [], "source": [ "def read_wave(audio_data):\n", " \"\"\"Reads audio data and returns (PCM audio data, sample rate).\n", " Assumes the input is a tuple (sample_rate, numpy_array).\n", " If the sample rate is unsupported, resamples to 16000 Hz.\n", " \"\"\"\n", " sample_rate, data = audio_data\n", "\n", " # Ensure data is in the correct shape\n", " assert len(data.shape) == 1, \"Audio data must be a 1D array\"\n", "\n", " # Convert to floating point if necessary\n", " if not np.issubdtype(data.dtype, np.floating):\n", " data = data.astype(np.float32) / np.iinfo(data.dtype).max\n", " \n", " # Supported sample rates\n", " supported_sample_rates = (8000, 16000, 32000, 48000)\n", " \n", " # If sample rate is not supported, resample to 16000 Hz\n", " if sample_rate not in supported_sample_rates:\n", " data = librosa.resample(data, orig_sr=sample_rate, target_sr=16000)\n", " sample_rate = 16000\n", " \n", " # Convert numpy array to PCM format\n", " pcm_data = (data * np.iinfo(np.int16).max).astype(np.int16).tobytes()\n", "\n", " return data, pcm_data" ] }, { "cell_type": "code", "execution_count": 10, "id": "0b56a2fc-83c3-4b36-95b8-5f1b656150ed", "metadata": {}, "outputs": [], "source": [ "def VAD_chunk(aggressiveness, data):\n", " audio, byte_audio = read_wave(data)\n", " vad = webrtcvad.Vad(int(aggressiveness))\n", " frames = frame_generator(20, byte_audio, hp.data.sr)\n", " frames = list(frames)\n", " times = vad_collector(hp.data.sr, 20, 200, vad, frames)\n", " speech_times = []\n", " speech_segs = []\n", " for i, time in enumerate(times):\n", " start = np.round(time[0],decimals=2)\n", " end = np.round(time[1],decimals=2)\n", " j = start\n", " while j + .4 < end:\n", " end_j = np.round(j+.4,decimals=2)\n", " speech_times.append((j, end_j))\n", " speech_segs.append(audio[int(j*hp.data.sr):int(end_j*hp.data.sr)])\n", " j = end_j\n", " else:\n", " speech_times.append((j, end))\n", " speech_segs.append(audio[int(j*hp.data.sr):int(end*hp.data.sr)])\n", " return speech_times, speech_segs" ] }, { "cell_type": "code", "execution_count": 11, "id": "72f257cf-7d3f-4ec5-944a-57779ba377e6", "metadata": {}, "outputs": [], "source": [ "def get_embedding(data, embedder_net, device, n_threshold=-1):\n", " times, segs = VAD_chunk(0, data)\n", " if not segs:\n", " print(f'No voice activity detected')\n", " return None\n", " concat_seg = concat_segs(times, segs)\n", " if not concat_seg:\n", " print(f'No concatenated segments')\n", " return None\n", " STFT_frames = get_STFTs(concat_seg)\n", " if not STFT_frames:\n", " #print(f'No STFT frames')\n", " return None\n", " STFT_frames = np.stack(STFT_frames, axis=2)\n", " STFT_frames = torch.tensor(np.transpose(STFT_frames, axes=(2, 1, 0)), device=device)\n", "\n", " with torch.no_grad():\n", " embeddings = embedder_net(STFT_frames)\n", " embeddings = embeddings[:n_threshold, :]\n", " \n", " avg_embedding = torch.mean(embeddings, dim=0, keepdim=True).cpu().numpy()\n", " return avg_embedding" ] }, { "cell_type": "code", "execution_count": 12, "id": "200df766-407d-4367-b0fb-7a6118653731", "metadata": {}, "outputs": [], "source": [ "model_path = \"./speech_id_checkpoint/saved_01.model\"" ] }, { "cell_type": "code", "execution_count": 13, "id": "db7613e6-67a8-4920-a999-caca4a0de360", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SpeechEmbedder(\n", " (LSTM_stack): LSTM(40, 768, num_layers=3, batch_first=True)\n", " (projection): Linear(in_features=768, out_features=256, bias=True)\n", ")" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device = torch.device(\"mps\" if torch.backends.mps.is_available() else \"cpu\")\n", "\n", "embedder_net = SpeechEmbedder().to(device)\n", "embedder_net.load_state_dict(torch.load(model_path, map_location=device))\n", "embedder_net.eval()" ] }, { "cell_type": "code", "execution_count": 14, "id": "8a7dd9bd-7b40-41f9-8e2f-d68be18f2111", "metadata": {}, "outputs": [], "source": [ "import gradio as gr" ] }, { "cell_type": "code", "execution_count": 28, "id": "bd6c073d-eab8-4ae6-8ba6-d90a0ec54c0e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7868\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def process_audio(audio1, audio2, threshold):\n", " e1 = get_embedding(audio1, embedder_net, device)\n", " if(e1 is None):\n", " return \"No Voice Detected in file 1\"\n", " e2 = get_embedding(audio2, embedder_net, device)\n", " if(e2 is None):\n", " return \"No Voice Detected in file 2\"\n", "\n", " cosi = cosine_similarity(e1, e2)\n", "\n", " if(cosi > threshold):\n", " return f\"Same Speaker\" \n", " else:\n", " return f\"Different Speaker\" \n", "\n", "# Define the Gradio interface\n", "def gradio_interface(audio1, audio2, threshold):\n", " output_text = process_audio(audio1, audio2, threshold)\n", " return output_text\n", "\n", "# Create the Gradio interface with microphone inputs\n", "iface = gr.Interface(\n", " fn=gradio_interface,\n", " inputs=[gr.Audio(\"microphone\", type=\"numpy\", label=\"Audio File 1\"),\n", " gr.Audio(\"microphone\", type=\"numpy\", label=\"Audio File 2\"),\n", " gr.Slider(0.0, 1.0, value=0.85, step=0.01, label=\"Threshold\")\n", " ],\n", " outputs=\"text\",\n", " title=\"LSTM Based Speaker Verification\",\n", " description=\"Record two audio files and get the text output from the model.\"\n", ")\n", "\n", "# Launch the interface\n", "iface.launch(share=False)" ] }, { "cell_type": "code", "execution_count": null, "id": "a098495c-9e7b-4232-86fc-55a1890c5e27", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "b99a253e-9b91-4210-b934-8bd1b6a2d912", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 5 }