import os
import cv2
import numpy as np
import torch
import subprocess
# Point FFMPEG_PATH at the bundled static ffmpeg build two directories above the working directory.
current_dir = os.getcwd()
parent_dir = os.path.dirname(os.path.dirname(current_dir))
os.environ["FFMPEG_PATH"] = f"{parent_dir}/ffmpeg-7.0.2-amd64-static"
# Function to install dependencies
# def install_dependencies():
#     try:
#         # Run the pip install commands one by one
#         subprocess.run("pip install --no-cache-dir -U openmim", shell=True, check=True)
#         subprocess.run("mim install mmengine", shell=True, check=True)
#         subprocess.run('mim install "mmcv>=2.0.1"', shell=True, check=True)
#         subprocess.run('mim install "mmdet>=3.1.0"', shell=True, check=True)
#         subprocess.run('mim install "mmpose>=1.1.0"', shell=True, check=True)
#         print("All dependencies installed successfully!")
#     except subprocess.CalledProcessError as e:
#         print(f"An error occurred: {e}")
#         exit(1)

# # Call this function before starting the Gradio app
# install_dependencies()
# Make sure the static ffmpeg build is on PATH so subprocess calls can find it.
ffmpeg_path = os.getenv('FFMPEG_PATH')
if ffmpeg_path is None:
    print("Please download ffmpeg-static and export its location as FFMPEG_PATH.\nFor example: export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static")
elif ffmpeg_path not in os.getenv('PATH'):
    print("add ffmpeg to path")
    os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}"

print(f"torch version : {torch.__version__}")
print(f"torch cuda version : {torch.version.cuda}")
from musetalk.whisper.audio2feature import Audio2Feature
from musetalk.models.vae import VAE
from musetalk.models.unet import UNet, PositionalEncoding
def load_all_model():
    """Load the Whisper audio feature extractor, VAE, UNet and positional encoding."""
    audio_processor = Audio2Feature(model_path="./models/whisper/tiny.pt")
    vae = VAE(model_path="./models/sd-vae-ft-mse/")
    unet = UNet(unet_config="./models/musetalk/musetalk.json",
                model_path="./models/musetalk/pytorch_model.bin")
    pe = PositionalEncoding(d_model=384)
    return audio_processor, vae, unet, pe
def get_file_type(video_path):
    _, ext = os.path.splitext(video_path)
    if ext.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
        return 'image'
    elif ext.lower() in ['.avi', '.mp4', '.mov', '.flv', '.mkv']:
        return 'video'
    else:
        return 'unsupported'
def get_video_fps(video_path):
    """Return the frame rate of a video as reported by OpenCV."""
    video = cv2.VideoCapture(video_path)
    fps = video.get(cv2.CAP_PROP_FPS)
    video.release()
    return fps
def datagen(whisper_chunks, vae_encode_latents, batch_size=8, delay_frame=0):
    """Yield (whisper features, latents) batches, cycling over the encoded latents."""
    whisper_batch, latent_batch = [], []
    for i, w in enumerate(whisper_chunks):
        idx = (i + delay_frame) % len(vae_encode_latents)
        latent = vae_encode_latents[idx]
        whisper_batch.append(w)
        latent_batch.append(latent)

        if len(latent_batch) >= batch_size:
            whisper_batch = np.asarray(whisper_batch)
            latent_batch = torch.cat(latent_batch, dim=0)
            yield whisper_batch, latent_batch
            whisper_batch, latent_batch = [], []

    # The last batch may be smaller than batch_size.
    if len(latent_batch) > 0:
        whisper_batch = np.asarray(whisper_batch)
        latent_batch = torch.cat(latent_batch, dim=0)
        yield whisper_batch, latent_batch
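
# Minimal usage sketch, not part of the module itself. Assumptions: the model
# files live under ./models/ as referenced in load_all_model above, and
# "video.mp4", whisper_chunks and vae_encode_latents are hypothetical
# placeholders supplied by the caller.
if __name__ == "__main__":
    audio_processor, vae, unet, pe = load_all_model()
    print(get_file_type("video.mp4"))   # -> 'video'
    print(get_video_fps("video.mp4"))   # frame rate reported by OpenCV
    # for whisper_batch, latent_batch in datagen(whisper_chunks, vae_encode_latents, batch_size=8):
    #     ...  # run the UNet on each (audio feature, latent) batch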