wav2vec / app.py
spycoder's picture
Update app.py
15a6f16
raw
history blame
3.6 kB
import gradio as gr
import torch
import soundfile as sf
import os
import numpy as np
import os
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
from collections import Counter
# All inference in this app runs on the CPU.
device = torch.device("cpu")

# Hub checkpoint that supplies both the audio preprocessor and the backbone
# architecture for the classifier.
_BASE_CHECKPOINT = "facebook/wav2vec2-base-960h"

processor = Wav2Vec2Processor.from_pretrained(_BASE_CHECKPOINT)

# Build a sequence-classification model with a two-label head on top of the
# base checkpoint, then move it to the inference device.
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    _BASE_CHECKPOINT,
    num_labels=2,
)
model = model.to(device)

# Weights file, given as a relative path (resolved against the working
# directory of the process).
model_path = "dysarthria_classifier12.pth"

# Overwrite the freshly initialised parameters with the saved state dict,
# mapping every stored tensor onto the CPU.
_saved_state = torch.load(model_path, map_location=device)
model.load_state_dict(_saved_state)
# Heading of the Gradio page (passed as `title=` to gr.Interface below).
# NOTE(review): the heading asks for an mp3, but predict() reads the file with
# soundfile.read; whether mp3 decodes depends on the installed libsndfile
# version -- confirm.
title = "Upload an mp3 file for parkinsons detection! (Thai Language)"
# Body text of the Gradio page (passed as `description=` to gr.Interface
# below). It states that the model was trained on Thai recordings and lists
# the prompts, one per line.
# NOTE(review): the trailing <img> tag points at an image from a different
# demo Space (Rick_and_Morty_QA); it looks like a template leftover -- confirm
# whether it should stay.
description = """
The model was trained on Thai audio recordings with the following sentences: \n
ชาวไร่ตัดต้นสนทำท่อนซุง\n
ปูม้าวิ่งไปมาบนใบไม้ (เน้นใช้ริมฝีปาก)\n
อีกาคอยคาบงูคาบไก่ (เน้นใช้เพดานปาก)\n
เพียงแค่ฝนตกลงที่หน้าต่างในบางครา\n
“อาาาาาาาาาาา”\n
“อีีีีีีีีี”\n
“อาาาา” (ดังขึ้นเรื่อยๆ)\n
“อาา อาาา อาาาาา”\n
<img src="https://huggingface.co/spaces/course-demos/Rick_and_Morty_QA/resolve/main/rick.png" width=200px>
"""
def predict(file_upload, microphone):
    """Classify one audio recording with the fine-tuned wav2vec2 model.

    Args:
        file_upload: Path of an uploaded audio file, or None if nothing was
            uploaded.
        microphone: Path of a microphone recording, or None if nothing was
            recorded.

    Returns:
        A verdict string ("You probably have SP" / "You probably don't have
        SP"). When both inputs are supplied, the verdict is prefixed with a
        warning that the microphone recording was used. When neither input is
        supplied, or the audio holds no samples, an "ERROR: ..." string is
        returned instead.
    """
    # Fixed number of samples handed to the model: shorter clips are
    # zero-padded, longer clips are truncated.
    max_length = 100000

    if microphone is None and file_upload is None:
        return "ERROR: You have to either use the microphone or upload an audio file"

    warn_output = ""
    if microphone is not None and file_upload is not None:
        warn_output = (
            "WARNING: You've uploaded an audio file and used the microphone. "
            "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
        )

    # The microphone recording takes precedence, exactly as the warning text
    # promises (previously the upload silently overrode it).
    file_path = microphone if microphone is not None else file_upload

    # NOTE(review): the file's real sample rate is discarded and 16 kHz is
    # asserted to the processor below; audio recorded at another rate is not
    # resampled. Confirm this matches how the classifier was trained before
    # changing it.
    wav_data, _ = sf.read(file_path)

    # soundfile returns (frames, channels) for multi-channel files. The
    # processor would treat that 2-D array as a batch, so downmix to mono.
    if wav_data.ndim > 1:
        wav_data = wav_data.mean(axis=1)

    # An empty clip would be padded to pure silence and still "classified".
    if wav_data.size == 0:
        return "ERROR: The audio file contains no samples"

    model.eval()
    with torch.no_grad():
        inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
        input_values = inputs.input_values.squeeze(0)

        # Bring the clip to exactly max_length samples along the last axis.
        missing = max_length - input_values.shape[-1]
        if missing > 0:
            input_values = torch.cat([input_values, input_values.new_zeros(missing)], dim=-1)
        else:
            input_values = input_values[..., :max_length]

        # Restore the batch dimension expected by the model.
        input_values = input_values.unsqueeze(0).to(device)
        logits = model(input_values=input_values).logits.squeeze()
        predicted_class_id = torch.argmax(logits, dim=-1).item()

    verdict = "You probably have SP" if predicted_class_id == 1 else "You probably don't have SP"
    return warn_output + verdict
# Two optional audio inputs, both delivered to predict() as file paths, in the
# order (file_upload, microphone).
# NOTE(review): `gr.inputs.Audio` is the legacy component namespace; confirm
# the pinned Gradio version still provides it before upgrading.
upload_input = gr.inputs.Audio(source="upload", type="filepath", optional=True)
microphone_input = gr.inputs.Audio(source="microphone", type="filepath", optional=True)

# Wire the classifier into a simple text-output interface and start serving.
demo = gr.Interface(
    fn=predict,
    inputs=[upload_input, microphone_input],
    outputs="text",
    title=title,
    description=description,
)
demo.launch()