Spaces:
Running
Running
import gradio as gr | |
import os | |
import torch | |
import librosa | |
from glob import glob | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline, AutoModelForTokenClassification, TokenClassificationPipeline, Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM | |
# Audio sampling rate (Hz) the wav2vec2 ASR models are fed with.
SAMPLE_RATE = 16_000

# Cache of loaded ASR objects, keyed by language code; filled lazily by transcribe().
models = {}

# (language code, Hugging Face checkpoint) pairs. Order matters: it is the order
# of the language dropdown in the UI.
_ASR_CHECKPOINTS = (
    ("en-US", "jonatasgrosman/wav2vec2-large-xlsr-53-english"),
    ("fr-FR", "jonatasgrosman/wav2vec2-large-xlsr-53-french"),
    ("nl-NL", "jonatasgrosman/wav2vec2-large-xlsr-53-dutch"),
    ("pl-PL", "jonatasgrosman/wav2vec2-large-xlsr-53-polish"),
    ("it-IT", "jonatasgrosman/wav2vec2-large-xlsr-53-italian"),
    ("ru-RU", "jonatasgrosman/wav2vec2-large-xlsr-53-russian"),
    ("pt-PT", "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese"),
    ("de-DE", "jonatasgrosman/wav2vec2-large-xlsr-53-german"),
    ("es-ES", "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"),
    ("ja-JP", "jonatasgrosman/wav2vec2-large-xlsr-53-japanese"),
    ("ar-SA", "jonatasgrosman/wav2vec2-large-xlsr-53-arabic"),
    ("fi-FI", "jonatasgrosman/wav2vec2-large-xlsr-53-finnish"),
    ("hu-HU", "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian"),
    ("zh-CN", "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn"),
    ("el-GR", "jonatasgrosman/wav2vec2-large-xlsr-53-greek"),
)

# Language code -> ASR checkpoint name.
models_paths = dict(_ASR_CHECKPOINTS)
def _load_pipeline(checkpoint, model_class, pipeline_class):
    """Load *checkpoint* with ``model_class`` and wrap it in ``pipeline_class``.

    Returns ``(tokenizer, model, pipeline)`` so each piece stays available as
    its own module-level name.
    """
    tok = AutoTokenizer.from_pretrained(checkpoint)
    net = model_class.from_pretrained(checkpoint)
    return tok, net, pipeline_class(model=net, tokenizer=tok)


# Intent classifier (text -> intent label)
model_name = 'qanastek/XLMRoberta-Alexa-Intents-Classification'
tokenizer_intent, model_intent, classifier_intent = _load_pipeline(
    model_name, AutoModelForSequenceClassification, TextClassificationPipeline
)

# Language classifier (text -> language label)
model_name = 'qanastek/51-languages-classifier'
tokenizer_langs, model_langs, classifier_language = _load_pipeline(
    model_name, AutoModelForSequenceClassification, TextClassificationPipeline
)

# Slot / named-entity tagger (text -> per-token entity tags)
model_name = 'qanastek/XLMRoberta-Alexa-Intents-NER-NLU'
tokenizer_ner, model_ner, predict_ner = _load_pipeline(
    model_name, AutoModelForTokenClassification, TokenClassificationPipeline
)
EXAMPLE_DIR = './wavs/'


def _example_lang_code(wav_path):
    """Return the language code encoded in an example file's name.

    The code is the part of the file name (not the full path) before the
    first ``"="``. Example files are therefore presumably named
    ``<lang_code>=<anything>.wav``. Only the base name is inspected, so a
    ``"="`` or a platform-specific separator in the directory part cannot
    corrupt the result.
    """
    return os.path.basename(wav_path).split("=")[0]


