diff --git "a/try.ipynb" "b/try.ipynb"
new file mode 100644--- /dev/null
+++ "b/try.ipynb"
@@ -0,0 +1,242 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "ModuleNotFoundError",
+ "evalue": "No module named 'model'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[5], line 7\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mIPython\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdisplay\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mipd\u001b[39;00m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m nn\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmodel\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m CNNEmotinoalClassifier\n",
+ "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'model'"
+ ]
+ }
+ ],
+ "source": [
+ "import gradio as gr\n",
+ "import torch\n",
+ "# from lr_ed.model import CNNEmotinoalClassifier\n",
+ "import torchaudio\n",
+ "import IPython.display as ipd\n",
+ "from torch import nn\n",
+ "from model import CNNEmotinoalClassifier"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "CNNEmotinoalClassifier(\n",
+ " (conv1): Sequential(\n",
+ " (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (1): ReLU()\n",
+ " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ " )\n",
+ " (conv2): Sequential(\n",
+ " (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (1): ReLU()\n",
+ " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ " )\n",
+ " (conv3): Sequential(\n",
+ " (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))\n",
+ " (1): ReLU()\n",
+ " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ " )\n",
+ " (conv4): Sequential(\n",
+ " (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))\n",
+ " (1): ReLU()\n",
+ " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ " )\n",
+ " (flatten): Flatten(start_dim=1, end_dim=-1)\n",
+ " (fully_connected): Sequential(\n",
+ " (0): Linear(in_features=32000, out_features=128, bias=True)\n",
+ " (1): ReLU()\n",
+ " (2): Linear(in_features=128, out_features=64, bias=True)\n",
+ " (3): ReLU()\n",
+ " (4): Linear(in_features=64, out_features=32, bias=True)\n",
+ " (5): ReLU()\n",
+ " (6): Linear(in_features=32, out_features=16, bias=True)\n",
+ " (7): ReLU()\n",
+ " (8): Linear(in_features=16, out_features=6, bias=True)\n",
+ " )\n",
+ " (softmax): Softmax(dim=1)\n",
+ ")"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Checkpoint location is machine-specific; adjust when running elsewhere.\n",
+ "CHECKPOINT_PATH = '/raid/adal_abilbekov/lr_ed/CNN_emotional_classifier/cnn_class_17.pt'\n",
+ "\n",
+ "model = CNNEmotinoalClassifier()\n",
+ "# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on a\n",
+ "# CPU-only host; nothing in this notebook moves the model to another device.\n",
+ "state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu')\n",
+ "model.load_state_dict(state_dict)\n",
+ "# eval() switches the module to inference mode and, as the last expression,\n",
+ "# displays the model architecture.\n",
+ "model.eval()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 47,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Demo clips live in one directory; switch clips by changing CLIP_NAME only.\n",
+ "# Clips used so far: 'Akzhol_happy.wav', 'Akzhol_neutral.wav', 'Marzhan_happy.wav'.\n",
+ "DEMO_DIR = '/raid/adal_abilbekov/emodiff_try_2/Emo_diff/demo_190224'\n",
+ "CLIP_NAME = 'Marzhan_happy.wav'\n",
+ "\n",
+ "path = f'{DEMO_DIR}/{CLIP_NAME}'\n",
+ "# torchaudio.load returns the audio tensor together with its sample rate.\n",
+ "waveform, sr = torchaudio.load(path)\n",
+ "# Last expression of the cell: renders an inline audio player for the clip.\n",
+ "ipd.Audio(data=waveform, rate=sr)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Mel-spectrogram transform applied to the waveform before it is fed to the model.\n",
+ "# NOTE(review): sample_rate is fixed here and is not checked against the `sr`\n",
+ "# returned when the clip is loaded -- confirm the clips are 22050 Hz.\n",
+ "to_melspec = torchaudio.transforms.MelSpectrogram(\n",
+ "    sample_rate=22050,\n",
+ "    n_fft=1024,\n",
+ "    hop_length=512,\n",
+ "    n_mels=64,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def _get_right_pad(target_waveform, waveform):\n",
+ "    \"\"\"Zero-pad `waveform` on the right to at least `target_waveform` samples.\n",
+ "\n",
+ "    Args:\n",
+ "        target_waveform: Minimum number of samples wanted along the last axis.\n",
+ "        waveform: Audio tensor whose last axis indexes samples.\n",
+ "\n",
+ "    Returns:\n",
+ "        `waveform` itself when it already holds `target_waveform` or more\n",
+ "        samples (it is never truncated), otherwise a right-padded copy.\n",
+ "    \"\"\"\n",
+ "    # Measure the same axis that nn.functional.pad extends (the last one), so\n",
+ "    # the length check and the padding cannot disagree for non-2-D input.\n",
+ "    missing = target_waveform - waveform.shape[-1]\n",
+ "    if missing <= 0:\n",
+ "        return waveform\n",
+ "    # (0, missing): nothing added on the left, `missing` samples on the right.\n",
+ "    return nn.functional.pad(waveform, (0, missing))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Fixed clip length, in samples, fed to the classifier.\n",
+ "TARGET_NUM_SAMPLES = 400384\n",
+ "\n",
+ "# Pad short clips, then crop long ones: the classifier's first Linear layer has a\n",
+ "# fixed in_features, so a clip longer than the target would not fit the model.\n",
+ "waveform = _get_right_pad(TARGET_NUM_SAMPLES, waveform)[..., :TARGET_NUM_SAMPLES]\n",
+ "input_x = to_melspec(waveform)\n",
+ "# Insert a singleton axis at dim=1 for the single input channel of the first\n",
+ "# Conv2d; the leading audio-channel axis then acts as the batch axis\n",
+ "# (assumes a mono clip -- TODO confirm).\n",
+ "input_x = torch.unsqueeze(input_x, dim=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 67,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Inference only: disable autograd so no graph is recorded and the result\n",
+ "# carries no grad_fn.\n",
+ "with torch.no_grad():\n",
+ "    probs = model(input_x)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Class names in alphabetical order (presumably the label order used when the\n",
+ "# model was trained -- TODO confirm against the training code).\n",
+ "emotions = sorted(['happy', 'angry', 'sad', 'neutral', 'surprised', 'fear'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 59,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# def get_probs(input_x, emotions):\n",
+ "# probs = model(input_x)\n",
+ "# prediction = emotions[probs.argmax(dim=1).item()]\n",
+ "# return prediction, dict(zip(emotions, list(map(float, probs))))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 70,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[2.9495e-18, 6.7292e-20, 9.9882e-01, 2.4566e-18, 1.0296e-12, 1.1847e-03]],\n",
+ " grad_fn=)"
+ ]
+ },
+ "execution_count": 70,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "probs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "asr_hug",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}