import gradio as gr
import torch
import torchaudio
from torch import nn

from model import CNNEmotinoalClassifier

# Load the trained classifier weights on CPU and switch to inference mode.
model = CNNEmotinoalClassifier()
model.load_state_dict(torch.load('./cnn_class_17.pt', map_location=torch.device('cpu')))
model.eval()

# Mel-spectrogram front end; the parameters assume 22050 Hz input audio.
to_melspec = torchaudio.transforms.MelSpectrogram(
    sample_rate=22050,
    n_fft=1024,
    hop_length=512,
    n_mels=64,
)
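
# With torchaudio's default centered STFT, a mono clip padded to 400384 samples
# (see _get_right_pad below) maps to a [1, 64, 783] tensor:
# 64 mel bins by 400384 // 512 + 1 = 783 frames.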

def _get_right_pad(target_length, waveform):
    """Right-pad the waveform with zeros so it spans target_length samples."""
    num_samples = waveform.shape[1]
    if num_samples < target_length:
        right_pad = target_length - num_samples
        waveform = nn.functional.pad(waveform, (0, right_pad))
    return waveform
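
# Quick illustration with a dummy tensor:
#   x = torch.zeros(1, 100)
#   _get_right_pad(400384, x).shape  # -> torch.Size([1, 400384]), zeros appended on the right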

def get_probs(audio=None):
    """Return a dict of emotion scores for a recorded or uploaded clip."""
    if audio is None:
        raise gr.Error("Please record or upload an audio clip.")

    # Emotion labels in alphabetical order, assumed to match the classifier's
    # output indices.
    emotions = sorted(['happy', 'angry', 'sad', 'neutral', 'surprised', 'fear'])

    waveform, sr = torchaudio.load(audio)
    # Downmix to mono and resample, since the mel transform above assumes
    # single-channel 22050 Hz audio.
    waveform = waveform.mean(dim=0, keepdim=True)
    if sr != 22050:
        waveform = torchaudio.transforms.Resample(sr, 22050)(waveform)
    waveform = _get_right_pad(400384, waveform)

    input_x = to_melspec(waveform)
    input_x = torch.unsqueeze(input_x, dim=1)  # add a channel dimension

    with torch.no_grad():
        probs = model(input_x)
    return dict(zip(emotions, map(float, probs[0])))
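
# A minimal sketch of calling the pipeline directly, assuming a WAV file named
# 'Akzhol_happy.wav' sits next to this script:
#   scores = get_probs('Akzhol_happy.wav')
#   print(max(scores, key=scores.get), scores)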

# Gradio UI: an audio input (microphone recording or file upload) mapped to a
# label component that displays the per-emotion scores.
audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath")
label = gr.Label()
examples = ['Akzhol_happy.wav']

iface = gr.Interface(fn=get_probs, inputs=audio_input, outputs=label, examples=examples)
iface.launch()