# Each example row is [audio_path, lang_code], matching the order of the
# Interface inputs (audio first, language dropdown second).
examples = [
    [wav_path, _example_lang_code(wav_path)]
    for wav_path in sorted(glob(os.path.join(EXAMPLE_DIR, '*.wav')))
]
def transcribe(audio_path, lang_code):
    """Transcribe an audio file with the wav2vec2 CTC model for *lang_code*.

    The processor/model pair of each language is loaded on first use and
    cached in the module-level ``models`` dict. Later calls for the same
    language reuse it.

    Args:
        audio_path: path to an audio file readable by ``librosa.load``.
        lang_code: a key of ``models_paths`` (e.g. ``"en-US"``).

    Returns:
        The greedy (argmax) CTC transcription as a string.

    Raises:
        KeyError: if *lang_code* is not a key of ``models_paths``.
    """
    # librosa resamples to the requested rate, so the rate it returns is unused.
    speech_array, _ = librosa.load(audio_path, sr=SAMPLE_RATE)

    if lang_code not in models:
        checkpoint = models_paths[lang_code]
        # Load both objects before touching the cache. A failed download (or an
        # unknown code) must not leave a half-filled entry that makes every
        # later call for this language fail.
        processor = Wav2Vec2Processor.from_pretrained(checkpoint)
        model = Wav2Vec2ForCTC.from_pretrained(checkpoint)
        models[lang_code] = {"processor": processor, "model": model}

    processor_asr = models[lang_code]["processor"]
    model_asr = models[lang_code]["model"]

    inputs = processor_asr(speech_array, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
    with torch.no_grad():
        # Not every feature extractor returns an attention mask; pass None then
        # instead of failing on a missing attribute.
        logits = model_asr(inputs.input_values, attention_mask=inputs.get("attention_mask")).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor_asr.batch_decode(predicted_ids)[0]
def getUniform(text):
    """Group BIO-tagged token predictions into ``(entity_type, words)`` pairs.

    Args:
        text: list of token dicts as returned by the NER
            ``TokenClassificationPipeline``. Each dict needs an ``"entity"``
            label (``"B-<type>"``, ``"I-<type>"``, or a bare ``"<type>"``)
            and a ``"word"``.

    Returns:
        A list of ``(entity_type, [word, ...])`` tuples, one per entity, in
        the order the entities start. The U+2581 word-boundary marker that
        the tokenizer prefixes to tokens is removed from each word.

    A ``"B-"`` token always starts a new entity. Any other token extends the
    currently open entity if it has the same type; otherwise it starts a new
    one, so an ``"I-"`` without a matching ``"B-"`` no longer crashes.
    """
    entities = []
    current = None  # the most recent entity, which a continuation token may extend
    for token in text:
        label = token["entity"]
        word = token["word"].replace("\u2581", "")
        if label.startswith(("B-", "I-")):
            prefix, entity_type = label[:2], label[2:]
        else:
            prefix, entity_type = "", label

        extends_current = (
            prefix != "B-"
            and current is not None
            and current[0] == entity_type
        )
        if extends_current:
            current[1].append(word)
        else:
            current = (entity_type, [word])
            entities.append(current)
    return entities
def predict(wav_file, lang_code):
    """Run the whole pipeline: ASR, then intent, language and slot prediction.

    Args:
        wav_file: path to the audio file (the Gradio Audio input uses
            ``type='filepath'``).
        lang_code: language key of the ASR model to use; must be a key of
            ``models_paths``.

    Returns:
        On success, a dict with ``"text"`` (the transcript fed to the NLU
        models), ``"language"``, ``"intent_class"`` and ``"named_entities"``
        (the output of ``getUniform``). On unusable input, a dict with a
        single ``"error"`` message.
    """
    # Always return a dict (the old bare set literal was not JSON-serializable)
    # so the JSON output component can render the error.
    if lang_code not in models_paths:
        return {"error": "The language code is unknown!"}
    if not wav_file:
        # presumably what the UI passes when nothing was recorded/uploaded
        return {"error": "No audio file was provided!"}

    # HACK: patch one specific ASR mis-segmentation ("apizza" -> "a pizza").
    # NOTE(review): a trailing " ." is appended before the NLU models run;
    # presumably to match the punctuation of their training data -- confirm.
    text = transcribe(wav_file, lang_code).replace("apizza", "a pizza") + " ."

    intent_class = classifier_intent(text)[0]["label"]
    language_class = classifier_language(text)[0]["label"]
    named_entities = getUniform(predict_ner(text))

    return {
        "text": text,
        "language": language_class,
        "intent_class": intent_class,
        "named_entities": named_entities,
    }
# UI inputs, in the order predict() receives them: a microphone recording
# (passed as a file path), then the ASR language.
_ui_inputs = [
    gr.inputs.Audio(label='wav file', source='microphone', type='filepath'),
    # Only languages that have an ASR checkpoint are offered.
    gr.inputs.Dropdown(choices=list(models_paths)),
]

# A single JSON panel shows the dict returned by predict().
_ui_outputs = [
    gr.outputs.JSON(label='ASR -> Slot Recognition + Intent Classification + Language Classification'),
]

# NOTE(review): the emoji in `title` and `article` look mis-encoded (mojibake)
# in this copy of the file; restore them from the original source.
iface = gr.Interface(
    fn=predict,
    inputs=_ui_inputs,
    outputs=_ui_outputs,
    title='Alexa Clone π©βπΌ πͺ π€ Multilingual NLU',
    description='Upload your wav file to test the models (<i>First execution take about 20s to 30s, then next run in less than 1s</i>)',
    examples=examples,
    article='Made with β€οΈ by <a href="https://www.linkedin.com/in/yanis-labrak-8a7412145/" target="_blank">Yanis Labrak</a> thanks to π€',
)
iface.launch()