import os
import cv2
import numpy as np
import torch
import subprocess

# Point FFMPEG_PATH at the static ffmpeg build expected two directory levels
# above the current working directory.
# NOTE(review): this assignment is unconditional, so a FFMPEG_PATH exported by
# the user is overwritten and the "is None" branch below cannot trigger.
# Confirm that this is intended before relying on the help message.
current_dir = os.getcwd()
parent_dir = os.path.dirname(os.path.dirname(current_dir))
os.environ["FFMPEG_PATH"] = f"{parent_dir}/ffmpeg-7.0.2-amd64-static"

# Function to install dependencies
# def install_dependencies():
#     try:
#         # Run the pip install commands one by one
#         subprocess.run("pip install --no-cache-dir -U openmim", shell=True, check=True)
#         subprocess.run("mim install mmengine", shell=True, check=True)
#         subprocess.run('mim install "mmcv>=2.0.1"', shell=True, check=True)
#         subprocess.run('mim install "mmdet>=3.1.0"', shell=True, check=True)
#         subprocess.run('mim install "mmpose>=1.1.0"', shell=True, check=True)
#         print("All dependencies installed successfully!")
#     except subprocess.CalledProcessError as e:
#         print(f"An error occurred: {e}")
#         exit(1)

# # Call this function before starting the Gradio app
# install_dependencies()

ffmpeg_path = os.getenv('FFMPEG_PATH')
if ffmpeg_path is None:
    print("please download ffmpeg-static and export to FFMPEG_PATH. \nFor example: export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static")
else:
    # Read PATH defensively: it may be unset, in which case the previous
    # `in os.getenv('PATH')` check raised TypeError.
    search_path = os.environ.get("PATH", "")
    # Compare against whole PATH entries rather than doing a substring test,
    # so a longer directory name that merely contains ffmpeg_path does not
    # count as "already on PATH".
    if ffmpeg_path not in search_path.split(os.pathsep):
        print("add ffmpeg to path")
        print(f"torch version : {torch.__version__}")
        print(f"torch cuda version : {torch.version.cuda}")
        # Prepend so this ffmpeg build takes precedence; skip the empty piece
        # when PATH was unset to avoid a dangling separator.
        os.environ["PATH"] = os.pathsep.join(
            part for part in (ffmpeg_path, search_path) if part
        )

# Project imports stay below the PATH setup, matching the original statement
# order of this module.
from musetalk.whisper.audio2feature import Audio2Feature
from musetalk.models.vae import VAE
from musetalk.models.unet import UNet,PositionalEncoding

# File extensions (lower-case, including the leading dot) recognised by
# get_file_type().
_IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'})
_VIDEO_EXTENSIONS = frozenset({'.avi', '.mp4', '.mov', '.flv', '.mkv'})


def load_all_model():
    """Construct the four model components from their fixed relative paths.

    The paths are relative ("./models/..."), so they resolve against the
    current working directory of the process.

    Returns:
        A tuple ``(audio_processor, vae, unet, pe)`` holding the
        ``Audio2Feature``, ``VAE``, ``UNet`` and ``PositionalEncoding``
        instances, in that order.
    """
    audio_processor = Audio2Feature(model_path="./models/whisper/tiny.pt")
    vae = VAE(model_path="./models/sd-vae-ft-mse/")
    unet = UNet(
        unet_config="./models/musetalk/musetalk.json",
        model_path="./models/musetalk/pytorch_model.bin",
    )
    pe = PositionalEncoding(d_model=384)
    return audio_processor, vae, unet, pe


def get_file_type(video_path):
    """Classify a path as image or video by its file extension.

    The comparison is case-insensitive and looks only at the extension; the
    file is never opened.

    Args:
        video_path: Path (or file name) to classify.

    Returns:
        ``'image'``, ``'video'`` or ``'unsupported'``.
    """
    ext = os.path.splitext(video_path)[1].lower()
    if ext in _IMAGE_EXTENSIONS:
        return 'image'
    if ext in _VIDEO_EXTENSIONS:
        return 'video'
    return 'unsupported'


def get_video_fps(video_path):
    """Return the frame rate OpenCV reports for ``video_path``.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.

    Returns:
        The value of ``cv2.CAP_PROP_FPS`` for the capture.
        NOTE(review): OpenCV presumably reports 0.0 when the source cannot be
        opened rather than raising; callers should verify the value.
    """
    video = cv2.VideoCapture(video_path)
    try:
        return video.get(cv2.CAP_PROP_FPS)
    finally:
        # Always free the capture handle, even if the property query fails.
        video.release()


def datagen(whisper_chunks, vae_encode_latents, batch_size=8, delay_frame=0):
    """Yield paired batches of audio chunks and latents.

    Chunk ``i`` is paired with ``vae_encode_latents[(i + delay_frame) % n]``,
    where ``n`` is ``len(vae_encode_latents)``, so the latents are reused
    cyclically when there are more chunks than latents.

    Args:
        whisper_chunks: Iterable of per-frame audio features.
        vae_encode_latents: Indexable sequence of tensors; each batch is built
            with ``torch.cat(..., dim=0)``.
        batch_size: Number of pairs per yielded batch.
        delay_frame: Offset added to the chunk index before the latent lookup.

    Yields:
        ``(whisper_batch, latent_batch)`` where ``whisper_batch`` is the
        ``np.asarray`` of the collected chunks and ``latent_batch`` is the
        concatenation of the collected latents along dim 0. The final batch
        may hold fewer than ``batch_size`` pairs.

    Raises:
        ZeroDivisionError: If ``vae_encode_latents`` is empty.
    """
    def pack(chunks, latents):
        # Convert the accumulated Python lists into the yielded array/tensor.
        return np.asarray(chunks), torch.cat(latents, dim=0)

    num_latents = len(vae_encode_latents)
    whisper_batch, latent_batch = [], []
    for i, w in enumerate(whisper_chunks):
        idx = (i + delay_frame) % num_latents
        whisper_batch.append(w)
        latent_batch.append(vae_encode_latents[idx])

        if len(latent_batch) >= batch_size:
            yield pack(whisper_batch, latent_batch)
            whisper_batch, latent_batch = [], []

    # The last batch may be smaller than batch_size.
    if len(latent_batch) > 0:
        yield pack(whisper_batch, latent_batch)