Spaces:
Running
on
A10G
Running
on
A10G
import os | |
import shutil | |
from huggingface_hub import snapshot_download | |
import gradio as gr | |
from gradio_client import Client, handle_file | |
from mutagen.mp3 import MP3 | |
from pydub import AudioSegment | |
from PIL import Image | |
import ffmpeg | |
os.chdir(os.path.dirname(os.path.abspath(__file__))) | |
from scripts.inference import inference_process | |
import argparse | |
import uuid | |
is_shared_ui = True if "fffiloni/tts-hallo-talking-portrait" in os.environ['SPACE_ID'] else False | |
hallo_dir = snapshot_download(repo_id="fudan-generative-ai/hallo", local_dir="pretrained_models") | |
AUDIO_MAX_DURATION = 4000 | |
############# | |
# UTILITIES # | |
############# | |
def is_mp3(file_path): | |
try: | |
audio = MP3(file_path) | |
return True | |
except Exception as e: | |
return False | |
def convert_mp3_to_wav(mp3_file_path, wav_file_path): | |
# Load the MP3 file | |
audio = AudioSegment.from_mp3(mp3_file_path) | |
# Export as WAV file | |
audio.export(wav_file_path, format="wav") | |
return wav_file_path | |
def trim_audio(file_path, output_path, max_duration): | |
# Load the audio file | |
audio = AudioSegment.from_wav(file_path) | |
# Check the length of the audio in milliseconds | |
audio_length = len(audio) | |
# If the audio is longer than the maximum duration, trim it | |
if audio_length > max_duration: | |
trimmed_audio = audio[:max_duration] | |
else: | |
trimmed_audio = audio | |
# Export the trimmed audio to a new file | |
trimmed_audio.export(output_path, format="wav") | |
return output_path | |
def add_silence_to_wav(wav_file_path, duration_s=1): | |
# Load the WAV file | |
audio = AudioSegment.from_wav(wav_file_path) | |
# Create 1 second of silence | |
silence = AudioSegment.silent(duration=duration_s * 1000) # duration is in milliseconds | |
# Add silence to the end of the audio file | |
audio_with_silence = audio + silence | |
# Export the modified audio | |
audio_with_silence.export(wav_file_path, format="wav") | |
return wav_file_path | |
def check_mp3(file_path): | |
if is_mp3(file_path): | |
unique_id = uuid.uuid4() | |
wav_file_path = f"{os.path.splitext(file_path)[0]}-{unique_id}.wav" | |
converted_audio = convert_mp3_to_wav(file_path, wav_file_path) | |
print(f"File converted to {wav_file_path}") | |
return converted_audio, gr.update(value=converted_audio, visible=True) | |
else: | |
print("The file is not an MP3 file.") | |
return file_path, gr.update(value=file_path, visible=True) | |
def check_and_convert_webp_to_png(input_path, output_path): | |
try: | |
# Open the image file | |
with Image.open(input_path) as img: | |
# Check if the image is in WebP format | |
if img.format == 'WEBP': | |
# Convert and save as PNG | |
img.save(output_path, 'PNG') | |
print(f"Converted {input_path} to {output_path}") | |
return output_path | |
else: | |
print(f"The file {input_path} is not in WebP format.") | |
return input_path | |
except IOError: | |
print(f"Cannot open {input_path}. The file might not exist or is not an image.") | |
def convert_user_uploded_webp(input_path): | |
# convert to png if necessary | |
input_file = input_path | |
unique_id = uuid.uuid4() | |
output_file = f"converted_to_png_portrait-{unique_id}.png" | |
ready_png = check_and_convert_webp_to_png(input_file, output_file) | |
print(f"PORTRAIT PNG FILE: {ready_png}") | |
return ready_png | |
def clear_audio_elms(): | |
return gr.update(value=None, visible=False) | |
def change_video_codec(input_file, output_file, codec='libx264', audio_codec='aac'): | |
try: | |
( | |
ffmpeg | |
.input(input_file) | |
.output(output_file, vcodec=codec, acodec=audio_codec) | |
.run(overwrite_output=True) | |
) | |
print(f'Successfully changed codec of {input_file} and saved as {output_file}') | |
except ffmpeg.Error as e: | |
print(f'Error occurred: {e.stderr.decode()}') | |
####################################################### | |
# Gradio APIs for optional image and voice generation # | |
####################################################### | |
def generate_portrait(prompt_image): | |
if prompt_image is None or prompt_image == "": | |
raise gr.Error("Can't generate a portrait without a prompt !") | |
try: | |
client = Client("ByteDance/SDXL-Lightning") | |
except: | |
raise gr.Error(f"ByteDance/SDXL-Lightning space's api might not be ready, please wait, or upload an image instead.") | |
result = client.predict( | |
prompt = prompt_image, | |
ckpt = "4-Step", | |
api_name = "/generate_image" | |
) | |
print(result) | |
# convert to png if necessary | |
input_file = result | |
unique_id = uuid.uuid4() | |
output_file = f"converted_to_png_portrait-{unique_id}.png" | |
ready_png = check_and_convert_webp_to_png(input_file, output_file) | |
print(f"PORTRAIT PNG FILE: {ready_png}") | |
return ready_png | |
def generate_voice_with_parler(prompt_audio, voice_description): | |
if prompt_audio is None or prompt_audio == "" : | |
raise gr.Error(f"Can't generate a voice without text to synthetize !") | |
if voice_description is None or voice_description == "": | |
gr.Info( | |
"For better control, You may want to provide a voice character description next time.", | |
duration = 10, | |
visible = True | |
) | |
try: | |
client = Client("parler-tts/parler_tts_mini") | |
except: | |
raise gr.Error(f"parler-tts/parler_tts_mini space's api might not be ready, please wait, or upload an audio instead.") | |
result = client.predict( | |
text = prompt_audio, | |
description = voice_description, | |
api_name = "/gen_tts" | |
) | |
print(result) | |
return result, gr.update(value=result, visible=True) | |
def get_whisperspeech(prompt_audio_whisperspeech, audio_to_clone): | |
try: | |
client = Client("collabora/WhisperSpeech") | |
except: | |
raise gr.Error(f"collabora/WhisperSpeech space's api might not be ready, please wait, or upload an audio instead.") | |
result = client.predict( | |
multilingual_text = prompt_audio_whisperspeech, | |
speaker_audio = handle_file(audio_to_clone), | |
speaker_url = "", | |
cps = 14, | |
api_name = "/whisper_speech_demo" | |
) | |
print(result) | |
return result, gr.update(value=result, visible=True) | |
def get_maskGCT_TTS(prompt_audio_maskGCT, audio_to_clone): | |
try: | |
client = Client("amphion/maskgct") | |
except: | |
raise gr.Error(f"amphion/maskgct space's api might not be ready, please wait, or upload an audio instead.") | |
result = client.predict( | |
prompt_wav = handle_file(audio_to_clone), | |
target_text = prompt_audio_maskGCT, | |
target_len=-1, | |
n_timesteps=25, | |
api_name="/predict" | |
) | |
print(result) | |
return result, gr.update(value=result, visible=True) | |
######################## | |
# TALKING PORTRAIT GEN # | |
######################## | |
def run_hallo(source_image, driving_audio, progress=gr.Progress(track_tqdm=True)): | |
unique_id = uuid.uuid4() | |
args = argparse.Namespace( | |
config = 'configs/inference/default.yaml', | |
source_image = source_image, | |
driving_audio = driving_audio, | |
output = f'output-{unique_id}.mp4', | |
pose_weight = 1.0, | |
face_weight = 1.0, | |
lip_weight = 1.0, | |
face_expand_ratio = 1.2, | |
checkpoint = None | |
) | |
inference_process(args) | |
return f'output-{unique_id}.mp4' | |
def generate_talking_portrait(portrait, voice, progress=gr.Progress(track_tqdm=True)): | |
if portrait is None: | |
raise gr.Error("Please provide a portrait to animate.") | |
if voice is None: | |
raise gr.Error("Please provide audio (4 seconds max).") | |
if is_shared_ui : | |
# Trim audio to AUDIO_MAX_DURATION for better shared experience with community | |
input_file = voice | |
unique_id = uuid.uuid4() | |
trimmed_output_file = f"-{unique_id}.wav" | |
trimmed_output_file = trim_audio(input_file, trimmed_output_file, AUDIO_MAX_DURATION) | |
voice = trimmed_output_file | |
# Add 1 second of silence at the end to avoid last word being cut by hallo | |
ready_audio = add_silence_to_wav(voice) | |
print(f"1 second of silence added to {voice}") | |
# Call hallo | |
talking_portrait_vid = run_hallo(portrait, ready_audio) | |
# Convert video to readable format | |
final_output_file = f"converted_{talking_portrait_vid}" | |
change_video_codec(talking_portrait_vid, final_output_file) | |
return final_output_file | |
css = ''' | |
#col-container { | |
margin: 0 auto; | |
} | |
#column-names { | |
margin-top: 50px; | |
} | |
#main-group { | |
background-color: none; | |
} | |
.tabs { | |
background-color: unset; | |
} | |
#image-block { | |
flex: 1; | |
} | |
#video-block { | |
flex: 9; | |
} | |
#audio-block, #audio-clone-elm, audio-clone-elm-maskGCT { | |
flex: 1; | |
} | |
div#audio-clone-elm > .audio-container > button { | |
height: 180px!important; | |
} | |
div#audio-clone-elm > .audio-container > button > .wrap { | |
font-size: 0.9em; | |
} | |
div#audio-clone-elm-maskGCT > .audio-container > button { | |
height: 180px!important; | |
} | |
div#audio-clone-elm-maskGCT > .audio-container > button > .wrap { | |
font-size: 0.9em; | |
} | |
#text-synth, #voice-desc{ | |
height: 130px; | |
} | |
#text-synth-wsp { | |
height: 120px; | |
} | |
#text-synth-maskGCT { | |
height: 120px; | |
} | |
#audio-column, #result-column { | |
display: flex; | |
} | |
#gen-voice-btn { | |
flex: 1; | |
} | |
#parler-tab, #whisperspeech-tab, #maskGCT-tab { | |
padding: 0; | |
} | |
#main-submit{ | |
flex: 1; | |
} | |
#pro-tips { | |
margin-top: 50px; | |
} | |
div#warning-ready { | |
background-color: #ecfdf5; | |
padding: 0 16px 16px; | |
margin: 20px 0; | |
color: #030303!important; | |
} | |
div#warning-ready > .gr-prose > h2, div#warning-ready > .gr-prose > p { | |
color: #057857!important; | |
} | |
div#warning-duplicate { | |
background-color: #ebf5ff; | |
padding: 0 16px 16px; | |
margin: 20px 0; | |
color: #030303!important; | |
} | |
div#warning-duplicate > .gr-prose > h2, div#warning-duplicate > .gr-prose > p { | |
color: #0f4592!important; | |
} | |
div#warning-duplicate strong { | |
color: #0f4592; | |
} | |
p.actions { | |
display: flex; | |
align-items: center; | |
margin: 20px 0; | |
} | |
div#warning-duplicate .actions a { | |
display: inline-block; | |
margin-right: 10px; | |
} | |
.dark #warning-duplicate { | |
background-color: #0c0c0c !important; | |
border: 1px solid white !important; | |
} | |
''' | |
with gr.Blocks(css=css) as demo: | |
with gr.Column(elem_id="col-container"): | |
gr.Markdown(""" | |
# TTS x Hallo Talking Portrait Generator | |
This demo allows you to generate a talking portrait with the help of several open-source projects: SDXL Lightning | Parler TTS | WhisperSpeech | Hallo | |
To let the community try and enjoy this demo, video length is limited to 4 seconds audio maximum. | |
Duplicate this space to skip the queue and get unlimited video duration. 4-5 seconds of audio will take ~5 minutes per inference, please be patient. | |
""") | |
with gr.Row(elem_id="column-names"): | |
gr.Markdown("## 1. Load Portrait") | |
gr.Markdown("## 2. Load Voice") | |
gr.Markdown("## 3. Result") | |
with gr.Group(elem_id="main-group"): | |
with gr.Row(): | |
with gr.Column(): | |
portrait = gr.Image( | |
sources = ["upload"], | |
type = "filepath", | |
format = "png", | |
elem_id = "image-block" | |
) | |
prompt_image = gr.Textbox( | |
label = "Generate image", | |
lines = 2, | |
max_lines = 2 | |
) | |
gen_image_btn = gr.Button("Generate portrait (optional)") | |
with gr.Column(elem_id="audio-column"): | |
voice = gr.Audio( | |
type = "filepath", | |
elem_id = "audio-block" | |
) | |
preprocess_audio_file = gr.File(visible=False) | |
with gr.Tab("Parler TTS", elem_id="parler-tab"): | |
prompt_audio = gr.Textbox( | |
label = "Text to synthetize", | |
lines = 3, | |
max_lines = 3, | |
elem_id = "text-synth" | |
) | |
voice_description = gr.Textbox( | |
label = "Voice description", | |
lines = 3, | |
max_lines = 3, | |
elem_id = "voice-desc" | |
) | |
gen_voice_btn = gr.Button("Generate voice (optional)") | |
with gr.Tab("WhisperSpeech", elem_id="whisperspeech-tab"): | |
prompt_audio_whisperspeech = gr.Textbox( | |
label = "Text to synthetize", | |
lines = 2, | |
max_lines = 2, | |
elem_id = "text-synth-wsp" | |
) | |
audio_to_clone = gr.Audio( | |
label = "Voice to clone", | |
type = "filepath", | |
elem_id = "audio-clone-elm" | |
) | |
gen_wsp_voice_btn = gr.Button("Generate voice clone (optional)") | |
with gr.Tab("MaskGCT TTS", elem_id="maskGCT-tab"): | |
prompt_audio_maskGCT = gr.Textbox( | |
label = "Text to synthetize", | |
lines = 2, | |
max_lines = 2, | |
elem_id = "text-synth-maskGCT" | |
) | |
audio_to_clone_maskGCT = gr.Audio( | |
label = "Voice to clone", | |
type = "filepath", | |
elem_id = "audio-clone-elm-maskGCT" | |
) | |
gen_maskGCT_voice_btn = gr.Button("Generate voice clone (optional)") | |
with gr.Column(elem_id="result-column"): | |
result = gr.Video( | |
elem_id="video-block" | |
) | |
submit_btn = gr.Button("Go talking Portrait !", elem_id="main-submit") | |
with gr.Row(elem_id="pro-tips"): | |
gr.Markdown(""" | |
# Hallo Pro Tips: | |
Hallo has a few simple requirements for input data: | |
For the source image: | |
1. It should be cropped into squares. | |
2. The face should be the main focus, making up 50%-70% of the image. | |
3. The face should be facing forward, with a rotation angle of less than 30° (no side profiles). | |
For the driving audio: | |
1. It must be in WAV format. | |
2. It must be in English since our training datasets are only in this language. | |
3. Ensure the vocals are clear; background music is acceptable. | |
""") | |
gr.Markdown(""" | |
# TTS Pro Tips: | |
For Parler TTS: | |
- Include the term "very clear audio" to generate the highest quality audio, and "very noisy audio" for high levels of background noise | |
- Punctuation can be used to control the prosody of the generations, e.g. use commas to add small breaks in speech | |
- The remaining speech features (gender, speaking rate, pitch and reverberation) can be controlled directly through the prompt | |
For WhisperSpeech: | |
WhisperSpeech is able to quickly clone a voice from an audio sample. | |
- Upload a voice sample in the WhisperSpeech tab | |
- Add text to synthetize, hit Generate voice clone button | |
""") | |
portrait.upload( | |
fn = convert_user_uploded_webp, | |
inputs = [portrait], | |
outputs = [portrait], | |
queue = False, | |
show_api = False | |
) | |
voice.upload( | |
fn = check_mp3, | |
inputs = [voice], | |
outputs = [voice, preprocess_audio_file], | |
queue = False, | |
show_api = False | |
) | |
voice.clear( | |
fn = clear_audio_elms, | |
inputs = None, | |
outputs = [preprocess_audio_file], | |
queue = False, | |
show_api = False | |
) | |
gen_image_btn.click( | |
fn = generate_portrait, | |
inputs = [prompt_image], | |
outputs = [portrait], | |
queue = False, | |
show_api = False | |
) | |
gen_voice_btn.click( | |
fn = generate_voice_with_parler, | |
inputs = [prompt_audio, voice_description], | |
outputs = [voice, preprocess_audio_file], | |
queue = False, | |
show_api = False | |
) | |
gen_wsp_voice_btn.click( | |
fn = get_whisperspeech, | |
inputs = [prompt_audio_whisperspeech, audio_to_clone], | |
outputs = [voice, preprocess_audio_file], | |
queue = False, | |
show_api = False | |
) | |
gen_maskGCT_voice_btn.click( | |
fn = get_maskGCT_TTS, | |
inputs = [prompt_audio_maskGCT, audio_to_clone_maskGCT], | |
outputs = [voice, preprocess_audio_file], | |
queue = False, | |
show_api = False | |
) | |
submit_btn.click( | |
fn = generate_talking_portrait, | |
inputs = [portrait, voice], | |
outputs = [result], | |
show_api = False | |
) | |
demo.queue(max_size=2).launch(show_error=True, show_api=False) |