import gradio as gr
import librosa
import soundfile
import tempfile
import os
import uuid
import json
import jieba

import nemo.collections.asr as nemo_asr
from nemo.collections.asr.models import ASRModel
from nemo.utils import logging

from align import main, AlignmentConfig, ASSFileConfig

SAMPLE_RATE = 16000

# Pre-download and cache the models to disk
logging.setLevel(logging.ERROR)
for tmp_model_name in [
    "stt_en_fastconformer_hybrid_large_pc",
    "stt_de_fastconformer_hybrid_large_pc",
    "stt_es_fastconformer_hybrid_large_pc",
    "stt_fr_conformer_ctc_large",
    "stt_zh_citrinet_1024_gamma_0_25",
]:
    tmp_model = ASRModel.from_pretrained(tmp_model_name, map_location='cpu')
    del tmp_model
logging.setLevel(logging.INFO)


def get_audio_data_and_duration(file):
    data, sr = librosa.load(file)

    if sr != SAMPLE_RATE:
        data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)

    # monochannel
    data = librosa.to_mono(data)

    duration = librosa.get_duration(y=data, sr=SAMPLE_RATE)
    return data, duration


def get_char_tokens(text, model):
    tokens = []
    for character in text:
        if character in model.decoder.vocabulary:
            tokens.append(model.decoder.vocabulary.index(character))
        else:
            tokens.append(len(model.decoder.vocabulary))  # return unk token (same as blank token)

    return tokens


def get_S_prime_and_T(text, model_name, model, audio_duration):
    """Estimate the number of ASR output timesteps T for this audio and the minimum
    number of timesteps S_prime needed to emit the tokens of `text`, so that we can
    warn the user if the text is too long to be aligned to the audio."""

    # estimate T
    if "citrinet" in model_name or "_fastconformer_" in model_name:
        output_timestep_duration = 0.08
    elif "_conformer_" in model_name:
        output_timestep_duration = 0.04
    elif "quartznet" in model_name:
        output_timestep_duration = 0.02
    else:
        raise RuntimeError("unexpected model name")

    T = int(audio_duration / output_timestep_duration) + 1

    # calculate S_prime = num tokens + num repetitions
    if hasattr(model, 'tokenizer'):
        all_tokens = model.tokenizer.text_to_ids(text)
    elif hasattr(model.decoder, "vocabulary"):  # i.e. tokenization is simply character-based
        all_tokens = get_char_tokens(text, model)
    else:
        raise RuntimeError("cannot obtain tokens from this model")

    n_token_repetitions = 0
    for i_tok in range(1, len(all_tokens)):
        if all_tokens[i_tok] == all_tokens[i_tok - 1]:
            n_token_repetitions += 1

    S_prime = len(all_tokens) + n_token_repetitions

    return S_prime, T


def hex_to_rgb_list(hex_string):
    hex_string = hex_string.lstrip("#")
    r = int(hex_string[:2], 16)
    g = int(hex_string[2:4], 16)
    b = int(hex_string[4:], 16)
    return [r, g, b]


def delete_mp4s_except_given_filepath(filepath):
    files_in_dir = os.listdir()
    mp4_files_in_dir = [x for x in files_in_dir if x.endswith(".mp4")]
    for mp4_file in mp4_files_in_dir:
        if mp4_file != filepath:
            os.remove(mp4_file)


def align(lang, Microphone, File_Upload, text, col1, col2, col3, progress=gr.Progress()):
    # Create utt_id, specify output_video_filepath and delete any MP4s
    # that are not that filepath. These stray MP4s can be created
    # if a user refreshes or exits the page while this 'align' function is executing.
    # This deletion will not delete any other users' video as long as this 'align' function
    # is run one at a time.
    utt_id = uuid.uuid4()
    output_video_filepath = f"{utt_id}.mp4"
    delete_mp4s_except_given_filepath(output_video_filepath)

    output_info = ""
    progress(0, desc="Validating input")

    # choose model
    if lang in ["en", "de", "es"]:
        model_name = f"stt_{lang}_fastconformer_hybrid_large_pc"
    elif lang in ["fr"]:
        model_name = f"stt_{lang}_conformer_ctc_large"
    elif lang in ["zh"]:
        model_name = f"stt_{lang}_citrinet_1024_gamma_0_25"

    # decide which of Mic / File_Upload is used as input & do error handling
    if (Microphone is not None) and (File_Upload is not None):
        raise gr.Error("Please use either the microphone or file upload input - not both")
    elif (Microphone is None) and (File_Upload is None):
        raise gr.Error("You have to either use the microphone or upload an audio file")
    elif Microphone is not None:
        file = Microphone
    else:
        file = File_Upload

    # check audio is not too long
    audio_data, duration = get_audio_data_and_duration(file)

    if duration > 4 * 60:
        raise gr.Error(
            f"Detected that uploaded audio has duration {duration/60:.1f} mins - please only upload audio of less than 4 mins duration"
        )

    # loading model
    progress(0.1, desc="Loading speech recognition model")
    model = ASRModel.from_pretrained(model_name)

    if text:  # check input text is not too long compared to audio
        S_prime, T = get_S_prime_and_T(text, model_name, model, duration)

        if S_prime > T:
            raise gr.Error(
                f"The number of tokens in the input text is too long compared to the duration of the audio."
                f" This model can handle {T} tokens + token repetitions at most. You have provided {S_prime} tokens + token repetitions. "
                f" (Adjacent tokens that are not in the model's vocabulary are also counted as a token repetition.)"
            )

    with tempfile.TemporaryDirectory() as tmpdir:
        audio_path = os.path.join(tmpdir, f'{utt_id}.wav')
        soundfile.write(audio_path, audio_data, SAMPLE_RATE)

        # getting the text if it hasn't been provided
        if not text:
            progress(0.2, desc="Transcribing audio")
            text = model.transcribe([audio_path])[0]
            if 'hybrid' in model_name:
                text = text[0]

            if text == "":
                raise gr.Error(
                    "ERROR: the ASR model did not detect any speech in the input audio. Please upload audio with speech."
                )

            output_info += (
                "You did not enter any input text, so the ASR model's transcription will be used:\n"
                "--------------------------\n"
                f"{text}\n"
                "--------------------------\n"
                f"You could try pasting the transcription into the text input box, correcting any"
                " transcription errors, and clicking 'Submit' again."
            )

        if lang == "zh" and " " not in text:
            # use jieba to add spaces between zh characters
            text = " ".join(jieba.cut(text))

        data = {
            "audio_filepath": audio_path,
            "text": text,
        }
        manifest_path = os.path.join(tmpdir, f"{utt_id}_manifest.json")
        with open(manifest_path, 'w') as fout:
            fout.write(f"{json.dumps(data)}\n")

        # run alignment
        if "|" in text:
            resegment_text_to_fill_space = False
        else:
            resegment_text_to_fill_space = True

        alignment_config = AlignmentConfig(
            pretrained_name=model_name,
            manifest_filepath=manifest_path,
            output_dir=f"{tmpdir}/nfa_output/",
            audio_filepath_parts_in_utt_id=1,
            batch_size=1,
            use_local_attention=True,
            additional_segment_grouping_separator="|",
            # transcribe_device='cpu',
            # viterbi_device='cpu',
            save_output_file_formats=["ass"],
            ass_file_config=ASSFileConfig(
                fontsize=45,
                resegment_text_to_fill_space=resegment_text_to_fill_space,
                max_lines_per_segment=4,
                text_already_spoken_rgb=hex_to_rgb_list(col1),
                text_being_spoken_rgb=hex_to_rgb_list(col2),
                text_not_yet_spoken_rgb=hex_to_rgb_list(col3),
            ),
        )

        progress(0.5, desc="Aligning audio")
        main(alignment_config)

        progress(0.95, desc="Saving generated alignments")

        if lang == "zh":
            # make video file from the token-level ASS file
            ass_file_for_video = f"{tmpdir}/nfa_output/ass/tokens/{utt_id}.ass"
        else:
            # make video file from the word-level ASS file
            ass_file_for_video = f"{tmpdir}/nfa_output/ass/words/{utt_id}.ass"

        ffmpeg_command = (
            f"ffmpeg -y -i {audio_path} "
            "-f lavfi -i color=c=white:s=1280x720:r=50 "
            "-crf 1 -shortest -vcodec libx264 -pix_fmt yuv420p "
            f"-vf 'ass={ass_file_for_video}' "
            f"{output_video_filepath}"
        )

        os.system(ffmpeg_command)

    return output_video_filepath, gr.update(value=output_info, visible=True), output_video_filepath


def delete_non_tmp_video(video_path):
    if video_path:
        if os.path.exists(video_path):
            os.remove(video_path)
    return None


with gr.Blocks(title="NeMo Forced Aligner", theme="huggingface") as demo:
    non_tmp_output_video_filepath = gr.State([])

    with gr.Row():
        with gr.Column():
            gr.Markdown("# NeMo Forced Aligner")
            gr.Markdown(
                "Demo for [NeMo Forced Aligner](https://github.com/NVIDIA/NeMo/tree/main/tools/nemo_forced_aligner) (NFA). "
                "Upload audio and (optionally) the text spoken in the audio to generate a video where each part of the text will be highlighted as it is spoken. ",
            )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Input")
            lang_drop = gr.Dropdown(choices=["de", "en", "es", "fr", "zh"], value="en", label="Audio language")

            mic_in = gr.Audio(source="microphone", type='filepath', label="Microphone input (max 4 mins)")
            audio_file_in = gr.Audio(source="upload", type='filepath', label="File upload (max 4 mins)")
            ref_text = gr.Textbox(
                label="[Optional] The reference text. Use '|' separators to specify which text will appear together. "
                "Leave this field blank to use an ASR model's transcription as the reference text instead."
            )

            gr.Markdown("[Optional] For fun - adjust the colors of the text in the output video")
            with gr.Row():
                col1 = gr.ColorPicker(label="text already spoken", value="#fcba03")
                col2 = gr.ColorPicker(label="text being spoken", value="#bf45bf")
                col3 = gr.ColorPicker(label="text to be spoken", value="#3e1af0")

            submit_button = gr.Button("Submit")

        with gr.Column(scale=1):
            gr.Markdown("## Output")
            video_out = gr.Video(label="output video")
            text_out = gr.Textbox(label="output info", visible=False)

    submit_button.click(
        fn=align,
        inputs=[lang_drop, mic_in, audio_file_in, ref_text, col1, col2, col3],
        outputs=[video_out, text_out, non_tmp_output_video_filepath],
    ).then(
        fn=delete_non_tmp_video,
        inputs=[non_tmp_output_video_filepath],
        outputs=None,
    )

demo.queue()
demo.launch()
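
# ---------------------------------------------------------------------------
# Reference sketch (comments only, not executed by this app): the alignment
# performed in `align` is driven by a one-line-per-utterance JSON manifest,
# e.g.
#
#   {"audio_filepath": "/path/to/audio.wav", "text": "hello world | how are you"}
#
# Assuming a local NeMo checkout at <NeMo_repo> (an illustrative placeholder,
# not something this app requires), roughly the same alignment can be run with
# NFA's standalone CLI using the same config fields passed to AlignmentConfig
# above:
#
#   python <NeMo_repo>/tools/nemo_forced_aligner/align.py \
#       pretrained_name="stt_en_fastconformer_hybrid_large_pc" \
#       manifest_filepath=<path to manifest> \
#       output_dir=<path to output dir>
# ---------------------------------------------------------------------------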