import gradio as gr
import torch
import numpy as np
from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset
# Check if a GPU is available and set the device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Load the Whisper ASR model
whisper_model_id = "riteshkr/quantized-whisper-large-v3"
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_id)
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_id)
# Set the language to English using forced_decoder_ids
forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language="english", task="transcribe")
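# Build a transformers ASR pipeline around the model; device=0 runs on the first GPU, -1 falls back to the CPU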
whisper_pipe = pipeline(
"automatic-speech-recognition",
model=whisper_model,
tokenizer=whisper_processor.tokenizer,
feature_extractor=whisper_processor.feature_extractor,
device=0 if torch.cuda.is_available() else -1
)
# Load the SpeechT5 TTS model
tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
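# SpeechT5 predicts mel spectrograms; the HiFi-GAN vocoder turns them into an audible waveform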
tts_model.to(device)
vocoder.to(device)
# Load speaker embeddings for TTS
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device)
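# Index 7306 selects a single speaker's x-vector from the CMU Arctic embeddings; a different row would give a different synthetic voice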
# Set target data type and max range for speech
target_dtype = np.int16
max_range = np.iinfo(target_dtype).max
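# The float waveform in [-1, 1] is later scaled by this maximum so Gradio receives 16-bit PCM audio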
# Define the transcription function (Whisper ASR)
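# Long recordings are split into 30-second chunks; a larger batch size is used only when a GPU is available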
def transcribe_speech(filepath):
    batch_size = 16 if torch.cuda.is_available() else 4
    output = whisper_pipe(
        filepath,
        max_new_tokens=256,
        generate_kwargs={"forced_decoder_ids": forced_decoder_ids},
        chunk_length_s=30,
        batch_size=batch_size,
    )
    return output["text"]
# Define the synthesis function (SpeechT5 TTS)
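# SpeechT5 takes tokenised text plus a speaker embedding and, via the vocoder, returns a 16 kHz waveform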
def synthesise(text):
    inputs = tts_processor(text=text, return_tensors="pt")
    speech = tts_model.generate_speech(
        inputs["input_ids"].to(device), speaker_embeddings, vocoder=vocoder
    )
    return speech.cpu()
# Define the speech-to-speech translation function
def speech_to_speech_translation(audio):
    # Transcribe the input speech to text
    transcribed_text = transcribe_speech(audio)
    # Synthesise speech from the transcribed text
    synthesised_speech = synthesise(transcribed_text)
    # Rescale the float waveform to 16-bit PCM for Gradio's (sample_rate, array) output format
    synthesised_speech = (synthesised_speech.numpy() * max_range).astype(target_dtype)
    return 16000, synthesised_speech
# Define the Gradio interfaces for speech-to-speech translation (microphone input and file upload)
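# type="filepath" makes Gradio pass a path to a temporary audio file, which the Whisper pipeline can read directly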
mic_translate = gr.Interface(
fn=speech_to_speech_translation,
inputs=gr.Audio(sources="microphone", type="filepath"),
outputs=gr.Audio(label="Generated Speech", type="numpy"),
)
file_translate = gr.Interface(
fn=speech_to_speech_translation,
inputs=gr.Audio(sources="upload", type="filepath"),
outputs=gr.Audio(label="Generated Speech", type="numpy"),
)
# Define the Gradio interfaces for transcription
mic_transcribe = gr.Interface(
fn=transcribe_speech,
inputs=gr.Audio(sources="microphone", type="filepath"),
outputs=gr.Textbox(),
)
file_transcribe = gr.Interface(
fn=transcribe_speech,
inputs=gr.Audio(sources="upload", type="filepath"),
outputs=gr.Textbox(),
)
# Create the app using Gradio Blocks with tabbed interfaces
demo = gr.Blocks()
with demo:
    gr.TabbedInterface(
        [
            mic_transcribe, file_transcribe,  # For transcription
            mic_translate, file_translate,    # For speech-to-speech translation
        ],
        [
            "Transcribe Microphone", "Transcribe Audio File",
            "Translate Microphone", "Translate Audio File",
        ],
    )
# Launch the app with debugging enabled
if __name__ == "__main__":
    demo.launch(debug=True, share=True)