diff --git a/.DS_Store b/.DS_Store index 4bfa35927bc56f98766c57df980c19883591cdbd..c608c17fb991145650faf16b087f0401d2d40e0f 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/aniportrait/.DS_Store b/aniportrait/.DS_Store deleted file mode 100644 index c34099dd3c099c3da89663afea31152c2a477eca..0000000000000000000000000000000000000000 Binary files a/aniportrait/.DS_Store and /dev/null differ diff --git a/aniportrait/audio2ldmk.py b/aniportrait/audio2ldmk.py deleted file mode 100644 index 3a4ff41987743161dd65f3a536e2e7dbc7c7c65d..0000000000000000000000000000000000000000 --- a/aniportrait/audio2ldmk.py +++ /dev/null @@ -1,310 +0,0 @@ -import argparse -import os -# import ffmpeg -import random -import numpy as np -import cv2 -import torch -import torchvision -from omegaconf import OmegaConf -from PIL import Image - -from src.audio_models.model import Audio2MeshModel -from src.audio_models.pose_model import Audio2PoseModel -from src.utils.audio_util import prepare_audio_feature -from src.utils.mp_utils import LMKExtractor -from src.utils.pose_util import project_points, smooth_pose_seq - - -PARTS = [ - ('FACE', [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], (10, 200, 10)), - ('LEFT_EYE', [43, 44, 45, 46, 47, 48, 43], (180, 200, 10)), - ('LEFT_EYEBROW', [23, 24, 25, 26, 27], (180, 220, 10)), - ('RIGHT_EYE', [37, 38, 39, 40, 41, 42, 37], (10, 200, 180)), - ('RIGHT_EYEBROW', [18, 19, 20, 21, 22], (10, 220, 180)), - ('NOSE_UP', [28, 29, 30, 31], (10, 200, 250)), - ('NOSE_DOWN', [32, 33, 34, 35, 36], (250, 200, 10)), - ('LIPS_OUTER_BOTTOM_LEFT', [55, 56, 57, 58], (10, 180, 20)), - ('LIPS_OUTER_BOTTOM_RIGHT', [49, 60, 59, 58], (20, 10, 180)), - ('LIPS_INNER_BOTTOM_LEFT', [65, 66, 67], (100, 100, 30)), - ('LIPS_INNER_BOTTOM_RIGHT', [61, 68, 67], (100, 150, 50)), - ('LIPS_OUTER_TOP_LEFT', [52, 53, 54, 55], (20, 80, 100)), - ('LIPS_OUTER_TOP_RIGHT', [52, 51, 50, 49], (80, 100, 20)), - ('LIPS_INNER_TOP_LEFT', [63, 64, 65], (120, 100, 200)), - ('LIPS_INNER_TOP_RIGHT', [63, 62, 61], (150, 120, 100)), -] - - -def draw_landmarks(keypoints, h, w): - - image = np.zeros((h, w, 3)) - - for name, indices, color in PARTS: - # select the keypoints belonging to the current part - indices = np.array(indices) - 1 - current_part_keypoints = keypoints[indices] - - # draw the keypoints - # for point in current_part_keypoints: - # x, y = point - # image[y, x, :] = color - - # draw the connecting lines - for i in range(len(indices) - 1): - x1, y1 = current_part_keypoints[i] - x2, y2 = current_part_keypoints[i + 1] - cv2.line(image, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness=2) - - return image - - - -def convert_ldmk_to_68(mediapipe_ldmk): - return np.stack([ - # face contour - mediapipe_ldmk[:, 234], - mediapipe_ldmk[:, 93], - mediapipe_ldmk[:, 132], - mediapipe_ldmk[:, 58], - mediapipe_ldmk[:, 172], - mediapipe_ldmk[:, 136], - mediapipe_ldmk[:, 150], - mediapipe_ldmk[:, 176], - mediapipe_ldmk[:, 152], - mediapipe_ldmk[:, 400], - mediapipe_ldmk[:, 379], - mediapipe_ldmk[:, 365], - mediapipe_ldmk[:, 397], - mediapipe_ldmk[:, 288], - mediapipe_ldmk[:, 361], - mediapipe_ldmk[:, 323], - mediapipe_ldmk[:, 454], - # right eyebrow - mediapipe_ldmk[:, 70], - mediapipe_ldmk[:, 63], - mediapipe_ldmk[:, 105], - mediapipe_ldmk[:, 66], - mediapipe_ldmk[:, 107], - # left eyebrow - mediapipe_ldmk[:, 336], - mediapipe_ldmk[:, 296], - mediapipe_ldmk[:, 334], - mediapipe_ldmk[:, 293], - mediapipe_ldmk[:, 300], - # nose - mediapipe_ldmk[:, 168], - mediapipe_ldmk[:, 6], - mediapipe_ldmk[:, 195], - mediapipe_ldmk[:, 4], - # nose down - mediapipe_ldmk[:, 239], -
mediapipe_ldmk[:, 241], - mediapipe_ldmk[:, 19], - mediapipe_ldmk[:, 461], - mediapipe_ldmk[:, 459], - # right eye - mediapipe_ldmk[:, 33], - mediapipe_ldmk[:, 160], - mediapipe_ldmk[:, 158], - mediapipe_ldmk[:, 133], - mediapipe_ldmk[:, 153], - mediapipe_ldmk[:, 144], - # left eye - mediapipe_ldmk[:, 362], - mediapipe_ldmk[:, 385], - mediapipe_ldmk[:, 387], - mediapipe_ldmk[:, 263], - mediapipe_ldmk[:, 373], - mediapipe_ldmk[:, 380], - # outer lips - mediapipe_ldmk[:, 61], - mediapipe_ldmk[:, 40], - mediapipe_ldmk[:, 37], - mediapipe_ldmk[:, 0], - mediapipe_ldmk[:, 267], - mediapipe_ldmk[:, 270], - mediapipe_ldmk[:, 291], - mediapipe_ldmk[:, 321], - mediapipe_ldmk[:, 314], - mediapipe_ldmk[:, 17], - mediapipe_ldmk[:, 84], - mediapipe_ldmk[:, 91], - # inner lips - mediapipe_ldmk[:, 78], - mediapipe_ldmk[:, 81], - mediapipe_ldmk[:, 13], - mediapipe_ldmk[:, 311], - mediapipe_ldmk[:, 308], - mediapipe_ldmk[:, 402], - mediapipe_ldmk[:, 14], - mediapipe_ldmk[:, 178], -], axis=1) - - - -# def parse_args(): -# parser = argparse.ArgumentParser() -# parser.add_argument("--config", type=str, default='./configs/prompts/animation_audio.yaml') -# parser.add_argument("-W", type=int, default=512) -# parser.add_argument("-H", type=int, default=512) -# parser.add_argument("-L", type=int) -# parser.add_argument("--seed", type=int, default=42) -# parser.add_argument("--cfg", type=float, default=3.5) -# parser.add_argument("--steps", type=int, default=25) -# parser.add_argument("--fps", type=int, default=30) -# parser.add_argument("-acc", "--accelerate", action='store_true') -# parser.add_argument("--fi_step", type=int, default=3) -# args = parser.parse_args() - -# return args - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--ref_image_path", type=str, required=True) - parser.add_argument("--audio_path", type=str, required=True) - parser.add_argument("--save_dir", type=str, required=True) - parser.add_argument("--fps", type=int, default=25) - parser.add_argument("--sr", type=int, default=16000) - args = parser.parse_args() - - return args - - -def set_seed(seed): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - - -def main(): - args = parse_args() - - config = OmegaConf.load('aniportrait/configs/config.yaml') - - set_seed(42) - - # if config.weight_dtype == "fp16": - # weight_dtype = torch.float16 - # else: - # weight_dtype = torch.float32 - - audio_infer_config = OmegaConf.load(config.audio_inference_config) - # prepare model - a2m_model = Audio2MeshModel(audio_infer_config['a2m_model']) - a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']), strict=False) - a2m_model.cuda().eval() - - a2p_model = Audio2PoseModel(audio_infer_config['a2p_model']) - a2p_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2p_ckpt']), strict=False) - a2p_model.cuda().eval() - - lmk_extractor = LMKExtractor() - - ref_image_path = args.ref_image_path - audio_path = args.audio_path - save_dir = args.save_dir - - ref_image_pil = Image.open(ref_image_path).convert("RGB") - ref_image_np = cv2.cvtColor(np.array(ref_image_pil), cv2.COLOR_RGB2BGR) - height, width, _ = ref_image_np.shape - - face_result = lmk_extractor(ref_image_np) - assert face_result is not None, "No face detected." 
- lmks = face_result['lmks'].astype(np.float32) - lmks[:, 0] *= width - lmks[:, 1] *= height - - # print(lmks.shape) - - # assert False - - sample = prepare_audio_feature(audio_path, fps=args.fps, wav2vec_model_path=audio_infer_config['a2m_model']['model_path']) - sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda() - sample['audio_feature'] = sample['audio_feature'].unsqueeze(0) - - # print(sample['audio_feature'].shape) - - # inference - pred = a2m_model.infer(sample['audio_feature'], sample['seq_len']) - pred = pred.squeeze().detach().cpu().numpy() - pred = pred.reshape(pred.shape[0], -1, 3) - - pred = pred + face_result['lmks3d'] - - # print(pred.shape) - - # assert False - - id_seed = 42 - id_seed = torch.LongTensor([id_seed]).cuda() - - # Currently, only inference up to a maximum length of 10 seconds is supported. - chunk_duration = 5 # 5 seconds - chunk_size = args.sr * chunk_duration - - - audio_chunks = list(sample['audio_feature'].split(chunk_size, dim=1)) - seq_len_list = [chunk_duration*args.fps] * (len(audio_chunks) - 1) + [sample['seq_len'] % (chunk_duration*args.fps)] - audio_chunks[-2] = torch.cat((audio_chunks[-2], audio_chunks[-1]), dim=1) - seq_len_list[-2] = seq_len_list[-2] + seq_len_list[-1] - del audio_chunks[-1] - del seq_len_list[-1] - - # assert False - - pose_seq = [] - for audio, seq_len in zip(audio_chunks, seq_len_list): - pose_seq_chunk = a2p_model.infer(audio, seq_len, id_seed) - pose_seq_chunk = pose_seq_chunk.squeeze().detach().cpu().numpy() - pose_seq_chunk[:, :3] *= 0.5 - pose_seq.append(pose_seq_chunk) - - pose_seq = np.concatenate(pose_seq, 0) - pose_seq = smooth_pose_seq(pose_seq, 7) - - # project 3D mesh to 2D landmark - projected_vertices = project_points(pred, face_result['trans_mat'], pose_seq, [height, width]) - projected_vertices = np.concatenate([lmks[:468, :2][None, :], projected_vertices], axis=0) - projected_vertices = convert_ldmk_to_68(projected_vertices) - - # print(projected_vertices.shape) - - pose_images = [] - for i in range(projected_vertices.shape[0]): - pose_img = draw_landmarks(projected_vertices[i], height, width) - pose_images.append(pose_img) - pose_images = np.array(pose_images) - - # print(pose_images.shape) - - ref_image_np = cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB) - ref_imgs = np.stack([ref_image_np]*(pose_images.shape[0]), axis=0) - - all_np = np.concatenate([ref_imgs, pose_images], axis=2) - - # print(projected_vertices.shape) - - os.makedirs(save_dir, exist_ok=True) - - np.save(os.path.join(save_dir, 'landmarks.npy'), projected_vertices) - - torchvision.io.write_video(os.path.join(save_dir, 'landmarks.mp4'), all_np, fps=args.fps, video_codec='h264', options={'crf': '10'}) - - # stream = ffmpeg.input(os.path.join(save_dir, 'landmarks.mp4')) - # audio = ffmpeg.input(args.audio_path) - # ffmpeg.output(stream.video, audio.audio, os.path.join(save_dir, 'landmarks_audio.mp4'), vcodec='copy', acodec='aac').run() - - - - - - - -if __name__ == "__main__": - main() - \ No newline at end of file diff --git a/aniportrait/configs/config.yaml b/aniportrait/configs/config.yaml deleted file mode 100644 index 1be3d696e3f5acc118053ebeff327ff603abb53b..0000000000000000000000000000000000000000 --- a/aniportrait/configs/config.yaml +++ /dev/null @@ -1,12 +0,0 @@ -pretrained_base_model_path: 'ckpts/aniportrait/stable-diffusion-v1-5' -pretrained_vae_path: 'ckpts/aniportrait/sd-vae-ft-mse' -image_encoder_path: 'ckpts/aniportrait/image_encoder' - -denoising_unet_path: "ckpts/aniportrait/denoising_unet.pth" 
-reference_unet_path: "ckpts/aniportrait/reference_unet.pth" -pose_guider_path: "ckpts/aniportrait/pose_guider.pth" -motion_module_path: "ckpts/aniportrait/motion_module.pth" - -audio_inference_config: "aniportrait/configs/inference_audio.yaml" -inference_config: "aniportrait/configs/inference_v2.yaml" -weight_dtype: 'fp16' diff --git a/aniportrait/configs/inference_audio.yaml b/aniportrait/configs/inference_audio.yaml deleted file mode 100644 index c63499327424dc58f81660cd7d7fbc3c9b1a4b39..0000000000000000000000000000000000000000 --- a/aniportrait/configs/inference_audio.yaml +++ /dev/null @@ -1,17 +0,0 @@ -a2m_model: - out_dim: 1404 - latent_dim: 512 - model_path: ckpts/aniportrait/wav2vec2-base-960h - only_last_fetures: True - from_pretrained: True - -a2p_model: - out_dim: 6 - latent_dim: 512 - model_path: ckpts/aniportrait/wav2vec2-base-960h - only_last_fetures: True - from_pretrained: True - -pretrained_model: - a2m_ckpt: ckpts/aniportrait/audio2mesh.pt - a2p_ckpt: ckpts/aniportrait/audio2pose.pt diff --git a/aniportrait/configs/inference_v2.yaml b/aniportrait/configs/inference_v2.yaml deleted file mode 100644 index d613dca2d2e48a41295a89f47b5a82fd7032dba5..0000000000000000000000000000000000000000 --- a/aniportrait/configs/inference_v2.yaml +++ /dev/null @@ -1,35 +0,0 @@ -unet_additional_kwargs: - use_inflated_groupnorm: true - unet_use_cross_frame_attention: false - unet_use_temporal_attention: false - use_motion_module: true - motion_module_resolutions: - - 1 - - 2 - - 4 - - 8 - motion_module_mid_block: true - motion_module_decoder_only: false - motion_module_type: Vanilla - motion_module_kwargs: - num_attention_heads: 8 - num_transformer_block: 1 - attention_block_types: - - Temporal_Self - - Temporal_Self - temporal_position_encoding: true - temporal_position_encoding_max_len: 32 - temporal_attention_dim_div: 1 - -noise_scheduler_kwargs: - beta_start: 0.00085 - beta_end: 0.012 - beta_schedule: "linear" - clip_sample: false - steps_offset: 1 - ### Zero-SNR params - prediction_type: "v_prediction" - rescale_betas_zero_snr: True - timestep_spacing: "trailing" - -sampler: DDIM \ No newline at end of file diff --git a/aniportrait/src/.DS_Store b/aniportrait/src/.DS_Store deleted file mode 100644 index 8126747bcf280f1d8dbd91bb5cddc97b93191099..0000000000000000000000000000000000000000 Binary files a/aniportrait/src/.DS_Store and /dev/null differ diff --git a/aniportrait/src/audio_models/mish.py b/aniportrait/src/audio_models/mish.py deleted file mode 100644 index 607b95d33edd40bb53f93682bdcd9e0ff31ffbe4..0000000000000000000000000000000000000000 --- a/aniportrait/src/audio_models/mish.py +++ /dev/null @@ -1,51 +0,0 @@ -""" -Applies the mish function element-wise: -mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) -""" - -# import pytorch -import torch -import torch.nn.functional as F -from torch import nn - -@torch.jit.script -def mish(input): - """ - Applies the mish function element-wise: - mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) - See additional documentation for mish class. 
- """ - return input * torch.tanh(F.softplus(input)) - -class Mish(nn.Module): - """ - Applies the mish function element-wise: - mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) - - Shape: - - Input: (N, *) where * means, any number of additional - dimensions - - Output: (N, *), same shape as the input - - Examples: - >>> m = Mish() - >>> input = torch.randn(2) - >>> output = m(input) - - Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html - """ - - def __init__(self): - """ - Init method. - """ - super().__init__() - - def forward(self, input): - """ - Forward pass of the function. - """ - if torch.__version__ >= "1.9": - return F.mish(input) - else: - return mish(input) \ No newline at end of file diff --git a/aniportrait/src/audio_models/model.py b/aniportrait/src/audio_models/model.py deleted file mode 100644 index 54040ee238677d2fc70b4f34dd78c191c13ed874..0000000000000000000000000000000000000000 --- a/aniportrait/src/audio_models/model.py +++ /dev/null @@ -1,71 +0,0 @@ -import os -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import Wav2Vec2Config - -from .torch_utils import get_mask_from_lengths -from .wav2vec2 import Wav2Vec2Model - - -class Audio2MeshModel(nn.Module): - def __init__( - self, - config - ): - super().__init__() - out_dim = config['out_dim'] - latent_dim = config['latent_dim'] - model_path = config['model_path'] - only_last_fetures = config['only_last_fetures'] - from_pretrained = config['from_pretrained'] - - self._only_last_features = only_last_fetures - - self.audio_encoder_config = Wav2Vec2Config.from_pretrained(model_path, local_files_only=True) - if from_pretrained: - self.audio_encoder = Wav2Vec2Model.from_pretrained(model_path, local_files_only=True) - else: - self.audio_encoder = Wav2Vec2Model(self.audio_encoder_config) - self.audio_encoder.feature_extractor._freeze_parameters() - - hidden_size = self.audio_encoder_config.hidden_size - - self.in_fn = nn.Linear(hidden_size, latent_dim) - - self.out_fn = nn.Linear(latent_dim, out_dim) - nn.init.constant_(self.out_fn.weight, 0) - nn.init.constant_(self.out_fn.bias, 0) - - def forward(self, audio, label, audio_len=None): - attention_mask = ~get_mask_from_lengths(audio_len) if audio_len else None - - seq_len = label.shape[1] - - embeddings = self.audio_encoder(audio, seq_len=seq_len, output_hidden_states=True, - attention_mask=attention_mask) - - if self._only_last_features: - hidden_states = embeddings.last_hidden_state - else: - hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states) - - layer_in = self.in_fn(hidden_states) - out = self.out_fn(layer_in) - - return out, None - - def infer(self, input_value, seq_len): - embeddings = self.audio_encoder(input_value, seq_len=seq_len, output_hidden_states=True) - - if self._only_last_features: - hidden_states = embeddings.last_hidden_state - else: - hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states) - - layer_in = self.in_fn(hidden_states) - out = self.out_fn(layer_in) - - return out - - diff --git a/aniportrait/src/audio_models/pose_model.py b/aniportrait/src/audio_models/pose_model.py deleted file mode 100644 index f72f7477d45ac82c7fb277902cec612773c14257..0000000000000000000000000000000000000000 --- a/aniportrait/src/audio_models/pose_model.py +++ /dev/null @@ -1,125 +0,0 @@ -import os -import math -import torch -import torch.nn as nn -from transformers import Wav2Vec2Config - -from .torch_utils import get_mask_from_lengths -from .wav2vec2 import 
Wav2Vec2Model - - -def init_biased_mask(n_head, max_seq_len, period): - def get_slopes(n): - def get_slopes_power_of_2(n): - start = (2**(-2**-(math.log2(n)-3))) - ratio = start - return [start*ratio**i for i in range(n)] - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2**math.floor(math.log2(n)) - return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2] - slopes = torch.Tensor(get_slopes(n_head)) - bias = torch.arange(start=0, end=max_seq_len, step=period).unsqueeze(1).repeat(1,period).view(-1)//(period) - bias = - torch.flip(bias,dims=[0]) - alibi = torch.zeros(max_seq_len, max_seq_len) - for i in range(max_seq_len): - alibi[i, :i+1] = bias[-(i+1):] - alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0) - mask = (torch.triu(torch.ones(max_seq_len, max_seq_len)) == 1).transpose(0, 1) - mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) - mask = mask.unsqueeze(0) + alibi - return mask - - -def enc_dec_mask(device, T, S): - mask = torch.ones(T, S) - for i in range(T): - mask[i, i] = 0 - return (mask==1).to(device=device) - - -class PositionalEncoding(nn.Module): - def __init__(self, d_model, max_len=600): - super(PositionalEncoding, self).__init__() - pe = torch.zeros(max_len, d_model) - position = torch.arange(0, max_len).unsqueeze(1).float() - div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - self.register_buffer('pe', pe) - - def forward(self, x): - x = x + self.pe[:, :x.size(1)] - return x - - -class Audio2PoseModel(nn.Module): - def __init__( - self, - config - ): - - super().__init__() - - latent_dim = config['latent_dim'] - model_path = config['model_path'] - only_last_fetures = config['only_last_fetures'] - from_pretrained = config['from_pretrained'] - out_dim = config['out_dim'] - - self.out_dim = out_dim - - self._only_last_features = only_last_fetures - - self.audio_encoder_config = Wav2Vec2Config.from_pretrained(model_path, local_files_only=True) - if from_pretrained: - self.audio_encoder = Wav2Vec2Model.from_pretrained(model_path, local_files_only=True) - else: - self.audio_encoder = Wav2Vec2Model(self.audio_encoder_config) - self.audio_encoder.feature_extractor._freeze_parameters() - - hidden_size = self.audio_encoder_config.hidden_size - - self.pose_map = nn.Linear(out_dim, latent_dim) - self.in_fn = nn.Linear(hidden_size, latent_dim) - - self.PPE = PositionalEncoding(latent_dim) - self.biased_mask = init_biased_mask(n_head = 8, max_seq_len = 600, period=1) - decoder_layer = nn.TransformerDecoderLayer(d_model=latent_dim, nhead=8, dim_feedforward=2*latent_dim, batch_first=True) - self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=8) - self.pose_map_r = nn.Linear(latent_dim, out_dim) - - self.id_embed = nn.Embedding(100, latent_dim) # 100 ids - - - def infer(self, input_value, seq_len, id_seed=None): - embeddings = self.audio_encoder(input_value, seq_len=seq_len, output_hidden_states=True) - - if self._only_last_features: - hidden_states = embeddings.last_hidden_state - else: - hidden_states = sum(embeddings.hidden_states) / len(embeddings.hidden_states) - - hidden_states = self.in_fn(hidden_states) - - id_embedding = self.id_embed(id_seed).unsqueeze(1) - - init_pose = torch.zeros([hidden_states.shape[0], 1, 
self.out_dim]).to(hidden_states.device) - for i in range(seq_len): - if i==0: - pose_emb = self.pose_map(init_pose) - pose_input = self.PPE(pose_emb) - else: - pose_input = self.PPE(pose_emb) - - pose_input = pose_input + id_embedding - tgt_mask = self.biased_mask[:, :pose_input.shape[1], :pose_input.shape[1]].clone().detach().to(hidden_states.device) - memory_mask = enc_dec_mask(hidden_states.device, pose_input.shape[1], hidden_states.shape[1]) - pose_out = self.transformer_decoder(pose_input, hidden_states, tgt_mask=tgt_mask, memory_mask=memory_mask) - pose_out = self.pose_map_r(pose_out) - new_output = self.pose_map(pose_out[:,-1,:]).unsqueeze(1) - pose_emb = torch.cat((pose_emb, new_output), 1) - return pose_out - diff --git a/aniportrait/src/audio_models/torch_utils.py b/aniportrait/src/audio_models/torch_utils.py deleted file mode 100644 index a91940405797c55cf685e8fa8f669adbc6089067..0000000000000000000000000000000000000000 --- a/aniportrait/src/audio_models/torch_utils.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -import torch.nn.functional as F - - -def get_mask_from_lengths(lengths, max_len=None): - lengths = lengths.to(torch.long) - if max_len is None: - max_len = torch.max(lengths).item() - - ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device) - mask = ids < lengths.unsqueeze(1).expand(-1, max_len) - - return mask - - -def linear_interpolation(features, seq_len): - features = features.transpose(1, 2) - output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear') - return output_features.transpose(1, 2) - - -if __name__ == "__main__": - import numpy as np - mask = ~get_mask_from_lengths(torch.from_numpy(np.array([4,6]))) - import pdb; pdb.set_trace() \ No newline at end of file diff --git a/aniportrait/src/audio_models/wav2vec2.py b/aniportrait/src/audio_models/wav2vec2.py deleted file mode 100644 index 5ec9c2b93454d47f6820b53c511e70208710e408..0000000000000000000000000000000000000000 --- a/aniportrait/src/audio_models/wav2vec2.py +++ /dev/null @@ -1,125 +0,0 @@ -from transformers import Wav2Vec2Config, Wav2Vec2Model -from transformers.modeling_outputs import BaseModelOutput - -from .torch_utils import linear_interpolation - -# the implementation of Wav2Vec2Model is borrowed from -# https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py -# initialize our encoder with the pre-trained wav2vec 2.0 weights. 
-class Wav2Vec2Model(Wav2Vec2Model): - def __init__(self, config: Wav2Vec2Config): - super().__init__(config) - - def forward( - self, - input_values, - seq_len, - attention_mask=None, - mask_time_indices=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - self.config.output_attentions = True - - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - extract_features = self.feature_extractor(input_values) - extract_features = extract_features.transpose(1, 2) - extract_features = linear_interpolation(extract_features, seq_len=seq_len) - - if attention_mask is not None: - # compute reduced attention_mask corresponding to feature vectors - attention_mask = self._get_feature_vector_attention_mask( - extract_features.shape[1], attention_mask, add_adapter=False - ) - - hidden_states, extract_features = self.feature_projection(extract_features) - hidden_states = self._mask_hidden_states( - hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask - ) - - encoder_outputs = self.encoder( - hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = encoder_outputs[0] - - if self.adapter is not None: - hidden_states = self.adapter(hidden_states) - - if not return_dict: - return (hidden_states, ) + encoder_outputs[1:] - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - - def feature_extract( - self, - input_values, - seq_len, - ): - extract_features = self.feature_extractor(input_values) - extract_features = extract_features.transpose(1, 2) - extract_features = linear_interpolation(extract_features, seq_len=seq_len) - - return extract_features - - def encode( - self, - extract_features, - attention_mask=None, - mask_time_indices=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - self.config.output_attentions = True - - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if attention_mask is not None: - # compute reduced attention_mask corresponding to feature vectors - attention_mask = self._get_feature_vector_attention_mask( - extract_features.shape[1], attention_mask, add_adapter=False - ) - - - hidden_states, extract_features = self.feature_projection(extract_features) - hidden_states = self._mask_hidden_states( - hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask - ) - - encoder_outputs = self.encoder( - hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = encoder_outputs[0] - - if self.adapter is not None: - hidden_states = self.adapter(hidden_states) - - if not return_dict: - return (hidden_states, ) + encoder_outputs[1:] - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) diff --git a/aniportrait/src/utils/audio_util.py b/aniportrait/src/utils/audio_util.py deleted file mode 100644 index 
7b42c5c8c25ae86f41ac87439223f822f55bb5c0..0000000000000000000000000000000000000000 --- a/aniportrait/src/utils/audio_util.py +++ /dev/null @@ -1,30 +0,0 @@ -import os -import math - -import librosa -import numpy as np -from transformers import Wav2Vec2FeatureExtractor - - -class DataProcessor: - def __init__(self, sampling_rate, wav2vec_model_path): - self._processor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True) - self._sampling_rate = sampling_rate - - def extract_feature(self, audio_path): - speech_array, sampling_rate = librosa.load(audio_path, sr=self._sampling_rate) - input_value = np.squeeze(self._processor(speech_array, sampling_rate=sampling_rate).input_values) - return input_value - - -def prepare_audio_feature(wav_file, fps=25, sampling_rate=16000, wav2vec_model_path=None): - data_preprocessor = DataProcessor(sampling_rate, wav2vec_model_path) - - input_value = data_preprocessor.extract_feature(wav_file) - seq_len = math.ceil(len(input_value)/sampling_rate*fps) - return { - "audio_feature": input_value, - "seq_len": seq_len - } - - diff --git a/aniportrait/src/utils/draw_util.py b/aniportrait/src/utils/draw_util.py deleted file mode 100644 index 96dbf49426c1f60dbf617f0970536ea7d37187f5..0000000000000000000000000000000000000000 --- a/aniportrait/src/utils/draw_util.py +++ /dev/null @@ -1,149 +0,0 @@ -import cv2 -import mediapipe as mp -import numpy as np -from mediapipe.framework.formats import landmark_pb2 - -class FaceMeshVisualizer: - def __init__(self, forehead_edge=False): - self.mp_drawing = mp.solutions.drawing_utils - mp_face_mesh = mp.solutions.face_mesh - self.mp_face_mesh = mp_face_mesh - self.forehead_edge = forehead_edge - - DrawingSpec = mp.solutions.drawing_styles.DrawingSpec - f_thick = 2 - f_rad = 1 - right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad) - right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad) - right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad) - left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad) - left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad) - left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad) - head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad) - - mouth_draw_obl = DrawingSpec(color=(10, 180, 20), thickness=f_thick, circle_radius=f_rad) - mouth_draw_obr = DrawingSpec(color=(20, 10, 180), thickness=f_thick, circle_radius=f_rad) - - mouth_draw_ibl = DrawingSpec(color=(100, 100, 30), thickness=f_thick, circle_radius=f_rad) - mouth_draw_ibr = DrawingSpec(color=(100, 150, 50), thickness=f_thick, circle_radius=f_rad) - - mouth_draw_otl = DrawingSpec(color=(20, 80, 100), thickness=f_thick, circle_radius=f_rad) - mouth_draw_otr = DrawingSpec(color=(80, 100, 20), thickness=f_thick, circle_radius=f_rad) - - mouth_draw_itl = DrawingSpec(color=(120, 100, 200), thickness=f_thick, circle_radius=f_rad) - mouth_draw_itr = DrawingSpec(color=(150 ,120, 100), thickness=f_thick, circle_radius=f_rad) - - FACEMESH_LIPS_OUTER_BOTTOM_LEFT = [(61,146),(146,91),(91,181),(181,84),(84,17)] - FACEMESH_LIPS_OUTER_BOTTOM_RIGHT = [(17,314),(314,405),(405,321),(321,375),(375,291)] - - FACEMESH_LIPS_INNER_BOTTOM_LEFT = [(78,95),(95,88),(88,178),(178,87),(87,14)] - FACEMESH_LIPS_INNER_BOTTOM_RIGHT = [(14,317),(317,402),(402,318),(318,324),(324,308)] - - 
FACEMESH_LIPS_OUTER_TOP_LEFT = [(61,185),(185,40),(40,39),(39,37),(37,0)] - FACEMESH_LIPS_OUTER_TOP_RIGHT = [(0,267),(267,269),(269,270),(270,409),(409,291)] - - FACEMESH_LIPS_INNER_TOP_LEFT = [(78,191),(191,80),(80,81),(81,82),(82,13)] - FACEMESH_LIPS_INNER_TOP_RIGHT = [(13,312),(312,311),(311,310),(310,415),(415,308)] - - FACEMESH_CUSTOM_FACE_OVAL = [(176, 149), (150, 136), (356, 454), (58, 132), (152, 148), (361, 288), (251, 389), (132, 93), (389, 356), (400, 377), (136, 172), (377, 152), (323, 361), (172, 58), (454, 323), (365, 379), (379, 378), (148, 176), (93, 234), (397, 365), (149, 150), (288, 397), (234, 127), (378, 400), (127, 162), (162, 21)] - - # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about. - face_connection_spec = {} - if self.forehead_edge: - for edge in mp_face_mesh.FACEMESH_FACE_OVAL: - face_connection_spec[edge] = head_draw - else: - for edge in FACEMESH_CUSTOM_FACE_OVAL: - face_connection_spec[edge] = head_draw - for edge in mp_face_mesh.FACEMESH_LEFT_EYE: - face_connection_spec[edge] = left_eye_draw - for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW: - face_connection_spec[edge] = left_eyebrow_draw - # for edge in mp_face_mesh.FACEMESH_LEFT_IRIS: - # face_connection_spec[edge] = left_iris_draw - for edge in mp_face_mesh.FACEMESH_RIGHT_EYE: - face_connection_spec[edge] = right_eye_draw - for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW: - face_connection_spec[edge] = right_eyebrow_draw - # for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS: - # face_connection_spec[edge] = right_iris_draw - # for edge in mp_face_mesh.FACEMESH_LIPS: - # face_connection_spec[edge] = mouth_draw - - for edge in FACEMESH_LIPS_OUTER_BOTTOM_LEFT: - face_connection_spec[edge] = mouth_draw_obl - for edge in FACEMESH_LIPS_OUTER_BOTTOM_RIGHT: - face_connection_spec[edge] = mouth_draw_obr - for edge in FACEMESH_LIPS_INNER_BOTTOM_LEFT: - face_connection_spec[edge] = mouth_draw_ibl - for edge in FACEMESH_LIPS_INNER_BOTTOM_RIGHT: - face_connection_spec[edge] = mouth_draw_ibr - for edge in FACEMESH_LIPS_OUTER_TOP_LEFT: - face_connection_spec[edge] = mouth_draw_otl - for edge in FACEMESH_LIPS_OUTER_TOP_RIGHT: - face_connection_spec[edge] = mouth_draw_otr - for edge in FACEMESH_LIPS_INNER_TOP_LEFT: - face_connection_spec[edge] = mouth_draw_itl - for edge in FACEMESH_LIPS_INNER_TOP_RIGHT: - face_connection_spec[edge] = mouth_draw_itr - - - iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw} - - self.face_connection_spec = face_connection_spec - def draw_pupils(self, image, landmark_list, drawing_spec, halfwidth: int = 2): - """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all - landmarks. 
Until our PR is merged into mediapipe, we need this separate method.""" - if len(image.shape) != 3: - raise ValueError("Input image must be H,W,C.") - image_rows, image_cols, image_channels = image.shape - if image_channels != 3: # BGR channels - raise ValueError('Input image must contain three channel bgr data.') - for idx, landmark in enumerate(landmark_list.landmark): - if ( - (landmark.HasField('visibility') and landmark.visibility < 0.9) or - (landmark.HasField('presence') and landmark.presence < 0.5) - ): - continue - if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0: - continue - image_x = int(image_cols*landmark.x) - image_y = int(image_rows*landmark.y) - draw_color = None - if isinstance(drawing_spec, Mapping): - if drawing_spec.get(idx) is None: - continue - else: - draw_color = drawing_spec[idx].color - elif isinstance(drawing_spec, DrawingSpec): - draw_color = drawing_spec.color - image[image_y-halfwidth:image_y+halfwidth, image_x-halfwidth:image_x+halfwidth, :] = draw_color - - - - def draw_landmarks(self, image_size, keypoints, normed=False): - ini_size = [512, 512] - image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8) - new_landmarks = landmark_pb2.NormalizedLandmarkList() - for i in range(keypoints.shape[0]): - landmark = new_landmarks.landmark.add() - if normed: - landmark.x = keypoints[i, 0] - landmark.y = keypoints[i, 1] - else: - landmark.x = keypoints[i, 0] / image_size[0] - landmark.y = keypoints[i, 1] / image_size[1] - landmark.z = 1.0 - - self.mp_drawing.draw_landmarks( - image=image, - landmark_list=new_landmarks, - connections=self.face_connection_spec.keys(), - landmark_drawing_spec=None, - connection_drawing_spec=self.face_connection_spec - ) - # draw_pupils(image, face_landmarks, iris_landmark_spec, 2) - image = cv2.resize(image, (image_size[0], image_size[1])) - - return image - diff --git a/aniportrait/src/utils/face_landmark.py b/aniportrait/src/utils/face_landmark.py deleted file mode 100644 index b6580cb2cded9dcfeab46b0d50c8931ed6256669..0000000000000000000000000000000000000000 --- a/aniportrait/src/utils/face_landmark.py +++ /dev/null @@ -1,3305 +0,0 @@ -# Copyright 2023 The MediaPipe Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""MediaPipe face landmarker task.""" - -import dataclasses -import enum -from typing import Callable, Mapping, Optional, List - -import numpy as np - -from mediapipe.framework.formats import classification_pb2 -from mediapipe.framework.formats import landmark_pb2 -from mediapipe.framework.formats import matrix_data_pb2 -from mediapipe.python import packet_creator -from mediapipe.python import packet_getter -from mediapipe.python._framework_bindings import image as image_module -from mediapipe.python._framework_bindings import packet as packet_module -# pylint: disable=unused-import -from mediapipe.tasks.cc.vision.face_geometry.proto import face_geometry_pb2 -# pylint: enable=unused-import -from mediapipe.tasks.cc.vision.face_landmarker.proto import face_landmarker_graph_options_pb2 -from mediapipe.tasks.python.components.containers import category as category_module -from mediapipe.tasks.python.components.containers import landmark as landmark_module -from mediapipe.tasks.python.core import base_options as base_options_module -from mediapipe.tasks.python.core import task_info as task_info_module -from mediapipe.tasks.python.core.optional_dependencies import doc_controls -from mediapipe.tasks.python.vision.core import base_vision_task_api -from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module -from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module - -_BaseOptions = base_options_module.BaseOptions -_FaceLandmarkerGraphOptionsProto = ( - face_landmarker_graph_options_pb2.FaceLandmarkerGraphOptions -) -_LayoutEnum = matrix_data_pb2.MatrixData.Layout -_RunningMode = running_mode_module.VisionTaskRunningMode -_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions -_TaskInfo = task_info_module.TaskInfo - -_IMAGE_IN_STREAM_NAME = 'image_in' -_IMAGE_OUT_STREAM_NAME = 'image_out' -_IMAGE_TAG = 'IMAGE' -_NORM_RECT_STREAM_NAME = 'norm_rect_in' -_NORM_RECT_TAG = 'NORM_RECT' -_NORM_LANDMARKS_STREAM_NAME = 'norm_landmarks' -_NORM_LANDMARKS_TAG = 'NORM_LANDMARKS' -_BLENDSHAPES_STREAM_NAME = 'blendshapes' -_BLENDSHAPES_TAG = 'BLENDSHAPES' -_FACE_GEOMETRY_STREAM_NAME = 'face_geometry' -_FACE_GEOMETRY_TAG = 'FACE_GEOMETRY' -_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph' -_MICRO_SECONDS_PER_MILLISECOND = 1000 - - -class Blendshapes(enum.IntEnum): - """The 52 blendshape coefficients.""" - - NEUTRAL = 0 - BROW_DOWN_LEFT = 1 - BROW_DOWN_RIGHT = 2 - BROW_INNER_UP = 3 - BROW_OUTER_UP_LEFT = 4 - BROW_OUTER_UP_RIGHT = 5 - CHEEK_PUFF = 6 - CHEEK_SQUINT_LEFT = 7 - CHEEK_SQUINT_RIGHT = 8 - EYE_BLINK_LEFT = 9 - EYE_BLINK_RIGHT = 10 - EYE_LOOK_DOWN_LEFT = 11 - EYE_LOOK_DOWN_RIGHT = 12 - EYE_LOOK_IN_LEFT = 13 - EYE_LOOK_IN_RIGHT = 14 - EYE_LOOK_OUT_LEFT = 15 - EYE_LOOK_OUT_RIGHT = 16 - EYE_LOOK_UP_LEFT = 17 - EYE_LOOK_UP_RIGHT = 18 - EYE_SQUINT_LEFT = 19 - EYE_SQUINT_RIGHT = 20 - EYE_WIDE_LEFT = 21 - EYE_WIDE_RIGHT = 22 - JAW_FORWARD = 23 - JAW_LEFT = 24 - JAW_OPEN = 25 - JAW_RIGHT = 26 - MOUTH_CLOSE = 27 - MOUTH_DIMPLE_LEFT = 28 - MOUTH_DIMPLE_RIGHT = 29 - MOUTH_FROWN_LEFT = 30 - MOUTH_FROWN_RIGHT = 31 - MOUTH_FUNNEL = 32 - MOUTH_LEFT = 33 - MOUTH_LOWER_DOWN_LEFT = 34 - MOUTH_LOWER_DOWN_RIGHT = 35 - MOUTH_PRESS_LEFT = 36 - MOUTH_PRESS_RIGHT = 37 - MOUTH_PUCKER = 38 - MOUTH_RIGHT = 39 - MOUTH_ROLL_LOWER = 40 - MOUTH_ROLL_UPPER = 41 - MOUTH_SHRUG_LOWER = 42 - MOUTH_SHRUG_UPPER = 43 - MOUTH_SMILE_LEFT = 44 - MOUTH_SMILE_RIGHT = 45 - MOUTH_STRETCH_LEFT = 46 - 
MOUTH_STRETCH_RIGHT = 47 - MOUTH_UPPER_UP_LEFT = 48 - MOUTH_UPPER_UP_RIGHT = 49 - NOSE_SNEER_LEFT = 50 - NOSE_SNEER_RIGHT = 51 - - -class FaceLandmarksConnections: - """The connections between face landmarks.""" - - @dataclasses.dataclass - class Connection: - """The connection class for face landmarks.""" - - start: int - end: int - - FACE_LANDMARKS_LIPS: List[Connection] = [ - Connection(61, 146), - Connection(146, 91), - Connection(91, 181), - Connection(181, 84), - Connection(84, 17), - Connection(17, 314), - Connection(314, 405), - Connection(405, 321), - Connection(321, 375), - Connection(375, 291), - Connection(61, 185), - Connection(185, 40), - Connection(40, 39), - Connection(39, 37), - Connection(37, 0), - Connection(0, 267), - Connection(267, 269), - Connection(269, 270), - Connection(270, 409), - Connection(409, 291), - Connection(78, 95), - Connection(95, 88), - Connection(88, 178), - Connection(178, 87), - Connection(87, 14), - Connection(14, 317), - Connection(317, 402), - Connection(402, 318), - Connection(318, 324), - Connection(324, 308), - Connection(78, 191), - Connection(191, 80), - Connection(80, 81), - Connection(81, 82), - Connection(82, 13), - Connection(13, 312), - Connection(312, 311), - Connection(311, 310), - Connection(310, 415), - Connection(415, 308), - ] - - FACE_LANDMARKS_LEFT_EYE: List[Connection] = [ - Connection(263, 249), - Connection(249, 390), - Connection(390, 373), - Connection(373, 374), - Connection(374, 380), - Connection(380, 381), - Connection(381, 382), - Connection(382, 362), - Connection(263, 466), - Connection(466, 388), - Connection(388, 387), - Connection(387, 386), - Connection(386, 385), - Connection(385, 384), - Connection(384, 398), - Connection(398, 362), - ] - - FACE_LANDMARKS_LEFT_EYEBROW: List[Connection] = [ - Connection(276, 283), - Connection(283, 282), - Connection(282, 295), - Connection(295, 285), - Connection(300, 293), - Connection(293, 334), - Connection(334, 296), - Connection(296, 336), - ] - - FACE_LANDMARKS_LEFT_IRIS: List[Connection] = [ - Connection(474, 475), - Connection(475, 476), - Connection(476, 477), - Connection(477, 474), - ] - - FACE_LANDMARKS_RIGHT_EYE: List[Connection] = [ - Connection(33, 7), - Connection(7, 163), - Connection(163, 144), - Connection(144, 145), - Connection(145, 153), - Connection(153, 154), - Connection(154, 155), - Connection(155, 133), - Connection(33, 246), - Connection(246, 161), - Connection(161, 160), - Connection(160, 159), - Connection(159, 158), - Connection(158, 157), - Connection(157, 173), - Connection(173, 133), - ] - - FACE_LANDMARKS_RIGHT_EYEBROW: List[Connection] = [ - Connection(46, 53), - Connection(53, 52), - Connection(52, 65), - Connection(65, 55), - Connection(70, 63), - Connection(63, 105), - Connection(105, 66), - Connection(66, 107), - ] - - FACE_LANDMARKS_RIGHT_IRIS: List[Connection] = [ - Connection(469, 470), - Connection(470, 471), - Connection(471, 472), - Connection(472, 469), - ] - - FACE_LANDMARKS_FACE_OVAL: List[Connection] = [ - Connection(10, 338), - Connection(338, 297), - Connection(297, 332), - Connection(332, 284), - Connection(284, 251), - Connection(251, 389), - Connection(389, 356), - Connection(356, 454), - Connection(454, 323), - Connection(323, 361), - Connection(361, 288), - Connection(288, 397), - Connection(397, 365), - Connection(365, 379), - Connection(379, 378), - Connection(378, 400), - Connection(400, 377), - Connection(377, 152), - Connection(152, 148), - Connection(148, 176), - Connection(176, 149), - Connection(149, 150), - 
Connection(150, 136), - Connection(136, 172), - Connection(172, 58), - Connection(58, 132), - Connection(132, 93), - Connection(93, 234), - Connection(234, 127), - Connection(127, 162), - Connection(162, 21), - Connection(21, 54), - Connection(54, 103), - Connection(103, 67), - Connection(67, 109), - Connection(109, 10), - ] - - FACE_LANDMARKS_CONTOURS: List[Connection] = ( - FACE_LANDMARKS_LIPS - + FACE_LANDMARKS_LEFT_EYE - + FACE_LANDMARKS_LEFT_EYEBROW - + FACE_LANDMARKS_RIGHT_EYE - + FACE_LANDMARKS_RIGHT_EYEBROW - + FACE_LANDMARKS_FACE_OVAL - ) - - FACE_LANDMARKS_TESSELATION: List[Connection] = [ - Connection(127, 34), - Connection(34, 139), - Connection(139, 127), - Connection(11, 0), - Connection(0, 37), - Connection(37, 11), - Connection(232, 231), - Connection(231, 120), - Connection(120, 232), - Connection(72, 37), - Connection(37, 39), - Connection(39, 72), - Connection(128, 121), - Connection(121, 47), - Connection(47, 128), - Connection(232, 121), - Connection(121, 128), - Connection(128, 232), - Connection(104, 69), - Connection(69, 67), - Connection(67, 104), - Connection(175, 171), - Connection(171, 148), - Connection(148, 175), - Connection(118, 50), - Connection(50, 101), - Connection(101, 118), - Connection(73, 39), - Connection(39, 40), - Connection(40, 73), - Connection(9, 151), - Connection(151, 108), - Connection(108, 9), - Connection(48, 115), - Connection(115, 131), - Connection(131, 48), - Connection(194, 204), - Connection(204, 211), - Connection(211, 194), - Connection(74, 40), - Connection(40, 185), - Connection(185, 74), - Connection(80, 42), - Connection(42, 183), - Connection(183, 80), - Connection(40, 92), - Connection(92, 186), - Connection(186, 40), - Connection(230, 229), - Connection(229, 118), - Connection(118, 230), - Connection(202, 212), - Connection(212, 214), - Connection(214, 202), - Connection(83, 18), - Connection(18, 17), - Connection(17, 83), - Connection(76, 61), - Connection(61, 146), - Connection(146, 76), - Connection(160, 29), - Connection(29, 30), - Connection(30, 160), - Connection(56, 157), - Connection(157, 173), - Connection(173, 56), - Connection(106, 204), - Connection(204, 194), - Connection(194, 106), - Connection(135, 214), - Connection(214, 192), - Connection(192, 135), - Connection(203, 165), - Connection(165, 98), - Connection(98, 203), - Connection(21, 71), - Connection(71, 68), - Connection(68, 21), - Connection(51, 45), - Connection(45, 4), - Connection(4, 51), - Connection(144, 24), - Connection(24, 23), - Connection(23, 144), - Connection(77, 146), - Connection(146, 91), - Connection(91, 77), - Connection(205, 50), - Connection(50, 187), - Connection(187, 205), - Connection(201, 200), - Connection(200, 18), - Connection(18, 201), - Connection(91, 106), - Connection(106, 182), - Connection(182, 91), - Connection(90, 91), - Connection(91, 181), - Connection(181, 90), - Connection(85, 84), - Connection(84, 17), - Connection(17, 85), - Connection(206, 203), - Connection(203, 36), - Connection(36, 206), - Connection(148, 171), - Connection(171, 140), - Connection(140, 148), - Connection(92, 40), - Connection(40, 39), - Connection(39, 92), - Connection(193, 189), - Connection(189, 244), - Connection(244, 193), - Connection(159, 158), - Connection(158, 28), - Connection(28, 159), - Connection(247, 246), - Connection(246, 161), - Connection(161, 247), - Connection(236, 3), - Connection(3, 196), - Connection(196, 236), - Connection(54, 68), - Connection(68, 104), - Connection(104, 54), - Connection(193, 168), - Connection(168, 8), 
- Connection(8, 193), - Connection(117, 228), - Connection(228, 31), - Connection(31, 117), - Connection(189, 193), - Connection(193, 55), - Connection(55, 189), - Connection(98, 97), - Connection(97, 99), - Connection(99, 98), - Connection(126, 47), - Connection(47, 100), - Connection(100, 126), - Connection(166, 79), - Connection(79, 218), - Connection(218, 166), - Connection(155, 154), - Connection(154, 26), - Connection(26, 155), - Connection(209, 49), - Connection(49, 131), - Connection(131, 209), - Connection(135, 136), - Connection(136, 150), - Connection(150, 135), - Connection(47, 126), - Connection(126, 217), - Connection(217, 47), - Connection(223, 52), - Connection(52, 53), - Connection(53, 223), - Connection(45, 51), - Connection(51, 134), - Connection(134, 45), - Connection(211, 170), - Connection(170, 140), - Connection(140, 211), - Connection(67, 69), - Connection(69, 108), - Connection(108, 67), - Connection(43, 106), - Connection(106, 91), - Connection(91, 43), - Connection(230, 119), - Connection(119, 120), - Connection(120, 230), - Connection(226, 130), - Connection(130, 247), - Connection(247, 226), - Connection(63, 53), - Connection(53, 52), - Connection(52, 63), - Connection(238, 20), - Connection(20, 242), - Connection(242, 238), - Connection(46, 70), - Connection(70, 156), - Connection(156, 46), - Connection(78, 62), - Connection(62, 96), - Connection(96, 78), - Connection(46, 53), - Connection(53, 63), - Connection(63, 46), - Connection(143, 34), - Connection(34, 227), - Connection(227, 143), - Connection(123, 117), - Connection(117, 111), - Connection(111, 123), - Connection(44, 125), - Connection(125, 19), - Connection(19, 44), - Connection(236, 134), - Connection(134, 51), - Connection(51, 236), - Connection(216, 206), - Connection(206, 205), - Connection(205, 216), - Connection(154, 153), - Connection(153, 22), - Connection(22, 154), - Connection(39, 37), - Connection(37, 167), - Connection(167, 39), - Connection(200, 201), - Connection(201, 208), - Connection(208, 200), - Connection(36, 142), - Connection(142, 100), - Connection(100, 36), - Connection(57, 212), - Connection(212, 202), - Connection(202, 57), - Connection(20, 60), - Connection(60, 99), - Connection(99, 20), - Connection(28, 158), - Connection(158, 157), - Connection(157, 28), - Connection(35, 226), - Connection(226, 113), - Connection(113, 35), - Connection(160, 159), - Connection(159, 27), - Connection(27, 160), - Connection(204, 202), - Connection(202, 210), - Connection(210, 204), - Connection(113, 225), - Connection(225, 46), - Connection(46, 113), - Connection(43, 202), - Connection(202, 204), - Connection(204, 43), - Connection(62, 76), - Connection(76, 77), - Connection(77, 62), - Connection(137, 123), - Connection(123, 116), - Connection(116, 137), - Connection(41, 38), - Connection(38, 72), - Connection(72, 41), - Connection(203, 129), - Connection(129, 142), - Connection(142, 203), - Connection(64, 98), - Connection(98, 240), - Connection(240, 64), - Connection(49, 102), - Connection(102, 64), - Connection(64, 49), - Connection(41, 73), - Connection(73, 74), - Connection(74, 41), - Connection(212, 216), - Connection(216, 207), - Connection(207, 212), - Connection(42, 74), - Connection(74, 184), - Connection(184, 42), - Connection(169, 170), - Connection(170, 211), - Connection(211, 169), - Connection(170, 149), - Connection(149, 176), - Connection(176, 170), - Connection(105, 66), - Connection(66, 69), - Connection(69, 105), - Connection(122, 6), - Connection(6, 168), - 
Connection(168, 122), - Connection(123, 147), - Connection(147, 187), - Connection(187, 123), - Connection(96, 77), - Connection(77, 90), - Connection(90, 96), - Connection(65, 55), - Connection(55, 107), - Connection(107, 65), - Connection(89, 90), - Connection(90, 180), - Connection(180, 89), - Connection(101, 100), - Connection(100, 120), - Connection(120, 101), - Connection(63, 105), - Connection(105, 104), - Connection(104, 63), - Connection(93, 137), - Connection(137, 227), - Connection(227, 93), - Connection(15, 86), - Connection(86, 85), - Connection(85, 15), - Connection(129, 102), - Connection(102, 49), - Connection(49, 129), - Connection(14, 87), - Connection(87, 86), - Connection(86, 14), - Connection(55, 8), - Connection(8, 9), - Connection(9, 55), - Connection(100, 47), - Connection(47, 121), - Connection(121, 100), - Connection(145, 23), - Connection(23, 22), - Connection(22, 145), - Connection(88, 89), - Connection(89, 179), - Connection(179, 88), - Connection(6, 122), - Connection(122, 196), - Connection(196, 6), - Connection(88, 95), - Connection(95, 96), - Connection(96, 88), - Connection(138, 172), - Connection(172, 136), - Connection(136, 138), - Connection(215, 58), - Connection(58, 172), - Connection(172, 215), - Connection(115, 48), - Connection(48, 219), - Connection(219, 115), - Connection(42, 80), - Connection(80, 81), - Connection(81, 42), - Connection(195, 3), - Connection(3, 51), - Connection(51, 195), - Connection(43, 146), - Connection(146, 61), - Connection(61, 43), - Connection(171, 175), - Connection(175, 199), - Connection(199, 171), - Connection(81, 82), - Connection(82, 38), - Connection(38, 81), - Connection(53, 46), - Connection(46, 225), - Connection(225, 53), - Connection(144, 163), - Connection(163, 110), - Connection(110, 144), - Connection(52, 65), - Connection(65, 66), - Connection(66, 52), - Connection(229, 228), - Connection(228, 117), - Connection(117, 229), - Connection(34, 127), - Connection(127, 234), - Connection(234, 34), - Connection(107, 108), - Connection(108, 69), - Connection(69, 107), - Connection(109, 108), - Connection(108, 151), - Connection(151, 109), - Connection(48, 64), - Connection(64, 235), - Connection(235, 48), - Connection(62, 78), - Connection(78, 191), - Connection(191, 62), - Connection(129, 209), - Connection(209, 126), - Connection(126, 129), - Connection(111, 35), - Connection(35, 143), - Connection(143, 111), - Connection(117, 123), - Connection(123, 50), - Connection(50, 117), - Connection(222, 65), - Connection(65, 52), - Connection(52, 222), - Connection(19, 125), - Connection(125, 141), - Connection(141, 19), - Connection(221, 55), - Connection(55, 65), - Connection(65, 221), - Connection(3, 195), - Connection(195, 197), - Connection(197, 3), - Connection(25, 7), - Connection(7, 33), - Connection(33, 25), - Connection(220, 237), - Connection(237, 44), - Connection(44, 220), - Connection(70, 71), - Connection(71, 139), - Connection(139, 70), - Connection(122, 193), - Connection(193, 245), - Connection(245, 122), - Connection(247, 130), - Connection(130, 33), - Connection(33, 247), - Connection(71, 21), - Connection(21, 162), - Connection(162, 71), - Connection(170, 169), - Connection(169, 150), - Connection(150, 170), - Connection(188, 174), - Connection(174, 196), - Connection(196, 188), - Connection(216, 186), - Connection(186, 92), - Connection(92, 216), - Connection(2, 97), - Connection(97, 167), - Connection(167, 2), - Connection(141, 125), - Connection(125, 241), - Connection(241, 141), - 
Connection(164, 167), - Connection(167, 37), - Connection(37, 164), - Connection(72, 38), - Connection(38, 12), - Connection(12, 72), - Connection(38, 82), - Connection(82, 13), - Connection(13, 38), - Connection(63, 68), - Connection(68, 71), - Connection(71, 63), - Connection(226, 35), - Connection(35, 111), - Connection(111, 226), - Connection(101, 50), - Connection(50, 205), - Connection(205, 101), - Connection(206, 92), - Connection(92, 165), - Connection(165, 206), - Connection(209, 198), - Connection(198, 217), - Connection(217, 209), - Connection(165, 167), - Connection(167, 97), - Connection(97, 165), - Connection(220, 115), - Connection(115, 218), - Connection(218, 220), - Connection(133, 112), - Connection(112, 243), - Connection(243, 133), - Connection(239, 238), - Connection(238, 241), - Connection(241, 239), - Connection(214, 135), - Connection(135, 169), - Connection(169, 214), - Connection(190, 173), - Connection(173, 133), - Connection(133, 190), - Connection(171, 208), - Connection(208, 32), - Connection(32, 171), - Connection(125, 44), - Connection(44, 237), - Connection(237, 125), - Connection(86, 87), - Connection(87, 178), - Connection(178, 86), - Connection(85, 86), - Connection(86, 179), - Connection(179, 85), - Connection(84, 85), - Connection(85, 180), - Connection(180, 84), - Connection(83, 84), - Connection(84, 181), - Connection(181, 83), - Connection(201, 83), - Connection(83, 182), - Connection(182, 201), - Connection(137, 93), - Connection(93, 132), - Connection(132, 137), - Connection(76, 62), - Connection(62, 183), - Connection(183, 76), - Connection(61, 76), - Connection(76, 184), - Connection(184, 61), - Connection(57, 61), - Connection(61, 185), - Connection(185, 57), - Connection(212, 57), - Connection(57, 186), - Connection(186, 212), - Connection(214, 207), - Connection(207, 187), - Connection(187, 214), - Connection(34, 143), - Connection(143, 156), - Connection(156, 34), - Connection(79, 239), - Connection(239, 237), - Connection(237, 79), - Connection(123, 137), - Connection(137, 177), - Connection(177, 123), - Connection(44, 1), - Connection(1, 4), - Connection(4, 44), - Connection(201, 194), - Connection(194, 32), - Connection(32, 201), - Connection(64, 102), - Connection(102, 129), - Connection(129, 64), - Connection(213, 215), - Connection(215, 138), - Connection(138, 213), - Connection(59, 166), - Connection(166, 219), - Connection(219, 59), - Connection(242, 99), - Connection(99, 97), - Connection(97, 242), - Connection(2, 94), - Connection(94, 141), - Connection(141, 2), - Connection(75, 59), - Connection(59, 235), - Connection(235, 75), - Connection(24, 110), - Connection(110, 228), - Connection(228, 24), - Connection(25, 130), - Connection(130, 226), - Connection(226, 25), - Connection(23, 24), - Connection(24, 229), - Connection(229, 23), - Connection(22, 23), - Connection(23, 230), - Connection(230, 22), - Connection(26, 22), - Connection(22, 231), - Connection(231, 26), - Connection(112, 26), - Connection(26, 232), - Connection(232, 112), - Connection(189, 190), - Connection(190, 243), - Connection(243, 189), - Connection(221, 56), - Connection(56, 190), - Connection(190, 221), - Connection(28, 56), - Connection(56, 221), - Connection(221, 28), - Connection(27, 28), - Connection(28, 222), - Connection(222, 27), - Connection(29, 27), - Connection(27, 223), - Connection(223, 29), - Connection(30, 29), - Connection(29, 224), - Connection(224, 30), - Connection(247, 30), - Connection(30, 225), - Connection(225, 247), - Connection(238, 79), 
- Connection(79, 20), - Connection(20, 238), - Connection(166, 59), - Connection(59, 75), - Connection(75, 166), - Connection(60, 75), - Connection(75, 240), - Connection(240, 60), - Connection(147, 177), - Connection(177, 215), - Connection(215, 147), - Connection(20, 79), - Connection(79, 166), - Connection(166, 20), - Connection(187, 147), - Connection(147, 213), - Connection(213, 187), - Connection(112, 233), - Connection(233, 244), - Connection(244, 112), - Connection(233, 128), - Connection(128, 245), - Connection(245, 233), - Connection(128, 114), - Connection(114, 188), - Connection(188, 128), - Connection(114, 217), - Connection(217, 174), - Connection(174, 114), - Connection(131, 115), - Connection(115, 220), - Connection(220, 131), - Connection(217, 198), - Connection(198, 236), - Connection(236, 217), - Connection(198, 131), - Connection(131, 134), - Connection(134, 198), - Connection(177, 132), - Connection(132, 58), - Connection(58, 177), - Connection(143, 35), - Connection(35, 124), - Connection(124, 143), - Connection(110, 163), - Connection(163, 7), - Connection(7, 110), - Connection(228, 110), - Connection(110, 25), - Connection(25, 228), - Connection(356, 389), - Connection(389, 368), - Connection(368, 356), - Connection(11, 302), - Connection(302, 267), - Connection(267, 11), - Connection(452, 350), - Connection(350, 349), - Connection(349, 452), - Connection(302, 303), - Connection(303, 269), - Connection(269, 302), - Connection(357, 343), - Connection(343, 277), - Connection(277, 357), - Connection(452, 453), - Connection(453, 357), - Connection(357, 452), - Connection(333, 332), - Connection(332, 297), - Connection(297, 333), - Connection(175, 152), - Connection(152, 377), - Connection(377, 175), - Connection(347, 348), - Connection(348, 330), - Connection(330, 347), - Connection(303, 304), - Connection(304, 270), - Connection(270, 303), - Connection(9, 336), - Connection(336, 337), - Connection(337, 9), - Connection(278, 279), - Connection(279, 360), - Connection(360, 278), - Connection(418, 262), - Connection(262, 431), - Connection(431, 418), - Connection(304, 408), - Connection(408, 409), - Connection(409, 304), - Connection(310, 415), - Connection(415, 407), - Connection(407, 310), - Connection(270, 409), - Connection(409, 410), - Connection(410, 270), - Connection(450, 348), - Connection(348, 347), - Connection(347, 450), - Connection(422, 430), - Connection(430, 434), - Connection(434, 422), - Connection(313, 314), - Connection(314, 17), - Connection(17, 313), - Connection(306, 307), - Connection(307, 375), - Connection(375, 306), - Connection(387, 388), - Connection(388, 260), - Connection(260, 387), - Connection(286, 414), - Connection(414, 398), - Connection(398, 286), - Connection(335, 406), - Connection(406, 418), - Connection(418, 335), - Connection(364, 367), - Connection(367, 416), - Connection(416, 364), - Connection(423, 358), - Connection(358, 327), - Connection(327, 423), - Connection(251, 284), - Connection(284, 298), - Connection(298, 251), - Connection(281, 5), - Connection(5, 4), - Connection(4, 281), - Connection(373, 374), - Connection(374, 253), - Connection(253, 373), - Connection(307, 320), - Connection(320, 321), - Connection(321, 307), - Connection(425, 427), - Connection(427, 411), - Connection(411, 425), - Connection(421, 313), - Connection(313, 18), - Connection(18, 421), - Connection(321, 405), - Connection(405, 406), - Connection(406, 321), - Connection(320, 404), - Connection(404, 405), - Connection(405, 320), - Connection(315, 
16), - Connection(16, 17), - Connection(17, 315), - Connection(426, 425), - Connection(425, 266), - Connection(266, 426), - Connection(377, 400), - Connection(400, 369), - Connection(369, 377), - Connection(322, 391), - Connection(391, 269), - Connection(269, 322), - Connection(417, 465), - Connection(465, 464), - Connection(464, 417), - Connection(386, 257), - Connection(257, 258), - Connection(258, 386), - Connection(466, 260), - Connection(260, 388), - Connection(388, 466), - Connection(456, 399), - Connection(399, 419), - Connection(419, 456), - Connection(284, 332), - Connection(332, 333), - Connection(333, 284), - Connection(417, 285), - Connection(285, 8), - Connection(8, 417), - Connection(346, 340), - Connection(340, 261), - Connection(261, 346), - Connection(413, 441), - Connection(441, 285), - Connection(285, 413), - Connection(327, 460), - Connection(460, 328), - Connection(328, 327), - Connection(355, 371), - Connection(371, 329), - Connection(329, 355), - Connection(392, 439), - Connection(439, 438), - Connection(438, 392), - Connection(382, 341), - Connection(341, 256), - Connection(256, 382), - Connection(429, 420), - Connection(420, 360), - Connection(360, 429), - Connection(364, 394), - Connection(394, 379), - Connection(379, 364), - Connection(277, 343), - Connection(343, 437), - Connection(437, 277), - Connection(443, 444), - Connection(444, 283), - Connection(283, 443), - Connection(275, 440), - Connection(440, 363), - Connection(363, 275), - Connection(431, 262), - Connection(262, 369), - Connection(369, 431), - Connection(297, 338), - Connection(338, 337), - Connection(337, 297), - Connection(273, 375), - Connection(375, 321), - Connection(321, 273), - Connection(450, 451), - Connection(451, 349), - Connection(349, 450), - Connection(446, 342), - Connection(342, 467), - Connection(467, 446), - Connection(293, 334), - Connection(334, 282), - Connection(282, 293), - Connection(458, 461), - Connection(461, 462), - Connection(462, 458), - Connection(276, 353), - Connection(353, 383), - Connection(383, 276), - Connection(308, 324), - Connection(324, 325), - Connection(325, 308), - Connection(276, 300), - Connection(300, 293), - Connection(293, 276), - Connection(372, 345), - Connection(345, 447), - Connection(447, 372), - Connection(352, 345), - Connection(345, 340), - Connection(340, 352), - Connection(274, 1), - Connection(1, 19), - Connection(19, 274), - Connection(456, 248), - Connection(248, 281), - Connection(281, 456), - Connection(436, 427), - Connection(427, 425), - Connection(425, 436), - Connection(381, 256), - Connection(256, 252), - Connection(252, 381), - Connection(269, 391), - Connection(391, 393), - Connection(393, 269), - Connection(200, 199), - Connection(199, 428), - Connection(428, 200), - Connection(266, 330), - Connection(330, 329), - Connection(329, 266), - Connection(287, 273), - Connection(273, 422), - Connection(422, 287), - Connection(250, 462), - Connection(462, 328), - Connection(328, 250), - Connection(258, 286), - Connection(286, 384), - Connection(384, 258), - Connection(265, 353), - Connection(353, 342), - Connection(342, 265), - Connection(387, 259), - Connection(259, 257), - Connection(257, 387), - Connection(424, 431), - Connection(431, 430), - Connection(430, 424), - Connection(342, 353), - Connection(353, 276), - Connection(276, 342), - Connection(273, 335), - Connection(335, 424), - Connection(424, 273), - Connection(292, 325), - Connection(325, 307), - Connection(307, 292), - Connection(366, 447), - Connection(447, 345), - 
Connection(345, 366), - Connection(271, 303), - Connection(303, 302), - Connection(302, 271), - Connection(423, 266), - Connection(266, 371), - Connection(371, 423), - Connection(294, 455), - Connection(455, 460), - Connection(460, 294), - Connection(279, 278), - Connection(278, 294), - Connection(294, 279), - Connection(271, 272), - Connection(272, 304), - Connection(304, 271), - Connection(432, 434), - Connection(434, 427), - Connection(427, 432), - Connection(272, 407), - Connection(407, 408), - Connection(408, 272), - Connection(394, 430), - Connection(430, 431), - Connection(431, 394), - Connection(395, 369), - Connection(369, 400), - Connection(400, 395), - Connection(334, 333), - Connection(333, 299), - Connection(299, 334), - Connection(351, 417), - Connection(417, 168), - Connection(168, 351), - Connection(352, 280), - Connection(280, 411), - Connection(411, 352), - Connection(325, 319), - Connection(319, 320), - Connection(320, 325), - Connection(295, 296), - Connection(296, 336), - Connection(336, 295), - Connection(319, 403), - Connection(403, 404), - Connection(404, 319), - Connection(330, 348), - Connection(348, 349), - Connection(349, 330), - Connection(293, 298), - Connection(298, 333), - Connection(333, 293), - Connection(323, 454), - Connection(454, 447), - Connection(447, 323), - Connection(15, 16), - Connection(16, 315), - Connection(315, 15), - Connection(358, 429), - Connection(429, 279), - Connection(279, 358), - Connection(14, 15), - Connection(15, 316), - Connection(316, 14), - Connection(285, 336), - Connection(336, 9), - Connection(9, 285), - Connection(329, 349), - Connection(349, 350), - Connection(350, 329), - Connection(374, 380), - Connection(380, 252), - Connection(252, 374), - Connection(318, 402), - Connection(402, 403), - Connection(403, 318), - Connection(6, 197), - Connection(197, 419), - Connection(419, 6), - Connection(318, 319), - Connection(319, 325), - Connection(325, 318), - Connection(367, 364), - Connection(364, 365), - Connection(365, 367), - Connection(435, 367), - Connection(367, 397), - Connection(397, 435), - Connection(344, 438), - Connection(438, 439), - Connection(439, 344), - Connection(272, 271), - Connection(271, 311), - Connection(311, 272), - Connection(195, 5), - Connection(5, 281), - Connection(281, 195), - Connection(273, 287), - Connection(287, 291), - Connection(291, 273), - Connection(396, 428), - Connection(428, 199), - Connection(199, 396), - Connection(311, 271), - Connection(271, 268), - Connection(268, 311), - Connection(283, 444), - Connection(444, 445), - Connection(445, 283), - Connection(373, 254), - Connection(254, 339), - Connection(339, 373), - Connection(282, 334), - Connection(334, 296), - Connection(296, 282), - Connection(449, 347), - Connection(347, 346), - Connection(346, 449), - Connection(264, 447), - Connection(447, 454), - Connection(454, 264), - Connection(336, 296), - Connection(296, 299), - Connection(299, 336), - Connection(338, 10), - Connection(10, 151), - Connection(151, 338), - Connection(278, 439), - Connection(439, 455), - Connection(455, 278), - Connection(292, 407), - Connection(407, 415), - Connection(415, 292), - Connection(358, 371), - Connection(371, 355), - Connection(355, 358), - Connection(340, 345), - Connection(345, 372), - Connection(372, 340), - Connection(346, 347), - Connection(347, 280), - Connection(280, 346), - Connection(442, 443), - Connection(443, 282), - Connection(282, 442), - Connection(19, 94), - Connection(94, 370), - Connection(370, 19), - Connection(441, 442), - 
Connection(442, 295), - Connection(295, 441), - Connection(248, 419), - Connection(419, 197), - Connection(197, 248), - Connection(263, 255), - Connection(255, 359), - Connection(359, 263), - Connection(440, 275), - Connection(275, 274), - Connection(274, 440), - Connection(300, 383), - Connection(383, 368), - Connection(368, 300), - Connection(351, 412), - Connection(412, 465), - Connection(465, 351), - Connection(263, 467), - Connection(467, 466), - Connection(466, 263), - Connection(301, 368), - Connection(368, 389), - Connection(389, 301), - Connection(395, 378), - Connection(378, 379), - Connection(379, 395), - Connection(412, 351), - Connection(351, 419), - Connection(419, 412), - Connection(436, 426), - Connection(426, 322), - Connection(322, 436), - Connection(2, 164), - Connection(164, 393), - Connection(393, 2), - Connection(370, 462), - Connection(462, 461), - Connection(461, 370), - Connection(164, 0), - Connection(0, 267), - Connection(267, 164), - Connection(302, 11), - Connection(11, 12), - Connection(12, 302), - Connection(268, 12), - Connection(12, 13), - Connection(13, 268), - Connection(293, 300), - Connection(300, 301), - Connection(301, 293), - Connection(446, 261), - Connection(261, 340), - Connection(340, 446), - Connection(330, 266), - Connection(266, 425), - Connection(425, 330), - Connection(426, 423), - Connection(423, 391), - Connection(391, 426), - Connection(429, 355), - Connection(355, 437), - Connection(437, 429), - Connection(391, 327), - Connection(327, 326), - Connection(326, 391), - Connection(440, 457), - Connection(457, 438), - Connection(438, 440), - Connection(341, 382), - Connection(382, 362), - Connection(362, 341), - Connection(459, 457), - Connection(457, 461), - Connection(461, 459), - Connection(434, 430), - Connection(430, 394), - Connection(394, 434), - Connection(414, 463), - Connection(463, 362), - Connection(362, 414), - Connection(396, 369), - Connection(369, 262), - Connection(262, 396), - Connection(354, 461), - Connection(461, 457), - Connection(457, 354), - Connection(316, 403), - Connection(403, 402), - Connection(402, 316), - Connection(315, 404), - Connection(404, 403), - Connection(403, 315), - Connection(314, 405), - Connection(405, 404), - Connection(404, 314), - Connection(313, 406), - Connection(406, 405), - Connection(405, 313), - Connection(421, 418), - Connection(418, 406), - Connection(406, 421), - Connection(366, 401), - Connection(401, 361), - Connection(361, 366), - Connection(306, 408), - Connection(408, 407), - Connection(407, 306), - Connection(291, 409), - Connection(409, 408), - Connection(408, 291), - Connection(287, 410), - Connection(410, 409), - Connection(409, 287), - Connection(432, 436), - Connection(436, 410), - Connection(410, 432), - Connection(434, 416), - Connection(416, 411), - Connection(411, 434), - Connection(264, 368), - Connection(368, 383), - Connection(383, 264), - Connection(309, 438), - Connection(438, 457), - Connection(457, 309), - Connection(352, 376), - Connection(376, 401), - Connection(401, 352), - Connection(274, 275), - Connection(275, 4), - Connection(4, 274), - Connection(421, 428), - Connection(428, 262), - Connection(262, 421), - Connection(294, 327), - Connection(327, 358), - Connection(358, 294), - Connection(433, 416), - Connection(416, 367), - Connection(367, 433), - Connection(289, 455), - Connection(455, 439), - Connection(439, 289), - Connection(462, 370), - Connection(370, 326), - Connection(326, 462), - Connection(2, 326), - Connection(326, 370), - Connection(370, 2), - 
Connection(305, 460), - Connection(460, 455), - Connection(455, 305), - Connection(254, 449), - Connection(449, 448), - Connection(448, 254), - Connection(255, 261), - Connection(261, 446), - Connection(446, 255), - Connection(253, 450), - Connection(450, 449), - Connection(449, 253), - Connection(252, 451), - Connection(451, 450), - Connection(450, 252), - Connection(256, 452), - Connection(452, 451), - Connection(451, 256), - Connection(341, 453), - Connection(453, 452), - Connection(452, 341), - Connection(413, 464), - Connection(464, 463), - Connection(463, 413), - Connection(441, 413), - Connection(413, 414), - Connection(414, 441), - Connection(258, 442), - Connection(442, 441), - Connection(441, 258), - Connection(257, 443), - Connection(443, 442), - Connection(442, 257), - Connection(259, 444), - Connection(444, 443), - Connection(443, 259), - Connection(260, 445), - Connection(445, 444), - Connection(444, 260), - Connection(467, 342), - Connection(342, 445), - Connection(445, 467), - Connection(459, 458), - Connection(458, 250), - Connection(250, 459), - Connection(289, 392), - Connection(392, 290), - Connection(290, 289), - Connection(290, 328), - Connection(328, 460), - Connection(460, 290), - Connection(376, 433), - Connection(433, 435), - Connection(435, 376), - Connection(250, 290), - Connection(290, 392), - Connection(392, 250), - Connection(411, 416), - Connection(416, 433), - Connection(433, 411), - Connection(341, 463), - Connection(463, 464), - Connection(464, 341), - Connection(453, 464), - Connection(464, 465), - Connection(465, 453), - Connection(357, 465), - Connection(465, 412), - Connection(412, 357), - Connection(343, 412), - Connection(412, 399), - Connection(399, 343), - Connection(360, 363), - Connection(363, 440), - Connection(440, 360), - Connection(437, 399), - Connection(399, 456), - Connection(456, 437), - Connection(420, 456), - Connection(456, 363), - Connection(363, 420), - Connection(401, 435), - Connection(435, 288), - Connection(288, 401), - Connection(372, 383), - Connection(383, 353), - Connection(353, 372), - Connection(339, 255), - Connection(255, 249), - Connection(249, 339), - Connection(448, 261), - Connection(261, 255), - Connection(255, 448), - Connection(133, 243), - Connection(243, 190), - Connection(190, 133), - Connection(133, 155), - Connection(155, 112), - Connection(112, 133), - Connection(33, 246), - Connection(246, 247), - Connection(247, 33), - Connection(33, 130), - Connection(130, 25), - Connection(25, 33), - Connection(398, 384), - Connection(384, 286), - Connection(286, 398), - Connection(362, 398), - Connection(398, 414), - Connection(414, 362), - Connection(362, 463), - Connection(463, 341), - Connection(341, 362), - Connection(263, 359), - Connection(359, 467), - Connection(467, 263), - Connection(263, 249), - Connection(249, 255), - Connection(255, 263), - Connection(466, 467), - Connection(467, 260), - Connection(260, 466), - Connection(75, 60), - Connection(60, 166), - Connection(166, 75), - Connection(238, 239), - Connection(239, 79), - Connection(79, 238), - Connection(162, 127), - Connection(127, 139), - Connection(139, 162), - Connection(72, 11), - Connection(11, 37), - Connection(37, 72), - Connection(121, 232), - Connection(232, 120), - Connection(120, 121), - Connection(73, 72), - Connection(72, 39), - Connection(39, 73), - Connection(114, 128), - Connection(128, 47), - Connection(47, 114), - Connection(233, 232), - Connection(232, 128), - Connection(128, 233), - Connection(103, 104), - Connection(104, 67), - 
Connection(67, 103), - Connection(152, 175), - Connection(175, 148), - Connection(148, 152), - Connection(119, 118), - Connection(118, 101), - Connection(101, 119), - Connection(74, 73), - Connection(73, 40), - Connection(40, 74), - Connection(107, 9), - Connection(9, 108), - Connection(108, 107), - Connection(49, 48), - Connection(48, 131), - Connection(131, 49), - Connection(32, 194), - Connection(194, 211), - Connection(211, 32), - Connection(184, 74), - Connection(74, 185), - Connection(185, 184), - Connection(191, 80), - Connection(80, 183), - Connection(183, 191), - Connection(185, 40), - Connection(40, 186), - Connection(186, 185), - Connection(119, 230), - Connection(230, 118), - Connection(118, 119), - Connection(210, 202), - Connection(202, 214), - Connection(214, 210), - Connection(84, 83), - Connection(83, 17), - Connection(17, 84), - Connection(77, 76), - Connection(76, 146), - Connection(146, 77), - Connection(161, 160), - Connection(160, 30), - Connection(30, 161), - Connection(190, 56), - Connection(56, 173), - Connection(173, 190), - Connection(182, 106), - Connection(106, 194), - Connection(194, 182), - Connection(138, 135), - Connection(135, 192), - Connection(192, 138), - Connection(129, 203), - Connection(203, 98), - Connection(98, 129), - Connection(54, 21), - Connection(21, 68), - Connection(68, 54), - Connection(5, 51), - Connection(51, 4), - Connection(4, 5), - Connection(145, 144), - Connection(144, 23), - Connection(23, 145), - Connection(90, 77), - Connection(77, 91), - Connection(91, 90), - Connection(207, 205), - Connection(205, 187), - Connection(187, 207), - Connection(83, 201), - Connection(201, 18), - Connection(18, 83), - Connection(181, 91), - Connection(91, 182), - Connection(182, 181), - Connection(180, 90), - Connection(90, 181), - Connection(181, 180), - Connection(16, 85), - Connection(85, 17), - Connection(17, 16), - Connection(205, 206), - Connection(206, 36), - Connection(36, 205), - Connection(176, 148), - Connection(148, 140), - Connection(140, 176), - Connection(165, 92), - Connection(92, 39), - Connection(39, 165), - Connection(245, 193), - Connection(193, 244), - Connection(244, 245), - Connection(27, 159), - Connection(159, 28), - Connection(28, 27), - Connection(30, 247), - Connection(247, 161), - Connection(161, 30), - Connection(174, 236), - Connection(236, 196), - Connection(196, 174), - Connection(103, 54), - Connection(54, 104), - Connection(104, 103), - Connection(55, 193), - Connection(193, 8), - Connection(8, 55), - Connection(111, 117), - Connection(117, 31), - Connection(31, 111), - Connection(221, 189), - Connection(189, 55), - Connection(55, 221), - Connection(240, 98), - Connection(98, 99), - Connection(99, 240), - Connection(142, 126), - Connection(126, 100), - Connection(100, 142), - Connection(219, 166), - Connection(166, 218), - Connection(218, 219), - Connection(112, 155), - Connection(155, 26), - Connection(26, 112), - Connection(198, 209), - Connection(209, 131), - Connection(131, 198), - Connection(169, 135), - Connection(135, 150), - Connection(150, 169), - Connection(114, 47), - Connection(47, 217), - Connection(217, 114), - Connection(224, 223), - Connection(223, 53), - Connection(53, 224), - Connection(220, 45), - Connection(45, 134), - Connection(134, 220), - Connection(32, 211), - Connection(211, 140), - Connection(140, 32), - Connection(109, 67), - Connection(67, 108), - Connection(108, 109), - Connection(146, 43), - Connection(43, 91), - Connection(91, 146), - Connection(231, 230), - Connection(230, 120), - 
Connection(120, 231), - Connection(113, 226), - Connection(226, 247), - Connection(247, 113), - Connection(105, 63), - Connection(63, 52), - Connection(52, 105), - Connection(241, 238), - Connection(238, 242), - Connection(242, 241), - Connection(124, 46), - Connection(46, 156), - Connection(156, 124), - Connection(95, 78), - Connection(78, 96), - Connection(96, 95), - Connection(70, 46), - Connection(46, 63), - Connection(63, 70), - Connection(116, 143), - Connection(143, 227), - Connection(227, 116), - Connection(116, 123), - Connection(123, 111), - Connection(111, 116), - Connection(1, 44), - Connection(44, 19), - Connection(19, 1), - Connection(3, 236), - Connection(236, 51), - Connection(51, 3), - Connection(207, 216), - Connection(216, 205), - Connection(205, 207), - Connection(26, 154), - Connection(154, 22), - Connection(22, 26), - Connection(165, 39), - Connection(39, 167), - Connection(167, 165), - Connection(199, 200), - Connection(200, 208), - Connection(208, 199), - Connection(101, 36), - Connection(36, 100), - Connection(100, 101), - Connection(43, 57), - Connection(57, 202), - Connection(202, 43), - Connection(242, 20), - Connection(20, 99), - Connection(99, 242), - Connection(56, 28), - Connection(28, 157), - Connection(157, 56), - Connection(124, 35), - Connection(35, 113), - Connection(113, 124), - Connection(29, 160), - Connection(160, 27), - Connection(27, 29), - Connection(211, 204), - Connection(204, 210), - Connection(210, 211), - Connection(124, 113), - Connection(113, 46), - Connection(46, 124), - Connection(106, 43), - Connection(43, 204), - Connection(204, 106), - Connection(96, 62), - Connection(62, 77), - Connection(77, 96), - Connection(227, 137), - Connection(137, 116), - Connection(116, 227), - Connection(73, 41), - Connection(41, 72), - Connection(72, 73), - Connection(36, 203), - Connection(203, 142), - Connection(142, 36), - Connection(235, 64), - Connection(64, 240), - Connection(240, 235), - Connection(48, 49), - Connection(49, 64), - Connection(64, 48), - Connection(42, 41), - Connection(41, 74), - Connection(74, 42), - Connection(214, 212), - Connection(212, 207), - Connection(207, 214), - Connection(183, 42), - Connection(42, 184), - Connection(184, 183), - Connection(210, 169), - Connection(169, 211), - Connection(211, 210), - Connection(140, 170), - Connection(170, 176), - Connection(176, 140), - Connection(104, 105), - Connection(105, 69), - Connection(69, 104), - Connection(193, 122), - Connection(122, 168), - Connection(168, 193), - Connection(50, 123), - Connection(123, 187), - Connection(187, 50), - Connection(89, 96), - Connection(96, 90), - Connection(90, 89), - Connection(66, 65), - Connection(65, 107), - Connection(107, 66), - Connection(179, 89), - Connection(89, 180), - Connection(180, 179), - Connection(119, 101), - Connection(101, 120), - Connection(120, 119), - Connection(68, 63), - Connection(63, 104), - Connection(104, 68), - Connection(234, 93), - Connection(93, 227), - Connection(227, 234), - Connection(16, 15), - Connection(15, 85), - Connection(85, 16), - Connection(209, 129), - Connection(129, 49), - Connection(49, 209), - Connection(15, 14), - Connection(14, 86), - Connection(86, 15), - Connection(107, 55), - Connection(55, 9), - Connection(9, 107), - Connection(120, 100), - Connection(100, 121), - Connection(121, 120), - Connection(153, 145), - Connection(145, 22), - Connection(22, 153), - Connection(178, 88), - Connection(88, 179), - Connection(179, 178), - Connection(197, 6), - Connection(6, 196), - Connection(196, 197), - 
Connection(89, 88), - Connection(88, 96), - Connection(96, 89), - Connection(135, 138), - Connection(138, 136), - Connection(136, 135), - Connection(138, 215), - Connection(215, 172), - Connection(172, 138), - Connection(218, 115), - Connection(115, 219), - Connection(219, 218), - Connection(41, 42), - Connection(42, 81), - Connection(81, 41), - Connection(5, 195), - Connection(195, 51), - Connection(51, 5), - Connection(57, 43), - Connection(43, 61), - Connection(61, 57), - Connection(208, 171), - Connection(171, 199), - Connection(199, 208), - Connection(41, 81), - Connection(81, 38), - Connection(38, 41), - Connection(224, 53), - Connection(53, 225), - Connection(225, 224), - Connection(24, 144), - Connection(144, 110), - Connection(110, 24), - Connection(105, 52), - Connection(52, 66), - Connection(66, 105), - Connection(118, 229), - Connection(229, 117), - Connection(117, 118), - Connection(227, 34), - Connection(34, 234), - Connection(234, 227), - Connection(66, 107), - Connection(107, 69), - Connection(69, 66), - Connection(10, 109), - Connection(109, 151), - Connection(151, 10), - Connection(219, 48), - Connection(48, 235), - Connection(235, 219), - Connection(183, 62), - Connection(62, 191), - Connection(191, 183), - Connection(142, 129), - Connection(129, 126), - Connection(126, 142), - Connection(116, 111), - Connection(111, 143), - Connection(143, 116), - Connection(118, 117), - Connection(117, 50), - Connection(50, 118), - Connection(223, 222), - Connection(222, 52), - Connection(52, 223), - Connection(94, 19), - Connection(19, 141), - Connection(141, 94), - Connection(222, 221), - Connection(221, 65), - Connection(65, 222), - Connection(196, 3), - Connection(3, 197), - Connection(197, 196), - Connection(45, 220), - Connection(220, 44), - Connection(44, 45), - Connection(156, 70), - Connection(70, 139), - Connection(139, 156), - Connection(188, 122), - Connection(122, 245), - Connection(245, 188), - Connection(139, 71), - Connection(71, 162), - Connection(162, 139), - Connection(149, 170), - Connection(170, 150), - Connection(150, 149), - Connection(122, 188), - Connection(188, 196), - Connection(196, 122), - Connection(206, 216), - Connection(216, 92), - Connection(92, 206), - Connection(164, 2), - Connection(2, 167), - Connection(167, 164), - Connection(242, 141), - Connection(141, 241), - Connection(241, 242), - Connection(0, 164), - Connection(164, 37), - Connection(37, 0), - Connection(11, 72), - Connection(72, 12), - Connection(12, 11), - Connection(12, 38), - Connection(38, 13), - Connection(13, 12), - Connection(70, 63), - Connection(63, 71), - Connection(71, 70), - Connection(31, 226), - Connection(226, 111), - Connection(111, 31), - Connection(36, 101), - Connection(101, 205), - Connection(205, 36), - Connection(203, 206), - Connection(206, 165), - Connection(165, 203), - Connection(126, 209), - Connection(209, 217), - Connection(217, 126), - Connection(98, 165), - Connection(165, 97), - Connection(97, 98), - Connection(237, 220), - Connection(220, 218), - Connection(218, 237), - Connection(237, 239), - Connection(239, 241), - Connection(241, 237), - Connection(210, 214), - Connection(214, 169), - Connection(169, 210), - Connection(140, 171), - Connection(171, 32), - Connection(32, 140), - Connection(241, 125), - Connection(125, 237), - Connection(237, 241), - Connection(179, 86), - Connection(86, 178), - Connection(178, 179), - Connection(180, 85), - Connection(85, 179), - Connection(179, 180), - Connection(181, 84), - Connection(84, 180), - Connection(180, 181), - 
Connection(182, 83), - Connection(83, 181), - Connection(181, 182), - Connection(194, 201), - Connection(201, 182), - Connection(182, 194), - Connection(177, 137), - Connection(137, 132), - Connection(132, 177), - Connection(184, 76), - Connection(76, 183), - Connection(183, 184), - Connection(185, 61), - Connection(61, 184), - Connection(184, 185), - Connection(186, 57), - Connection(57, 185), - Connection(185, 186), - Connection(216, 212), - Connection(212, 186), - Connection(186, 216), - Connection(192, 214), - Connection(214, 187), - Connection(187, 192), - Connection(139, 34), - Connection(34, 156), - Connection(156, 139), - Connection(218, 79), - Connection(79, 237), - Connection(237, 218), - Connection(147, 123), - Connection(123, 177), - Connection(177, 147), - Connection(45, 44), - Connection(44, 4), - Connection(4, 45), - Connection(208, 201), - Connection(201, 32), - Connection(32, 208), - Connection(98, 64), - Connection(64, 129), - Connection(129, 98), - Connection(192, 213), - Connection(213, 138), - Connection(138, 192), - Connection(235, 59), - Connection(59, 219), - Connection(219, 235), - Connection(141, 242), - Connection(242, 97), - Connection(97, 141), - Connection(97, 2), - Connection(2, 141), - Connection(141, 97), - Connection(240, 75), - Connection(75, 235), - Connection(235, 240), - Connection(229, 24), - Connection(24, 228), - Connection(228, 229), - Connection(31, 25), - Connection(25, 226), - Connection(226, 31), - Connection(230, 23), - Connection(23, 229), - Connection(229, 230), - Connection(231, 22), - Connection(22, 230), - Connection(230, 231), - Connection(232, 26), - Connection(26, 231), - Connection(231, 232), - Connection(233, 112), - Connection(112, 232), - Connection(232, 233), - Connection(244, 189), - Connection(189, 243), - Connection(243, 244), - Connection(189, 221), - Connection(221, 190), - Connection(190, 189), - Connection(222, 28), - Connection(28, 221), - Connection(221, 222), - Connection(223, 27), - Connection(27, 222), - Connection(222, 223), - Connection(224, 29), - Connection(29, 223), - Connection(223, 224), - Connection(225, 30), - Connection(30, 224), - Connection(224, 225), - Connection(113, 247), - Connection(247, 225), - Connection(225, 113), - Connection(99, 60), - Connection(60, 240), - Connection(240, 99), - Connection(213, 147), - Connection(147, 215), - Connection(215, 213), - Connection(60, 20), - Connection(20, 166), - Connection(166, 60), - Connection(192, 187), - Connection(187, 213), - Connection(213, 192), - Connection(243, 112), - Connection(112, 244), - Connection(244, 243), - Connection(244, 233), - Connection(233, 245), - Connection(245, 244), - Connection(245, 128), - Connection(128, 188), - Connection(188, 245), - Connection(188, 114), - Connection(114, 174), - Connection(174, 188), - Connection(134, 131), - Connection(131, 220), - Connection(220, 134), - Connection(174, 217), - Connection(217, 236), - Connection(236, 174), - Connection(236, 198), - Connection(198, 134), - Connection(134, 236), - Connection(215, 177), - Connection(177, 58), - Connection(58, 215), - Connection(156, 143), - Connection(143, 124), - Connection(124, 156), - Connection(25, 110), - Connection(110, 7), - Connection(7, 25), - Connection(31, 228), - Connection(228, 25), - Connection(25, 31), - Connection(264, 356), - Connection(356, 368), - Connection(368, 264), - Connection(0, 11), - Connection(11, 267), - Connection(267, 0), - Connection(451, 452), - Connection(452, 349), - Connection(349, 451), - Connection(267, 302), - 
Connection(302, 269), - Connection(269, 267), - Connection(350, 357), - Connection(357, 277), - Connection(277, 350), - Connection(350, 452), - Connection(452, 357), - Connection(357, 350), - Connection(299, 333), - Connection(333, 297), - Connection(297, 299), - Connection(396, 175), - Connection(175, 377), - Connection(377, 396), - Connection(280, 347), - Connection(347, 330), - Connection(330, 280), - Connection(269, 303), - Connection(303, 270), - Connection(270, 269), - Connection(151, 9), - Connection(9, 337), - Connection(337, 151), - Connection(344, 278), - Connection(278, 360), - Connection(360, 344), - Connection(424, 418), - Connection(418, 431), - Connection(431, 424), - Connection(270, 304), - Connection(304, 409), - Connection(409, 270), - Connection(272, 310), - Connection(310, 407), - Connection(407, 272), - Connection(322, 270), - Connection(270, 410), - Connection(410, 322), - Connection(449, 450), - Connection(450, 347), - Connection(347, 449), - Connection(432, 422), - Connection(422, 434), - Connection(434, 432), - Connection(18, 313), - Connection(313, 17), - Connection(17, 18), - Connection(291, 306), - Connection(306, 375), - Connection(375, 291), - Connection(259, 387), - Connection(387, 260), - Connection(260, 259), - Connection(424, 335), - Connection(335, 418), - Connection(418, 424), - Connection(434, 364), - Connection(364, 416), - Connection(416, 434), - Connection(391, 423), - Connection(423, 327), - Connection(327, 391), - Connection(301, 251), - Connection(251, 298), - Connection(298, 301), - Connection(275, 281), - Connection(281, 4), - Connection(4, 275), - Connection(254, 373), - Connection(373, 253), - Connection(253, 254), - Connection(375, 307), - Connection(307, 321), - Connection(321, 375), - Connection(280, 425), - Connection(425, 411), - Connection(411, 280), - Connection(200, 421), - Connection(421, 18), - Connection(18, 200), - Connection(335, 321), - Connection(321, 406), - Connection(406, 335), - Connection(321, 320), - Connection(320, 405), - Connection(405, 321), - Connection(314, 315), - Connection(315, 17), - Connection(17, 314), - Connection(423, 426), - Connection(426, 266), - Connection(266, 423), - Connection(396, 377), - Connection(377, 369), - Connection(369, 396), - Connection(270, 322), - Connection(322, 269), - Connection(269, 270), - Connection(413, 417), - Connection(417, 464), - Connection(464, 413), - Connection(385, 386), - Connection(386, 258), - Connection(258, 385), - Connection(248, 456), - Connection(456, 419), - Connection(419, 248), - Connection(298, 284), - Connection(284, 333), - Connection(333, 298), - Connection(168, 417), - Connection(417, 8), - Connection(8, 168), - Connection(448, 346), - Connection(346, 261), - Connection(261, 448), - Connection(417, 413), - Connection(413, 285), - Connection(285, 417), - Connection(326, 327), - Connection(327, 328), - Connection(328, 326), - Connection(277, 355), - Connection(355, 329), - Connection(329, 277), - Connection(309, 392), - Connection(392, 438), - Connection(438, 309), - Connection(381, 382), - Connection(382, 256), - Connection(256, 381), - Connection(279, 429), - Connection(429, 360), - Connection(360, 279), - Connection(365, 364), - Connection(364, 379), - Connection(379, 365), - Connection(355, 277), - Connection(277, 437), - Connection(437, 355), - Connection(282, 443), - Connection(443, 283), - Connection(283, 282), - Connection(281, 275), - Connection(275, 363), - Connection(363, 281), - Connection(395, 431), - Connection(431, 369), - Connection(369, 395), 
- Connection(299, 297), - Connection(297, 337), - Connection(337, 299), - Connection(335, 273), - Connection(273, 321), - Connection(321, 335), - Connection(348, 450), - Connection(450, 349), - Connection(349, 348), - Connection(359, 446), - Connection(446, 467), - Connection(467, 359), - Connection(283, 293), - Connection(293, 282), - Connection(282, 283), - Connection(250, 458), - Connection(458, 462), - Connection(462, 250), - Connection(300, 276), - Connection(276, 383), - Connection(383, 300), - Connection(292, 308), - Connection(308, 325), - Connection(325, 292), - Connection(283, 276), - Connection(276, 293), - Connection(293, 283), - Connection(264, 372), - Connection(372, 447), - Connection(447, 264), - Connection(346, 352), - Connection(352, 340), - Connection(340, 346), - Connection(354, 274), - Connection(274, 19), - Connection(19, 354), - Connection(363, 456), - Connection(456, 281), - Connection(281, 363), - Connection(426, 436), - Connection(436, 425), - Connection(425, 426), - Connection(380, 381), - Connection(381, 252), - Connection(252, 380), - Connection(267, 269), - Connection(269, 393), - Connection(393, 267), - Connection(421, 200), - Connection(200, 428), - Connection(428, 421), - Connection(371, 266), - Connection(266, 329), - Connection(329, 371), - Connection(432, 287), - Connection(287, 422), - Connection(422, 432), - Connection(290, 250), - Connection(250, 328), - Connection(328, 290), - Connection(385, 258), - Connection(258, 384), - Connection(384, 385), - Connection(446, 265), - Connection(265, 342), - Connection(342, 446), - Connection(386, 387), - Connection(387, 257), - Connection(257, 386), - Connection(422, 424), - Connection(424, 430), - Connection(430, 422), - Connection(445, 342), - Connection(342, 276), - Connection(276, 445), - Connection(422, 273), - Connection(273, 424), - Connection(424, 422), - Connection(306, 292), - Connection(292, 307), - Connection(307, 306), - Connection(352, 366), - Connection(366, 345), - Connection(345, 352), - Connection(268, 271), - Connection(271, 302), - Connection(302, 268), - Connection(358, 423), - Connection(423, 371), - Connection(371, 358), - Connection(327, 294), - Connection(294, 460), - Connection(460, 327), - Connection(331, 279), - Connection(279, 294), - Connection(294, 331), - Connection(303, 271), - Connection(271, 304), - Connection(304, 303), - Connection(436, 432), - Connection(432, 427), - Connection(427, 436), - Connection(304, 272), - Connection(272, 408), - Connection(408, 304), - Connection(395, 394), - Connection(394, 431), - Connection(431, 395), - Connection(378, 395), - Connection(395, 400), - Connection(400, 378), - Connection(296, 334), - Connection(334, 299), - Connection(299, 296), - Connection(6, 351), - Connection(351, 168), - Connection(168, 6), - Connection(376, 352), - Connection(352, 411), - Connection(411, 376), - Connection(307, 325), - Connection(325, 320), - Connection(320, 307), - Connection(285, 295), - Connection(295, 336), - Connection(336, 285), - Connection(320, 319), - Connection(319, 404), - Connection(404, 320), - Connection(329, 330), - Connection(330, 349), - Connection(349, 329), - Connection(334, 293), - Connection(293, 333), - Connection(333, 334), - Connection(366, 323), - Connection(323, 447), - Connection(447, 366), - Connection(316, 15), - Connection(15, 315), - Connection(315, 316), - Connection(331, 358), - Connection(358, 279), - Connection(279, 331), - Connection(317, 14), - Connection(14, 316), - Connection(316, 317), - Connection(8, 285), - 
Connection(285, 9), - Connection(9, 8), - Connection(277, 329), - Connection(329, 350), - Connection(350, 277), - Connection(253, 374), - Connection(374, 252), - Connection(252, 253), - Connection(319, 318), - Connection(318, 403), - Connection(403, 319), - Connection(351, 6), - Connection(6, 419), - Connection(419, 351), - Connection(324, 318), - Connection(318, 325), - Connection(325, 324), - Connection(397, 367), - Connection(367, 365), - Connection(365, 397), - Connection(288, 435), - Connection(435, 397), - Connection(397, 288), - Connection(278, 344), - Connection(344, 439), - Connection(439, 278), - Connection(310, 272), - Connection(272, 311), - Connection(311, 310), - Connection(248, 195), - Connection(195, 281), - Connection(281, 248), - Connection(375, 273), - Connection(273, 291), - Connection(291, 375), - Connection(175, 396), - Connection(396, 199), - Connection(199, 175), - Connection(312, 311), - Connection(311, 268), - Connection(268, 312), - Connection(276, 283), - Connection(283, 445), - Connection(445, 276), - Connection(390, 373), - Connection(373, 339), - Connection(339, 390), - Connection(295, 282), - Connection(282, 296), - Connection(296, 295), - Connection(448, 449), - Connection(449, 346), - Connection(346, 448), - Connection(356, 264), - Connection(264, 454), - Connection(454, 356), - Connection(337, 336), - Connection(336, 299), - Connection(299, 337), - Connection(337, 338), - Connection(338, 151), - Connection(151, 337), - Connection(294, 278), - Connection(278, 455), - Connection(455, 294), - Connection(308, 292), - Connection(292, 415), - Connection(415, 308), - Connection(429, 358), - Connection(358, 355), - Connection(355, 429), - Connection(265, 340), - Connection(340, 372), - Connection(372, 265), - Connection(352, 346), - Connection(346, 280), - Connection(280, 352), - Connection(295, 442), - Connection(442, 282), - Connection(282, 295), - Connection(354, 19), - Connection(19, 370), - Connection(370, 354), - Connection(285, 441), - Connection(441, 295), - Connection(295, 285), - Connection(195, 248), - Connection(248, 197), - Connection(197, 195), - Connection(457, 440), - Connection(440, 274), - Connection(274, 457), - Connection(301, 300), - Connection(300, 368), - Connection(368, 301), - Connection(417, 351), - Connection(351, 465), - Connection(465, 417), - Connection(251, 301), - Connection(301, 389), - Connection(389, 251), - Connection(394, 395), - Connection(395, 379), - Connection(379, 394), - Connection(399, 412), - Connection(412, 419), - Connection(419, 399), - Connection(410, 436), - Connection(436, 322), - Connection(322, 410), - Connection(326, 2), - Connection(2, 393), - Connection(393, 326), - Connection(354, 370), - Connection(370, 461), - Connection(461, 354), - Connection(393, 164), - Connection(164, 267), - Connection(267, 393), - Connection(268, 302), - Connection(302, 12), - Connection(12, 268), - Connection(312, 268), - Connection(268, 13), - Connection(13, 312), - Connection(298, 293), - Connection(293, 301), - Connection(301, 298), - Connection(265, 446), - Connection(446, 340), - Connection(340, 265), - Connection(280, 330), - Connection(330, 425), - Connection(425, 280), - Connection(322, 426), - Connection(426, 391), - Connection(391, 322), - Connection(420, 429), - Connection(429, 437), - Connection(437, 420), - Connection(393, 391), - Connection(391, 326), - Connection(326, 393), - Connection(344, 440), - Connection(440, 438), - Connection(438, 344), - Connection(458, 459), - Connection(459, 461), - Connection(461, 458), 
- Connection(364, 434), - Connection(434, 394), - Connection(394, 364), - Connection(428, 396), - Connection(396, 262), - Connection(262, 428), - Connection(274, 354), - Connection(354, 457), - Connection(457, 274), - Connection(317, 316), - Connection(316, 402), - Connection(402, 317), - Connection(316, 315), - Connection(315, 403), - Connection(403, 316), - Connection(315, 314), - Connection(314, 404), - Connection(404, 315), - Connection(314, 313), - Connection(313, 405), - Connection(405, 314), - Connection(313, 421), - Connection(421, 406), - Connection(406, 313), - Connection(323, 366), - Connection(366, 361), - Connection(361, 323), - Connection(292, 306), - Connection(306, 407), - Connection(407, 292), - Connection(306, 291), - Connection(291, 408), - Connection(408, 306), - Connection(291, 287), - Connection(287, 409), - Connection(409, 291), - Connection(287, 432), - Connection(432, 410), - Connection(410, 287), - Connection(427, 434), - Connection(434, 411), - Connection(411, 427), - Connection(372, 264), - Connection(264, 383), - Connection(383, 372), - Connection(459, 309), - Connection(309, 457), - Connection(457, 459), - Connection(366, 352), - Connection(352, 401), - Connection(401, 366), - Connection(1, 274), - Connection(274, 4), - Connection(4, 1), - Connection(418, 421), - Connection(421, 262), - Connection(262, 418), - Connection(331, 294), - Connection(294, 358), - Connection(358, 331), - Connection(435, 433), - Connection(433, 367), - Connection(367, 435), - Connection(392, 289), - Connection(289, 439), - Connection(439, 392), - Connection(328, 462), - Connection(462, 326), - Connection(326, 328), - Connection(94, 2), - Connection(2, 370), - Connection(370, 94), - Connection(289, 305), - Connection(305, 455), - Connection(455, 289), - Connection(339, 254), - Connection(254, 448), - Connection(448, 339), - Connection(359, 255), - Connection(255, 446), - Connection(446, 359), - Connection(254, 253), - Connection(253, 449), - Connection(449, 254), - Connection(253, 252), - Connection(252, 450), - Connection(450, 253), - Connection(252, 256), - Connection(256, 451), - Connection(451, 252), - Connection(256, 341), - Connection(341, 452), - Connection(452, 256), - Connection(414, 413), - Connection(413, 463), - Connection(463, 414), - Connection(286, 441), - Connection(441, 414), - Connection(414, 286), - Connection(286, 258), - Connection(258, 441), - Connection(441, 286), - Connection(258, 257), - Connection(257, 442), - Connection(442, 258), - Connection(257, 259), - Connection(259, 443), - Connection(443, 257), - Connection(259, 260), - Connection(260, 444), - Connection(444, 259), - Connection(260, 467), - Connection(467, 445), - Connection(445, 260), - Connection(309, 459), - Connection(459, 250), - Connection(250, 309), - Connection(305, 289), - Connection(289, 290), - Connection(290, 305), - Connection(305, 290), - Connection(290, 460), - Connection(460, 305), - Connection(401, 376), - Connection(376, 435), - Connection(435, 401), - Connection(309, 250), - Connection(250, 392), - Connection(392, 309), - Connection(376, 411), - Connection(411, 433), - Connection(433, 376), - Connection(453, 341), - Connection(341, 464), - Connection(464, 453), - Connection(357, 453), - Connection(453, 465), - Connection(465, 357), - Connection(343, 357), - Connection(357, 412), - Connection(412, 343), - Connection(437, 343), - Connection(343, 399), - Connection(399, 437), - Connection(344, 360), - Connection(360, 440), - Connection(440, 344), - Connection(420, 437), - 
Connection(437, 456), - Connection(456, 420), - Connection(360, 420), - Connection(420, 363), - Connection(363, 360), - Connection(361, 401), - Connection(401, 288), - Connection(288, 361), - Connection(265, 372), - Connection(372, 353), - Connection(353, 265), - Connection(390, 339), - Connection(339, 249), - Connection(249, 390), - Connection(339, 448), - Connection(448, 255), - Connection(255, 339), - ] - - -@dataclasses.dataclass -class FaceLandmarkerResult: - """The face landmarks detection result from FaceLandmarker, where each vector element represents a single face detected in the image. - - Attributes: - face_landmarks: Detected face landmarks in normalized image coordinates. - face_blendshapes: Optional face blendshapes results. - facial_transformation_matrixes: Optional facial transformation matrix. - """ - - face_landmarks: List[List[landmark_module.NormalizedLandmark]] - face_blendshapes: List[List[category_module.Category]] - facial_transformation_matrixes: List[np.ndarray] - - -def _build_landmarker_result( - output_packets: Mapping[str, packet_module.Packet] -) -> FaceLandmarkerResult: - """Constructs a `FaceLandmarkerResult` from output packets.""" - face_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_NORM_LANDMARKS_STREAM_NAME] - ) - - face_landmarks_results = [] - for proto in face_landmarks_proto_list: - face_landmarks = landmark_pb2.NormalizedLandmarkList() - face_landmarks.MergeFrom(proto) - face_landmarks_list = [] - for face_landmark in face_landmarks.landmark: - face_landmarks_list.append( - landmark_module.NormalizedLandmark.create_from_pb2(face_landmark) - ) - face_landmarks_results.append(face_landmarks_list) - - face_blendshapes_results = [] - if _BLENDSHAPES_STREAM_NAME in output_packets: - face_blendshapes_proto_list = packet_getter.get_proto_list( - output_packets[_BLENDSHAPES_STREAM_NAME] - ) - for proto in face_blendshapes_proto_list: - face_blendshapes_categories = [] - face_blendshapes_classifications = classification_pb2.ClassificationList() - face_blendshapes_classifications.MergeFrom(proto) - for face_blendshapes in face_blendshapes_classifications.classification: - face_blendshapes_categories.append( - category_module.Category( - index=face_blendshapes.index, - score=face_blendshapes.score, - display_name=face_blendshapes.display_name, - category_name=face_blendshapes.label, - ) - ) - face_blendshapes_results.append(face_blendshapes_categories) - - facial_transformation_matrixes_results = [] - if _FACE_GEOMETRY_STREAM_NAME in output_packets: - facial_transformation_matrixes_proto_list = packet_getter.get_proto_list( - output_packets[_FACE_GEOMETRY_STREAM_NAME] - ) - for proto in facial_transformation_matrixes_proto_list: - if hasattr(proto, 'pose_transform_matrix'): - matrix_data = matrix_data_pb2.MatrixData() - matrix_data.MergeFrom(proto.pose_transform_matrix) - matrix = np.array(matrix_data.packed_data) - matrix = matrix.reshape((matrix_data.rows, matrix_data.cols)) - matrix = ( - matrix if matrix_data.layout == _LayoutEnum.ROW_MAJOR else matrix.T - ) - facial_transformation_matrixes_results.append(matrix) - - return FaceLandmarkerResult( - face_landmarks_results, - face_blendshapes_results, - facial_transformation_matrixes_results, - ) - -def _build_landmarker_result2( - output_packets: Mapping[str, packet_module.Packet] -) -> FaceLandmarkerResult: - """Constructs a `FaceLandmarkerResult` from output packets.""" - face_landmarks_proto_list = packet_getter.get_proto_list( - output_packets[_NORM_LANDMARKS_STREAM_NAME] - ) - 
- face_landmarks_results = [] - for proto in face_landmarks_proto_list: - face_landmarks = landmark_pb2.NormalizedLandmarkList() - face_landmarks.MergeFrom(proto) - face_landmarks_list = [] - for face_landmark in face_landmarks.landmark: - face_landmarks_list.append( - landmark_module.NormalizedLandmark.create_from_pb2(face_landmark) - ) - face_landmarks_results.append(face_landmarks_list) - - face_blendshapes_results = [] - if _BLENDSHAPES_STREAM_NAME in output_packets: - face_blendshapes_proto_list = packet_getter.get_proto_list( - output_packets[_BLENDSHAPES_STREAM_NAME] - ) - for proto in face_blendshapes_proto_list: - face_blendshapes_categories = [] - face_blendshapes_classifications = classification_pb2.ClassificationList() - face_blendshapes_classifications.MergeFrom(proto) - for face_blendshapes in face_blendshapes_classifications.classification: - face_blendshapes_categories.append( - category_module.Category( - index=face_blendshapes.index, - score=face_blendshapes.score, - display_name=face_blendshapes.display_name, - category_name=face_blendshapes.label, - ) - ) - face_blendshapes_results.append(face_blendshapes_categories) - - facial_transformation_matrixes_results = [] - if _FACE_GEOMETRY_STREAM_NAME in output_packets: - facial_transformation_matrixes_proto_list = packet_getter.get_proto_list( - output_packets[_FACE_GEOMETRY_STREAM_NAME] - ) - for proto in facial_transformation_matrixes_proto_list: - if hasattr(proto, 'pose_transform_matrix'): - matrix_data = matrix_data_pb2.MatrixData() - matrix_data.MergeFrom(proto.pose_transform_matrix) - matrix = np.array(matrix_data.packed_data) - matrix = matrix.reshape((matrix_data.rows, matrix_data.cols)) - matrix = ( - matrix if matrix_data.layout == _LayoutEnum.ROW_MAJOR else matrix.T - ) - facial_transformation_matrixes_results.append(matrix) - - return FaceLandmarkerResult( - face_landmarks_results, - face_blendshapes_results, - facial_transformation_matrixes_results, - ), facial_transformation_matrixes_proto_list[0].mesh - -@dataclasses.dataclass -class FaceLandmarkerOptions: - """Options for the face landmarker task. - - Attributes: - base_options: Base options for the face landmarker task. - running_mode: The running mode of the task. Default to the image mode. - FaceLandmarker has three running modes: 1) The image mode for detecting - face landmarks on single image inputs. 2) The video mode for detecting - face landmarks on the decoded frames of a video. 3) The live stream mode - for detecting face landmarks on the live stream of input data, such as - from camera. In this mode, the "result_callback" below must be specified - to receive the detection results asynchronously. - num_faces: The maximum number of faces that can be detected by the - FaceLandmarker. - min_face_detection_confidence: The minimum confidence score for the face - detection to be considered successful. - min_face_presence_confidence: The minimum confidence score of face presence - score in the face landmark detection. - min_tracking_confidence: The minimum confidence score for the face tracking - to be considered successful. - output_face_blendshapes: Whether FaceLandmarker outputs face blendshapes - classification. Face blendshapes are used for rendering the 3D face model. - output_facial_transformation_matrixes: Whether FaceLandmarker outputs facial - transformation_matrix. Facial transformation matrix is used to transform - the face landmarks in canonical face to the detected face, so that users - can apply face effects on the detected landmarks. 
- result_callback: The user-defined result callback for processing live stream - data. The result callback should only be specified when the running mode - is set to the live stream mode. - """ - - base_options: _BaseOptions - running_mode: _RunningMode = _RunningMode.IMAGE - num_faces: int = 1 - min_face_detection_confidence: float = 0.5 - min_face_presence_confidence: float = 0.5 - min_tracking_confidence: float = 0.5 - output_face_blendshapes: bool = False - output_facial_transformation_matrixes: bool = False - result_callback: Optional[ - Callable[[FaceLandmarkerResult, image_module.Image, int], None] - ] = None - - @doc_controls.do_not_generate_docs - def to_pb2(self) -> _FaceLandmarkerGraphOptionsProto: - """Generates an FaceLandmarkerGraphOptions protobuf object.""" - base_options_proto = self.base_options.to_pb2() - base_options_proto.use_stream_mode = ( - False if self.running_mode == _RunningMode.IMAGE else True - ) - - # Initialize the face landmarker options from base options. - face_landmarker_options_proto = _FaceLandmarkerGraphOptionsProto( - base_options=base_options_proto - ) - - # Configure face detector options. - face_landmarker_options_proto.face_detector_graph_options.num_faces = ( - self.num_faces - ) - face_landmarker_options_proto.face_detector_graph_options.min_detection_confidence = ( - self.min_face_detection_confidence - ) - - # Configure face landmark detector options. - face_landmarker_options_proto.min_tracking_confidence = ( - self.min_tracking_confidence - ) - face_landmarker_options_proto.face_landmarks_detector_graph_options.min_detection_confidence = ( - self.min_face_detection_confidence - ) - return face_landmarker_options_proto - - -class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi): - """Class that performs face landmarks detection on images.""" - - @classmethod - def create_from_model_path(cls, model_path: str) -> 'FaceLandmarker': - """Creates an `FaceLandmarker` object from a TensorFlow Lite model and the default `FaceLandmarkerOptions`. - - Note that the created `FaceLandmarker` instance is in image mode, for - detecting face landmarks on single image inputs. - - Args: - model_path: Path to the model. - - Returns: - `FaceLandmarker` object that's created from the model file and the - default `FaceLandmarkerOptions`. - - Raises: - ValueError: If failed to create `FaceLandmarker` object from the - provided file such as invalid file path. - RuntimeError: If other types of error occurred. - """ - base_options = _BaseOptions(model_asset_path=model_path) - options = FaceLandmarkerOptions( - base_options=base_options, running_mode=_RunningMode.IMAGE - ) - return cls.create_from_options(options) - - @classmethod - def create_from_options( - cls, options: FaceLandmarkerOptions - ) -> 'FaceLandmarker': - """Creates the `FaceLandmarker` object from face landmarker options. - - Args: - options: Options for the face landmarker task. - - Returns: - `FaceLandmarker` object that's created from `options`. - - Raises: - ValueError: If failed to create `FaceLandmarker` object from - `FaceLandmarkerOptions` such as missing the model. - RuntimeError: If other types of error occurred. 
- """ - - def packets_callback(output_packets: Mapping[str, packet_module.Packet]): - if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): - return - - image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) - if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): - return - - if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty(): - empty_packet = output_packets[_NORM_LANDMARKS_STREAM_NAME] - options.result_callback( - FaceLandmarkerResult([], [], []), - image, - empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, - ) - return - - face_landmarks_result = _build_landmarker_result(output_packets) - timestamp = output_packets[_NORM_LANDMARKS_STREAM_NAME].timestamp - options.result_callback( - face_landmarks_result, - image, - timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, - ) - - output_streams = [ - ':'.join([_NORM_LANDMARKS_TAG, _NORM_LANDMARKS_STREAM_NAME]), - ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), - ] - - if options.output_face_blendshapes: - output_streams.append( - ':'.join([_BLENDSHAPES_TAG, _BLENDSHAPES_STREAM_NAME]) - ) - if options.output_facial_transformation_matrixes: - output_streams.append( - ':'.join([_FACE_GEOMETRY_TAG, _FACE_GEOMETRY_STREAM_NAME]) - ) - - task_info = _TaskInfo( - task_graph=_TASK_GRAPH_NAME, - input_streams=[ - ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), - ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), - ], - output_streams=output_streams, - task_options=options, - ) - return cls( - task_info.generate_graph_config( - enable_flow_limiting=options.running_mode - == _RunningMode.LIVE_STREAM - ), - options.running_mode, - packets_callback if options.result_callback else None, - ) - - def detect( - self, - image: image_module.Image, - image_processing_options: Optional[_ImageProcessingOptions] = None, - ) -> FaceLandmarkerResult: - """Performs face landmarks detection on the given image. - - Only use this method when the FaceLandmarker is created with the image - running mode. - - The image can be of any size with format RGB or RGBA. - TODO: Describes how the input image will be preprocessed after the yuv - support is implemented. - - Args: - image: MediaPipe Image. - image_processing_options: Options for image processing. - - Returns: - The face landmarks detection results. - - Raises: - ValueError: If any of the input arguments is invalid. - RuntimeError: If face landmarker detection failed to run. - """ - - normalized_rect = self.convert_to_normalized_rect( - image_processing_options, image, roi_allowed=False - ) - output_packets = self._process_image_data({ - _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), - _NORM_RECT_STREAM_NAME: packet_creator.create_proto( - normalized_rect.to_pb2() - ), - }) - - if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty(): - return FaceLandmarkerResult([], [], []) - - return _build_landmarker_result2(output_packets) - - def detect_for_video( - self, - image: image_module.Image, - timestamp_ms: int, - image_processing_options: Optional[_ImageProcessingOptions] = None, - ): - """Performs face landmarks detection on the provided video frame. - - Only use this method when the FaceLandmarker is created with the video - running mode. - - Only use this method when the FaceLandmarker is created with the video - running mode. It's required to provide the video frame's timestamp (in - milliseconds) along with the video frame. The input timestamps should be - monotonically increasing for adjacent calls of this method. - - Args: - image: MediaPipe Image. 
- timestamp_ms: The timestamp of the input video frame in milliseconds. - image_processing_options: Options for image processing. - - Returns: - The face landmarks detection results. - - Raises: - ValueError: If any of the input arguments is invalid. - RuntimeError: If face landmarker detection failed to run. - """ - normalized_rect = self.convert_to_normalized_rect( - image_processing_options, image, roi_allowed=False - ) - output_packets = self._process_video_data({ - _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND - ), - _NORM_RECT_STREAM_NAME: packet_creator.create_proto( - normalized_rect.to_pb2() - ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), - }) - - if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty(): - return FaceLandmarkerResult([], [], []) - - return _build_landmarker_result2(output_packets) - - def detect_async( - self, - image: image_module.Image, - timestamp_ms: int, - image_processing_options: Optional[_ImageProcessingOptions] = None, - ) -> None: - """Sends live image data to perform face landmarks detection. - - The results will be available via the "result_callback" provided in the - FaceLandmarkerOptions. Only use this method when the FaceLandmarker is - created with the live stream running mode. - - Only use this method when the FaceLandmarker is created with the live - stream running mode. The input timestamps should be monotonically increasing - for adjacent calls of this method. This method will return immediately after - the input image is accepted. The results will be available via the - `result_callback` provided in the `FaceLandmarkerOptions`. The - `detect_async` method is designed to process live stream data such as - camera input. To lower the overall latency, face landmarker may drop the - input images if needed. In other words, it's not guaranteed to have output - per input image. - - The `result_callback` provides: - - The face landmarks detection results. - - The input image that the face landmarker runs on. - - The input timestamp in milliseconds. - - Args: - image: MediaPipe Image. - timestamp_ms: The timestamp of the input image in milliseconds. - image_processing_options: Options for image processing. - - Raises: - ValueError: If the current input timestamp is smaller than what the - face landmarker has already processed. 
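    Example (an illustrative sketch; `landmarker` is assumed to have been
    created with running_mode=_RunningMode.LIVE_STREAM and the callback below
    passed as `result_callback`; `mp_image` and `frame_timestamp_ms` are
    placeholders for the caller's frame and its monotonically increasing
    timestamp):

      def on_result(result: FaceLandmarkerResult,
                    image: image_module.Image,
                    timestamp_ms: int) -> None:
          print(len(result.face_landmarks), 'face(s) at', timestamp_ms, 'ms')

      landmarker.detect_async(mp_image, frame_timestamp_ms)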
- """ - normalized_rect = self.convert_to_normalized_rect( - image_processing_options, image, roi_allowed=False - ) - self._send_live_stream_data({ - _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( - timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND - ), - _NORM_RECT_STREAM_NAME: packet_creator.create_proto( - normalized_rect.to_pb2() - ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), - }) \ No newline at end of file diff --git a/aniportrait/src/utils/frame_interpolation.py b/aniportrait/src/utils/frame_interpolation.py deleted file mode 100644 index 6ae04817ffef8aaf8a980cfd27c728dc496eeae4..0000000000000000000000000000000000000000 --- a/aniportrait/src/utils/frame_interpolation.py +++ /dev/null @@ -1,69 +0,0 @@ -# Adapted from https://github.com/dajes/frame-interpolation-pytorch -import os -import cv2 -import numpy as np -import torch -import bisect -import shutil -import pdb -from tqdm import tqdm - -def init_frame_interpolation_model(): - print("Initializing frame interpolation model") - checkpoint_name = os.path.join("./pretrained_model/film_net_fp16.pt") - - model = torch.jit.load(checkpoint_name, map_location='cpu') - model.eval() - model = model.half() - model = model.to(device="cuda") - return model - - -def batch_images_interpolation_tool(input_tensor, model, inter_frames=1): - - video_tensor = [] - frame_num = input_tensor.shape[2] # bs, channel, frame, height, width - - for idx in tqdm(range(frame_num-1)): - image1 = input_tensor[:,:,idx] - image2 = input_tensor[:,:,idx+1] - - results = [image1, image2] - - inter_frames = int(inter_frames) - idxes = [0, inter_frames + 1] - remains = list(range(1, inter_frames + 1)) - - splits = torch.linspace(0, 1, inter_frames + 2) - - for _ in range(len(remains)): - starts = splits[idxes[:-1]] - ends = splits[idxes[1:]] - distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs() - matrix = torch.argmin(distances).item() - start_i, step = np.unravel_index(matrix, distances.shape) - end_i = start_i + 1 - - x0 = results[start_i] - x1 = results[end_i] - - x0 = x0.half() - x1 = x1.half() - x0 = x0.cuda() - x1 = x1.cuda() - - dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]]) - - with torch.no_grad(): - prediction = model(x0, x1, dt) - insert_position = bisect.bisect_left(idxes, remains[step]) - idxes.insert(insert_position, remains[step]) - results.insert(insert_position, prediction.clamp(0, 1).cpu().float()) - del remains[step] - - for sub_idx in range(len(results)-1): - video_tensor.append(results[sub_idx].unsqueeze(2)) - - video_tensor.append(input_tensor[:,:,-1].unsqueeze(2)) - video_tensor = torch.cat(video_tensor, dim=2) - return video_tensor \ No newline at end of file diff --git a/aniportrait/src/utils/mp_models/blaze_face_short_range.tflite b/aniportrait/src/utils/mp_models/blaze_face_short_range.tflite deleted file mode 100644 index 2645898ee18d8bf53746df830303779c9deabc7d..0000000000000000000000000000000000000000 --- a/aniportrait/src/utils/mp_models/blaze_face_short_range.tflite +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b4578f35940bf5a1a655214a1cce5cab13eba73c1297cd78e1a04c2380b0152f -size 229746 diff --git a/aniportrait/src/utils/mp_models/face_landmarker_v2_with_blendshapes.task b/aniportrait/src/utils/mp_models/face_landmarker_v2_with_blendshapes.task deleted file mode 100644 index 
fedb14de6d2b6708a56c04ae259783e23404c1aa..0000000000000000000000000000000000000000 --- a/aniportrait/src/utils/mp_models/face_landmarker_v2_with_blendshapes.task +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:64184e229b263107bc2b804c6625db1341ff2bb731874b0bcc2fe6544e0bc9ff -size 3758596 diff --git a/aniportrait/src/utils/mp_models/pose_landmarker_heavy.task b/aniportrait/src/utils/mp_models/pose_landmarker_heavy.task deleted file mode 100644 index 5f2c1e254fe2d104606a9031b20b266863d014a6..0000000000000000000000000000000000000000 --- a/aniportrait/src/utils/mp_models/pose_landmarker_heavy.task +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:64437af838a65d18e5ba7a0d39b465540069bc8aae8308de3e318aad31fcbc7b -size 30664242 diff --git a/aniportrait/src/utils/mp_utils.py b/aniportrait/src/utils/mp_utils.py deleted file mode 100644 index cb4385128d1d188fc9023ee6c3df51fb8be39e2f..0000000000000000000000000000000000000000 --- a/aniportrait/src/utils/mp_utils.py +++ /dev/null @@ -1,95 +0,0 @@ -import os -import numpy as np -import cv2 -import time -from tqdm import tqdm -import multiprocessing -import glob - -import mediapipe as mp -from mediapipe import solutions -from mediapipe.framework.formats import landmark_pb2 -from mediapipe.tasks import python -from mediapipe.tasks.python import vision -from . import face_landmark - -CUR_DIR = os.path.dirname(__file__) - - -class LMKExtractor(): - def __init__(self, FPS=25): - # Create an FaceLandmarker object. - self.mode = mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE - base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/face_landmarker_v2_with_blendshapes.task')) - base_options.delegate = mp.tasks.BaseOptions.Delegate.CPU - options = vision.FaceLandmarkerOptions(base_options=base_options, - running_mode=self.mode, - output_face_blendshapes=True, - output_facial_transformation_matrixes=True, - num_faces=1) - self.detector = face_landmark.FaceLandmarker.create_from_options(options) - self.last_ts = 0 - self.frame_ms = int(1000 / FPS) - - det_base_options = python.BaseOptions(model_asset_path=os.path.join(CUR_DIR, 'mp_models/blaze_face_short_range.tflite')) - det_options = vision.FaceDetectorOptions(base_options=det_base_options) - self.det_detector = vision.FaceDetector.create_from_options(det_options) - - - def __call__(self, img): - frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame) - t0 = time.time() - if self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.VIDEO: - det_result = self.det_detector.detect(image) - if len(det_result.detections) != 1: - return None - self.last_ts += self.frame_ms - try: - detection_result, mesh3d = self.detector.detect_for_video(image, timestamp_ms=self.last_ts) - except: - return None - elif self.mode == mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE: - # det_result = self.det_detector.detect(image) - - # if len(det_result.detections) != 1: - # return None - try: - detection_result, mesh3d = self.detector.detect(image) - except: - return None - - - bs_list = detection_result.face_blendshapes - if len(bs_list) == 1: - bs = bs_list[0] - bs_values = [] - for index in range(len(bs)): - bs_values.append(bs[index].score) - bs_values = bs_values[1:] # remove neutral - trans_mat = detection_result.facial_transformation_matrixes[0] - face_landmarks_list = detection_result.face_landmarks - face_landmarks = face_landmarks_list[0] - lmks = [] - 
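            # Collect the normalized (x, y, z) coordinates of each detected
            # face landmark before converting them to a numpy array below.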
for index in range(len(face_landmarks)): - x = face_landmarks[index].x - y = face_landmarks[index].y - z = face_landmarks[index].z - lmks.append([x, y, z]) - lmks = np.array(lmks) - - lmks3d = np.array(mesh3d.vertex_buffer) - lmks3d = lmks3d.reshape(-1, 5)[:, :3] - mp_tris = np.array(mesh3d.index_buffer).reshape(-1, 3) + 1 - - return { - "lmks": lmks, - 'lmks3d': lmks3d, - "trans_mat": trans_mat, - 'faces': mp_tris, - "bs": bs_values - } - else: - # print('multiple faces in the image: {}'.format(img_path)) - return None - \ No newline at end of file diff --git a/aniportrait/src/utils/pose_util.py b/aniportrait/src/utils/pose_util.py deleted file mode 100644 index a09d07f20d404fbdfb6d444e2896df641ccc364c..0000000000000000000000000000000000000000 --- a/aniportrait/src/utils/pose_util.py +++ /dev/null @@ -1,89 +0,0 @@ -import math - -import numpy as np -from scipy.spatial.transform import Rotation as R - - -def create_perspective_matrix(aspect_ratio): - kDegreesToRadians = np.pi / 180. - near = 1 - far = 10000 - perspective_matrix = np.zeros(16, dtype=np.float32) - - # Standard perspective projection matrix calculations. - f = 1.0 / np.tan(kDegreesToRadians * 63 / 2.) - - denom = 1.0 / (near - far) - perspective_matrix[0] = f / aspect_ratio - perspective_matrix[5] = f - perspective_matrix[10] = (near + far) * denom - perspective_matrix[11] = -1. - perspective_matrix[14] = 1. * far * near * denom - - # If the environment's origin point location is in the top left corner, - # then skip additional flip along Y-axis is required to render correctly. - - perspective_matrix[5] *= -1. - return perspective_matrix - - -def project_points(points_3d, transformation_matrix, pose_vectors, image_shape): - P = create_perspective_matrix(image_shape[1] / image_shape[0]).reshape(4, 4).T - L, N, _ = points_3d.shape - projected_points = np.zeros((L, N, 2)) - for i in range(L): - points_3d_frame = points_3d[i] - ones = np.ones((points_3d_frame.shape[0], 1)) - points_3d_homogeneous = np.hstack([points_3d_frame, ones]) - transformed_points = points_3d_homogeneous @ (transformation_matrix @ euler_and_translation_to_matrix(pose_vectors[i][:3], pose_vectors[i][3:])).T @ P - projected_points_frame = transformed_points[:, :2] / transformed_points[:, 3, np.newaxis] # -1 ~ 1 - projected_points_frame[:, 0] = (projected_points_frame[:, 0] + 1) * 0.5 * image_shape[1] - projected_points_frame[:, 1] = (projected_points_frame[:, 1] + 1) * 0.5 * image_shape[0] - projected_points[i] = projected_points_frame - return projected_points - - -def project_points_with_trans(points_3d, transformation_matrix, image_shape): - P = create_perspective_matrix(image_shape[1] / image_shape[0]).reshape(4, 4).T - L, N, _ = points_3d.shape - projected_points = np.zeros((L, N, 2)) - for i in range(L): - points_3d_frame = points_3d[i] - ones = np.ones((points_3d_frame.shape[0], 1)) - points_3d_homogeneous = np.hstack([points_3d_frame, ones]) - transformed_points = points_3d_homogeneous @ transformation_matrix[i].T @ P - projected_points_frame = transformed_points[:, :2] / transformed_points[:, 3, np.newaxis] # -1 ~ 1 - projected_points_frame[:, 0] = (projected_points_frame[:, 0] + 1) * 0.5 * image_shape[1] - projected_points_frame[:, 1] = (projected_points_frame[:, 1] + 1) * 0.5 * image_shape[0] - projected_points[i] = projected_points_frame - return projected_points - - -def euler_and_translation_to_matrix(euler_angles, translation_vector): - rotation = R.from_euler('xyz', euler_angles, degrees=True) - rotation_matrix = rotation.as_matrix() - - 
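    # Pack the rotation and translation into one 4x4 homogeneous transform:
    # the 3x3 rotation fills the upper-left block, the translation the last column.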
matrix = np.eye(4) - matrix[:3, :3] = rotation_matrix - matrix[:3, 3] = translation_vector - - return matrix - - -def matrix_to_euler_and_translation(matrix): - rotation_matrix = matrix[:3, :3] - translation_vector = matrix[:3, 3] - rotation = R.from_matrix(rotation_matrix) - euler_angles = rotation.as_euler('xyz', degrees=True) - return euler_angles, translation_vector - - -def smooth_pose_seq(pose_seq, window_size=5): - smoothed_pose_seq = np.zeros_like(pose_seq) - - for i in range(len(pose_seq)): - start = max(0, i - window_size // 2) - end = min(len(pose_seq), i + window_size // 2 + 1) - smoothed_pose_seq[i] = np.mean(pose_seq[start:end], axis=0) - - return smoothed_pose_seq \ No newline at end of file diff --git a/aniportrait/src/utils/util.py b/aniportrait/src/utils/util.py deleted file mode 100644 index 3e473082ae134df241868fdb4ef1ffc61129cf10..0000000000000000000000000000000000000000 --- a/aniportrait/src/utils/util.py +++ /dev/null @@ -1,181 +0,0 @@ -import importlib -import os -import os.path as osp -import shutil -import sys -import cv2 -from pathlib import Path - -import av -import numpy as np -import torch -import torchvision -from einops import rearrange -from PIL import Image - - -def seed_everything(seed): - import random - - import numpy as np - - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - np.random.seed(seed % (2**32)) - random.seed(seed) - - -def import_filename(filename): - spec = importlib.util.spec_from_file_location("mymodule", filename) - module = importlib.util.module_from_spec(spec) - sys.modules[spec.name] = module - spec.loader.exec_module(module) - return module - - -def delete_additional_ckpt(base_path, num_keep): - dirs = [] - for d in os.listdir(base_path): - if d.startswith("checkpoint-"): - dirs.append(d) - num_tot = len(dirs) - if num_tot <= num_keep: - return - # ensure ckpt is sorted and delete the ealier! - del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep] - for d in del_dirs: - path_to_dir = osp.join(base_path, d) - if osp.exists(path_to_dir): - shutil.rmtree(path_to_dir) - - -def save_videos_from_pil(pil_images, path, fps=8): - import av - - save_fmt = Path(path).suffix - os.makedirs(os.path.dirname(path), exist_ok=True) - width, height = pil_images[0].size - - if save_fmt == ".mp4": - codec = "libx264" - container = av.open(path, "w") - stream = container.add_stream(codec, rate=fps) - - stream.width = width - stream.height = height - - for pil_image in pil_images: - # pil_image = Image.fromarray(image_arr).convert("RGB") - av_frame = av.VideoFrame.from_image(pil_image) - container.mux(stream.encode(av_frame)) - container.mux(stream.encode()) - container.close() - - elif save_fmt == ".gif": - pil_images[0].save( - fp=path, - format="GIF", - append_images=pil_images[1:], - save_all=True, - duration=(1 / fps * 1000), - loop=0, - ) - else: - raise ValueError("Unsupported file type. 
Use .mp4 or .gif.") - - -def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): - videos = rearrange(videos, "b c t h w -> t b c h w") - height, width = videos.shape[-2:] - outputs = [] - - for x in videos: - x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w) - x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c) - if rescale: - x = (x + 1.0) / 2.0 # -1,1 -> 0,1 - x = (x * 255).numpy().astype(np.uint8) - x = Image.fromarray(x) - - outputs.append(x) - - os.makedirs(os.path.dirname(path), exist_ok=True) - - save_videos_from_pil(outputs, path, fps) - - -def read_frames(video_path): - container = av.open(video_path) - - video_stream = next(s for s in container.streams if s.type == "video") - frames = [] - for packet in container.demux(video_stream): - for frame in packet.decode(): - image = Image.frombytes( - "RGB", - (frame.width, frame.height), - frame.to_rgb().to_ndarray(), - ) - frames.append(image) - - return frames - - -def get_fps(video_path): - container = av.open(video_path) - video_stream = next(s for s in container.streams if s.type == "video") - fps = video_stream.average_rate - container.close() - return fps - -def crop_face(img, lmk_extractor, expand=1.5): - result = lmk_extractor(img) # cv2 BGR - - if result is None: - return None - - H, W, _ = img.shape - lmks = result['lmks'] - lmks[:, 0] *= W - lmks[:, 1] *= H - - x_min = np.min(lmks[:, 0]) - x_max = np.max(lmks[:, 0]) - y_min = np.min(lmks[:, 1]) - y_max = np.max(lmks[:, 1]) - - width = x_max - x_min - height = y_max - y_min - - if width*height >= W*H*0.15: - if W == H: - return img - size = min(H, W) - offset = int((max(H, W) - size)/2) - if size == H: - return img[:, offset:-offset] - else: - return img[offset:-offset, :] - else: - center_x = x_min + width / 2 - center_y = y_min + height / 2 - - width *= expand - height *= expand - - size = max(width, height) - - x_min = int(center_x - size / 2) - x_max = int(center_x + size / 2) - y_min = int(center_y - size / 2) - y_max = int(center_y + size / 2) - - top = max(0, -y_min) - bottom = max(0, y_max - img.shape[0]) - left = max(0, -x_min) - right = max(0, x_max - img.shape[1]) - img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0) - - cropped_img = img[y_min + top:y_max + top, x_min + left:x_max + left] - - return cropped_img \ No newline at end of file diff --git a/ckpt_tree.md b/ckpt_tree.md deleted file mode 100644 index 5018c23169e4fc414546ee8b439fa8696a5b61fc..0000000000000000000000000000000000000000 --- a/ckpt_tree.md +++ /dev/null @@ -1,108 +0,0 @@ - -``` -|-- ckpts -| |-- aniportrait -| | `-- motion_module.pth -| | `-- audio2mesh.pt -| | `-- film_net_fp16.pt -| | |-- sd-vae-ft-mse -| | | `-- diffusion_pytorch_model.safetensors -| | | `-- config.json -| | | `-- diffusion_pytorch_model.bin -| | `-- denoising_unet.pth -| | `-- audio2pose.pt -| | `-- pose_guider.pth -| | |-- sd-image-variations-diffusers -| | | `-- v1-montage.jpg -| | | |-- scheduler -| | | | `-- scheduler_config.json -| | | `-- README.md -| | | `-- model_index.json -| | | |-- unet -| | | | `-- config.json -| | | | `-- diffusion_pytorch_model.bin -| | | |-- feature_extractor -| | | | `-- preprocessor_config.json -| | | `-- v2-montage.jpg -| | | |-- vae -| | | | `-- config.json -| | | | `-- diffusion_pytorch_model.bin -| | | `-- alias-montage.jpg -| | | `-- inputs.jpg -| | | |-- safety_checker -| | | | `-- pytorch_model.bin -| | | | `-- config.json -| | | `-- earring.jpg -| | | `-- default-montage.jpg -| | |-- image_encoder -| | 
| `-- pytorch_model.bin -| | | `-- config.json -| | |-- stable-diffusion-v1-5 -| | | `-- model_index.json -| | | `-- v1-inference.yaml -| | | |-- unet -| | | | `-- config.json -| | | | `-- diffusion_pytorch_model.bin -| | | |-- feature_extractor -| | | | `-- preprocessor_config.json -| | `-- reference_unet.pth -| | |-- wav2vec2-base-960h -| | | `-- pytorch_model.bin -| | | `-- README.md -| | | `-- vocab.json -| | | `-- config.json -| | | `-- tf_model.h5 -| | | `-- tokenizer_config.json -| | | `-- model.safetensors -| | | `-- special_tokens_map.json -| | | `-- preprocessor_config.json -| | | `-- feature_extractor_config.json -| |-- mofa -| | |-- traj_controlnet -| | | `-- diffusion_pytorch_model.safetensors -| | | `-- config.json -| | |-- stable-video-diffusion-img2vid-xt-1-1 -| | | |-- scheduler -| | | | `-- scheduler_config.json -| | | `-- README.md -| | | `-- model_index.json -| | | |-- unet -| | | | `-- diffusion_pytorch_model.fp16.safetensors -| | | | `-- config.json -| | | |-- feature_extractor -| | | | `-- preprocessor_config.json -| | | |-- vae -| | | | `-- diffusion_pytorch_model.fp16.safetensors -| | | | `-- config.json -| | | `-- LICENSE -| | | `-- svd11.webp -| | | |-- image_encoder -| | | | `-- config.json -| | | | `-- model.fp16.safetensors -| | |-- ldmk_controlnet -| | | `-- diffusion_pytorch_model.safetensors -| | | `-- config.json -| |-- sad_talker -| | `-- SadTalker_V0.0.2_256.safetensors -| | |-- hub -| | `-- mapping_00229-model.pth.tar -| | |-- BFM_Fitting -| | | `-- select_vertex_id.mat -| | | `-- facemodel_info.mat -| | | `-- BFM_exp_idx.mat -| | | `-- BFM_model_front.mat -| | | `-- 01_MorphableModel.mat -| | | `-- similarity_Lm3D_all.mat -| | | `-- BFM_front_idx.mat -| | | `-- Exp_Pca.bin -| | | `-- std_exp.txt -| | `-- SadTalker_V0.0.2_512.safetensors -| | `-- similarity_Lm3D_all.mat -| | `-- epoch_00190_iteration_000400000_checkpoint.pt -| | `-- mapping_00109-model.pth.tar -| |-- gfpgan -| | `-- alignment_WFLW_4HG.pth -| | `-- parsing_parsenet.pth -| | `-- detection_Resnet50_Final.pth - -``` \ No newline at end of file diff --git a/expression.mat b/expression.mat deleted file mode 100644 index bf4d3c687be74adda57b4096cf05e279b9bf72ec..0000000000000000000000000000000000000000 --- a/expression.mat +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:93e9d69eb46e866ed5cbb569ed2bdb3813254720fb0cb745d5b56181faf9aec5 -size 1456 diff --git a/models/.DS_Store b/models/.DS_Store deleted file mode 100644 index 087053b36a558287fc3b017d1f8393a4e004363e..0000000000000000000000000000000000000000 Binary files a/models/.DS_Store and /dev/null differ diff --git a/models/cmp/.DS_Store b/models/cmp/.DS_Store deleted file mode 100644 index ff76e5324d7be1a0c28204838fcff4fc1fb09c7a..0000000000000000000000000000000000000000 Binary files a/models/cmp/.DS_Store and /dev/null differ diff --git a/models/cmp/experiments/.DS_Store b/models/cmp/experiments/.DS_Store deleted file mode 100644 index 9359ef6cd967551d5476ebf1ce12f2e1a233d898..0000000000000000000000000000000000000000 Binary files a/models/cmp/experiments/.DS_Store and /dev/null differ diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/config.yaml b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/config.yaml deleted file mode 100644 index 2944c4056bc0683c9a94bc20017c1056352356e1..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/config.yaml +++ /dev/null @@ -1,59 +0,0 @@ -model: 
- arch: CMP - total_iter: 140000 - lr_steps: [80000, 120000] - lr_mults: [0.1, 0.1] - lr: 0.1 - optim: SGD - warmup_lr: [] - warmup_steps: [] - module: - arch: CMP - image_encoder: alexnet_fcn_32x - sparse_encoder: shallownet32x - flow_decoder: MotionDecoderPlain - skip_layer: False - img_enc_dim: 256 - sparse_enc_dim: 16 - output_dim: 198 - decoder_combo: [1,2,4] - pretrained_image_encoder: False - flow_criterion: "DiscreteLoss" - nbins: 99 - fmax: 50 -data: - workers: 2 - batch_size: 12 - batch_size_test: 1 - data_mean: [123.675, 116.28, 103.53] # RGB - data_div: [58.395, 57.12, 57.375] - short_size: 416 - crop_size: [384, 384] - sample_strategy: ['grid', 'watershed'] - sample_bg_ratio: 0.000025 - nms_ks: 81 - max_num_guide: 150 - - flow_file_type: "jpg" - image_flow_aug: - flip: False - flow_aug: - reverse: False - scale: False - rotate: False - train_source: - - data/yfcc/lists/train.txt - - data/youtube9000/lists/train.txt - val_source: - - data/yfcc/lists/val.txt - memcached: False -trainer: - initial_val: True - print_freq: 100 - val_freq: 10000 - save_freq: 10000 - val_iter: -1 - val_disp_start_iter: 0 - val_disp_end_iter: 16 - loss_record: ['loss_flow'] - tensorboard: False diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume.sh deleted file mode 100644 index 06bd63a2c51db22a687f347635759f3a41ea30b2..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 \ - --nnodes=2 --node_rank=$1 \ - --master_addr="192.168.1.1" main.py \ - --config $work_path/config.yaml --launcher pytorch \ - --load-iter 10000 \ - --resume diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume_slurm.sh deleted file mode 100644 index 644276733e346ef31fa9d3aaa4110b0b360cff3f..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume_slurm.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py \ - --config $work_path/config.yaml --launcher slurm \ - --load-iter 10000 \ - --resume diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train.sh deleted file mode 100644 index 5f2b03a431e84f04599c76865ec14cd499ff3063..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 \ - --nnodes=2 --node_rank=$1 \ - --master_addr="192.168.1.1" main.py \ - --config $work_path/config.yaml --launcher pytorch diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train_slurm.sh deleted file mode 100644 index e9c1a9f27ef9e639802ecf29247297ff7eb022d1..0000000000000000000000000000000000000000 --- 
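Each CMP experiment removed here follows the same layout: a `config.yaml` with `model`, `data` and `trainer` sections plus thin train/resume/validate launch scripts. A minimal sketch of inspecting one such config from Python (the path is one of the deleted files above; key nesting is assumed to follow the original indentation, which the flattened diff does not preserve):

```
import yaml

cfg_path = "models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/config.yaml"
with open(cfg_path) as f:
    cfg = yaml.safe_load(f)

# Top-level sections present in the file.
print(sorted(cfg))  # expected: ['data', 'model', 'trainer']

# Training schedule knobs (assumed to sit under the `model` section).
print(cfg["model"]["total_iter"], cfg["model"]["lr_steps"])
```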
a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train_slurm.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py \ - --config $work_path/config.yaml --launcher slurm diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate.sh deleted file mode 100644 index 6d7abca3cc02d020fce66d22d880f5c9e03ce34c..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 main.py \ - --config $work_path/config.yaml --launcher pytorch \ - --load-iter 70000 \ - --validate diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate_slurm.sh deleted file mode 100644 index 9bfe2eec2f1a52089b86f7d8a2550f12251a269e..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate_slurm.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py --config $work_path/config.yaml --launcher slurm \ - --load-iter 70000 \ - --validate diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/config.yaml b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/config.yaml deleted file mode 100644 index 9296127836bea77efc5d6d28ccf363c6e8adbf91..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/config.yaml +++ /dev/null @@ -1,58 +0,0 @@ -model: - arch: CMP - total_iter: 70000 - lr_steps: [40000, 60000] - lr_mults: [0.1, 0.1] - lr: 0.1 - optim: SGD - warmup_lr: [] - warmup_steps: [] - module: - arch: CMP - image_encoder: alexnet_fcn_32x - sparse_encoder: shallownet32x - flow_decoder: MotionDecoderPlain - skip_layer: False - img_enc_dim: 256 - sparse_enc_dim: 16 - output_dim: 198 - decoder_combo: [1,2,4] - pretrained_image_encoder: False - flow_criterion: "DiscreteLoss" - nbins: 99 - fmax: 50 -data: - workers: 2 - batch_size: 12 - batch_size_test: 1 - data_mean: [123.675, 116.28, 103.53] # RGB - data_div: [58.395, 57.12, 57.375] - short_size: 416 - crop_size: [384, 384] - sample_strategy: ['grid', 'watershed'] - sample_bg_ratio: 0.00015625 - nms_ks: 41 - max_num_guide: 150 - - flow_file_type: "jpg" - image_flow_aug: - flip: False - flow_aug: - reverse: False - scale: False - rotate: False - train_source: - - data/yfcc/lists/train.txt - val_source: - - data/yfcc/lists/val.txt - memcached: False -trainer: - initial_val: True - print_freq: 100 - val_freq: 10000 - save_freq: 10000 - val_iter: -1 - val_disp_start_iter: 0 - val_disp_end_iter: 16 - loss_record: ['loss_flow'] - tensorboard: False diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume.sh deleted file mode 100644 index 06bd63a2c51db22a687f347635759f3a41ea30b2..0000000000000000000000000000000000000000 --- 
a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 \ - --nnodes=2 --node_rank=$1 \ - --master_addr="192.168.1.1" main.py \ - --config $work_path/config.yaml --launcher pytorch \ - --load-iter 10000 \ - --resume diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume_slurm.sh deleted file mode 100644 index 644276733e346ef31fa9d3aaa4110b0b360cff3f..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume_slurm.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py \ - --config $work_path/config.yaml --launcher slurm \ - --load-iter 10000 \ - --resume diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train.sh deleted file mode 100644 index 5f2b03a431e84f04599c76865ec14cd499ff3063..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 \ - --nnodes=2 --node_rank=$1 \ - --master_addr="192.168.1.1" main.py \ - --config $work_path/config.yaml --launcher pytorch diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train_slurm.sh deleted file mode 100644 index e9c1a9f27ef9e639802ecf29247297ff7eb022d1..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train_slurm.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py \ - --config $work_path/config.yaml --launcher slurm diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate.sh deleted file mode 100644 index 6d7abca3cc02d020fce66d22d880f5c9e03ce34c..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 main.py \ - --config $work_path/config.yaml --launcher pytorch \ - --load-iter 70000 \ - --validate diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate_slurm.sh deleted file mode 100644 index 9bfe2eec2f1a52089b86f7d8a2550f12251a269e..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate_slurm.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py --config $work_path/config.yaml --launcher slurm \ - --load-iter 70000 \ - --validate diff --git 
a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/config.yaml b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/config.yaml deleted file mode 100644 index 6e8751ff794627d37449771734bb2fe1521f527a..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/config.yaml +++ /dev/null @@ -1,58 +0,0 @@ -model: - arch: CMP - total_iter: 140000 - lr_steps: [80000, 120000] - lr_mults: [0.1, 0.1] - lr: 0.1 - optim: SGD - warmup_lr: [] - warmup_steps: [] - module: - arch: CMP - image_encoder: alexnet_fcn_32x - sparse_encoder: shallownet32x - flow_decoder: MotionDecoderPlain - skip_layer: False - img_enc_dim: 256 - sparse_enc_dim: 16 - output_dim: 198 - decoder_combo: [1,2,4] - pretrained_image_encoder: False - flow_criterion: "DiscreteLoss" - nbins: 99 - fmax: 50 -data: - workers: 2 - batch_size: 12 - batch_size_test: 1 - data_mean: [123.675, 116.28, 103.53] # RGB - data_div: [58.395, 57.12, 57.375] - short_size: 416 - crop_size: [384, 384] - sample_strategy: ['grid', 'watershed'] - sample_bg_ratio: 0.00015625 - nms_ks: 41 - max_num_guide: 150 - - flow_file_type: "jpg" - image_flow_aug: - flip: False - flow_aug: - reverse: False - scale: False - rotate: False - train_source: - - data/yfcc/lists/train.txt - val_source: - - data/yfcc/lists/val.txt - memcached: False -trainer: - initial_val: True - print_freq: 100 - val_freq: 10000 - save_freq: 10000 - val_iter: -1 - val_disp_start_iter: 0 - val_disp_end_iter: 16 - loss_record: ['loss_flow'] - tensorboard: False diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume.sh deleted file mode 100644 index 17cec90cef2555c6b7dd5acfe3b938c9be451346..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 main.py \ - --config $work_path/config.yaml --launcher pytorch \ - --load-iter 10000 \ - --resume diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume_slurm.sh deleted file mode 100644 index 94b7cccac61566afb3eef7924a6d8b56027b2d13..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume_slurm.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py \ - --config $work_path/config.yaml --launcher slurm \ - --load-iter 10000 \ - --resume diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train.sh deleted file mode 100644 index 330d14c459f8549ea81f956c3497e13ddf68aed0..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 main.py \ - --config $work_path/config.yaml --launcher pytorch diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train_slurm.sh deleted file mode 100644 index 
140bfae1f0543e2b186f06f3dfc7a934c0aeccf1..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train_slurm.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py \ - --config $work_path/config.yaml --launcher slurm diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate.sh deleted file mode 100644 index 6d7abca3cc02d020fce66d22d880f5c9e03ce34c..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 main.py \ - --config $work_path/config.yaml --launcher pytorch \ - --load-iter 70000 \ - --validate diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate_slurm.sh deleted file mode 100644 index 9bfe2eec2f1a52089b86f7d8a2550f12251a269e..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate_slurm.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py --config $work_path/config.yaml --launcher slurm \ - --load-iter 70000 \ - --validate diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/config.yaml b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/config.yaml deleted file mode 100644 index 5dd44cf7642837242711326eb413b950c384dd26..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/config.yaml +++ /dev/null @@ -1,61 +0,0 @@ -model: - arch: CMP - total_iter: 70000 - lr_steps: [40000, 60000] - lr_mults: [0.1, 0.1] - lr: 0.1 - optim: SGD - warmup_lr: [] - warmup_steps: [] - module: - arch: CMP - image_encoder: resnet50 - sparse_encoder: shallownet8x - flow_decoder: MotionDecoderPlain - skip_layer: False - img_enc_dim: 256 - sparse_enc_dim: 16 - output_dim: 198 - decoder_combo: [1,2,4] - pretrained_image_encoder: False - flow_criterion: "DiscreteLoss" - nbins: 99 - fmax: 50 -data: - workers: 2 - batch_size: 10 - batch_size_test: 1 - data_mean: [123.675, 116.28, 103.53] # RGB - data_div: [58.395, 57.12, 57.375] - short_size: 416 - crop_size: [320, 320] - sample_strategy: ['grid', 'watershed'] - sample_bg_ratio: 0.00015625 - nms_ks: 15 - max_num_guide: -1 - - flow_file_type: "jpg" - image_flow_aug: - flip: False - flow_aug: - reverse: False - scale: False - rotate: False - train_source: - - data/yfcc/lists/train.txt - - data/youtube9000/lists/train.txt - - data/VIP/lists/train.txt - - data/MPII/lists/train.txt - val_source: - - data/yfcc/lists/val.txt - memcached: False -trainer: - initial_val: True - print_freq: 100 - val_freq: 10000 - save_freq: 10000 - val_iter: -1 - val_disp_start_iter: 0 - val_disp_end_iter: 16 - loss_record: ['loss_flow'] - tensorboard: False diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume.sh 
b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume.sh deleted file mode 100644 index 06bd63a2c51db22a687f347635759f3a41ea30b2..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 \ - --nnodes=2 --node_rank=$1 \ - --master_addr="192.168.1.1" main.py \ - --config $work_path/config.yaml --launcher pytorch \ - --load-iter 10000 \ - --resume diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume_slurm.sh deleted file mode 100644 index 644276733e346ef31fa9d3aaa4110b0b360cff3f..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume_slurm.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py \ - --config $work_path/config.yaml --launcher slurm \ - --load-iter 10000 \ - --resume diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train.sh deleted file mode 100644 index 5f2b03a431e84f04599c76865ec14cd499ff3063..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 \ - --nnodes=2 --node_rank=$1 \ - --master_addr="192.168.1.1" main.py \ - --config $work_path/config.yaml --launcher pytorch diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train_slurm.sh deleted file mode 100644 index e9c1a9f27ef9e639802ecf29247297ff7eb022d1..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train_slurm.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py \ - --config $work_path/config.yaml --launcher slurm diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate.sh deleted file mode 100644 index 6d7abca3cc02d020fce66d22d880f5c9e03ce34c..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 main.py \ - --config $work_path/config.yaml --launcher pytorch \ - --load-iter 70000 \ - --validate diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate_slurm.sh deleted file mode 100644 index 
9bfe2eec2f1a52089b86f7d8a2550f12251a269e..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate_slurm.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py --config $work_path/config.yaml --launcher slurm \ - --load-iter 70000 \ - --validate diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/config.yaml b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/config.yaml deleted file mode 100644 index 1a453c27947f570320609a61fde9c862819842bc..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/config.yaml +++ /dev/null @@ -1,58 +0,0 @@ -model: - arch: CMP - total_iter: 42000 - lr_steps: [24000, 36000] - lr_mults: [0.1, 0.1] - lr: 0.1 - optim: SGD - warmup_lr: [] - warmup_steps: [] - module: - arch: CMP - image_encoder: resnet50 - sparse_encoder: shallownet8x - flow_decoder: MotionDecoderPlain - skip_layer: False - img_enc_dim: 256 - sparse_enc_dim: 16 - output_dim: 198 - decoder_combo: [1,2,4] - pretrained_image_encoder: False - flow_criterion: "DiscreteLoss" - nbins: 99 - fmax: 50 -data: - workers: 2 - batch_size: 16 - batch_size_test: 1 - data_mean: [123.675, 116.28, 103.53] # RGB - data_div: [58.395, 57.12, 57.375] - short_size: 333 - crop_size: [256, 256] - sample_strategy: ['grid', 'watershed'] - sample_bg_ratio: 0.00005632 - nms_ks: 49 - max_num_guide: -1 - - flow_file_type: "jpg" - image_flow_aug: - flip: False - flow_aug: - reverse: False - scale: False - rotate: False - train_source: - - data/yfcc/lists/train.txt - val_source: - - data/yfcc/lists/val.txt - memcached: False -trainer: - initial_val: True - print_freq: 100 - val_freq: 10000 - save_freq: 10000 - val_iter: -1 - val_disp_start_iter: 0 - val_disp_end_iter: 16 - loss_record: ['loss_flow'] - tensorboard: False diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume.sh deleted file mode 100644 index 06bd63a2c51db22a687f347635759f3a41ea30b2..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 \ - --nnodes=2 --node_rank=$1 \ - --master_addr="192.168.1.1" main.py \ - --config $work_path/config.yaml --launcher pytorch \ - --load-iter 10000 \ - --resume diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume_slurm.sh deleted file mode 100644 index 644276733e346ef31fa9d3aaa4110b0b360cff3f..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume_slurm.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py \ - --config $work_path/config.yaml --launcher slurm \ - --load-iter 10000 \ - --resume diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train.sh deleted file mode 100644 index 
5f2b03a431e84f04599c76865ec14cd499ff3063..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 \ - --nnodes=2 --node_rank=$1 \ - --master_addr="192.168.1.1" main.py \ - --config $work_path/config.yaml --launcher pytorch diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train_slurm.sh deleted file mode 100644 index e9c1a9f27ef9e639802ecf29247297ff7eb022d1..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train_slurm.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py \ - --config $work_path/config.yaml --launcher slurm diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate.sh deleted file mode 100644 index 6d7abca3cc02d020fce66d22d880f5c9e03ce34c..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 main.py \ - --config $work_path/config.yaml --launcher pytorch \ - --load-iter 70000 \ - --validate diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate_slurm.sh deleted file mode 100644 index 9bfe2eec2f1a52089b86f7d8a2550f12251a269e..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate_slurm.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py --config $work_path/config.yaml --launcher slurm \ - --load-iter 70000 \ - --validate diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/config.yaml b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/config.yaml deleted file mode 100644 index 47ba5c8c0d6f63247b7fcf6ac18554f4cddb0eac..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/config.yaml +++ /dev/null @@ -1,58 +0,0 @@ -model: - arch: CMP - total_iter: 42000 - lr_steps: [24000, 36000] - lr_mults: [0.1, 0.1] - lr: 0.1 - optim: SGD - warmup_lr: [] - warmup_steps: [] - module: - arch: CMP - image_encoder: resnet50 - sparse_encoder: shallownet8x - flow_decoder: MotionDecoderPlain - skip_layer: False - img_enc_dim: 256 - sparse_enc_dim: 16 - output_dim: 198 - decoder_combo: [1,2,4] - pretrained_image_encoder: False - flow_criterion: "DiscreteLoss" - nbins: 99 - fmax: 50 -data: - workers: 2 - batch_size: 10 - batch_size_test: 1 - data_mean: [123.675, 116.28, 103.53] # RGB - data_div: [58.395, 57.12, 57.375] - short_size: 416 - crop_size: [320, 320] - sample_strategy: ['grid', 'watershed'] - sample_bg_ratio: 0.00003629 - nms_ks: 67 - max_num_guide: -1 - - flow_file_type: "jpg" - image_flow_aug: - flip: False - flow_aug: - reverse: False - scale: False - 
rotate: False - train_source: - - data/yfcc/lists/train.txt - val_source: - - data/yfcc/lists/val.txt - memcached: False -trainer: - initial_val: True - print_freq: 100 - val_freq: 10000 - save_freq: 10000 - val_iter: -1 - val_disp_start_iter: 0 - val_disp_end_iter: 16 - loss_record: ['loss_flow'] - tensorboard: False diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume.sh deleted file mode 100644 index 06bd63a2c51db22a687f347635759f3a41ea30b2..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 \ - --nnodes=2 --node_rank=$1 \ - --master_addr="192.168.1.1" main.py \ - --config $work_path/config.yaml --launcher pytorch \ - --load-iter 10000 \ - --resume diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume_slurm.sh deleted file mode 100644 index 644276733e346ef31fa9d3aaa4110b0b360cff3f..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume_slurm.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py \ - --config $work_path/config.yaml --launcher slurm \ - --load-iter 10000 \ - --resume diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train.sh deleted file mode 100644 index 5f2b03a431e84f04599c76865ec14cd499ff3063..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 \ - --nnodes=2 --node_rank=$1 \ - --master_addr="192.168.1.1" main.py \ - --config $work_path/config.yaml --launcher pytorch diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train_slurm.sh deleted file mode 100644 index e9c1a9f27ef9e639802ecf29247297ff7eb022d1..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train_slurm.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py \ - --config $work_path/config.yaml --launcher slurm diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate.sh deleted file mode 100644 index 6d7abca3cc02d020fce66d22d880f5c9e03ce34c..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 main.py \ - --config $work_path/config.yaml --launcher pytorch \ - --load-iter 70000 \ - --validate diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate_slurm.sh 
b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate_slurm.sh deleted file mode 100644 index 9bfe2eec2f1a52089b86f7d8a2550f12251a269e..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate_slurm.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py --config $work_path/config.yaml --launcher slurm \ - --load-iter 70000 \ - --validate diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar deleted file mode 100644 index a15fde53bc352803ac906bb48f7ec6f08f55f817..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:cd3a385e227c29f89b5c7c6f4c89d356f6022fa7fcfc71ab1bd40e9833048dd6 -size 228465722 diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml deleted file mode 100644 index fc56f53ce2088872c5f6987a0f1a44dabaf76f9d..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml +++ /dev/null @@ -1,59 +0,0 @@ -model: - arch: CMP - total_iter: 42000 - lr_steps: [24000, 36000] - lr_mults: [0.1, 0.1] - lr: 0.1 - optim: SGD - warmup_lr: [] - warmup_steps: [] - module: - arch: CMP - image_encoder: resnet50 - sparse_encoder: shallownet8x - flow_decoder: MotionDecoderSkipLayer - skip_layer: True - img_enc_dim: 256 - sparse_enc_dim: 16 - output_dim: 198 - decoder_combo: [1,2,4] - pretrained_image_encoder: False - flow_criterion: "DiscreteLoss" - nbins: 99 - fmax: 50 -data: - workers: 2 - batch_size: 8 - batch_size_test: 1 - data_mean: [123.675, 116.28, 103.53] # RGB - data_div: [58.395, 57.12, 57.375] - short_size: 416 - crop_size: [384, 384] - sample_strategy: ['grid', 'watershed'] - sample_bg_ratio: 5.74e-5 - nms_ks: 41 - max_num_guide: -1 - - flow_file_type: "jpg" - image_flow_aug: - flip: False - flow_aug: - reverse: False - scale: False - rotate: False - train_source: - - data/VIP/lists/train.txt - - data/MPII/lists/train.txt - val_source: - - data/VIP/lists/randval.txt - memcached: False -trainer: - initial_val: True - print_freq: 100 - val_freq: 5000 - save_freq: 5000 - val_iter: -1 - val_disp_start_iter: 0 - val_disp_end_iter: 16 - loss_record: ['loss_flow'] - tensorboard: True diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume.sh b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume.sh deleted file mode 100644 index 17cec90cef2555c6b7dd5acfe3b938c9be451346..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 main.py \ - --config $work_path/config.yaml --launcher pytorch \ - --load-iter 10000 \ - --resume diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume_slurm.sh b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume_slurm.sh deleted file mode 
100644 index 94b7cccac61566afb3eef7924a6d8b56027b2d13..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume_slurm.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py \ - --config $work_path/config.yaml --launcher slurm \ - --load-iter 10000 \ - --resume diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train.sh b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train.sh deleted file mode 100644 index 330d14c459f8549ea81f956c3497e13ddf68aed0..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 main.py \ - --config $work_path/config.yaml --launcher pytorch diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train_slurm.sh b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train_slurm.sh deleted file mode 100644 index 140bfae1f0543e2b186f06f3dfc7a934c0aeccf1..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train_slurm.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \ - --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py \ - --config $work_path/config.yaml --launcher slurm diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate.sh b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate.sh deleted file mode 100644 index 6d7abca3cc02d020fce66d22d880f5c9e03ce34c..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -python -m torch.distributed.launch --nproc_per_node=8 main.py \ - --config $work_path/config.yaml --launcher pytorch \ - --load-iter 70000 \ - --validate diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate_slurm.sh b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate_slurm.sh deleted file mode 100644 index aef377a6e02a61de710eb8a72769ede93ce897e7..0000000000000000000000000000000000000000 --- a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate_slurm.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -work_path=$(dirname $0) -partition=$1 -GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition \ - -n8 --gres=gpu:8 --ntasks-per-node=8 \ - python -u main.py --config $work_path/config.yaml --launcher slurm \ - --load-iter 70000 \ - --validate diff --git a/models/cmp/losses.py b/models/cmp/losses.py deleted file mode 100644 index b562ff841da3e8508ddf5e1264de382fb510376d..0000000000000000000000000000000000000000 --- a/models/cmp/losses.py +++ /dev/null @@ -1,536 +0,0 @@ -import torch -import numpy as np -import torch.nn as nn -import torch.nn.functional as F -from torch.autograd import Variable -import random -import math - -def MultiChannelSoftBinaryCrossEntropy(input, target, reduction='mean'): - ''' - input: N x 38 x H x W --> 19N x 2 x H x W - target: N x 19 x H x W --> 19N x 1 x H x W - ''' - input = input.view(-1, 2, input.size(2), input.size(3)) - target 
= target.view(-1, 1, input.size(2), input.size(3)) - - logsoftmax = nn.LogSoftmax(dim=1) - if reduction == 'mean': - return torch.mean(torch.sum(-target * logsoftmax(input), dim=1)) - else: - return torch.sum(torch.sum(-target * logsoftmax(input), dim=1)) - -class EdgeAwareLoss(): - def __init__(self, nc=2, loss_type="L1", reduction='mean'): - assert loss_type in ['L1', 'BCE'], "Undefined loss type: {}".format(loss_type) - self.nc = nc - self.loss_type = loss_type - self.kernelx = Variable(torch.Tensor([[1,0,-1],[2,0,-2],[1,0,-1]]).cuda()) - self.kernelx = self.kernelx.repeat(nc,1,1,1) - self.kernely = Variable(torch.Tensor([[1,2,1],[0,0,0],[-1,-2,-1]]).cuda()) - self.kernely = self.kernely.repeat(nc,1,1,1) - self.bias = Variable(torch.zeros(nc).cuda()) - self.reduction = reduction - if loss_type == 'L1': - self.loss = nn.SmoothL1Loss(reduction=reduction) - elif loss_type == 'BCE': - self.loss = self.bce2d - - def bce2d(self, input, target): - assert not target.requires_grad - beta = 1 - torch.mean(target) - weights = 1 - beta + (2 * beta - 1) * target - loss = nn.functional.binary_cross_entropy(input, target, weights, reduction=self.reduction) - return loss - - def get_edge(self, var): - assert var.size(1) == self.nc, \ - "input size at dim 1 should be consistent with nc, {} vs {}".format(var.size(1), self.nc) - outputx = nn.functional.conv2d(var, self.kernelx, bias=self.bias, padding=1, groups=self.nc) - outputy = nn.functional.conv2d(var, self.kernely, bias=self.bias, padding=1, groups=self.nc) - eps=1e-05 - return torch.sqrt(outputx.pow(2) + outputy.pow(2) + eps).mean(dim=1, keepdim=True) - - def __call__(self, input, target): - size = target.shape[2:4] - input = nn.functional.interpolate(input, size=size, mode="bilinear", align_corners=True) - target_edge = self.get_edge(target) - if self.loss_type == 'L1': - return self.loss(self.get_edge(input), target_edge) - elif self.loss_type == 'BCE': - raise NotImplemented - #target_edge = torch.sign(target_edge - 0.1) - #pred = self.get_edge(nn.functional.sigmoid(input)) - #return self.loss(pred, target_edge) - -def KLD(mean, logvar): - return -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp()) - -class DiscreteLoss(nn.Module): - def __init__(self, nbins, fmax): - super().__init__() - self.loss = nn.CrossEntropyLoss() - assert nbins % 2 == 1, "nbins should be odd" - self.nbins = nbins - self.fmax = fmax - self.step = 2 * fmax / float(nbins) - - def tobin(self, target): - target = torch.clamp(target, -self.fmax + 1e-3, self.fmax - 1e-3) - quantized_target = torch.floor((target + self.fmax) / self.step) - return quantized_target.type(torch.cuda.LongTensor) - - def __call__(self, input, target): - size = target.shape[2:4] - if input.shape[2] != size[0] or input.shape[3] != size[1]: - input = nn.functional.interpolate(input, size=size, mode="bilinear", align_corners=True) - target = self.tobin(target) - assert input.size(1) == self.nbins * 2 - # print(target.shape) - # print(input.shape) - # print(torch.max(target)) - target[target>=99]=98 # odd bugs of the training loss. 
We have [0 ~ 99] in GT flow, but nbins = 99 - return self.loss(input[:,:self.nbins,...], target[:,0,...]) + self.loss(input[:,self.nbins:,...], target[:,1,...]) - -class MultiDiscreteLoss(): - def __init__(self, nbins=19, fmax=47.5, reduction='mean', xy_weight=(1., 1.), quantize_strategy='linear'): - self.loss = nn.CrossEntropyLoss(reduction=reduction) - assert nbins % 2 == 1, "nbins should be odd" - self.nbins = nbins - self.fmax = fmax - self.step = 2 * fmax / float(nbins) - self.x_weight, self.y_weight = xy_weight - self.quantize_strategy = quantize_strategy - - def tobin(self, target): - target = torch.clamp(target, -self.fmax + 1e-3, self.fmax - 1e-3) - if self.quantize_strategy == "linear": - quantized_target = torch.floor((target + self.fmax) / self.step) - elif self.quantize_strategy == "quadratic": - ind = target.data > 0 - quantized_target = target.clone() - quantized_target[ind] = torch.floor(self.nbins * torch.sqrt(target[ind] / (4 * self.fmax)) + self.nbins / 2.) - quantized_target[~ind] = torch.floor(-self.nbins * torch.sqrt(-target[~ind] / (4 * self.fmax)) + self.nbins / 2.) - return quantized_target.type(torch.cuda.LongTensor) - - def __call__(self, input, target): - size = target.shape[2:4] - target = self.tobin(target) - if isinstance(input, list): - input = [nn.functional.interpolate(ip, size=size, mode="bilinear", align_corners=True) for ip in input] - return sum([self.x_weight * self.loss(input[k][:,:self.nbins,...], target[:,0,...]) + self.y_weight * self.loss(input[k][:,self.nbins:,...], target[:,1,...]) for k in range(len(input))]) / float(len(input)) - else: - input = nn.functional.interpolate(input, size=size, mode="bilinear", align_corners=True) - return self.x_weight * self.loss(input[:,:self.nbins,...], target[:,0,...]) + self.y_weight * self.loss(input[:,self.nbins:,...], target[:,1,...]) - -class MultiL1Loss(): - def __init__(self, reduction='mean'): - self.loss = nn.SmoothL1Loss(reduction=reduction) - - def __call__(self, input, target): - size = target.shape[2:4] - if isinstance(input, list): - input = [nn.functional.interpolate(ip, size=size, mode="bilinear", align_corners=True) for ip in input] - return sum([self.loss(input[k], target) for k in range(len(input))]) / float(len(input)) - else: - input = nn.functional.interpolate(input, size=size, mode="bilinear", align_corners=True) - return self.loss(input, target) - -class MultiMSELoss(): - def __init__(self): - self.loss = nn.MSELoss() - - def __call__(self, predicts, targets): - loss = 0 - for predict, target in zip(predicts, targets): - loss += self.loss(predict, target) - return loss - -class JointDiscreteLoss(): - def __init__(self, nbins=19, fmax=47.5, reduction='mean', quantize_strategy='linear'): - self.loss = nn.CrossEntropyLoss(reduction=reduction) - assert nbins % 2 == 1, "nbins should be odd" - self.nbins = nbins - self.fmax = fmax - self.step = 2 * fmax / float(nbins) - self.quantize_strategy = quantize_strategy - - def tobin(self, target): - target = torch.clamp(target, -self.fmax + 1e-3, self.fmax - 1e-3) - if self.quantize_strategy == "linear": - quantized_target = torch.floor((target + self.fmax) / self.step) - elif self.quantize_strategy == "quadratic": - ind = target.data > 0 - quantized_target = target.clone() - quantized_target[ind] = torch.floor(self.nbins * torch.sqrt(target[ind] / (4 * self.fmax)) + self.nbins / 2.) - quantized_target[~ind] = torch.floor(-self.nbins * torch.sqrt(-target[~ind] / (4 * self.fmax)) + self.nbins / 2.) 
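For reference, DiscreteLoss.tobin above quantizes each flow component into one of nbins classes of width 2*fmax/nbins. With the nbins=99, fmax=50 setting used in the experiment configs, a standalone sketch of the same linear binning (the to_bin helper below is illustrative, not part of the repo):

    import math

    def to_bin(value, nbins=99, fmax=50.0):
        """Map one flow component (in pixels) to one of `nbins` class indices."""
        step = 2 * fmax / float(nbins)                      # ~1.0101 px per bin
        value = min(max(value, -fmax + 1e-3), fmax - 1e-3)  # clamp as in tobin()
        return int(math.floor((value + fmax) / step))

    assert to_bin(0.0) == 49     # zero motion falls in the center bin
    assert to_bin(-50.0) == 0    # clamped into the first bin
    assert to_bin(49.9) == 98    # last bin; the `target >= 99` guard above clips the same float edge case

DiscreteLoss then applies an independent cross-entropy over the first nbins channels for the x component and the last nbins channels for y, which is the two-term sum in its __call__.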
- else: - raise Exception("No such quantize strategy: {}".format(self.quantize_strategy)) - joint_target = quantized_target[:,0,:,:] * self.nbins + quantized_target[:,1,:,:] - return joint_target.type(torch.cuda.LongTensor) - - def __call__(self, input, target): - target = self.tobin(target) - assert input.size(1) == self.nbins ** 2 - return self.loss(input, target) - -class PolarDiscreteLoss(): - def __init__(self, abins=30, rbins=20, fmax=50., reduction='mean', ar_weight=(1., 1.), quantize_strategy='linear'): - self.loss = nn.CrossEntropyLoss(reduction=reduction) - self.fmax = fmax - self.rbins = rbins - self.abins = abins - self.a_weight, self.r_weight = ar_weight - self.quantize_strategy = quantize_strategy - - def tobin(self, target): - indxneg = target.data[:,0,:,:] < 0 - eps = torch.zeros(target.data[:,0,:,:].size()).cuda() - epsind = target.data[:,0,:,:] == 0 - eps[epsind] += 1e-5 - angle = torch.atan(target.data[:,1,:,:] / (target.data[:,0,:,:] + eps)) - angle[indxneg] += np.pi - angle += np.pi / 2 # 0 to 2pi - angle = torch.clamp(angle, 0, 2 * np.pi - 1e-3) - radius = torch.sqrt(target.data[:,0,:,:] ** 2 + target.data[:,1,:,:] ** 2) - radius = torch.clamp(radius, 0, self.fmax - 1e-3) - quantized_angle = torch.floor(self.abins * angle / (2 * np.pi)) - if self.quantize_strategy == 'linear': - quantized_radius = torch.floor(self.rbins * radius / self.fmax) - elif self.quantize_strategy == 'quadratic': - quantized_radius = torch.floor(self.rbins * torch.sqrt(radius / self.fmax)) - else: - raise Exception("No such quantize strategy: {}".format(self.quantize_strategy)) - quantized_target = torch.autograd.Variable(torch.cat([torch.unsqueeze(quantized_angle, 1), torch.unsqueeze(quantized_radius, 1)], dim=1)) - return quantized_target.type(torch.cuda.LongTensor) - - def __call__(self, input, target): - target = self.tobin(target) - assert (target >= 0).all() and (target[:,0,:,:] < self.abins).all() and (target[:,1,:,:] < self.rbins).all() - return self.a_weight * self.loss(input[:,:self.abins,...], target[:,0,...]) + self.r_weight * self.loss(input[:,self.abins:,...], target[:,1,...]) - -class WeightedDiscreteLoss(): - def __init__(self, nbins=19, fmax=47.5, reduction='mean'): - self.loss = CrossEntropy2d(reduction=reduction) - assert nbins % 2 == 1, "nbins should be odd" - self.nbins = nbins - self.fmax = fmax - self.step = 2 * fmax / float(nbins) - self.weight = np.ones((nbins), dtype=np.float32) - self.weight[int(self.fmax / self.step)] = 0.01 - self.weight = torch.from_numpy(self.weight).cuda() - - def tobin(self, target): - target = torch.clamp(target, -self.fmax + 1e-3, self.fmax - 1e-3) - return torch.floor((target + self.fmax) / self.step).type(torch.cuda.LongTensor) - - def __call__(self, input, target): - target = self.tobin(target) - assert (target >= 0).all() and (target < self.nbins).all() - return self.loss(input[:,:self.nbins,...], target[:,0,...]) + self.loss(input[:,self.nbins:,...], target[:,1,...], self.weight) - - -class CrossEntropy2d(nn.Module): - def __init__(self, reduction='mean', ignore_label=-1): - super(CrossEntropy2d, self).__init__() - self.ignore_label = ignore_label - self.reduction = reduction - - def forward(self, predict, target, weight=None): - """ - Args: - predict:(n, c, h, w) - target:(n, h, w) - weight (Tensor, optional): a manual rescaling weight given to each class. 
- If given, has to be a Tensor of size "nclasses" - """ - assert not target.requires_grad - assert predict.dim() == 4 - assert target.dim() == 3 - assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0)) - assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1)) - assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(3)) - n, c, h, w = predict.size() - target_mask = (target >= 0) * (target != self.ignore_label) - target = target[target_mask] - predict = predict.transpose(1, 2).transpose(2, 3).contiguous() - predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c) - loss = F.cross_entropy(predict, target, weight=weight, reduction=self.reduction) - return loss - -#class CrossPixelSimilarityLoss(): -# ''' -# Modified from: https://github.com/lppllppl920/Challenge2018/blob/master/loss.py -# ''' -# def __init__(self, sigma=0.0036, sampling_size=512): -# self.sigma = sigma -# self.sampling_size = sampling_size -# self.epsilon = 1.0e-15 -# self.embed_norm = True # loss does not decrease no matter it is true or false. -# -# def __call__(self, embeddings, flows): -# ''' -# embedding: Variable Nx256xHxW (not hyper-column) -# flows: Variable Nx2xHxW -# ''' -# assert flows.size(1) == 2 -# -# # flow normalization -# positive_mask = (flows > 0) -# flows = -torch.clamp(torch.log(torch.abs(flows) + 1) / math.log(50. + 1), max=1.) -# flows[positive_mask] = -flows[positive_mask] -# -# # embedding normalization -# if self.embed_norm: -# embeddings /= torch.norm(embeddings, p=2, dim=1, keepdim=True) -# -# # Spatially random sampling (512 samples) -# flows_flatten = flows.view(flows.shape[0], 2, -1) -# random_locations = Variable(torch.from_numpy(np.array(random.sample(range(flows_flatten.shape[2]), self.sampling_size))).long().cuda()) -# flows_sample = torch.index_select(flows_flatten, 2, random_locations) -# -# # K_f -# k_f = self.epsilon + torch.norm(torch.unsqueeze(flows_sample, dim=-1).permute(0, 3, 2, 1) - -# torch.unsqueeze(flows_sample, dim=-1).permute(0, 2, 3, 1), p=2, dim=3, -# keepdim=False) ** 2 -# exp_k_f = torch.exp(-k_f / 2. 
/ self.sigma) -# -# -# # mask -# eye = Variable(torch.unsqueeze(torch.eye(k_f.shape[1]), dim=0).cuda()) -# mask = torch.ones_like(exp_k_f) - eye -# -# # S_f -# masked_exp_k_f = torch.mul(mask, exp_k_f) + eye -# s_f = masked_exp_k_f / torch.sum(masked_exp_k_f, dim=1, keepdim=True) -# -# # K_theta -# embeddings_flatten = embeddings.view(embeddings.shape[0], embeddings.shape[1], -1) -# embeddings_sample = torch.index_select(embeddings_flatten, 2, random_locations) -# embeddings_sample_norm = torch.norm(embeddings_sample, p=2, dim=1, keepdim=True) -# k_theta = 0.25 * (torch.matmul(embeddings_sample.permute(0, 2, 1), embeddings_sample)) / (self.epsilon + torch.matmul(embeddings_sample_norm.permute(0, 2, 1), embeddings_sample_norm)) -# exp_k_theta = torch.exp(k_theta) -# -# # S_theta -# masked_exp_k_theta = torch.mul(mask, exp_k_theta) + math.exp(-0.75) * eye -# s_theta = masked_exp_k_theta / torch.sum(masked_exp_k_theta, dim=1, keepdim=True) -# -# # loss -# loss = -torch.mean(torch.mul(s_f, torch.log(s_theta))) -# -# return loss - -class CrossPixelSimilarityLoss(): - ''' - Modified from: https://github.com/lppllppl920/Challenge2018/blob/master/loss.py - ''' - def __init__(self, sigma=0.01, sampling_size=512): - self.sigma = sigma - self.sampling_size = sampling_size - self.epsilon = 1.0e-15 - self.embed_norm = True # loss does not decrease no matter it is true or false. - - def __call__(self, embeddings, flows): - ''' - embedding: Variable Nx256xHxW (not hyper-column) - flows: Variable Nx2xHxW - ''' - assert flows.size(1) == 2 - - # flow normalization - positive_mask = (flows > 0) - flows = -torch.clamp(torch.log(torch.abs(flows) + 1) / math.log(50. + 1), max=1.) - flows[positive_mask] = -flows[positive_mask] - - # embedding normalization - if self.embed_norm: - embeddings /= torch.norm(embeddings, p=2, dim=1, keepdim=True) - - # Spatially random sampling (512 samples) - flows_flatten = flows.view(flows.shape[0], 2, -1) - random_locations = Variable(torch.from_numpy(np.array(random.sample(range(flows_flatten.shape[2]), self.sampling_size))).long().cuda()) - flows_sample = torch.index_select(flows_flatten, 2, random_locations) - - # K_f - k_f = self.epsilon + torch.norm(torch.unsqueeze(flows_sample, dim=-1).permute(0, 3, 2, 1) - - torch.unsqueeze(flows_sample, dim=-1).permute(0, 2, 3, 1), p=2, dim=3, - keepdim=False) ** 2 - exp_k_f = torch.exp(-k_f / 2. 
/ self.sigma) - - - # mask - eye = Variable(torch.unsqueeze(torch.eye(k_f.shape[1]), dim=0).cuda()) - mask = torch.ones_like(exp_k_f) - eye - - # S_f - masked_exp_k_f = torch.mul(mask, exp_k_f) + eye - s_f = masked_exp_k_f / torch.sum(masked_exp_k_f, dim=1, keepdim=True) - - # K_theta - embeddings_flatten = embeddings.view(embeddings.shape[0], embeddings.shape[1], -1) - embeddings_sample = torch.index_select(embeddings_flatten, 2, random_locations) - embeddings_sample_norm = torch.norm(embeddings_sample, p=2, dim=1, keepdim=True) - k_theta = 0.25 * (torch.matmul(embeddings_sample.permute(0, 2, 1), embeddings_sample)) / (self.epsilon + torch.matmul(embeddings_sample_norm.permute(0, 2, 1), embeddings_sample_norm)) - exp_k_theta = torch.exp(k_theta) - - # S_theta - masked_exp_k_theta = torch.mul(mask, exp_k_theta) + eye - s_theta = masked_exp_k_theta / torch.sum(masked_exp_k_theta, dim=1, keepdim=True) - - # loss - loss = -torch.mean(torch.mul(s_f, torch.log(s_theta))) - - return loss - - -class CrossPixelSimilarityFullLoss(): - ''' - Modified from: https://github.com/lppllppl920/Challenge2018/blob/master/loss.py - ''' - def __init__(self, sigma=0.01): - self.sigma = sigma - self.epsilon = 1.0e-15 - self.embed_norm = True # loss does not decrease no matter it is true or false. - - def __call__(self, embeddings, flows): - ''' - embedding: Variable Nx256xHxW (not hyper-column) - flows: Variable Nx2xHxW - ''' - assert flows.size(1) == 2 - - # downsample flow - factor = flows.shape[2] // embeddings.shape[2] - flows = nn.functional.avg_pool2d(flows, factor, factor) - assert flows.shape[2] == embeddings.shape[2] - - # flow normalization - positive_mask = (flows > 0) - flows = -torch.clamp(torch.log(torch.abs(flows) + 1) / math.log(50. + 1), max=1.) - flows[positive_mask] = -flows[positive_mask] - - # embedding normalization - if self.embed_norm: - embeddings /= torch.norm(embeddings, p=2, dim=1, keepdim=True) - - # Spatially random sampling (512 samples) - flows_flatten = flows.view(flows.shape[0], 2, -1) - #random_locations = Variable(torch.from_numpy(np.array(random.sample(range(flows_flatten.shape[2]), self.sampling_size))).long().cuda()) - #flows_sample = torch.index_select(flows_flatten, 2, random_locations) - - # K_f - k_f = self.epsilon + torch.norm(torch.unsqueeze(flows_flatten, dim=-1).permute(0, 3, 2, 1) - - torch.unsqueeze(flows_flatten, dim=-1).permute(0, 2, 3, 1), p=2, dim=3, - keepdim=False) ** 2 - exp_k_f = torch.exp(-k_f / 2. 
/ self.sigma) - - - # mask - eye = Variable(torch.unsqueeze(torch.eye(k_f.shape[1]), dim=0).cuda()) - mask = torch.ones_like(exp_k_f) - eye - - # S_f - masked_exp_k_f = torch.mul(mask, exp_k_f) + eye - s_f = masked_exp_k_f / torch.sum(masked_exp_k_f, dim=1, keepdim=True) - - # K_theta - embeddings_flatten = embeddings.view(embeddings.shape[0], embeddings.shape[1], -1) - #embeddings_sample = torch.index_select(embeddings_flatten, 2, random_locations) - embeddings_flatten_norm = torch.norm(embeddings_flatten, p=2, dim=1, keepdim=True) - k_theta = 0.25 * (torch.matmul(embeddings_flatten.permute(0, 2, 1), embeddings_flatten)) / (self.epsilon + torch.matmul(embeddings_flatten_norm.permute(0, 2, 1), embeddings_flatten_norm)) - exp_k_theta = torch.exp(k_theta) - - # S_theta - masked_exp_k_theta = torch.mul(mask, exp_k_theta) + eye - s_theta = masked_exp_k_theta / torch.sum(masked_exp_k_theta, dim=1, keepdim=True) - - # loss - loss = -torch.mean(torch.mul(s_f, torch.log(s_theta))) - - return loss - - -def get_column(embeddings, index, full_size): - col = [] - for embd in embeddings: - ind = (index.float() / full_size * embd.size(2)).long() - col.append(torch.index_select(embd.view(embd.shape[0], embd.shape[1], -1), 2, ind)) - return torch.cat(col, dim=1) # N x coldim x sparsenum - -class CrossPixelSimilarityColumnLoss(nn.Module): - ''' - Modified from: https://github.com/lppllppl920/Challenge2018/blob/master/loss.py - ''' - def __init__(self, sigma=0.0036, sampling_size=512): - super(CrossPixelSimilarityColumnLoss, self).__init__() - self.sigma = sigma - self.sampling_size = sampling_size - self.epsilon = 1.0e-15 - self.embed_norm = True # loss does not decrease no matter it is true or false. - self.mlp = nn.Sequential( - nn.Linear(96 + 96 + 384 + 256 + 4096, 256), - nn.ReLU(inplace=True), - nn.Linear(256, 16)) - - def forward(self, feats, flows): - ''' - embedding: Variable Nx256xHxW (not hyper-column) - flows: Variable Nx2xHxW - ''' - assert flows.size(1) == 2 - - # flow normalization - positive_mask = (flows > 0) - flows = -torch.clamp(torch.log(torch.abs(flows) + 1) / math.log(50. + 1), max=1.) - flows[positive_mask] = -flows[positive_mask] - - # Spatially random sampling (512 samples) - flows_flatten = flows.view(flows.shape[0], 2, -1) - random_locations = Variable(torch.from_numpy(np.array(random.sample(range(flows_flatten.shape[2]), self.sampling_size))).long().cuda()) - flows_sample = torch.index_select(flows_flatten, 2, random_locations) - - # K_f - k_f = self.epsilon + torch.norm(torch.unsqueeze(flows_sample, dim=-1).permute(0, 3, 2, 1) - - torch.unsqueeze(flows_sample, dim=-1).permute(0, 2, 3, 1), p=2, dim=3, - keepdim=False) ** 2 - exp_k_f = torch.exp(-k_f / 2. 
/ self.sigma) - - - # mask - eye = Variable(torch.unsqueeze(torch.eye(k_f.shape[1]), dim=0).cuda()) - mask = torch.ones_like(exp_k_f) - eye - - # S_f - masked_exp_k_f = torch.mul(mask, exp_k_f) + eye - s_f = masked_exp_k_f / torch.sum(masked_exp_k_f, dim=1, keepdim=True) - - - # column - column = get_column(feats, random_locations, flows.shape[2]) - embedding = self.mlp(column) - # K_theta - embedding_norm = torch.norm(embedding, p=2, dim=1, keepdim=True) - k_theta = 0.25 * (torch.matmul(embedding.permute(0, 2, 1), embedding)) / (self.epsilon + torch.matmul(embedding_norm.permute(0, 2, 1), embedding_norm)) - exp_k_theta = torch.exp(k_theta) - - # S_theta - masked_exp_k_theta = torch.mul(mask, exp_k_theta) + math.exp(-0.75) * eye - s_theta = masked_exp_k_theta / torch.sum(masked_exp_k_theta, dim=1, keepdim=True) - - # loss - loss = -torch.mean(torch.mul(s_f, torch.log(s_theta))) - - return loss - - -def print_info(name, var): - print(name, var.size(), torch.max(var).data.cpu()[0], torch.min(var).data.cpu()[0], torch.mean(var).data.cpu()[0]) - - -def MaskL1Loss(input, target, mask): - input_size = input.size() - res = torch.sum(torch.abs(input * mask - target * mask)) - total = torch.sum(mask).item() - if total > 0: - res = res / (total * input_size[1]) - return res diff --git a/models/cmp/models/.DS_Store b/models/cmp/models/.DS_Store deleted file mode 100644 index 778b39036de751baf323736bc69406f0202af4ea..0000000000000000000000000000000000000000 Binary files a/models/cmp/models/.DS_Store and /dev/null differ diff --git a/models/cmp/models/__init__.py b/models/cmp/models/__init__.py deleted file mode 100644 index 168e1aae54937d9c23f4f40eae871b7fd73dc5c8..0000000000000000000000000000000000000000 --- a/models/cmp/models/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .single_stage_model import * -from .cmp import * -from . import modules -from . 
import backbone diff --git a/models/cmp/models/backbone/__init__.py b/models/cmp/models/backbone/__init__.py deleted file mode 100644 index eea305c40902faaf491f9ed7ca70a56c0b9ae7fb..0000000000000000000000000000000000000000 --- a/models/cmp/models/backbone/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .resnet import * -from .alexnet import * diff --git a/models/cmp/models/backbone/alexnet.py b/models/cmp/models/backbone/alexnet.py deleted file mode 100644 index d4ac39a8d8af096e9363854a2e7d720623ecd73e..0000000000000000000000000000000000000000 --- a/models/cmp/models/backbone/alexnet.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch.nn as nn -import math - -class AlexNetBN_FCN(nn.Module): - - def __init__(self, output_dim=256, stride=[4, 2, 2, 2], dilation=[1, 1], padding=[1, 1]): - super(AlexNetBN_FCN, self).__init__() - BN = nn.BatchNorm2d - - self.conv1 = nn.Sequential( - nn.Conv2d(3, 96, kernel_size=11, stride=stride[0], padding=5), - BN(96), - nn.ReLU(inplace=True)) - self.pool1 = nn.MaxPool2d(kernel_size=3, stride=stride[1], padding=1) - self.conv2 = nn.Sequential( - nn.Conv2d(96, 256, kernel_size=5, padding=2), - BN(256), - nn.ReLU(inplace=True)) - self.pool2 = nn.MaxPool2d(kernel_size=3, stride=stride[2], padding=1) - self.conv3 = nn.Sequential( - nn.Conv2d(256, 384, kernel_size=3, padding=1), - BN(384), - nn.ReLU(inplace=True)) - self.conv4 = nn.Sequential( - nn.Conv2d(384, 384, kernel_size=3, padding=padding[0], dilation=dilation[0]), - BN(384), - nn.ReLU(inplace=True)) - self.conv5 = nn.Sequential( - nn.Conv2d(384, 256, kernel_size=3, padding=padding[1], dilation=dilation[1]), - BN(256), - nn.ReLU(inplace=True)) - self.pool5 = nn.MaxPool2d(kernel_size=3, stride=stride[3], padding=1) - - self.fc6 = nn.Sequential( - nn.Conv2d(256, 4096, kernel_size=3, stride=1, padding=1), - BN(4096), - nn.ReLU(inplace=True)) - self.drop6 = nn.Dropout(0.5) - self.fc7 = nn.Sequential( - nn.Conv2d(4096, 4096, kernel_size=1, stride=1, padding=0), - BN(4096), - nn.ReLU(inplace=True)) - self.drop7 = nn.Dropout(0.5) - self.conv8 = nn.Conv2d(4096, output_dim, kernel_size=1) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1] - scale = math.sqrt(2. 
/ fan_in) - m.weight.data.uniform_(-scale, scale) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - - def forward(self, x, ret_feat=False): - if ret_feat: - raise NotImplemented - x = self.conv1(x) - x = self.pool1(x) - x = self.conv2(x) - x = self.pool2(x) - x = self.conv3(x) - x = self.conv4(x) - x = self.conv5(x) - x = self.pool5(x) - x = self.fc6(x) - x = self.drop6(x) - x = self.fc7(x) - x = self.drop7(x) - x = self.conv8(x) - return x - -def alexnet_fcn_32x(output_dim, pretrained=False, **kwargs): - assert pretrained == False - model = AlexNetBN_FCN(output_dim=output_dim, **kwargs) - return model - -def alexnet_fcn_8x(output_dim, use_ppm=False, pretrained=False, **kwargs): - assert pretrained == False - model = AlexNetBN_FCN(output_dim=output_dim, stride=[2, 2, 2, 1], **kwargs) - return model diff --git a/models/cmp/models/backbone/resnet.py b/models/cmp/models/backbone/resnet.py deleted file mode 100644 index 126ef386b7abba0ff9d09b2f051494ada0cfab30..0000000000000000000000000000000000000000 --- a/models/cmp/models/backbone/resnet.py +++ /dev/null @@ -1,201 +0,0 @@ -import torch.nn as nn -import math -import torch.utils.model_zoo as model_zoo - -BN = None - - -model_urls = { - 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', - 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', - 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', - 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', - 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', -} - - -def conv3x3(in_planes, out_planes, stride=1): - "3x3 convolution with padding" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(BasicBlock, self).__init__() - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = BN(planes) - self.relu = nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = BN(planes) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(Bottleneck, self).__init__() - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) - self.bn1 = BN(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, - padding=1, bias=False) - self.bn2 = BN(planes) - self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) - self.bn3 = BN(planes * 4) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - -class ResNet(nn.Module): - - def __init__(self, output_dim, block, layers): - - global BN - - BN = nn.BatchNorm2d - - 
self.inplanes = 64 - super(ResNet, self).__init__() - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, - bias=False) - self.bn1 = BN(64) - self.relu = nn.ReLU(inplace=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer(block, 64, layers[0]) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2) - - self.conv5 = nn.Conv2d(2048, output_dim, kernel_size=1) - - ## dilation - for n, m in self.layer3.named_modules(): - if 'conv2' in n: - m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) - elif 'downsample.0' in n: - m.stride = (1, 1) - for n, m in self.layer4.named_modules(): - if 'conv2' in n: - m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) - elif 'downsample.0' in n: - m.stride = (1, 1) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - - def _make_layer(self, block, planes, blocks, stride=1): - downsample = None - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False), - BN(planes * block.expansion), - ) - - layers = [] - layers.append(block(self.inplanes, planes, stride, downsample)) - self.inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append(block(self.inplanes, planes)) - - return nn.Sequential(*layers) - - def forward(self, img, ret_feat=False): - x = self.conv1(img) # 1/2 - x = self.bn1(x) - conv1 = self.relu(x) # 1/2 - pool1 = self.maxpool(conv1) # 1/4 - - layer1 = self.layer1(pool1) # 1/4 - layer2 = self.layer2(layer1) # 1/8 - layer3 = self.layer3(layer2) # 1/8 - layer4 = self.layer4(layer3) # 1/8 - out = self.conv5(layer4) - - if ret_feat: - return out, [img, conv1, layer1] # 3, 64, 256 - else: - return out - -def resnet18(output_dim, pretrained=False): - model = ResNet(output_dim, BasicBlock, [2, 2, 2, 2]) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) - return model - - -def resnet34(output_dim, pretrained=False): - model = ResNet(output_dim, BasicBlock, [3, 4, 6, 3]) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) - return model - - -def resnet50(output_dim, pretrained=False): - model = ResNet(output_dim, Bottleneck, [3, 4, 6, 3]) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False) - return model - -def resnet101(output_dim, pretrained=False): - model = ResNet(output_dim, Bottleneck, [3, 4, 23, 3]) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False) - return model - - -def resnet152(output_dim, pretrained=False): - model = ResNet(output_dim, Bottleneck, [3, 8, 36, 3]) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet152']), strict=False) - return model diff --git a/models/cmp/models/cmp.py b/models/cmp/models/cmp.py deleted file mode 100644 index 11987b4b7c14a2a2e7a2ad01b34e84f7f77bc03f..0000000000000000000000000000000000000000 --- a/models/cmp/models/cmp.py +++ /dev/null @@ -1,64 +0,0 @@ -import torch -import torch.nn as nn - -import models.cmp.losses as losses -import models.cmp.utils as utils - -from . 
import SingleStageModel - -class CMP(SingleStageModel): - - def __init__(self, params, dist_model=False): - super(CMP, self).__init__(params, dist_model) - model_params = params['module'] - - # define loss - if model_params['flow_criterion'] == 'L1': - self.flow_criterion = nn.SmoothL1Loss() - elif model_params['flow_criterion'] == 'L2': - self.flow_criterion = nn.MSELoss() - elif model_params['flow_criterion'] == 'DiscreteLoss': - self.flow_criterion = losses.DiscreteLoss( - nbins=model_params['nbins'], fmax=model_params['fmax']) - else: - raise Exception("No such flow loss: {}".format(model_params['flow_criterion'])) - - self.fuser = utils.Fuser(nbins=model_params['nbins'], - fmax=model_params['fmax']) - self.model_params = model_params - - def eval(self, ret_loss=True): - with torch.no_grad(): - cmp_output = self.model(self.image_input, self.sparse_input) - if self.model_params['flow_criterion'] == "DiscreteLoss": - self.flow = self.fuser.convert_flow(cmp_output) - else: - self.flow = cmp_output - if self.flow.shape[2] != self.image_input.shape[2]: - self.flow = nn.functional.interpolate( - self.flow, size=self.image_input.shape[2:4], - mode="bilinear", align_corners=True) - - ret_tensors = { - 'flow_tensors': [self.flow, self.flow_target], - 'common_tensors': [], - 'rgb_tensors': []} # except for image_input - - if ret_loss: - if cmp_output.shape[2] != self.flow_target.shape[2]: - cmp_output = nn.functional.interpolate( - cmp_output, size=self.flow_target.shape[2:4], - mode="bilinear", align_corners=True) - loss_flow = self.flow_criterion(cmp_output, self.flow_target) / self.world_size - return ret_tensors, {'loss_flow': loss_flow} - else: - return ret_tensors - - def step(self): - cmp_output = self.model(self.image_input, self.sparse_input) - loss_flow = self.flow_criterion(cmp_output, self.flow_target) / self.world_size - self.optim.zero_grad() - loss_flow.backward() - utils.average_gradients(self.model) - self.optim.step() - return {'loss_flow': loss_flow} diff --git a/models/cmp/models/modules/__init__.py b/models/cmp/models/modules/__init__.py deleted file mode 100644 index eff11cb76475f299be4ce9641182686866e00f99..0000000000000000000000000000000000000000 --- a/models/cmp/models/modules/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .warp import * -from .others import * -from .shallownet import * -from .decoder import * -from .cmp import * - diff --git a/models/cmp/models/modules/cmp.py b/models/cmp/models/modules/cmp.py deleted file mode 100644 index 7c5130353000c6f971425a37c2588e45d8710664..0000000000000000000000000000000000000000 --- a/models/cmp/models/modules/cmp.py +++ /dev/null @@ -1,37 +0,0 @@ -import torch -import torch.nn as nn -import models.cmp.models as models - - -class CMP(nn.Module): - - def __init__(self, params): - super(CMP, self).__init__() - img_enc_dim = params['img_enc_dim'] - sparse_enc_dim = params['sparse_enc_dim'] - output_dim = params['output_dim'] - pretrained = params['pretrained_image_encoder'] - decoder_combo = params['decoder_combo'] - self.skip_layer = params['skip_layer'] - if self.skip_layer: - assert params['flow_decoder'] == "MotionDecoderSkipLayer" - - self.image_encoder = models.backbone.__dict__[params['image_encoder']]( - img_enc_dim, pretrained) - self.flow_encoder = models.modules.__dict__[params['sparse_encoder']]( - sparse_enc_dim) - self.flow_decoder = models.modules.__dict__[params['flow_decoder']]( - input_dim=img_enc_dim+sparse_enc_dim, - output_dim=output_dim, combo=decoder_combo) - - def forward(self, image, sparse): - 
sparse_enc = self.flow_encoder(sparse) - if self.skip_layer: - img_enc, skip_feat = self.image_encoder(image, ret_feat=True) - flow_dec = self.flow_decoder(torch.cat((img_enc, sparse_enc), dim=1), skip_feat) - else: - img_enc = self.image_encoder(image) - flow_dec = self.flow_decoder(torch.cat((img_enc, sparse_enc), dim=1)) - return flow_dec - - diff --git a/models/cmp/models/modules/decoder.py b/models/cmp/models/modules/decoder.py deleted file mode 100644 index 8f1c0e395f55f4e348410c98d2d37a13441d7139..0000000000000000000000000000000000000000 --- a/models/cmp/models/modules/decoder.py +++ /dev/null @@ -1,358 +0,0 @@ -import torch -import torch.nn as nn -import math - -class MotionDecoderPlain(nn.Module): - - def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4]): - super(MotionDecoderPlain, self).__init__() - BN = nn.BatchNorm2d - - self.combo = combo - for c in combo: - assert c in [1,2,4,8], "invalid combo: {}".format(combo) - - if 1 in combo: - self.decoder1 = nn.Sequential( - nn.Conv2d(input_dim, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True)) - - if 2 in combo: - self.decoder2 = nn.Sequential( - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True)) - - if 4 in combo: - self.decoder4 = nn.Sequential( - nn.MaxPool2d(kernel_size=4, stride=4), - nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True)) - - if 8 in combo: - self.decoder8 = nn.Sequential( - nn.MaxPool2d(kernel_size=8, stride=8), - nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True)) - - self.head = nn.Conv2d(128 * len(self.combo), output_dim, kernel_size=1, padding=0) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1] - scale = math.sqrt(2. 
/ fan_in) - m.weight.data.uniform_(-scale, scale) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - if not m.weight is None: - m.weight.data.fill_(1) - if not m.bias is None: - m.bias.data.zero_() - - def forward(self, x): - - cat_list = [] - if 1 in self.combo: - x1 = self.decoder1(x) - cat_list.append(x1) - if 2 in self.combo: - x2 = nn.functional.interpolate( - self.decoder2(x), size=(x.size(2), x.size(3)), - mode="bilinear", align_corners=True) - cat_list.append(x2) - if 4 in self.combo: - x4 = nn.functional.interpolate( - self.decoder4(x), size=(x.size(2), x.size(3)), - mode="bilinear", align_corners=True) - cat_list.append(x4) - if 8 in self.combo: - x8 = nn.functional.interpolate( - self.decoder8(x), size=(x.size(2), x.size(3)), - mode="bilinear", align_corners=True) - cat_list.append(x8) - - cat = torch.cat(cat_list, dim=1) - flow = self.head(cat) - return flow - - -class MotionDecoderSkipLayer(nn.Module): - - def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4,8]): - super(MotionDecoderSkipLayer, self).__init__() - - BN = nn.BatchNorm2d - - self.decoder1 = nn.Sequential( - nn.Conv2d(input_dim, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True)) - - self.decoder2 = nn.Sequential( - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True)) - - self.decoder4 = nn.Sequential( - nn.MaxPool2d(kernel_size=4, stride=4), - nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True)) - - self.decoder8 = nn.Sequential( - nn.MaxPool2d(kernel_size=8, stride=8), - nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True)) - - self.fusion8 = nn.Sequential( - nn.Conv2d(512, 256, kernel_size=3, padding=1), - BN(256), - nn.ReLU(inplace=True)) - - self.skipconv4 = nn.Sequential( - nn.Conv2d(256, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True)) - self.fusion4 = nn.Sequential( - nn.Conv2d(256 + 128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True)) - - self.skipconv2 = nn.Sequential( - nn.Conv2d(64, 32, kernel_size=3, padding=1), - BN(32), - nn.ReLU(inplace=True)) - self.fusion2 = nn.Sequential( - nn.Conv2d(128 + 32, 64, kernel_size=3, padding=1), - BN(64), - nn.ReLU(inplace=True)) - - self.head = nn.Conv2d(64, output_dim, kernel_size=1, padding=0) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1] - scale = math.sqrt(2. 
/ fan_in) - m.weight.data.uniform_(-scale, scale) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - if not m.weight is None: - m.weight.data.fill_(1) - if not m.bias is None: - m.bias.data.zero_() - - def forward(self, x, skip_feat): - layer1, layer2, layer4 = skip_feat - - x1 = self.decoder1(x) - x2 = nn.functional.interpolate( - self.decoder2(x), size=(x1.size(2), x1.size(3)), - mode="bilinear", align_corners=True) - x4 = nn.functional.interpolate( - self.decoder4(x), size=(x1.size(2), x1.size(3)), - mode="bilinear", align_corners=True) - x8 = nn.functional.interpolate( - self.decoder8(x), size=(x1.size(2), x1.size(3)), - mode="bilinear", align_corners=True) - cat = torch.cat([x1, x2, x4, x8], dim=1) - f8 = self.fusion8(cat) - - f8_up = nn.functional.interpolate( - f8, size=(layer4.size(2), layer4.size(3)), - mode="bilinear", align_corners=True) - f4 = self.fusion4(torch.cat([f8_up, self.skipconv4(layer4)], dim=1)) - - f4_up = nn.functional.interpolate( - f4, size=(layer2.size(2), layer2.size(3)), - mode="bilinear", align_corners=True) - f2 = self.fusion2(torch.cat([f4_up, self.skipconv2(layer2)], dim=1)) - - flow = self.head(f2) - return flow - - -class MotionDecoderFlowNet(nn.Module): - - def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4,8]): - super(MotionDecoderFlowNet, self).__init__() - global BN - - BN = nn.BatchNorm2d - - self.decoder1 = nn.Sequential( - nn.Conv2d(input_dim, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True)) - - self.decoder2 = nn.Sequential( - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True)) - - self.decoder4 = nn.Sequential( - nn.MaxPool2d(kernel_size=4, stride=4), - nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True)) - - self.decoder8 = nn.Sequential( - nn.MaxPool2d(kernel_size=8, stride=8), - nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 128, kernel_size=3, padding=1), - BN(128), - nn.ReLU(inplace=True)) - - self.fusion8 = nn.Sequential( - nn.Conv2d(512, 256, kernel_size=3, padding=1), - BN(256), - nn.ReLU(inplace=True)) - - # flownet head - self.predict_flow8 = predict_flow(256, output_dim) - self.predict_flow4 = predict_flow(384 + output_dim, output_dim) - self.predict_flow2 = predict_flow(192 + output_dim, output_dim) - self.predict_flow1 = predict_flow(67 + output_dim, output_dim) - - self.upsampled_flow8_to_4 = nn.ConvTranspose2d( - output_dim, output_dim, 4, 2, 1, bias=False) - self.upsampled_flow4_to_2 = nn.ConvTranspose2d( - output_dim, output_dim, 4, 2, 1, bias=False) - self.upsampled_flow2_to_1 = nn.ConvTranspose2d( - output_dim, output_dim, 4, 2, 1, bias=False) - - self.deconv8 = deconv(256, 128) - self.deconv4 = deconv(384 + output_dim, 128) - self.deconv2 = deconv(192 + output_dim, 64) - - for m 
in self.modules(): - if isinstance(m, nn.Conv2d): - fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1] - scale = math.sqrt(2. / fan_in) - m.weight.data.uniform_(-scale, scale) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - if not m.weight is None: - m.weight.data.fill_(1) - if not m.bias is None: - m.bias.data.zero_() - - def forward(self, x, skip_feat): - layer1, layer2, layer4 = skip_feat # 3, 64, 256 - - # propagation nets - x1 = self.decoder1(x) - x2 = nn.functional.interpolate( - self.decoder2(x), size=(x1.size(2), x1.size(3)), - mode="bilinear", align_corners=True) - x4 = nn.functional.interpolate( - self.decoder4(x), size=(x1.size(2), x1.size(3)), - mode="bilinear", align_corners=True) - x8 = nn.functional.interpolate( - self.decoder8(x), size=(x1.size(2), x1.size(3)), - mode="bilinear", align_corners=True) - cat = torch.cat([x1, x2, x4, x8], dim=1) - feat8 = self.fusion8(cat) # 256 - - # flownet head - flow8 = self.predict_flow8(feat8) - flow8_up = self.upsampled_flow8_to_4(flow8) - out_deconv8 = self.deconv8(feat8) # 128 - - concat4 = torch.cat((layer4, out_deconv8, flow8_up), dim=1) # 394 + out - flow4 = self.predict_flow4(concat4) - flow4_up = self.upsampled_flow4_to_2(flow4) - out_deconv4 = self.deconv4(concat4) # 128 - - concat2 = torch.cat((layer2, out_deconv4, flow4_up), dim=1) # 192 + out - flow2 = self.predict_flow2(concat2) - flow2_up = self.upsampled_flow2_to_1(flow2) - out_deconv2 = self.deconv2(concat2) # 64 - - concat1 = torch.cat((layer1, out_deconv2, flow2_up), dim=1) # 67 + out - flow1 = self.predict_flow1(concat1) - - return [flow1, flow2, flow4, flow8] - - -def predict_flow(in_planes, out_planes): - return nn.Conv2d(in_planes, out_planes, kernel_size=3, - stride=1, padding=1, bias=True) - - -def deconv(in_planes, out_planes): - return nn.Sequential( - nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, - stride=2, padding=1, bias=True), - nn.LeakyReLU(0.1, inplace=True) - ) - - diff --git a/models/cmp/models/modules/others.py b/models/cmp/models/modules/others.py deleted file mode 100644 index 591ce94f7d10db49fb3209d4d74a4a973f9a6cf5..0000000000000000000000000000000000000000 --- a/models/cmp/models/modules/others.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch.nn as nn - -class FixModule(nn.Module): - - def __init__(self, m): - super(FixModule, self).__init__() - self.module = m - - def forward(self, *args, **kwargs): - return self.module(*args, **kwargs) - diff --git a/models/cmp/models/modules/shallownet.py b/models/cmp/models/modules/shallownet.py deleted file mode 100644 index b37fedd26b5096e34c0e6303f69e54b3d58c39b4..0000000000000000000000000000000000000000 --- a/models/cmp/models/modules/shallownet.py +++ /dev/null @@ -1,49 +0,0 @@ -import torch.nn as nn -import math - -class ShallowNet(nn.Module): - - def __init__(self, input_dim=4, output_dim=16, stride=[2, 2, 2]): - super(ShallowNet, self).__init__() - global BN - - BN = nn.BatchNorm2d - - self.features = nn.Sequential( - nn.Conv2d(input_dim, 16, kernel_size=5, stride=stride[0], padding=2), - nn.BatchNorm2d(16), - nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=stride[1], stride=stride[1]), - nn.Conv2d(16, output_dim, kernel_size=3, padding=1), - nn.BatchNorm2d(output_dim), - nn.ReLU(inplace=True), - nn.AvgPool2d(kernel_size=stride[2], stride=stride[2]), - ) - for m in self.modules(): - if isinstance(m, nn.Conv2d): - fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1] - scale = math.sqrt(2. 
/ fan_in) - m.weight.data.uniform_(-scale, scale) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - if not m.weight is None: - m.weight.data.fill_(1) - if not m.bias is None: - m.bias.data.zero_() - - def forward(self, x): - x = self.features(x) - return x - - -def shallownet8x(output_dim): - model = ShallowNet(output_dim=output_dim, stride=[2,2,2]) - return model - -def shallownet32x(output_dim, **kwargs): - model = ShallowNet(output_dim=output_dim, stride=[2,2,8]) - return model - - - diff --git a/models/cmp/models/modules/warp.py b/models/cmp/models/modules/warp.py deleted file mode 100644 index d32dc5db787345c9d2622fa6f65d463dd78ef8ba..0000000000000000000000000000000000000000 --- a/models/cmp/models/modules/warp.py +++ /dev/null @@ -1,68 +0,0 @@ -import torch -import torch.nn as nn - -class WarpingLayerBWFlow(nn.Module): - - def __init__(self): - super(WarpingLayerBWFlow, self).__init__() - - def forward(self, image, flow): - flow_for_grip = torch.zeros_like(flow) - flow_for_grip[:,0,:,:] = flow[:,0,:,:] / ((flow.size(3) - 1.0) / 2.0) - flow_for_grip[:,1,:,:] = flow[:,1,:,:] / ((flow.size(2) - 1.0) / 2.0) - - torchHorizontal = torch.linspace( - -1.0, 1.0, image.size(3)).view( - 1, 1, 1, image.size(3)).expand( - image.size(0), 1, image.size(2), image.size(3)) - torchVertical = torch.linspace( - -1.0, 1.0, image.size(2)).view( - 1, 1, image.size(2), 1).expand( - image.size(0), 1, image.size(2), image.size(3)) - grid = torch.cat([torchHorizontal, torchVertical], 1).cuda() - - grid = (grid + flow_for_grip).permute(0, 2, 3, 1) - return torch.nn.functional.grid_sample(image, grid) - - -class WarpingLayerFWFlow(nn.Module): - - def __init__(self): - super(WarpingLayerFWFlow, self).__init__() - self.initialized = False - - def forward(self, image, flow, ret_mask = False): - n, h, w = image.size(0), image.size(2), image.size(3) - - if not self.initialized or n != self.meshx.shape[0] or h * w != self.meshx.shape[1]: - self.meshx = torch.arange(w).view(1, 1, w).expand( - n, h, w).contiguous().view(n, -1).cuda() - self.meshy = torch.arange(h).view(1, h, 1).expand( - n, h, w).contiguous().view(n, -1).cuda() - self.warped_image = torch.zeros((n, 3, h, w), dtype=torch.float32).cuda() - if ret_mask: - self.hole_mask = torch.ones((n, 1, h, w), dtype=torch.float32).cuda() - self.initialized = True - - v = (flow[:,0,:,:] ** 2 + flow[:,1,:,:] ** 2).view(n, -1) - _, sortidx = torch.sort(v, dim=1) - - warped_meshx = self.meshx + flow[:,0,:,:].long().view(n, -1) - warped_meshy = self.meshy + flow[:,1,:,:].long().view(n, -1) - - warped_meshx = torch.clamp(warped_meshx, 0, w - 1) - warped_meshy = torch.clamp(warped_meshy, 0, h - 1) - - self.warped_image.zero_() - if ret_mask: - self.hole_mask.fill_(1.) - for i in range(n): - for c in range(3): - ind = sortidx[i] - self.warped_image[i,c,warped_meshy[i][ind],warped_meshx[i][ind]] = image[i,c,self.meshy[i][ind],self.meshx[i][ind]] - if ret_mask: - self.hole_mask[i,0,warped_meshy[i],warped_meshx[i]] = 0. 
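The warp.py hunk here contains two warping directions: WarpingLayerBWFlow resamples the source image with grid_sample at flow-displaced positions, while WarpingLayerFWFlow splats pixels forward (larger motions overwrite smaller ones) and can return a hole mask. A hypothetical usage sketch, assuming both classes are importable and a CUDA device is present (both allocate their grids with .cuda()):

    import torch

    warp_bw = WarpingLayerBWFlow()
    image = torch.rand(2, 3, 64, 64).cuda()
    flow = torch.zeros(2, 2, 64, 64).cuda()
    flow[:, 0] = 5.0                    # displace every pixel 5 px along x
    out_bw = warp_bw(image, flow)       # same shape as image

    warp_fw = WarpingLayerFWFlow()
    out_fw, holes = warp_fw(image, flow, ret_mask=True)   # holes == 1 where nothing was splatted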
- if ret_mask: - return self.warped_image, self.hole_mask - else: - return self.warped_image diff --git a/models/cmp/models/single_stage_model.py b/models/cmp/models/single_stage_model.py deleted file mode 100644 index 79f5d4ab7ccffba99f72612ad6c77bf4dc3f2521..0000000000000000000000000000000000000000 --- a/models/cmp/models/single_stage_model.py +++ /dev/null @@ -1,72 +0,0 @@ -import os -import torch -import torch.backends.cudnn as cudnn -import torch.distributed as dist - -import models.cmp.models as models -import models.cmp.utils as utils - - -class SingleStageModel(object): - - def __init__(self, params, dist_model=False): - model_params = params['module'] - self.model = models.modules.__dict__[params['module']['arch']](model_params) - utils.init_weights(self.model, init_type='xavier') - self.model.cuda() - if dist_model: - self.model = utils.DistModule(self.model) - self.world_size = dist.get_world_size() - else: - self.model = models.modules.FixModule(self.model) - self.world_size = 1 - - if params['optim'] == 'SGD': - self.optim = torch.optim.SGD( - self.model.parameters(), lr=params['lr'], - momentum=0.9, weight_decay=0.0001) - elif params['optim'] == 'Adam': - self.optim = torch.optim.Adam( - self.model.parameters(), lr=params['lr'], - betas=(params['beta1'], 0.999)) - else: - raise Exception("No such optimizer: {}".format(params['optim'])) - - cudnn.benchmark = True - - def set_input(self, image_input, sparse_input, flow_target=None, rgb_target=None): - self.image_input = image_input - self.sparse_input = sparse_input - self.flow_target = flow_target - self.rgb_target = rgb_target - - def eval(self, ret_loss=True): - pass - - def step(self): - pass - - def load_state(self, path, Iter, resume=False): - path = os.path.join(path, "ckpt_iter_{}.pth.tar".format(Iter)) - - if resume: - utils.load_state(path, self.model, self.optim) - else: - utils.load_state(path, self.model) - - def load_pretrain(self, load_path): - utils.load_state(load_path, self.model) - - def save_state(self, path, Iter): - path = os.path.join(path, "ckpt_iter_{}.pth.tar".format(Iter)) - - torch.save({ - 'step': Iter, - 'state_dict': self.model.state_dict(), - 'optimizer': self.optim.state_dict()}, path) - - def switch_to(self, phase): - if phase == 'train': - self.model.train() - else: - self.model.eval() diff --git a/models/cmp/utils/__init__.py b/models/cmp/utils/__init__.py deleted file mode 100644 index 29be9c14049e0540b324db3fc65eedf1b492358e..0000000000000000000000000000000000000000 --- a/models/cmp/utils/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .common_utils import * -from .data_utils import * -from .distributed_utils import * -from .visualize_utils import * -from .scheduler import * -from . import flowlib diff --git a/models/cmp/utils/common_utils.py b/models/cmp/utils/common_utils.py deleted file mode 100644 index 0a3862068c32b5094b7ee3caa045d250c1d63264..0000000000000000000000000000000000000000 --- a/models/cmp/utils/common_utils.py +++ /dev/null @@ -1,118 +0,0 @@ -import os -import logging -import numpy as np - -import torch -from torch.nn import init - -def init_weights(net, init_type='normal', init_gain=0.02): - """Initialize network weights. - Parameters: - net (network) -- network to be initialized - init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal - init_gain (float) -- scaling factor for normal, xavier and orthogonal. - We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might - work better for some applications. 
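SingleStageModel above initializes the freshly built CMP module with utils.init_weights(self.model, init_type='xavier'). A small hypothetical example of applying the same initializer to a toy network, assuming init_weights is imported from this common_utils module:

    import torch.nn as nn

    net = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(inplace=True),
        nn.Conv2d(16, 2, kernel_size=1),
    )
    # xavier_normal_ on conv weights, N(1.0, init_gain) on BatchNorm weights, zero biases
    init_weights(net, init_type='xavier', init_gain=0.02)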
Feel free to try yourself. - """ - def init_func(m): # define the initialization function - classname = m.__class__.__name__ - if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): - if init_type == 'normal': - init.normal_(m.weight.data, 0.0, init_gain) - elif init_type == 'xavier': - init.xavier_normal_(m.weight.data, gain=init_gain) - elif init_type == 'kaiming': - init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') - elif init_type == 'orthogonal': - init.orthogonal_(m.weight.data, gain=init_gain) - else: - raise NotImplementedError('initialization method [%s] is not implemented' % init_type) - if hasattr(m, 'bias') and m.bias is not None: - init.constant_(m.bias.data, 0.0) - elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. - init.normal_(m.weight.data, 1.0, init_gain) - init.constant_(m.bias.data, 0.0) - - net.apply(init_func) # apply the initialization function - -def create_logger(name, log_file, level=logging.INFO): - l = logging.getLogger(name) - formatter = logging.Formatter('[%(asctime)s] %(message)s') - fh = logging.FileHandler(log_file) - fh.setFormatter(formatter) - sh = logging.StreamHandler() - sh.setFormatter(formatter) - l.setLevel(level) - l.addHandler(fh) - l.addHandler(sh) - return l - -class AverageMeter(object): - """Computes and stores the average and current value""" - def __init__(self, length=0): - self.length = length - self.reset() - - def reset(self): - if self.length > 0: - self.history = [] - else: - self.count = 0 - self.sum = 0.0 - self.val = 0.0 - self.avg = 0.0 - - def update(self, val): - if self.length > 0: - self.history.append(val) - if len(self.history) > self.length: - del self.history[0] - - self.val = self.history[-1] - self.avg = np.mean(self.history) - else: - self.val = val - self.sum += val - self.count += 1 - self.avg = self.sum / self.count - -def accuracy(output, target, topk=(1,)): - """Computes the precision@k for the specified values of k""" - maxk = max(topk) - batch_size = target.size(0) - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - correct = pred.eq(target.view(1, -1).expand_as(pred)) - - res = [] - for k in topk: - correct_k = correct[:k].view(-1).float().sum(0, keepdims=True) - res.append(correct_k.mul_(100.0 / batch_size)) - return res - -def load_state(path, model, optimizer=None): - def map_func(storage, location): - return storage.cuda() - if os.path.isfile(path): - print("=> loading checkpoint '{}'".format(path)) - checkpoint = torch.load(path, map_location=map_func) - model.load_state_dict(checkpoint['state_dict'], strict=False) - ckpt_keys = set(checkpoint['state_dict'].keys()) - own_keys = set(model.state_dict().keys()) - missing_keys = own_keys - ckpt_keys - # print(ckpt_keys) - # print(own_keys) - for k in missing_keys: - print('caution: missing keys from checkpoint {}: {}'.format(path, k)) - - last_iter = checkpoint['step'] - if optimizer != None: - optimizer.load_state_dict(checkpoint['optimizer']) - print("=> also loaded optimizer from checkpoint '{}' (iter {})" - .format(path, last_iter)) - return last_iter - else: - print("=> no checkpoint found at '{}'".format(path)) - - diff --git a/models/cmp/utils/data_utils.py b/models/cmp/utils/data_utils.py deleted file mode 100644 index 0651fc5d9fefa638c76f86ebdb9db3139fecf991..0000000000000000000000000000000000000000 --- a/models/cmp/utils/data_utils.py +++ /dev/null @@ -1,280 +0,0 @@ -from PIL import Image, ImageOps -import scipy.ndimage as 
ndimage -import cv2 -import random -import numpy as np -from scipy.ndimage.filters import maximum_filter -from scipy import signal -cv2.ocl.setUseOpenCL(False) - -def get_edge(data, blur=False): - if blur: - data = cv2.GaussianBlur(data, (3, 3), 1.) - sobel = np.array([[1,0,-1],[2,0,-2],[1,0,-1]]).astype(np.float32) - ch_edges = [] - for k in range(data.shape[2]): - edgex = signal.convolve2d(data[:,:,k], sobel, boundary='symm', mode='same') - edgey = signal.convolve2d(data[:,:,k], sobel.T, boundary='symm', mode='same') - ch_edges.append(np.sqrt(edgex**2 + edgey**2)) - return sum(ch_edges) - -def get_max(score, bbox): - u = max(0, bbox[0]) - d = min(score.shape[0], bbox[1]) - l = max(0, bbox[2]) - r = min(score.shape[1], bbox[3]) - return score[u:d,l:r].max() - -def nms(score, ks): - assert ks % 2 == 1 - ret_score = score.copy() - maxpool = maximum_filter(score, footprint=np.ones((ks, ks))) - ret_score[score < maxpool] = 0. - return ret_score - -def image_flow_crop(img1, img2, flow, crop_size, phase): - assert len(crop_size) == 2 - pad_h = max(crop_size[0] - img1.height, 0) - pad_w = max(crop_size[1] - img1.width, 0) - pad_h_half = int(pad_h / 2) - pad_w_half = int(pad_w / 2) - if pad_h > 0 or pad_w > 0: - flow_expand = np.zeros((img1.height + pad_h, img1.width + pad_w, 2), dtype=np.float32) - flow_expand[pad_h_half:pad_h_half+img1.height, pad_w_half:pad_w_half+img1.width, :] = flow - flow = flow_expand - border = (pad_w_half, pad_h_half, pad_w - pad_w_half, pad_h - pad_h_half) - img1 = ImageOps.expand(img1, border=border, fill=(0,0,0)) - img2 = ImageOps.expand(img2, border=border, fill=(0,0,0)) - if phase == 'train': - hoff = int(np.random.rand() * (img1.height - crop_size[0])) - woff = int(np.random.rand() * (img1.width - crop_size[1])) - else: - hoff = (img1.height - crop_size[0]) // 2 - woff = (img1.width - crop_size[1]) // 2 - - img1 = img1.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0])) - img2 = img2.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0])) - flow = flow[hoff:hoff+crop_size[0], woff:woff+crop_size[1], :] - offset = (hoff, woff) - return img1, img2, flow, offset - -def image_crop(img, crop_size): - pad_h = max(crop_size[0] - img.height, 0) - pad_w = max(crop_size[1] - img.width, 0) - pad_h_half = int(pad_h / 2) - pad_w_half = int(pad_w / 2) - if pad_h > 0 or pad_w > 0: - border = (pad_w_half, pad_h_half, pad_w - pad_w_half, pad_h - pad_h_half) - img = ImageOps.expand(img, border=border, fill=(0,0,0)) - hoff = (img.height - crop_size[0]) // 2 - woff = (img.width - crop_size[1]) // 2 - return img.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0])), (pad_w_half, pad_h_half) - -def image_flow_resize(img1, img2, flow, short_size=None, long_size=None): - assert (short_size is None) ^ (long_size is None) - w, h = img1.width, img1.height - if short_size is not None: - if w < h: - neww = short_size - newh = int(short_size / float(w) * h) - else: - neww = int(short_size / float(h) * w) - newh = short_size - else: - if w < h: - neww = int(long_size / float(h) * w) - newh = long_size - else: - neww = long_size - newh = int(long_size / float(w) * h) - img1 = img1.resize((neww, newh), Image.BICUBIC) - img2 = img2.resize((neww, newh), Image.BICUBIC) - ratio = float(newh) / h - flow = cv2.resize(flow.copy(), (neww, newh), interpolation=cv2.INTER_LINEAR) * ratio - return img1, img2, flow, ratio - -def image_resize(img, short_size=None, long_size=None): - assert (short_size is None) ^ (long_size is None) - w, h = img.width, img.height - if short_size is not None: - if w < 
h: - neww = short_size - newh = int(short_size / float(w) * h) - else: - neww = int(short_size / float(h) * w) - newh = short_size - else: - if w < h: - neww = int(long_size / float(h) * w) - newh = long_size - else: - neww = long_size - newh = int(long_size / float(w) * h) - img = img.resize((neww, newh), Image.BICUBIC) - return img, [w, h] - - -def image_pose_crop(img, posemap, crop_size, scale): - assert len(crop_size) == 2 - assert crop_size[0] <= img.height - assert crop_size[1] <= img.width - hoff = (img.height - crop_size[0]) // 2 - woff = (img.width - crop_size[1]) // 2 - img = img.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0])) - posemap = posemap[hoff//scale:hoff//scale+crop_size[0]//scale, woff//scale:woff//scale+crop_size[1]//scale,:] - return img, posemap - -def neighbor_elim(ph, pw, d): - valid = np.ones((len(ph))).astype(np.int) - h_dist = np.fabs(np.tile(ph[:,np.newaxis], [1,len(ph)]) - np.tile(ph.T[np.newaxis,:], [len(ph),1])) - w_dist = np.fabs(np.tile(pw[:,np.newaxis], [1,len(pw)]) - np.tile(pw.T[np.newaxis,:], [len(pw),1])) - idx1, idx2 = np.where((h_dist < d) & (w_dist < d)) - for i,j in zip(idx1, idx2): - if valid[i] and valid[j] and i != j: - if np.random.rand() > 0.5: - valid[i] = 0 - else: - valid[j] = 0 - valid_idx = np.where(valid==1) - return ph[valid_idx], pw[valid_idx] - -def remove_border(mask): - mask[0,:] = 0 - mask[:,0] = 0 - mask[mask.shape[0]-1,:] = 0 - mask[:,mask.shape[1]-1] = 0 - -def flow_sampler(flow, strategy=['grid'], bg_ratio=1./6400, nms_ks=15, max_num_guide=-1, guidepoint=None): - assert bg_ratio >= 0 and bg_ratio <= 1, "sampling ratio must be in (0, 1]" - for s in strategy: - assert s in ['grid', 'uniform', 'gradnms', 'watershed', 'single', 'full', 'specified'], "No such strategy: {}".format(s) - h = flow.shape[0] - w = flow.shape[1] - ds = max(1, max(h, w) // 400) # reduce computation - - if 'full' in strategy: - sparse = flow.copy() - mask = np.ones(flow.shape, dtype=np.int) - return sparse, mask - - pts_h = [] - pts_w = [] - if 'grid' in strategy: - stride = int(np.sqrt(1./bg_ratio)) - mesh_start_h = int((h - h // stride * stride) / 2) - mesh_start_w = int((w - w // stride * stride) / 2) - mesh = np.meshgrid(np.arange(mesh_start_h, h, stride), np.arange(mesh_start_w, w, stride)) - pts_h.append(mesh[0].flat) - pts_w.append(mesh[1].flat) - if 'uniform' in strategy: - pts_h.append(np.random.randint(0, h, int(bg_ratio * h * w))) - pts_w.append(np.random.randint(0, w, int(bg_ratio * h * w))) - if "gradnms" in strategy: - ks = w // ds // 20 - edge = get_edge(flow[::ds,::ds,:]) - kernel = np.ones((ks, ks), dtype=np.float32) / (ks * ks) - subkernel = np.ones((ks//2, ks//2), dtype=np.float32) / (ks//2 * ks//2) - score = signal.convolve2d(edge, kernel, boundary='symm', mode='same') - subscore = signal.convolve2d(edge, subkernel, boundary='symm', mode='same') - score = score / score.max() - subscore / subscore.max() - nms_res = nms(score, nms_ks) - pth, ptw = np.where(nms_res > 0.1) - pts_h.append(pth * ds) - pts_w.append(ptw * ds) - if "watershed" in strategy: - edge = get_edge(flow[::ds,::ds,:]) - edge /= max(edge.max(), 0.01) - edge = (edge > 0.1).astype(np.float32) - watershed = ndimage.distance_transform_edt(1-edge) - nms_res = nms(watershed, nms_ks) - remove_border(nms_res) - pth, ptw = np.where(nms_res > 0) - pth, ptw = neighbor_elim(pth, ptw, (nms_ks-1)/2) - pts_h.append(pth * ds) - pts_w.append(ptw * ds) - if "single" in strategy: - pth, ptw = np.where((flow[:,:,0] != 0) | (flow[:,:,1] != 0)) - randidx = np.random.randint(len(pth)) - 
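# 'single' strategy: only this one randomly chosen non-zero-flow pixel is kept as the guidance point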
pts_h.append(pth[randidx:randidx+1]) - pts_w.append(ptw[randidx:randidx+1]) - if 'specified' in strategy: - assert guidepoint is not None, "if using \"specified\", switch \"with_info\" on." - pts_h.append(guidepoint[:,1]) - pts_w.append(guidepoint[:,0]) - - pts_h = np.concatenate(pts_h) - pts_w = np.concatenate(pts_w) - - if max_num_guide == -1: - max_num_guide = np.inf - - randsel = np.random.permutation(len(pts_h))[:len(pts_h)] - selidx = randsel[np.arange(min(max_num_guide, len(randsel)))] - pts_h = pts_h[selidx] - pts_w = pts_w[selidx] - - sparse = np.zeros(flow.shape, dtype=flow.dtype) - mask = np.zeros(flow.shape, dtype=np.int) - - sparse[:, :, 0][(pts_h, pts_w)] = flow[:, :, 0][(pts_h, pts_w)] - sparse[:, :, 1][(pts_h, pts_w)] = flow[:, :, 1][(pts_h, pts_w)] - - mask[:,:,0][(pts_h, pts_w)] = 1 - mask[:,:,1][(pts_h, pts_w)] = 1 - return sparse, mask - -def image_flow_aug(img1, img2, flow, flip_horizon=True): - if flip_horizon: - if random.random() < 0.5: - img1 = img1.transpose(Image.FLIP_LEFT_RIGHT) - img2 = img2.transpose(Image.FLIP_LEFT_RIGHT) - flow = flow[:,::-1,:].copy() - flow[:,:,0] = -flow[:,:,0] - return img1, img2, flow - -def flow_aug(flow, reverse=True, scale=True, rotate=True): - if reverse: - if random.random() < 0.5: - flow = -flow - if scale: - rand_scale = random.uniform(0.5, 2.0) - flow = flow * rand_scale - if rotate and random.random() < 0.5: - lengh = np.sqrt(np.square(flow[:,:,0]) + np.square(flow[:,:,1])) - alpha = np.arctan(flow[:,:,1] / flow[:,:,0]) - theta = random.uniform(0, np.pi*2) - flow[:,:,0] = lengh * np.cos(alpha + theta) - flow[:,:,1] = lengh * np.sin(alpha + theta) - return flow - -def draw_gaussian(img, pt, sigma, type='Gaussian'): - # Check that any part of the gaussian is in-bounds - ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)] - br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)] - if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or - br[0] < 0 or br[1] < 0): - # If not, just return the image as is - return img - - # Generate gaussian - size = 6 * sigma + 1 - x = np.arange(0, size, 1, float) - y = x[:, np.newaxis] - x0 = y0 = size // 2 - # The gaussian is not normalized, we want the center value to equal 1 - if type == 'Gaussian': - g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) - elif type == 'Cauchy': - g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5) - - # Usable gaussian range - g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0] - g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1] - # Image range - img_x = max(0, ul[0]), min(br[0], img.shape[1]) - img_y = max(0, ul[1]), min(br[1], img.shape[0]) - - img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]] - return img - - diff --git a/models/cmp/utils/distributed_utils.py b/models/cmp/utils/distributed_utils.py deleted file mode 100644 index 97056fc313c198ea11ec96a6ad7575db5de6b302..0000000000000000000000000000000000000000 --- a/models/cmp/utils/distributed_utils.py +++ /dev/null @@ -1,229 +0,0 @@ -import os -import subprocess -import numpy as np -import multiprocessing as mp -import math - -import torch -import torch.distributed as dist -from torch.utils.data.sampler import Sampler -from torch.nn import Module - -class DistModule(Module): - def __init__(self, module): - super(DistModule, self).__init__() - self.module = module - broadcast_params(self.module) - def forward(self, *inputs, **kwargs): - return self.module(*inputs, **kwargs) - def train(self, mode=True): - super(DistModule, self).train(mode) - 
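# explicitly propagate the train/eval mode to the wrapped module as well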
self.module.train(mode) - -def average_gradients(model): - """ average gradients """ - for param in model.parameters(): - if param.requires_grad: - dist.all_reduce(param.grad.data) - -def broadcast_params(model): - """ broadcast model parameters """ - for p in model.state_dict().values(): - dist.broadcast(p, 0) - -def dist_init(launcher, backend='nccl', **kwargs): - if mp.get_start_method(allow_none=True) is None: - mp.set_start_method('spawn') - if launcher == 'pytorch': - _init_dist_pytorch(backend, **kwargs) - elif launcher == 'mpi': - _init_dist_mpi(backend, **kwargs) - elif launcher == 'slurm': - _init_dist_slurm(backend, **kwargs) - else: - raise ValueError('Invalid launcher type: {}'.format(launcher)) - -def _init_dist_pytorch(backend, **kwargs): - rank = int(os.environ['RANK']) - num_gpus = torch.cuda.device_count() - torch.cuda.set_device(rank % num_gpus) - dist.init_process_group(backend=backend, **kwargs) - -def _init_dist_mpi(backend, **kwargs): - raise NotImplementedError - -def _init_dist_slurm(backend, port=10086, **kwargs): - proc_id = int(os.environ['SLURM_PROCID']) - ntasks = int(os.environ['SLURM_NTASKS']) - node_list = os.environ['SLURM_NODELIST'] - num_gpus = torch.cuda.device_count() - torch.cuda.set_device(proc_id % num_gpus) - addr = subprocess.getoutput( - 'scontrol show hostname {} | head -n1'.format(node_list)) - os.environ['MASTER_PORT'] = str(port) - os.environ['MASTER_ADDR'] = addr - os.environ['WORLD_SIZE'] = str(ntasks) - os.environ['RANK'] = str(proc_id) - dist.init_process_group(backend=backend) - -def gather_tensors(input_array): - world_size = dist.get_world_size() - ## gather shapes first - myshape = input_array.shape - mycount = input_array.size - shape_tensor = torch.Tensor(np.array(myshape)).cuda() - all_shape = [torch.Tensor(np.array(myshape)).cuda() for i in range(world_size)] - dist.all_gather(all_shape, shape_tensor) - ## compute largest shapes - all_shape = [x.cpu().numpy() for x in all_shape] - all_count = [int(x.prod()) for x in all_shape] - all_shape = [list(map(int, x)) for x in all_shape] - max_count = max(all_count) - ## padding tensors and gather them - output_tensors = [torch.Tensor(max_count).cuda() for i in range(world_size)] - padded_input_array = np.zeros(max_count) - padded_input_array[:mycount] = input_array.reshape(-1) - input_tensor = torch.Tensor(padded_input_array).cuda() - dist.all_gather(output_tensors, input_tensor) - ## unpadding gathered tensors - padded_output = [x.cpu().numpy() for x in output_tensors] - output = [x[:all_count[i]].reshape(all_shape[i]) for i,x in enumerate(padded_output)] - return output - -def gather_tensors_batch(input_array, part_size=10): - # gather - rank = dist.get_rank() - all_features = [] - part_num = input_array.shape[0] // part_size + 1 if input_array.shape[0] % part_size != 0 else input_array.shape[0] // part_size - for i in range(part_num): - part_feat = input_array[i * part_size:min((i+1)*part_size, input_array.shape[0]),...] 
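# gather this slice across all ranks; chunking keeps each individual all_gather small instead of exchanging the whole array in one call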
- assert part_feat.shape[0] > 0, "rank: {}, length of part features should > 0".format(rank) - print("rank: {}, gather part: {}/{}, length: {}".format(rank, i, part_num, len(part_feat))) - gather_part_feat = gather_tensors(part_feat) - all_features.append(gather_part_feat) - print("rank: {}, gather done.".format(rank)) - all_features = np.concatenate([np.concatenate([all_features[i][j] for i in range(part_num)], axis=0) for j in range(len(all_features[0]))], axis=0) - return all_features - -def reduce_tensors(tensor): - reduced_tensor = tensor.clone() - dist.all_reduce(reduced_tensor) - return reduced_tensor - -class DistributedSequentialSampler(Sampler): - def __init__(self, dataset, world_size=None, rank=None): - if world_size == None: - world_size = dist.get_world_size() - if rank == None: - rank = dist.get_rank() - self.dataset = dataset - self.world_size = world_size - self.rank = rank - assert len(self.dataset) >= self.world_size, '{} vs {}'.format(len(self.dataset), self.world_size) - sub_num = int(math.ceil(len(self.dataset) * 1.0 / self.world_size)) - self.beg = sub_num * self.rank - #self.end = min(self.beg+sub_num, len(self.dataset)) - self.end = self.beg + sub_num - self.padded_ind = list(range(len(self.dataset))) + list(range(sub_num * self.world_size - len(self.dataset))) - - def __iter__(self): - indices = [self.padded_ind[i] for i in range(self.beg, self.end)] - return iter(indices) - - def __len__(self): - return self.end - self.beg - -class GivenIterationSampler(Sampler): - def __init__(self, dataset, total_iter, batch_size, last_iter=-1): - self.dataset = dataset - self.total_iter = total_iter - self.batch_size = batch_size - self.last_iter = last_iter - - self.total_size = self.total_iter * self.batch_size - self.indices = self.gen_new_list() - self.call = 0 - - def __iter__(self): - if self.call == 0: - self.call = 1 - return iter(self.indices[(self.last_iter + 1) * self.batch_size:]) - else: - raise RuntimeError("this sampler is not designed to be called more than once!!") - - def gen_new_list(self): - - # each process shuffle all list with same seed, and pick one piece according to rank - np.random.seed(0) - - all_size = self.total_size - indices = np.arange(len(self.dataset)) - indices = indices[:all_size] - num_repeat = (all_size-1) // indices.shape[0] + 1 - indices = np.tile(indices, num_repeat) - indices = indices[:all_size] - - np.random.shuffle(indices) - - assert len(indices) == self.total_size - - return indices - - def __len__(self): - return self.total_size - - -class DistributedGivenIterationSampler(Sampler): - def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1): - if world_size is None: - world_size = dist.get_world_size() - if rank is None: - rank = dist.get_rank() - assert rank < world_size - self.dataset = dataset - self.total_iter = total_iter - self.batch_size = batch_size - self.world_size = world_size - self.rank = rank - self.last_iter = last_iter - - self.total_size = self.total_iter*self.batch_size - - self.indices = self.gen_new_list() - self.call = 0 - - def __iter__(self): - if self.call == 0: - self.call = 1 - return iter(self.indices[(self.last_iter+1)*self.batch_size:]) - else: - raise RuntimeError("this sampler is not designed to be called more than once!!") - - def gen_new_list(self): - - # each process shuffle all list with same seed, and pick one piece according to rank - np.random.seed(0) - - all_size = self.total_size * self.world_size - indices = np.arange(len(self.dataset)) - indices = 
indices[:all_size] - num_repeat = (all_size-1) // indices.shape[0] + 1 - indices = np.tile(indices, num_repeat) - indices = indices[:all_size] - - np.random.shuffle(indices) - beg = self.total_size * self.rank - indices = indices[beg:beg+self.total_size] - - assert len(indices) == self.total_size - - return indices - - def __len__(self): - # note here we do not take last iter into consideration, since __len__ - # should only be used for displaying, the correct remaining size is - # handled by dataloader - #return self.total_size - (self.last_iter+1)*self.batch_size - return self.total_size - - diff --git a/models/cmp/utils/flowlib.py b/models/cmp/utils/flowlib.py deleted file mode 100644 index 8a0ab1a8bf3cbe05b55c50449319a55d4ae8d1ee..0000000000000000000000000000000000000000 --- a/models/cmp/utils/flowlib.py +++ /dev/null @@ -1,308 +0,0 @@ -#!/usr/bin/python -""" -# ============================== -# flowlib.py -# library for optical flow processing -# Author: Ruoteng Li -# Date: 6th Aug 2016 -# ============================== -""" -#import png -import numpy as np -from PIL import Image -import io - -UNKNOWN_FLOW_THRESH = 1e7 -SMALLFLOW = 0.0 -LARGEFLOW = 1e8 - -""" -============= -Flow Section -============= -""" - -def write_flow(flow, filename): - """ - write optical flow in Middlebury .flo format - :param flow: optical flow map - :param filename: optical flow file path to be saved - :return: None - """ - f = open(filename, 'wb') - magic = np.array([202021.25], dtype=np.float32) - (height, width) = flow.shape[0:2] - w = np.array([width], dtype=np.int32) - h = np.array([height], dtype=np.int32) - magic.tofile(f) - w.tofile(f) - h.tofile(f) - flow.tofile(f) - f.close() - - -def save_flow_image(flow, image_file): - """ - save flow visualization into image file - :param flow: optical flow data - :param flow_fil - :return: None - """ - flow_img = flow_to_image(flow) - img_out = Image.fromarray(flow_img) - img_out.save(image_file) - -def segment_flow(flow): - h = flow.shape[0] - w = flow.shape[1] - u = flow[:, :, 0] - v = flow[:, :, 1] - - idx = ((abs(u) > LARGEFLOW) | (abs(v) > LARGEFLOW)) - idx2 = (abs(u) == SMALLFLOW) - class0 = (v == 0) & (u == 0) - u[idx2] = 0.00001 - tan_value = v / u - - class1 = (tan_value < 1) & (tan_value >= 0) & (u > 0) & (v >= 0) - class2 = (tan_value >= 1) & (u >= 0) & (v >= 0) - class3 = (tan_value < -1) & (u <= 0) & (v >= 0) - class4 = (tan_value < 0) & (tan_value >= -1) & (u < 0) & (v >= 0) - class8 = (tan_value >= -1) & (tan_value < 0) & (u > 0) & (v <= 0) - class7 = (tan_value < -1) & (u >= 0) & (v <= 0) - class6 = (tan_value >= 1) & (u <= 0) & (v <= 0) - class5 = (tan_value >= 0) & (tan_value < 1) & (u < 0) & (v <= 0) - - seg = np.zeros((h, w)) - - seg[class1] = 1 - seg[class2] = 2 - seg[class3] = 3 - seg[class4] = 4 - seg[class5] = 5 - seg[class6] = 6 - seg[class7] = 7 - seg[class8] = 8 - seg[class0] = 0 - seg[idx] = 0 - - return seg - -def flow_to_image(flow): - """ - Convert flow into middlebury color code image - :param flow: optical flow map - :return: optical flow image in middlebury color - """ - u = flow[:, :, 0] - v = flow[:, :, 1] - - maxu = -999. - maxv = -999. - minu = 999. - minv = 999. 
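# Middlebury-style visualization: zero out "unknown" flow vectors, normalize u/v by the
# maximum flow magnitude, then map direction to a hue on the color wheel and magnitude to
# color saturation in compute_color(); unknown pixels are blacked out at the end.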
- - idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) - u[idxUnknow] = 0 - v[idxUnknow] = 0 - - maxu = max(maxu, np.max(u)) - minu = min(minu, np.min(u)) - - maxv = max(maxv, np.max(v)) - minv = min(minv, np.min(v)) - - rad = np.sqrt(u ** 2 + v ** 2) - maxrad = max(5, np.max(rad)) - #maxrad = max(-1, 99) - - u = u/(maxrad + np.finfo(float).eps) - v = v/(maxrad + np.finfo(float).eps) - - img = compute_color(u, v) - - idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) - img[idx] = 0 - - return np.uint8(img) - -def disp_to_flowfile(disp, filename): - """ - Read KITTI disparity file in png format - :param disp: disparity matrix - :param filename: the flow file name to save - :return: None - """ - f = open(filename, 'wb') - magic = np.array([202021.25], dtype=np.float32) - (height, width) = disp.shape[0:2] - w = np.array([width], dtype=np.int32) - h = np.array([height], dtype=np.int32) - empty_map = np.zeros((height, width), dtype=np.float32) - data = np.dstack((disp, empty_map)) - magic.tofile(f) - w.tofile(f) - h.tofile(f) - data.tofile(f) - f.close() - -def compute_color(u, v): - """ - compute optical flow color map - :param u: optical flow horizontal map - :param v: optical flow vertical map - :return: optical flow in color code - """ - [h, w] = u.shape - img = np.zeros([h, w, 3]) - nanIdx = np.isnan(u) | np.isnan(v) - u[nanIdx] = 0 - v[nanIdx] = 0 - - colorwheel = make_color_wheel() - ncols = np.size(colorwheel, 0) - - rad = np.sqrt(u**2+v**2) - - a = np.arctan2(-v, -u) / np.pi - - fk = (a+1) / 2 * (ncols - 1) + 1 - - k0 = np.floor(fk).astype(int) - - k1 = k0 + 1 - k1[k1 == ncols+1] = 1 - f = fk - k0 - - for i in range(0, np.size(colorwheel,1)): - tmp = colorwheel[:, i] - col0 = tmp[k0-1] / 255 - col1 = tmp[k1-1] / 255 - col = (1-f) * col0 + f * col1 - - idx = rad <= 1 - col[idx] = 1-rad[idx]*(1-col[idx]) - notidx = np.logical_not(idx) - - col[notidx] *= 0.75 - img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx))) - - return img - - -def make_color_wheel(): - """ - Generate color wheel according Middlebury color code - :return: Color wheel - """ - RY = 15 - YG = 6 - GC = 4 - CB = 11 - BM = 13 - MR = 6 - - ncols = RY + YG + GC + CB + BM + MR - - colorwheel = np.zeros([ncols, 3]) - - col = 0 - - # RY - colorwheel[0:RY, 0] = 255 - colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY)) - col += RY - - # YG - colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG)) - colorwheel[col:col+YG, 1] = 255 - col += YG - - # GC - colorwheel[col:col+GC, 1] = 255 - colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC)) - col += GC - - # CB - colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB)) - colorwheel[col:col+CB, 2] = 255 - col += CB - - # BM - colorwheel[col:col+BM, 2] = 255 - colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM)) - col += + BM - - # MR - colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) - colorwheel[col:col+MR, 0] = 255 - - return colorwheel - - -def read_flo_file(filename, memcached=False): - """ - Read from Middlebury .flo file - :param flow_file: name of the flow file - :return: optical flow data in matrix - """ - if memcached: - filename = io.BytesIO(filename) - f = open(filename, 'rb') - magic = np.fromfile(f, np.float32, count=1)[0] - data2d = None - - if 202021.25 != magic: - print('Magic number incorrect. 
Invalid .flo file') - else: - w = np.fromfile(f, np.int32, count=1)[0] - h = np.fromfile(f, np.int32, count=1)[0] - data2d = np.fromfile(f, np.float32, count=2 * w * h) - # reshape data into 3D array (columns, rows, channels) - data2d = np.resize(data2d, (h, w, 2)) - f.close() - return data2d - - -# fast resample layer -def resample(img, sz): - """ - img: flow map to be resampled - sz: new flow map size. Must be [height,weight] - """ - original_image_size = img.shape - in_height = img.shape[0] - in_width = img.shape[1] - out_height = sz[0] - out_width = sz[1] - out_flow = np.zeros((out_height, out_width, 2)) - # find scale - height_scale = float(in_height) / float(out_height) - width_scale = float(in_width) / float(out_width) - - [x,y] = np.meshgrid(range(out_width), range(out_height)) - xx = x * width_scale - yy = y * height_scale - x0 = np.floor(xx).astype(np.int32) - x1 = x0 + 1 - y0 = np.floor(yy).astype(np.int32) - y1 = y0 + 1 - - x0 = np.clip(x0,0,in_width-1) - x1 = np.clip(x1,0,in_width-1) - y0 = np.clip(y0,0,in_height-1) - y1 = np.clip(y1,0,in_height-1) - - Ia = img[y0,x0,:] - Ib = img[y1,x0,:] - Ic = img[y0,x1,:] - Id = img[y1,x1,:] - - wa = (y1-yy) * (x1-xx) - wb = (yy-y0) * (x1-xx) - wc = (y1-yy) * (xx-x0) - wd = (yy-y0) * (xx-x0) - out_flow[:,:,0] = (Ia[:,:,0]*wa + Ib[:,:,0]*wb + Ic[:,:,0]*wc + Id[:,:,0]*wd) * out_width / in_width - out_flow[:,:,1] = (Ia[:,:,1]*wa + Ib[:,:,1]*wb + Ic[:,:,1]*wc + Id[:,:,1]*wd) * out_height / in_height - - return out_flow diff --git a/models/cmp/utils/scheduler.py b/models/cmp/utils/scheduler.py deleted file mode 100644 index 4f34f6321b0b1f567c656e57ee0e32c6baafe9ce..0000000000000000000000000000000000000000 --- a/models/cmp/utils/scheduler.py +++ /dev/null @@ -1,102 +0,0 @@ -import torch -from bisect import bisect_right - -class _LRScheduler(object): - def __init__(self, optimizer, last_iter=-1): - if not isinstance(optimizer, torch.optim.Optimizer): - raise TypeError('{} is not an Optimizer'.format( - type(optimizer).__name__)) - self.optimizer = optimizer - if last_iter == -1: - for group in optimizer.param_groups: - group.setdefault('initial_lr', group['lr']) - else: - for i, group in enumerate(optimizer.param_groups): - if 'initial_lr' not in group: - raise KeyError("param 'initial_lr' is not specified " - "in param_groups[{}] when resuming an optimizer".format(i)) - self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) - self.last_iter = last_iter - - def _get_new_lr(self): - raise NotImplementedError - - def get_lr(self): - return list(map(lambda group: group['lr'], self.optimizer.param_groups)) - - def step(self, this_iter=None): - if this_iter is None: - this_iter = self.last_iter + 1 - self.last_iter = this_iter - for param_group, lr in zip(self.optimizer.param_groups, self._get_new_lr()): - param_group['lr'] = lr - -class _WarmUpLRSchedulerOld(_LRScheduler): - - def __init__(self, optimizer, base_lr, warmup_lr, warmup_steps, last_iter=-1): - self.base_lr = base_lr - self.warmup_steps = warmup_steps - if warmup_steps == 0: - self.warmup_lr = base_lr - else: - self.warmup_lr = warmup_lr - super(_WarmUpLRSchedulerOld, self).__init__(optimizer, last_iter) - - def _get_warmup_lr(self): - if self.warmup_steps > 0 and self.last_iter < self.warmup_steps: - # first compute relative scale for self.base_lr, then multiply to base_lr - scale = ((self.last_iter/self.warmup_steps)*(self.warmup_lr - self.base_lr) + self.base_lr)/self.base_lr - #print('last_iter: {}, warmup_lr: {}, base_lr: {}, scale: 
{}'.format(self.last_iter, self.warmup_lr, self.base_lr, scale)) - return [scale * base_lr for base_lr in self.base_lrs] - else: - return None - -class _WarmUpLRScheduler(_LRScheduler): - - def __init__(self, optimizer, base_lr, warmup_lr, warmup_steps, last_iter=-1): - self.base_lr = base_lr - self.warmup_lr = warmup_lr - self.warmup_steps = warmup_steps - assert isinstance(warmup_lr, list) - assert isinstance(warmup_steps, list) - assert len(warmup_lr) == len(warmup_steps) - super(_WarmUpLRScheduler, self).__init__(optimizer, last_iter) - - def _get_warmup_lr(self): - pos = bisect_right(self.warmup_steps, self.last_iter) - if pos >= len(self.warmup_steps): - return None - else: - if pos == 0: - curr_lr = self.base_lr + self.last_iter * (self.warmup_lr[pos] - self.base_lr) / self.warmup_steps[pos] - else: - curr_lr = self.warmup_lr[pos - 1] + (self.last_iter - self.warmup_steps[pos - 1]) * (self.warmup_lr[pos] - self.warmup_lr[pos - 1]) / (self.warmup_steps[pos] - self.warmup_steps[pos - 1]) - scale = curr_lr / self.base_lr - return [scale * base_lr for base_lr in self.base_lrs] - -class StepLRScheduler(_WarmUpLRScheduler): - def __init__(self, optimizer, milestones, lr_mults, base_lr, warmup_lr, warmup_steps, last_iter=-1): - super(StepLRScheduler, self).__init__(optimizer, base_lr, warmup_lr, warmup_steps, last_iter) - - assert len(milestones) == len(lr_mults), "{} vs {}".format(milestones, lr_mults) - for x in milestones: - assert isinstance(x, int) - if not list(milestones) == sorted(milestones): - raise ValueError('Milestones should be a list of' - ' increasing integers. Got {}', milestones) - self.milestones = milestones - self.lr_mults = [1.0] - for x in lr_mults: - self.lr_mults.append(self.lr_mults[-1]*x) - - def _get_new_lr(self): - warmup_lrs = self._get_warmup_lr() - if warmup_lrs is not None: - return warmup_lrs - - pos = bisect_right(self.milestones, self.last_iter) - if len(self.warmup_lr) == 0: - scale = self.lr_mults[pos] - else: - scale = self.warmup_lr[-1] * self.lr_mults[pos] / self.base_lr - return [base_lr * scale for base_lr in self.base_lrs] diff --git a/models/cmp/utils/visualize_utils.py b/models/cmp/utils/visualize_utils.py deleted file mode 100644 index cfb4796a980156e9a9e23f0cf86604ba24dfbc4e..0000000000000000000000000000000000000000 --- a/models/cmp/utils/visualize_utils.py +++ /dev/null @@ -1,109 +0,0 @@ -import numpy as np - -import torch -from . 
import flowlib - -class Fuser(object): - def __init__(self, nbins, fmax): - self.nbins = nbins - self.fmax = fmax - self.step = 2 * fmax / float(nbins) - self.mesh = torch.arange(nbins).view(1,-1,1,1).float().cuda() * self.step - fmax + self.step / 2 - - def convert_flow(self, flow_prob): - flow_probx = torch.nn.functional.softmax(flow_prob[:, :self.nbins, :, :], dim=1) - flow_proby = torch.nn.functional.softmax(flow_prob[:, self.nbins:, :, :], dim=1) - flow_probx = flow_probx * self.mesh - flow_proby = flow_proby * self.mesh - flow = torch.cat([flow_probx.sum(dim=1, keepdim=True), flow_proby.sum(dim=1, keepdim=True)], dim=1) - return flow - -def visualize_tensor_old(image, mask, flow_pred, flow_target, warped, rgb_gen, image_target, image_mean, image_div): - together = [ - draw_cross(unormalize(image.cpu(), mean=image_mean, div=image_div), mask.cpu(), radius=int(image.size(3) / 50.)), - flow_to_image(flow_pred.detach().cpu()), - flow_to_image(flow_target.detach().cpu())] - if warped is not None: - together.append(torch.clamp(unormalize(warped.detach().cpu(), mean=image_mean, div=image_div), 0, 255)) - if rgb_gen is not None: - together.append(torch.clamp(unormalize(rgb_gen.detach().cpu(), mean=image_mean, div=image_div), 0, 255)) - if image_target is not None: - together.append(torch.clamp(unormalize(image_target.cpu(), mean=image_mean, div=image_div), 0, 255)) - together = torch.cat(together, dim=3) - return together - -def visualize_tensor(image, mask, flow_tensors, common_tensors, rgb_tensors, image_mean, image_div): - together = [ - draw_cross(unormalize(image.cpu(), mean=image_mean, div=image_div), mask.cpu(), radius=int(image.size(3) / 50.))] - for ft in flow_tensors: - together.append(flow_to_image(ft.cpu())) - for ct in common_tensors: - together.append(torch.clamp(ct.cpu(), 0, 255)) - for rt in rgb_tensors: - together.append(torch.clamp(unormalize(rt.cpu(), mean=image_mean, div=image_div), 0, 255)) - together = torch.cat(together, dim=3) - return together - - -def unormalize(tensor, mean, div): - for c, (m, d) in enumerate(zip(mean, div)): - tensor[:,c,:,:].mul_(d).add_(m) - return tensor - - -def flow_to_image(flow): - flow = flow.numpy() - flow_img = np.array([flowlib.flow_to_image(fl.transpose((1,2,0))).transpose((2,0,1)) for fl in flow]).astype(np.float32) - return torch.from_numpy(flow_img) - -def shift_tensor(input, offh, offw): - new = torch.zeros(input.size()) - h = input.size(2) - w = input.size(3) - new[:,:,max(0,offh):min(h,h+offh),max(0,offw):min(w,w+offw)] = input[:,:,max(0,-offh):min(h,h-offh),max(0,-offw):min(w,w-offw)] - return new - -def draw_block(mask, radius=5): - ''' - input: tensor (NxCxHxW) - output: block_mask (Nx1xHxW) - ''' - all_mask = [] - mask = mask[:,0:1,:,:] - for offh in range(-radius, radius+1): - for offw in range(-radius, radius+1): - all_mask.append(shift_tensor(mask, offh, offw)) - block_mask = sum(all_mask) - block_mask[block_mask > 0] = 1 - return block_mask - -def expand_block(sparse, radius=5): - ''' - input: sparse (NxCxHxW) - output: block_sparse (NxCxHxW) - ''' - all_sparse = [] - for offh in range(-radius, radius+1): - for offw in range(-radius, radius+1): - all_sparse.append(shift_tensor(sparse, offh, offw)) - block_sparse = sum(all_sparse) - return block_sparse - -def draw_cross(tensor, mask, radius=5, thickness=2): - ''' - input: tensor (NxCxHxW) - mask (NxXxHxW) - output: new_tensor (NxCxHxW) - ''' - all_mask = [] - mask = mask[:,0:1,:,:] - for off in range(-radius, radius+1): - for t in range(-thickness, thickness+1): - 
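# accumulate copies of the mask shifted along both axes so every marked pixel is dilated
# into a '+'-shaped cross with arms of length radius and stroke thickness thickness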
all_mask.append(shift_tensor(mask, off, t)) - all_mask.append(shift_tensor(mask, t, off)) - cross_mask = sum(all_mask) - new_tensor = tensor.clone() - new_tensor[:,0:1,:,:][cross_mask > 0] = 255.0 - new_tensor[:,1:2,:,:][cross_mask > 0] = 0.0 - new_tensor[:,2:3,:,:][cross_mask > 0] = 0.0 - return new_tensor diff --git a/models/controlnet_sdv.py b/models/controlnet_sdv.py deleted file mode 100644 index d45f1597955446b5e8e6e92ac0346a94a56828f4..0000000000000000000000000000000000000000 --- a/models/controlnet_sdv.py +++ /dev/null @@ -1,782 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -from torch import nn -from torch.nn import functional as F - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.loaders import FromOriginalControlnetMixin -from diffusers.utils import BaseOutput, logging -from diffusers.models.attention_processor import ( - ADDED_KV_ATTENTION_PROCESSORS, - CROSS_ATTENTION_PROCESSORS, - AttentionProcessor, - AttnAddedKVProcessor, - AttnProcessor, -) -from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps -from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.unet_3d_blocks import ( - get_down_block, get_up_block,UNetMidBlockSpatioTemporal, -) -from diffusers.models import UNetSpatioTemporalConditionModel - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class ControlNetOutput(BaseOutput): - """ - The output of [`ControlNetModel`]. - - Args: - down_block_res_samples (`tuple[torch.Tensor]`): - A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should - be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be - used to condition the original UNet's downsampling activations. - mid_down_block_re_sample (`torch.Tensor`): - The activation of the midde block (the lowest sample resolution). Each tensor should be of shape - `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. - Output can be used to condition the original UNet's middle block activation. - """ - - down_block_res_samples: Tuple[torch.Tensor] - mid_block_res_sample: torch.Tensor - - -class ControlNetConditioningEmbeddingSVD(nn.Module): - """ - Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN - [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized - training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the - convolution size. 
We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides - (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full - model) to encode image-space conditions ... into feature maps ..." - """ - - def __init__( - self, - conditioning_embedding_channels: int, - conditioning_channels: int = 3, - block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), - ): - super().__init__() - - - - self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) - - self.blocks = nn.ModuleList([]) - - for i in range(len(block_out_channels) - 1): - channel_in = block_out_channels[i] - channel_out = block_out_channels[i + 1] - self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) - self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) - - self.conv_out = zero_module( - nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) - ) - - def forward(self, conditioning): - #this seeems appropriate? idk if i should be applying a more complex setup to handle the frames - #combine batch and frames dimensions - batch_size, frames, channels, height, width = conditioning.size() - conditioning = conditioning.view(batch_size * frames, channels, height, width) - - embedding = self.conv_in(conditioning) - embedding = F.silu(embedding) - - for block in self.blocks: - embedding = block(embedding) - embedding = F.silu(embedding) - - embedding = self.conv_out(embedding) - - #split them apart again - #actually not needed - #new_channels, new_height, new_width = embedding.shape[1], embedding.shape[2], embedding.shape[3] - #embedding = embedding.view(batch_size, frames, new_channels, new_height, new_width) - - - return embedding - - -class ControlNetSDVModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): - r""" - A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample - shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented - for all models (such as downloading or saving). - - Parameters: - sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): - Height and width of input/output sample. - in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. - down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): - The tuple of downsample blocks to use. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): - The tuple of upsample blocks to use. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - addition_time_embed_dim: (`int`, defaults to 256): - Dimension to to encode the additional time ids. - projection_class_embeddings_input_dim (`int`, defaults to 768): - The dimension of the projection of encoded `added_time_ids`. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. 
- cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): - The dimension of the cross attention features. - transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for - [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], - [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. - num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): - The number of attention heads. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: Optional[int] = None, - in_channels: int = 8, - out_channels: int = 4, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlockSpatioTemporal", - "CrossAttnDownBlockSpatioTemporal", - "CrossAttnDownBlockSpatioTemporal", - "DownBlockSpatioTemporal", - ), - up_block_types: Tuple[str] = ( - "UpBlockSpatioTemporal", - "CrossAttnUpBlockSpatioTemporal", - "CrossAttnUpBlockSpatioTemporal", - "CrossAttnUpBlockSpatioTemporal", - ), - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - addition_time_embed_dim: int = 256, - projection_class_embeddings_input_dim: int = 768, - layers_per_block: Union[int, Tuple[int]] = 2, - cross_attention_dim: Union[int, Tuple[int]] = 1024, - transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, - num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20), - num_frames: int = 25, - conditioning_channels: int = 3, - conditioning_embedding_out_channels : Optional[Tuple[int, ...]] = (16, 32, 96, 256), - ): - super().__init__() - self.sample_size = sample_size - - print("layers per block is", layers_per_block) - - # Check inputs - if len(down_block_types) != len(up_block_types): - raise ValueError( - f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." - ) - - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." - ) - - if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 
- ) - - # input - self.conv_in = nn.Conv2d( - in_channels, - block_out_channels[0], - kernel_size=3, - padding=1, - ) - - # time - time_embed_dim = block_out_channels[0] * 4 - - self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0) - timestep_input_dim = block_out_channels[0] - - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - - self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) - self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - - self.down_blocks = nn.ModuleList([]) - self.controlnet_down_blocks = nn.ModuleList([]) - - if isinstance(num_attention_heads, int): - num_attention_heads = (num_attention_heads,) * len(down_block_types) - - if isinstance(cross_attention_dim, int): - cross_attention_dim = (cross_attention_dim,) * len(down_block_types) - - if isinstance(layers_per_block, int): - layers_per_block = [layers_per_block] * len(down_block_types) - - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) - - blocks_time_embed_dim = time_embed_dim - self.controlnet_cond_embedding = ControlNetConditioningEmbeddingSVD( - conditioning_embedding_channels=block_out_channels[0], - block_out_channels=conditioning_embedding_out_channels, - conditioning_channels=conditioning_channels, - ) - - # down - output_channel = block_out_channels[0] - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_down_blocks.append(controlnet_block) - - - - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block[i], - transformer_layers_per_block=transformer_layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - temb_channels=blocks_time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=1e-5, - cross_attention_dim=cross_attention_dim[i], - num_attention_heads=num_attention_heads[i], - resnet_act_fn="silu", - ) - self.down_blocks.append(down_block) - - for _ in range(layers_per_block[i]): - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_down_blocks.append(controlnet_block) - - if not is_final_block: - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_down_blocks.append(controlnet_block) - - - # mid - mid_block_channel = block_out_channels[-1] - controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_mid_block = controlnet_block - - - self.mid_block = UNetMidBlockSpatioTemporal( - block_out_channels[-1], - temb_channels=blocks_time_embed_dim, - transformer_layers_per_block=transformer_layers_per_block[-1], - cross_attention_dim=cross_attention_dim[-1], - num_attention_heads=num_attention_heads[-1], - ) - - - - - # out - #self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5) - #self.conv_act = nn.SiLU() - - #self.conv_out = nn.Conv2d( - # block_out_channels[0], - # out_channels, - # kernel_size=3, - # padding=1, - #) - - @property - def 
attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors( - name: str, - module: torch.nn.Module, - processors: Dict[str, AttentionProcessor], - ): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. - """ - if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnProcessor() - else: - raise ValueError( - f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" - ) - - self.set_attn_processor(processor) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking - def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - """ - Sets the attention processor to use [feed forward - chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). - - Parameters: - chunk_size (`int`, *optional*): - The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually - over each tensor of dim=`dim`. - dim (`int`, *optional*, defaults to `0`): - The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) - or dim=1 (sequence length). 
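        For example, an illustrative call (assuming `controlnet` is an instance of this model):

            controlnet.enable_forward_chunking(chunk_size=1, dim=1)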
- """ - if dim not in [0, 1]: - raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") - - # By default chunk size is 1 - chunk_size = chunk_size or 1 - - def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): - if hasattr(module, "set_chunk_feed_forward"): - module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) - - for child in module.children(): - fn_recursive_feed_forward(child, chunk_size, dim) - - for module in self.children(): - fn_recursive_feed_forward(module, chunk_size, dim) - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - added_time_ids: torch.Tensor, - controlnet_cond: torch.FloatTensor = None, - image_only_indicator: Optional[torch.Tensor] = None, - return_dict: bool = True, - guess_mode: bool = False, - conditioning_scale: float = 1.0, - - - ) -> Union[ControlNetOutput, Tuple]: - r""" - The [`UNetSpatioTemporalConditionModel`] forward method. - - Args: - sample (`torch.FloatTensor`): - The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. - timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. - encoder_hidden_states (`torch.FloatTensor`): - The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. - added_time_ids: (`torch.FloatTensor`): - The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal - embeddings and added to the time embeddings. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain - tuple. - Returns: - [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise - a `tuple` is returned where the first element is the sample tensor. - """ - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - batch_size, num_frames = sample.shape[:2] - timesteps = timesteps.expand(batch_size) - - t_emb = self.time_proj(timesteps) - - # `Timesteps` does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. 
- t_emb = t_emb.to(dtype=sample.dtype) - - # print(t_emb.dtype) - - emb = self.time_embedding(t_emb) - - time_embeds = self.add_time_proj(added_time_ids.flatten()) - time_embeds = time_embeds.reshape((batch_size, -1)) - time_embeds = time_embeds.to(emb.dtype) - aug_emb = self.add_embedding(time_embeds) - emb = emb + aug_emb - - # Flatten the batch and frames dimensions - # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] - sample = sample.flatten(0, 1) - # Repeat the embeddings num_video_frames times - # emb: [batch, channels] -> [batch * frames, channels] - emb = emb.repeat_interleave(num_frames, dim=0) - # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] - encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) - - # 2. pre-process - sample = self.conv_in(sample) - - #controlnet cond - if controlnet_cond != None: - controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) - sample = sample + controlnet_cond - - - image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) - - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - image_only_indicator=image_only_indicator, - ) - else: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - image_only_indicator=image_only_indicator, - ) - - down_block_res_samples += res_samples - - # 4. mid - sample = self.mid_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - image_only_indicator=image_only_indicator, - ) - - controlnet_down_block_res_samples = () - - for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): - down_block_res_sample = controlnet_block(down_block_res_sample) - controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) - - down_block_res_samples = controlnet_down_block_res_samples - - mid_block_res_sample = self.controlnet_mid_block(sample) - - # 6. scaling - - down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] - mid_block_res_sample = mid_block_res_sample * conditioning_scale - - if not return_dict: - return (down_block_res_samples, mid_block_res_sample) - - return ControlNetOutput( - down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample - ) - - - @classmethod - def from_unet( - cls, - unet: UNetSpatioTemporalConditionModel, - controlnet_conditioning_channel_order: str = "rgb", - conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), - load_weights_from_unet: bool = True, - conditioning_channels: int = 3, - ): - r""" - Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. - - Parameters: - unet (`UNet2DConditionModel`): - The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied - where applicable. 
- """ - - transformer_layers_per_block = ( - unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 - ) - encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None - encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None - addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None - addition_time_embed_dim = ( - unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None - ) - print(unet.config) - controlnet = cls( - in_channels=unet.config.in_channels, - down_block_types=unet.config.down_block_types, - block_out_channels=unet.config.block_out_channels, - addition_time_embed_dim=unet.config.addition_time_embed_dim, - transformer_layers_per_block=unet.config.transformer_layers_per_block, - cross_attention_dim=unet.config.cross_attention_dim, - num_attention_heads=unet.config.num_attention_heads, - num_frames=unet.config.num_frames, - sample_size=unet.config.sample_size, # Added based on the dict - layers_per_block=unet.config.layers_per_block, - projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, - conditioning_channels = conditioning_channels, - conditioning_embedding_out_channels = conditioning_embedding_out_channels, - ) - #controlnet rgb channel order ignored, set to not makea difference by default - - if load_weights_from_unet: - controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) - controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) - controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) - - # if controlnet.class_embedding: - # controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) - - controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) - controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) - - return controlnet - - @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False - ): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. 
This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor, _remove_lora=_remove_lora) - else: - module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. - """ - if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnAddedKVProcessor() - elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnProcessor() - else: - raise ValueError( - f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" - ) - - self.set_attn_processor(processor, _remove_lora=True) - - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice - def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module splits the input tensor in slices to compute attention in - several steps. This is useful for saving some memory in exchange for a small decrease in speed. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If - `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. 
- """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - - for child in module.children(): - fn_recursive_retrieve_sliceable_dims(child) - - # retrieve number of attention layers - for module in self.children(): - fn_recursive_retrieve_sliceable_dims(module) - - num_sliceable_layers = len(sliceable_head_dims) - - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = [dim // 2 for dim in sliceable_head_dims] - elif slice_size == "max": - # make smallest slice possible - slice_size = num_sliceable_layers * [1] - - slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size - - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" - f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." - ) - - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - # Recursively walk through all the children. - # Any children which exposes the set_attention_slice method - # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - - # def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - # if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): - # module.gradient_checkpointing = value - - -def zero_module(module): - for p in module.parameters(): - nn.init.zeros_(p) - return module diff --git a/models/ldmk_ctrlnet.py b/models/ldmk_ctrlnet.py deleted file mode 100644 index 50239721612b1db928e9f59a45c82364cc40a967..0000000000000000000000000000000000000000 --- a/models/ldmk_ctrlnet.py +++ /dev/null @@ -1,575 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple, Union -from dataclasses import dataclass - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from diffusers.configuration_utils import register_to_config -from diffusers.utils import BaseOutput - -from models.controlnet_sdv import ControlNetSDVModel, zero_module -from models.softsplat import softsplat -import models.cmp.models as cmp_models -import models.cmp.utils as cmp_utils -from models.occlusion.hourglass import ForegroundMatting - -import yaml -import os -import torchvision.transforms as transforms - - -class ArgObj(object): - def __init__(self): - pass - - -class CMP_demo(nn.Module): - def __init__(self, configfn, load_iter): - super().__init__() - args = ArgObj() - with open(configfn) as f: - config = yaml.full_load(f) - for k, v in config.items(): - setattr(args, k, v) - setattr(args, 'load_iter', load_iter) - setattr(args, 'exp_path', os.path.dirname(configfn)) - - self.model = cmp_models.__dict__[args.model['arch']](args.model, dist_model=False) - self.model.load_state("{}/checkpoints".format(args.exp_path), args.load_iter, False) - 
self.model.switch_to('eval') - - self.data_mean = args.data['data_mean'] - self.data_div = args.data['data_div'] - - self.img_transform = transforms.Compose([ - transforms.Normalize(self.data_mean, self.data_div)]) - - self.args = args - self.fuser = cmp_utils.Fuser(args.model['module']['nbins'], args.model['module']['fmax']) - torch.cuda.synchronize() - - def run(self, image, sparse, mask): - image = image * 2 - 1 - cmp_output = self.model.model(image, torch.cat([sparse, mask], dim=1)) - flow = self.fuser.convert_flow(cmp_output) - if flow.shape[2] != image.shape[2]: - flow = nn.functional.interpolate( - flow, size=image.shape[2:4], - mode="bilinear", align_corners=True) - - return flow # [b, 2, h, w] - - # tensor_dict = self.model.eval(ret_loss=False) - # flow = tensor_dict['flow_tensors'][0].cpu().numpy().squeeze().transpose(1,2,0) - - # return flow - - - -class FlowControlNetConditioningEmbeddingSVD(nn.Module): - - def __init__( - self, - conditioning_embedding_channels: int, - conditioning_channels: int = 3, - block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), - ): - super().__init__() - - self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) - - self.blocks = nn.ModuleList([]) - - for i in range(len(block_out_channels) - 1): - channel_in = block_out_channels[i] - channel_out = block_out_channels[i + 1] - self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) - self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) - - self.conv_out = zero_module( - nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) - ) - - def forward(self, conditioning): - - embedding = self.conv_in(conditioning) - embedding = F.silu(embedding) - - for block in self.blocks: - embedding = block(embedding) - embedding = F.silu(embedding) - - embedding = self.conv_out(embedding) - - return embedding - - - - -class FlowControlNetFirstFrameEncoderLayer(nn.Module): - - def __init__( - self, - c_in, - c_out, - is_downsample=False - ): - super().__init__() - - self.conv_in = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2 if is_downsample else 1) - - def forward(self, feature): - ''' - feature: [b, c, h, w] - ''' - - embedding = self.conv_in(feature) - embedding = F.silu(embedding) - - return embedding - - - -class FlowControlNetFirstFrameEncoder(nn.Module): - def __init__( - self, - c_in=320, - channels=[320, 640, 1280], - downsamples=[True, True, True], - use_zeroconv=True - ): - super().__init__() - - self.encoders = nn.ModuleList([]) - # self.zeroconvs = nn.ModuleList([]) - - for channel, downsample in zip(channels, downsamples): - self.encoders.append(FlowControlNetFirstFrameEncoderLayer(c_in, channel, is_downsample=downsample)) - # self.zeroconvs.append(zero_module(nn.Conv2d(channel, channel, kernel_size=1)) if use_zeroconv else nn.Identity()) - c_in = channel - - def forward(self, first_frame): - feature = first_frame - deep_features = [] - # for encoder, zeroconv in zip(self.encoders, self.zeroconvs): - for encoder in self.encoders: - feature = encoder(feature) - # print(feature.shape) - # deep_features.append(zeroconv(feature)) - deep_features.append(feature) - return deep_features - - - -@dataclass -class FlowControlNetOutput(BaseOutput): - """ - The output of [`FlowControlNetOutput`]. - - Args: - down_block_res_samples (`tuple[torch.Tensor]`): - A tuple of downsample activations at different resolutions for each downsampling block. 
Each tensor should - be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be - used to condition the original UNet's downsampling activations. - mid_down_block_re_sample (`torch.Tensor`): - The activation of the midde block (the lowest sample resolution). Each tensor should be of shape - `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. - Output can be used to condition the original UNet's middle block activation. - """ - - down_block_res_samples: Tuple[torch.Tensor] - mid_block_res_sample: torch.Tensor - controlnet_flow: torch.Tensor - occlusion_masks: torch.Tensor - - -class FlowControlNet(ControlNetSDVModel): - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: Optional[int] = None, - in_channels: int = 8, - out_channels: int = 4, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlockSpatioTemporal", - "CrossAttnDownBlockSpatioTemporal", - "CrossAttnDownBlockSpatioTemporal", - "DownBlockSpatioTemporal", - ), - up_block_types: Tuple[str] = ( - "UpBlockSpatioTemporal", - "CrossAttnUpBlockSpatioTemporal", - "CrossAttnUpBlockSpatioTemporal", - "CrossAttnUpBlockSpatioTemporal", - ), - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - addition_time_embed_dim: int = 256, - projection_class_embeddings_input_dim: int = 768, - layers_per_block: Union[int, Tuple[int]] = 2, - cross_attention_dim: Union[int, Tuple[int]] = 1024, - transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, - num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20), - num_frames: int = 25, - conditioning_channels: int = 3, - conditioning_embedding_out_channels : Optional[Tuple[int, ...]] = (16, 32, 96, 256), - ): - super().__init__() - - self.flow_encoder = FlowControlNetFirstFrameEncoder() - - # time_embed_dim = block_out_channels[0] * 4 - # blocks_time_embed_dim = time_embed_dim - self.controlnet_cond_embedding = FlowControlNetConditioningEmbeddingSVD( - conditioning_embedding_channels=block_out_channels[0], - block_out_channels=conditioning_embedding_out_channels, - conditioning_channels=conditioning_channels, - ) - - self.controlnet_ldmk_embedding = FlowControlNetConditioningEmbeddingSVD( - conditioning_embedding_channels=block_out_channels[0], - block_out_channels=(16, 32, 64, 128), - conditioning_channels=conditioning_channels, - ) - - self.zero_outs = nn.ModuleDict( - { - '8': zero_module(nn.Conv2d(320, 320, kernel_size=1)), - '16': zero_module(nn.Conv2d(320, 320, kernel_size=1)), - '32': zero_module(nn.Conv2d(640, 640, kernel_size=1)), - '64': zero_module(nn.Conv2d(1280, 1280, kernel_size=1)) - } - ) - - self.occlusions = nn.ModuleDict( - { - '8': ForegroundMatting(320), - '16': ForegroundMatting(320), - '32': ForegroundMatting(640), - '64': ForegroundMatting(1280), - } - ) - - # self.occlusions = nn.ModuleDict( - # {'8': nn.Sequential( - # nn.Conv2d(320+320, 128, 7, 1, 3), - # nn.SiLU(), - # nn.Conv2d(128, 64, 5, 1, 2), - # nn.SiLU(), - # nn.Conv2d(64, 1, 3, 1, 1), - # nn.Sigmoid() - # ), - # '16': nn.Sequential( - # nn.Conv2d(320+320, 128, 5, 1, 2), - # nn.SiLU(), - # nn.Conv2d(128, 64, 5, 1, 2), - # nn.SiLU(), - # nn.Conv2d(64, 1, 3, 1, 1), - # nn.Sigmoid() - # ), - # '32': nn.Sequential( - # nn.Conv2d(640+640, 128, 5, 1, 2), - # nn.SiLU(), - # nn.Conv2d(128, 64, 3, 1, 1), - # nn.SiLU(), - # nn.Conv2d(64, 1, 3, 1, 1), - # nn.Sigmoid() - # ), - # '64': nn.Sequential( - # nn.Conv2d(1280+1280, 128, 3, 1, 1), - # nn.SiLU(), - # 
nn.Conv2d(128, 64, 3, 1, 1), - # nn.SiLU(), - # nn.Conv2d(64, 1, 3, 1, 1), - # nn.Sigmoid() - # )} - # ) - - def get_warped_frames(self, first_frame, flows, scale): - ''' - video_frame: [b, c, w, h] - flows: [b, t-1, c, w, h] - ''' - dtype = first_frame.dtype - warped_frames = [] - occlusion_masks = [] - for i in range(flows.shape[1]): - warped_frame = softsplat(tenIn=first_frame.float(), tenFlow=flows[:, i].float(), tenMetric=None, strMode='avg').to(dtype) # [b, c, w, h] - - # print(first_frame.shape) - # print(warped_frame.shape) - - # occlusion_mask = self.occlusions[str(scale)](torch.cat([first_frame, warped_frame], dim=1)) # [b, 1, w, h] - # warped_frame = warped_frame * occlusion_mask - - warped_frame, occlusion_mask = self.occlusions[str(scale)]( - first_frame, flows[:, i], warped_frame - ) - - # occlusion_mask = torch.ones_like(warped_frame[:, 0:1, :, :]) - - warped_frame = self.zero_outs[str(scale)](warped_frame) - - warped_frames.append(warped_frame.unsqueeze(1)) # [b, 1, c, w, h] - occlusion_masks.append(occlusion_mask.unsqueeze(1)) # [b, 1, 1, w, h] - warped_frames = torch.cat(warped_frames, dim=1) # [b, t-1, c, w, h] - occlusion_masks = torch.cat(occlusion_masks, dim=1) # [b, t-1, 1, w, h] - return warped_frames, occlusion_masks - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - added_time_ids: torch.Tensor, - controlnet_cond: torch.FloatTensor = None, # [b, 3, h, w] - controlnet_flow: torch.FloatTensor = None, # [b, 13, 2, h, w] - landmarks: torch.FloatTensor = None, # [b, 14, 2, h, w] - # controlnet_mask: torch.FloatTensor = None, # [b, 13, 2, h, w] - # pixel_values_384: torch.FloatTensor = None, - # sparse_optical_flow_384: torch.FloatTensor = None, - # mask_384: torch.FloatTensor = None, - image_only_indicator: Optional[torch.Tensor] = None, - return_dict: bool = True, - guess_mode: bool = False, - conditioning_scale: float = 1.0, - ) -> Union[FlowControlNetOutput, Tuple]: - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - batch_size, num_frames = sample.shape[:2] - timesteps = timesteps.expand(batch_size) - - t_emb = self.time_proj(timesteps) - - # `Timesteps` does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. 
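# A minimal, illustrative sketch (standalone helper, assumed name) of the flow rescaling rule
# used further below in this forward pass: when a dense flow field [b, 2, h, w] is downsampled
# by a factor s, its pixel-displacement values are divided by s as well, so the vectors stay
# consistent with the smaller resolution.
import torch
import torch.nn.functional as F

def downscale_flow(flow: torch.Tensor, scale: int) -> torch.Tensor:
    """flow: [b, 2, h, w] in pixels -> [b, 2, h // scale, w // scale]."""
    down = F.interpolate(flow, scale_factor=1 / scale)
    return down / scale  # displacements shrink together with the spatial resolution

# e.g. the same 1/8 ... 1/64 pyramid as the code below builds:
# pyramid = {s: downscale_flow(torch.randn(1, 2, 256, 256), s) for s in (8, 16, 32, 64)}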
- t_emb = t_emb.to(dtype=sample.dtype) - - emb = self.time_embedding(t_emb) - - time_embeds = self.add_time_proj(added_time_ids.flatten()) - time_embeds = time_embeds.reshape((batch_size, -1)) - time_embeds = time_embeds.to(emb.dtype) - aug_emb = self.add_embedding(time_embeds) - emb = emb + aug_emb - - # Flatten the batch and frames dimensions - # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] - sample = sample.flatten(0, 1) - # Repeat the embeddings num_video_frames times - # emb: [batch, channels] -> [batch * frames, channels] - emb = emb.repeat_interleave(num_frames, dim=0) - # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] - encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) - - # 2. pre-process - sample = self.conv_in(sample) # [b*l, 320, h//8, w//8] - - # controlnet cond - if controlnet_cond is not None: - # embed to 64x64, the same shape as the latent - controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) # [b, 320, h//8, w//8] - # sample = sample + controlnet_cond - - # ldb, ldl, ldc, ldh, ldw = landmarks.shape - - landmarks = landmarks.flatten(0, 1) - - # print(landmarks.shape) - # print(sample.shape) - - if landmarks is not None: - # embed to 64x64, the same shape as the latent - landmarks = self.controlnet_ldmk_embedding(landmarks) # [b, 320, h//8, w//8] - - scale_landmarks = {landmarks.shape[-2]: landmarks} - for scale in [2, 4]: - scaled_ldmk = F.interpolate(landmarks, scale_factor=1/scale) - # print(scaled_ldmk.shape) - scale_landmarks[scaled_ldmk.shape[-2]] = scaled_ldmk - - - # assert False - controlnet_cond_features = [controlnet_cond] + self.flow_encoder(controlnet_cond) # [4] - - # print(controlnet_cond.shape) - - ''' - torch.Size([2, 320, 32, 32]) - torch.Size([2, 320, 16, 16]) - torch.Size([2, 640, 8, 8]) - torch.Size([2, 1280, 4, 4]) - ''' - - # for x in controlnet_cond_features: - # print(x.shape) - - # assert False - - scales = [8, 16, 32, 64] - scale_flows = {} - fb, fl, fc, fh, fw = controlnet_flow.shape - # print(controlnet_flow.shape) - for scale in scales: - scaled_flow = F.interpolate(controlnet_flow.reshape(-1, fc, fh, fw), scale_factor=1/scale) - scaled_flow = scaled_flow.reshape(fb, fl, fc, fh // scale, fw // scale) / scale - scale_flows[scale] = scaled_flow - - # for k in scale_flows.keys(): - # print(scale_flows[k].shape) - - # assert False - - warped_cond_features = [] - occlusion_masks = [] - for cond_feature in controlnet_cond_features: - cb, cc, ch, cw = cond_feature.shape - # print(cond_feature.shape) - warped_cond_feature, occlusion_mask = self.get_warped_frames(cond_feature, scale_flows[fh // ch], fh // ch) - warped_cond_feature = torch.cat([cond_feature.unsqueeze(1), warped_cond_feature], dim=1) # [b, l, c, h, w] - wb, wl, wc, wh, ww = warped_cond_feature.shape - # print(warped_cond_feature.shape) - warped_cond_features.append(warped_cond_feature.reshape(wb * wl, wc, wh, ww)) - occlusion_masks.append(occlusion_mask) - - # for x in warped_cond_features: - # print(x.shape) - # assert False - - ''' - torch.Size([28, 320, 32, 32]) - torch.Size([28, 320, 16, 16]) - torch.Size([28, 640, 8, 8]) - torch.Size([28, 1280, 4, 4]) - ''' - - image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) - - - count = 0 - length = len(warped_cond_features) - - # print(sample.shape) - # print(warped_cond_features[count].shape) - - # add the warped feature in the first scale - sample = sample + warped_cond_features[count] + 
scale_landmarks[sample.shape[-2]] - count += 1 - - down_block_res_samples = (sample,) - - # print(sample.shape) - - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - image_only_indicator=image_only_indicator, - ) - else: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - image_only_indicator=image_only_indicator, - ) - - # print(sample.shape) - # print(warped_cond_features[min(count, length - 1)].shape) - # print(sample.shape[-2]) - # print(scale_landmarks[sample.shape[-2]].shape) - - if sample.shape[1] == 320: - sample = sample + warped_cond_features[min(count, length - 1)] + scale_landmarks[sample.shape[-2]] - else: - sample = sample + warped_cond_features[min(count, length - 1)] - - count += 1 - - down_block_res_samples += res_samples - - # print(len(res_samples)) - # for i in range(len(res_samples)): - # print(res_samples[i].shape) - - # [28, 320, 32, 32] - # [28, 320, 32, 32] - # [28, 320, 16, 16] - - # [28, 640, 16, 16] - # [28, 640, 16, 16] - # [28, 640, 8, 8] - - # [28, 1280, 8, 8] - # [28, 1280, 8, 8] - # [28, 1280, 4, 4] - - # [28, 1280, 4, 4] - # [28, 1280, 4, 4] - - # print(sample.shape) - # print(warped_cond_features[-1].shape) - - # add the warped feature in the last scale - sample = sample + warped_cond_features[-1] - - # sample = sample + warped_cond_features[-1] + scale_landmarks[sample.shape[-2]] - - # 4. mid - sample = self.mid_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - image_only_indicator=image_only_indicator, - ) # [b*l, 1280, h // 64, w // 64] - - # print(sample.shape) - - # assert False - - controlnet_down_block_res_samples = () - - for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): - down_block_res_sample = controlnet_block(down_block_res_sample) - controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) - - down_block_res_samples = controlnet_down_block_res_samples - - mid_block_res_sample = self.controlnet_mid_block(sample) - - # 6. 
scaling - - down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] - mid_block_res_sample = mid_block_res_sample * conditioning_scale - - # for sample in down_block_res_samples: - # print(torch.max(sample), torch.min(sample)) - # print(torch.max(mid_block_res_sample), torch.min(mid_block_res_sample)) - # assert False - - if not return_dict: - return (down_block_res_samples, mid_block_res_sample, controlnet_flow, occlusion_masks) - - return FlowControlNetOutput( - down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample, controlnet_flow=controlnet_flow, occlusion_masks=occlusion_masks - ) - diff --git a/models/occlusion/hourglass.py b/models/occlusion/hourglass.py deleted file mode 100644 index f0b57090ae30b037d66373fb2394b7fd2791ce99..0000000000000000000000000000000000000000 --- a/models/occlusion/hourglass.py +++ /dev/null @@ -1,298 +0,0 @@ -from torch import nn -from torch import nn -import torch.nn.functional as F -import torch -# from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d - -# class ResBlock2d(nn.Module): -# def __init__(self, in_features, kernel_size, padding): -# super(ResBlock2d, self).__init__() -# self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, -# padding=padding) -# self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, -# padding=padding) -# self.norm1 = BatchNorm2d(in_features) -# self.norm2 = BatchNorm2d(in_features) -# self.relu = nn.ReLU() -# def forward(self, x): -# out = self.norm1(x) -# out = self.relu(out) -# out = self.conv1(out) -# out = self.norm2(out) -# out = self.relu(out) -# out = self.conv2(out) -# out += x -# return out - -class UpBlock2d(nn.Module): - def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): - super(UpBlock2d, self).__init__() - self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, - padding=padding, groups=groups) - # self.norm = BatchNorm2d(out_features) - self.relu = nn.ReLU() - def forward(self, x): - out = x - # out = F.interpolate(x, scale_factor=2) - out = self.conv(out) - # out = self.norm(out) - out = F.relu(out) - return out - -class DownBlock2d(nn.Module): - def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): - super(DownBlock2d, self).__init__() - self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, - padding=padding, groups=groups) - # self.norm = BatchNorm2d(out_features) - # self.pool = nn.AvgPool2d(kernel_size=(2, 2)) - self.relu = nn.ReLU() - def forward(self, x): - out = self.conv(x) - # out = self.norm(out) - out = self.relu(out) - # out = self.pool(out) - return out - -class SameBlock2d(nn.Module): - def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1): - super(SameBlock2d, self).__init__() - self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, - kernel_size=kernel_size, padding=padding, groups=groups) - # self.norm = BatchNorm2d(out_features) - self.relu = nn.ReLU() - def forward(self, x): - out = self.conv(x) - # out = self.norm(out) - out = self.relu(out) - return out - -class HourglassEncoder(nn.Module): - def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): - super(HourglassEncoder, self).__init__() - down_blocks = [] - for i in range(num_blocks): - down_blocks.append(DownBlock2d(in_features if i == 0 else 
min(max_features, block_expansion * (2 ** i)), - min(max_features, block_expansion * (2 ** (i + 1))), - kernel_size=3, padding=1)) - self.down_blocks = nn.ModuleList(down_blocks) - - def forward(self, x): - outs = [x] - for down_block in self.down_blocks: - outs.append(down_block(outs[-1])) - outs = outs[1:] - return outs - -class HourglassDecoder(nn.Module): - def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): - super(HourglassDecoder, self).__init__() - up_blocks = [] - for i in range(num_blocks)[::-1]: - in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) - out_filters = min(max_features, block_expansion * (2 ** i)) - up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1)) - self.up_blocks = nn.ModuleList(up_blocks) - self.out_filters = block_expansion - def forward(self, x): - new_out = None - for up_block in self.up_blocks: - out = x.pop() - if new_out is not None: - out = torch.cat([out, new_out], dim=1) - new_out = up_block(out) - - return new_out - -class Hourglass(nn.Module): - def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): - super(Hourglass, self).__init__() - self.encoder = HourglassEncoder(block_expansion, in_features, num_blocks, max_features) - self.decoder = HourglassDecoder(block_expansion, in_features, num_blocks, max_features) - self.out_filters = self.decoder.out_filters - def forward(self, x): - return self.decoder(self.encoder(x)) - -# class AntiAliasInterpolation2d(nn.Module): -# """ -# Band-limited downsampling, for better preservation of the input signal. -# """ -# def __init__(self, channels, scale): -# super(AntiAliasInterpolation2d, self).__init__() -# sigma = (1 / scale - 1) / 2 -# kernel_size = 2 * round(sigma * 4) + 1 -# self.ka = kernel_size // 2 -# self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka - -# kernel_size = [kernel_size, kernel_size] -# sigma = [sigma, sigma] -# # The gaussian kernel is the product of the -# # gaussian function of each dimension. -# kernel = 1 -# meshgrids = torch.meshgrid( -# [ -# torch.arange(size, dtype=torch.float32) -# for size in kernel_size -# ] -# ) -# for size, std, mgrid in zip(kernel_size, sigma, meshgrids): -# mean = (size - 1) / 2 -# kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) - -# # Make sure sum of values in gaussian kernel equals 1. 
-# kernel = kernel / torch.sum(kernel) -# # Reshape to depthwise convolutional weight -# kernel = kernel.view(1, 1, *kernel.size()) -# kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) - -# self.register_buffer('weight', kernel) -# self.groups = channels -# self.scale = scale - -# def forward(self, input): -# if self.scale == 1.0: -# return input - -# out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) -# out = F.conv2d(out, weight=self.weight, groups=self.groups) -# out = F.interpolate(out, scale_factor=(self.scale, self.scale)) - -# return out - -# class Encoder(nn.Module): -# def __init__(self, num_channels, num_down_blocks=3, block_expansion=64, max_features=512, -# ): -# super(Encoder, self).__init__() -# self.in_conv = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3)) -# down_blocks = [] -# for i in range(num_down_blocks): -# in_features = min(max_features, block_expansion * (2 ** i)) -# out_features = min(max_features, block_expansion * (2 ** (i + 1))) -# down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) -# self.down_blocks = nn.Sequential(*down_blocks) -# def forward(self, image): -# out = self.in_conv(image) -# out = self.down_blocks(out) -# return out - -# class Bottleneck(nn.Module): -# def __init__(self, num_bottleneck_blocks,num_down_blocks=3, block_expansion=64, max_features=512): -# super(Bottleneck, self).__init__() -# bottleneck = [] -# in_features = min(max_features, block_expansion * (2 ** num_down_blocks)) -# for i in range(num_bottleneck_blocks): -# bottleneck.append(ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1))) -# self.bottleneck = nn.Sequential(*bottleneck) -# def forward(self, feature_map): -# out = self.bottleneck(feature_map) -# return out - -class Decoder(nn.Module): - def __init__(self,num_channels, num_down_blocks=3, block_expansion=64, max_features=512): - super(Decoder, self).__init__() - up_blocks = [] - for i in range(num_down_blocks): - in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i))) - out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1))) - up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) - self.up_blocks = nn.Sequential(*up_blocks) - self.out_conv = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3)) - self.sigmoid = nn.Sigmoid() - def forward(self, feature_map): - out = self.up_blocks(feature_map) - out = self.out_conv(out) - out = self.sigmoid(out) - return out - -# def warp_image(image, motion_flow): -# _, h_old, w_old, _ = motion_flow.shape -# _, _, h, w = image.shape -# if h_old != h or w_old != w: -# motion_flow = motion_flow.permute(0, 3, 1, 2) -# motion_flow = F.interpolate(motion_flow, size=(h, w), mode='bilinear') -# motion_flow = motion_flow.permute(0, 2, 3, 1) -# return F.grid_sample(image, motion_flow) - -# def make_coordinate_grid(spatial_size, type): -# h, w = spatial_size -# x = torch.arange(w).type(type) -# y = torch.arange(h).type(type) -# x = (2 * (x / (w - 1)) - 1) -# y = (2 * (y / (h - 1)) - 1) -# yy = y.view(-1, 1).repeat(1, w) -# xx = x.view(1, -1).repeat(h, 1) -# meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) -# return meshed - -class ForegroundMatting(nn.Module): - def __init__(self, num_channels, num_blocks=3, block_expansion=64, max_features=512): - super(ForegroundMatting, self).__init__() - # self.down_sample_image = AntiAliasInterpolation2d(num_channels, scale_factor) - # 
self.down_sample_flow = AntiAliasInterpolation2d(2, scale_factor) - self.hourglass = Hourglass( - block_expansion=block_expansion, - in_features=num_channels * 2 + 2, - max_features=max_features, - num_blocks=num_blocks - ) - - # self.foreground_mask = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3)) - - self.matting_mask = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3)) - self.matting = nn.Conv2d(self.hourglass.out_filters, num_channels, kernel_size=(7, 7), padding=(3, 3)) - - # self.scale_factor = scale_factor - self.sigmoid = nn.Sigmoid() - - def forward(self, reference_image, dense_flow, warped_image): - ''' - source_image : b, c, h, w - dense_tensor: b, 2, h, w - warped_image: b, c, h, w - ''' - - # res_out = {} - # batch, _, h, w = reference_image.shape - - # warped_image = warp_image(reference_image, dense_flow)#warp the image with dense flow - # res_out['warped_image'] = warped_image - - hourglass_input = torch.cat([reference_image, dense_flow, warped_image], dim=1) - hourglass_out = self.hourglass(hourglass_input) - - # foreground_mask = self.foreground_mask(hourglass_out) # compute foreground mask - # foreground_mask = self.sigmoid(foreground_mask).permute(0,2,3,1) - # res_out['foreground_mask'] = foreground_mask - # grid_flow = make_coordinate_grid((h, w), dense_flow.type()) - # dense_flow_foreground = dense_flow * foreground_mask + (1-foreground_mask) * grid_flow.unsqueeze(0) ## revise the dense flow - # res_out['dense_flow_foreground'] = dense_flow_foreground - # res_out['dense_flow_foreground_vis'] = dense_flow * foreground_mask - - matting_mask = self.matting_mask(hourglass_out) # compute matting mask - matting_mask = self.sigmoid(matting_mask) - # res_out['matting_mask'] = matting_mask - - matting_image = self.matting(hourglass_out) # computing matting image - # res_out['matting_image'] = matting_image - - out = warped_image * matting_mask + matting_image * (1 - matting_mask) - - return out, matting_mask - - - -if __name__ == '__main__': - - device = 'cuda' - b, c, h, w = 2, 1280, 40, 40 - - m = ForegroundMatting(c).to(device) - - print(m) - - - reference_image = torch.randn(b, c, h, w).to(device) - dense_flow = torch.randn(b, 2, h, w).to(device) - warped_image = torch.randn(b, c, h, w).to(device) - - o = m(reference_image, dense_flow, warped_image) \ No newline at end of file diff --git a/models/softsplat.py b/models/softsplat.py deleted file mode 100644 index f35ccc21604479940c2c86580c287e73f3dc327d..0000000000000000000000000000000000000000 --- a/models/softsplat.py +++ /dev/null @@ -1,529 +0,0 @@ -#!/usr/bin/env python - -import collections -import cupy -import os -import re -import torch -import typing - - -########################################################## - - -objCudacache = {} - - -def cuda_int32(intIn:int): - return cupy.int32(intIn) -# end - - -def cuda_float32(fltIn:float): - return cupy.float32(fltIn) -# end - - -def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict): - if 'device' not in objCudacache: - objCudacache['device'] = torch.cuda.get_device_name() - # end - - strKey = strFunction - - for strVariable in objVariables: - objValue = objVariables[strVariable] - - strKey += strVariable - - if objValue is None: - continue - - elif type(objValue) == int: - strKey += str(objValue) - - elif type(objValue) == float: - strKey += str(objValue) - - elif type(objValue) == bool: - strKey += str(objValue) - - elif type(objValue) == str: - strKey += objValue - - elif type(objValue) == 
torch.Tensor: - strKey += str(objValue.dtype) - strKey += str(objValue.shape) - strKey += str(objValue.stride()) - - elif True: - print(strVariable, type(objValue)) - assert(False) - - # end - # end - - strKey += objCudacache['device'] - - if strKey not in objCudacache: - for strVariable in objVariables: - objValue = objVariables[strVariable] - - if objValue is None: - continue - - elif type(objValue) == int: - strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) - - elif type(objValue) == float: - strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) - - elif type(objValue) == bool: - strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) - - elif type(objValue) == str: - strKernel = strKernel.replace('{{' + strVariable + '}}', objValue) - - elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: - strKernel = strKernel.replace('{{type}}', 'unsigned char') - - elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: - strKernel = strKernel.replace('{{type}}', 'half') - - elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: - strKernel = strKernel.replace('{{type}}', 'float') - - elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: - strKernel = strKernel.replace('{{type}}', 'double') - - elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32: - strKernel = strKernel.replace('{{type}}', 'int') - - elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: - strKernel = strKernel.replace('{{type}}', 'long') - - elif type(objValue) == torch.Tensor: - print(strVariable, objValue.dtype) - assert(False) - - elif True: - print(strVariable, type(objValue)) - assert(False) - - # end - # end - - while True: - objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) - - if objMatch is None: - break - # end - - intArg = int(objMatch.group(2)) - - strTensor = objMatch.group(4) - intSizes = objVariables[strTensor].size() - - strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) - # end - - while True: - objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) - - if objMatch is None: - break - # end - - intStart = objMatch.span()[1] - intStop = objMatch.span()[1] - intParentheses = 1 - - while True: - intParentheses += 1 if strKernel[intStop] == '(' else 0 - intParentheses -= 1 if strKernel[intStop] == ')' else 0 - - if intParentheses == 0: - break - # end - - intStop += 1 - # end - - intArgs = int(objMatch.group(2)) - strArgs = strKernel[intStart:intStop].split(',') - - assert(intArgs == len(strArgs) - 1) - - strTensor = strArgs[0] - intStrides = objVariables[strTensor].stride() - - strIndex = [] - - for intArg in range(intArgs): - strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') - # end - - strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')') - # end - - while True: - objMatch = re.search('(VALUE_)([0-4])(\()', strKernel) - - if objMatch is None: - break - # end - - intStart = objMatch.span()[1] - intStop = objMatch.span()[1] - intParentheses = 1 - - while True: - intParentheses += 1 if strKernel[intStop] == '(' else 0 - intParentheses -= 1 if strKernel[intStop] == ')' else 0 - - if intParentheses == 0: - break - # end 
- - intStop += 1 - # end - - intArgs = int(objMatch.group(2)) - strArgs = strKernel[intStart:intStop].split(',') - - assert(intArgs == len(strArgs) - 1) - - strTensor = strArgs[0] - intStrides = objVariables[strTensor].stride() - - strIndex = [] - - for intArg in range(intArgs): - strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') - # end - - strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']') - # end - - objCudacache[strKey] = { - 'strFunction': strFunction, - 'strKernel': strKernel - } - # end - - return strKey -# end - - -@cupy.memoize(for_each_device=True) -def cuda_launch(strKey:str): - if 'CUDA_HOME' not in os.environ: - os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path() - # end - - return cupy.cuda.compile_with_cache(objCudacache[strKey]['strKernel'], tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])).get_function(objCudacache[strKey]['strFunction']) -# end - - -########################################################## - - -def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor, tenMetric:torch.Tensor, strMode:str): - assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft']) - - if strMode == 'sum': assert(tenMetric is None) - if strMode == 'avg': assert(tenMetric is None) - if strMode.split('-')[0] == 'linear': assert(tenMetric is not None) - if strMode.split('-')[0] == 'soft': assert(tenMetric is not None) - - if strMode == 'avg': - tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1) - - elif strMode.split('-')[0] == 'linear': - tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1) - - elif strMode.split('-')[0] == 'soft': - tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1) - - # end - - tenOut = softsplat_func.apply(tenIn, tenFlow) - - if strMode.split('-')[0] in ['avg', 'linear', 'soft']: - tenNormalize = tenOut[:, -1:, :, :] - - if len(strMode.split('-')) == 1: - tenNormalize = tenNormalize + 0.0000001 - - elif strMode.split('-')[1] == 'addeps': - tenNormalize = tenNormalize + 0.0000001 - - elif strMode.split('-')[1] == 'zeroeps': - tenNormalize[tenNormalize == 0.0] = 1.0 - - elif strMode.split('-')[1] == 'clipeps': - tenNormalize = tenNormalize.clip(0.0000001, None) - - # end - - tenOut = tenOut[:, :-1, :, :] / tenNormalize - # end - - return tenOut -# end - - -class softsplat_func(torch.autograd.Function): - @staticmethod - @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) - def forward(self, tenIn, tenFlow): - tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) - - if tenIn.is_cuda == True: - cuda_launch(cuda_kernel('softsplat_out', ''' - extern "C" __global__ void __launch_bounds__(512) softsplat_out( - const int n, - const {{type}}* __restrict__ tenIn, - const {{type}}* __restrict__ tenFlow, - {{type}}* __restrict__ tenOut - ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { - const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut); - const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut); - const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); - const int intX = ( intIndex ) % SIZE_3(tenOut); - - assert(SIZE_1(tenFlow) == 2); - - 
{{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); - {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); - - if (isfinite(fltX) == false) { return; } - if (isfinite(fltY) == false) { return; } - - {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX); - - int intNorthwestX = (int) (floor(fltX)); - int intNorthwestY = (int) (floor(fltY)); - int intNortheastX = intNorthwestX + 1; - int intNortheastY = intNorthwestY; - int intSouthwestX = intNorthwestX; - int intSouthwestY = intNorthwestY + 1; - int intSoutheastX = intNorthwestX + 1; - int intSoutheastY = intNorthwestY + 1; - - {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); - {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); - {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); - {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); - - if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) { - atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest); - } - - if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) { - atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast); - } - - if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) { - atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest); - } - - if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) { - atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast); - } - } } - ''', { - 'tenIn': tenIn, - 'tenFlow': tenFlow, - 'tenOut': tenOut - }))( - grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]), - block=tuple([512, 1, 1]), - args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()], - stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) - ) - - elif tenIn.is_cuda != True: - assert(False) - - # end - - self.save_for_backward(tenIn, tenFlow) - - return tenOut - # end - - @staticmethod - @torch.cuda.amp.custom_bwd - def backward(self, tenOutgrad): - tenIn, tenFlow = self.saved_tensors - - tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True) - - tenIngrad = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if self.needs_input_grad[0] == True else None - tenFlowgrad = tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if self.needs_input_grad[1] == True else None - - if tenIngrad is not None: - cuda_launch(cuda_kernel('softsplat_ingrad', ''' - extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad( - const int n, - const {{type}}* __restrict__ tenIn, - const {{type}}* __restrict__ tenFlow, - const {{type}}* __restrict__ tenOutgrad, - {{type}}* __restrict__ tenIngrad, - {{type}}* __restrict__ tenFlowgrad - ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { - const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad); - 
const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad); - const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad); - const int intX = ( intIndex ) % SIZE_3(tenIngrad); - - assert(SIZE_1(tenFlow) == 2); - - {{type}} fltIngrad = 0.0f; - - {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); - {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); - - if (isfinite(fltX) == false) { return; } - if (isfinite(fltY) == false) { return; } - - int intNorthwestX = (int) (floor(fltX)); - int intNorthwestY = (int) (floor(fltY)); - int intNortheastX = intNorthwestX + 1; - int intNortheastY = intNorthwestY; - int intSouthwestX = intNorthwestX; - int intSouthwestY = intNorthwestY + 1; - int intSoutheastX = intNorthwestX + 1; - int intSoutheastY = intNorthwestY + 1; - - {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); - {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); - {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); - {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); - - if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { - fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; - } - - if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { - fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast; - } - - if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { - fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; - } - - if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { - fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; - } - - tenIngrad[intIndex] = fltIngrad; - } } - ''', { - 'tenIn': tenIn, - 'tenFlow': tenFlow, - 'tenOutgrad': tenOutgrad, - 'tenIngrad': tenIngrad, - 'tenFlowgrad': tenFlowgrad - }))( - grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]), - block=tuple([512, 1, 1]), - args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None], - stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) - ) - # end - - if tenFlowgrad is not None: - cuda_launch(cuda_kernel('softsplat_flowgrad', ''' - extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad( - const int n, - const {{type}}* __restrict__ tenIn, - const {{type}}* __restrict__ tenFlow, - const {{type}}* __restrict__ tenOutgrad, - {{type}}* __restrict__ tenIngrad, - {{type}}* __restrict__ tenFlowgrad - ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { - const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad); - const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad); - const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad); - const int intX = ( intIndex ) % SIZE_3(tenFlowgrad); - - 
assert(SIZE_1(tenFlow) == 2); - - {{type}} fltFlowgrad = 0.0f; - - {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); - {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); - - if (isfinite(fltX) == false) { return; } - if (isfinite(fltY) == false) { return; } - - int intNorthwestX = (int) (floor(fltX)); - int intNorthwestY = (int) (floor(fltY)); - int intNortheastX = intNorthwestX + 1; - int intNortheastY = intNorthwestY; - int intSouthwestX = intNorthwestX; - int intSouthwestY = intNorthwestY + 1; - int intSoutheastX = intNorthwestX + 1; - int intSoutheastY = intNorthwestY + 1; - - {{type}} fltNorthwest = 0.0f; - {{type}} fltNortheast = 0.0f; - {{type}} fltSouthwest = 0.0f; - {{type}} fltSoutheast = 0.0f; - - if (intC == 0) { - fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY); - fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY); - fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY)); - fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY)); - - } else if (intC == 1) { - fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f)); - fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f)); - fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f)); - fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f)); - - } - - for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) { - {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX); - - if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { - fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest; - } - - if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { - fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast; - } - - if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { - fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest; - } - - if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { - fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast; - } - } - - tenFlowgrad[intIndex] = fltFlowgrad; - } } - ''', { - 'tenIn': tenIn, - 'tenFlow': tenFlow, - 'tenOutgrad': tenOutgrad, - 'tenIngrad': tenIngrad, - 'tenFlowgrad': tenFlowgrad - }))( - grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]), - block=tuple([512, 1, 1]), - args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()], - stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) - ) - # end - - return tenIngrad, tenFlowgrad - # end -# end diff --git a/models/traj_ctrlnet.py b/models/traj_ctrlnet.py deleted file mode 100644 index f9e8305c6a2cf8b8bb85d9251ecd549b947175d1..0000000000000000000000000000000000000000 --- a/models/traj_ctrlnet.py +++ /dev/null @@ -1,515 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple, Union -from dataclasses import dataclass - -import torch -import torch.nn as nn -import torch.nn.functional as F - 
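# A minimal, illustrative usage sketch of the forward-warping operator defined in the softsplat
# module above (softsplat(tenIn, tenFlow, tenMetric, strMode)); the helper name and tensor shapes
# are assumptions, and the underlying kernel requires a CUDA device plus cupy.
import torch
from models.softsplat import softsplat

def warp_feature_avg(feature: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """feature: [b, c, h, w], flow: [b, 2, h, w] in pixels -> forward-warped [b, c, h, w]."""
    # 'avg' mode appends an all-ones channel, splats it along with the features, and normalizes
    # by the splatted ones, so source pixels landing on the same target location are averaged.
    return softsplat(tenIn=feature.float(), tenFlow=flow.float(), tenMetric=None, strMode="avg")

# e.g.: warp_feature_avg(torch.randn(1, 320, 32, 32, device="cuda"),
#                        torch.randn(1, 2, 32, 32, device="cuda"))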
-from diffusers.configuration_utils import register_to_config -from diffusers.utils import BaseOutput - -from models.controlnet_sdv import ControlNetSDVModel, zero_module -# from unimatch.unimatch.geometry import flow_warp -from models.softsplat import softsplat -# from models.hourglass.dense_motion import DenseMotionNetwork -import models.cmp.models as cmp_models -import models.cmp.utils as cmp_utils - -import yaml -import os -import torchvision.transforms as transforms - - -class ArgObj(object): - def __init__(self): - pass - - -class CMP_demo(nn.Module): - def __init__(self, configfn, load_iter): - super().__init__() - args = ArgObj() - with open(configfn) as f: - config = yaml.full_load(f) - for k, v in config.items(): - setattr(args, k, v) - setattr(args, 'load_iter', load_iter) - setattr(args, 'exp_path', os.path.dirname(configfn)) - - self.model = cmp_models.__dict__[args.model['arch']](args.model, dist_model=False) - self.model.load_state("{}/checkpoints".format(args.exp_path), args.load_iter, False) - self.model.switch_to('eval') - - self.data_mean = args.data['data_mean'] - self.data_div = args.data['data_div'] - - self.img_transform = transforms.Compose([ - transforms.Normalize(self.data_mean, self.data_div)]) - - self.args = args - self.fuser = cmp_utils.Fuser(args.model['module']['nbins'], args.model['module']['fmax']) - torch.cuda.synchronize() - - def run(self, image, sparse, mask): - dtype = image.dtype - image = image * 2 - 1 - self.model.set_input(image.float(), torch.cat([sparse, mask], dim=1).float(), None) - cmp_output = self.model.model(self.model.image_input, self.model.sparse_input) - flow = self.fuser.convert_flow(cmp_output) - if flow.shape[2] != self.model.image_input.shape[2]: - flow = nn.functional.interpolate( - flow, size=self.model.image_input.shape[2:4], - mode="bilinear", align_corners=True) - - return flow.to(dtype) # [b, 2, h, w] - - # tensor_dict = self.model.eval(ret_loss=False) - # flow = tensor_dict['flow_tensors'][0].cpu().numpy().squeeze().transpose(1,2,0) - - # return flow - - - -class FlowControlNetConditioningEmbeddingSVD(nn.Module): - - def __init__( - self, - conditioning_embedding_channels: int, - conditioning_channels: int = 3, - block_out_channels: Tuple[int, ...] 
= (16, 32, 96, 256), - ): - super().__init__() - - self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) - - self.blocks = nn.ModuleList([]) - - for i in range(len(block_out_channels) - 1): - channel_in = block_out_channels[i] - channel_out = block_out_channels[i + 1] - self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) - self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) - - self.conv_out = zero_module( - nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) - ) - - def forward(self, conditioning): - - embedding = self.conv_in(conditioning) - embedding = F.silu(embedding) - - for block in self.blocks: - embedding = block(embedding) - embedding = F.silu(embedding) - - embedding = self.conv_out(embedding) - - return embedding - - - - -class FlowControlNetFirstFrameEncoderLayer(nn.Module): - - def __init__( - self, - c_in, - c_out, - is_downsample=False - ): - super().__init__() - - self.conv_in = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2 if is_downsample else 1) - - def forward(self, feature): - ''' - feature: [b, c, h, w] - ''' - - embedding = self.conv_in(feature) - embedding = F.silu(embedding) - - return embedding - - - -class FlowControlNetFirstFrameEncoder(nn.Module): - def __init__( - self, - c_in=320, - channels=[320, 640, 1280], - downsamples=[True, True, True], - use_zeroconv=True - ): - super().__init__() - - self.encoders = nn.ModuleList([]) - self.zeroconvs = nn.ModuleList([]) - - for channel, downsample in zip(channels, downsamples): - self.encoders.append(FlowControlNetFirstFrameEncoderLayer(c_in, channel, is_downsample=downsample)) - self.zeroconvs.append(zero_module(nn.Conv2d(channel, channel, kernel_size=1)) if use_zeroconv else nn.Identity()) - c_in = channel - - def forward(self, first_frame): - feature = first_frame - deep_features = [] - for encoder, zeroconv in zip(self.encoders, self.zeroconvs): - feature = encoder(feature) - # print(feature.shape) - deep_features.append(zeroconv(feature)) - return deep_features - - -@dataclass -class FlowControlNetOutput(BaseOutput): - """ - The output of [`FlowControlNetOutput`]. - - Args: - down_block_res_samples (`tuple[torch.Tensor]`): - A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should - be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be - used to condition the original UNet's downsampling activations. - mid_down_block_re_sample (`torch.Tensor`): - The activation of the midde block (the lowest sample resolution). Each tensor should be of shape - `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. - Output can be used to condition the original UNet's middle block activation. 
- """ - - down_block_res_samples: Tuple[torch.Tensor] - mid_block_res_sample: torch.Tensor - controlnet_flow: torch.Tensor - cmp_output: torch.Tensor - - -class FlowControlNet(ControlNetSDVModel): - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: Optional[int] = None, - in_channels: int = 8, - out_channels: int = 4, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlockSpatioTemporal", - "CrossAttnDownBlockSpatioTemporal", - "CrossAttnDownBlockSpatioTemporal", - "DownBlockSpatioTemporal", - ), - up_block_types: Tuple[str] = ( - "UpBlockSpatioTemporal", - "CrossAttnUpBlockSpatioTemporal", - "CrossAttnUpBlockSpatioTemporal", - "CrossAttnUpBlockSpatioTemporal", - ), - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - addition_time_embed_dim: int = 256, - projection_class_embeddings_input_dim: int = 768, - layers_per_block: Union[int, Tuple[int]] = 2, - cross_attention_dim: Union[int, Tuple[int]] = 1024, - transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, - num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20), - num_frames: int = 25, - conditioning_channels: int = 3, - conditioning_embedding_out_channels : Optional[Tuple[int, ...]] = (16, 32, 96, 256), - ): - super().__init__() - - self.flow_encoder = FlowControlNetFirstFrameEncoder() - - # time_embed_dim = block_out_channels[0] * 4 - # blocks_time_embed_dim = time_embed_dim - self.controlnet_cond_embedding = FlowControlNetConditioningEmbeddingSVD( - conditioning_embedding_channels=block_out_channels[0], - block_out_channels=conditioning_embedding_out_channels, - conditioning_channels=conditioning_channels, - ) - - def get_warped_frames(self, first_frame, flows): - ''' - video_frame: [b, c, w, h] - flows: [b, t-1, c, w, h] - ''' - dtype = first_frame.dtype - warped_frames = [] - for i in range(flows.shape[1]): - warped_frame = softsplat(tenIn=first_frame.float(), tenFlow=flows[:, i].float(), tenMetric=None, strMode='avg').to(dtype) # [b, c, w, h] - warped_frames.append(warped_frame.unsqueeze(1)) # [b, 1, c, w, h] - warped_frames = torch.cat(warped_frames, dim=1) # [b, t-1, c, w, h] - return warped_frames - - def get_cmp_flow(self, frames, sparse_optical_flow, mask): - ''' - frames: [b, 13, 3, 384, 384] (0, 1) tensor - sparse_optical_flow: [b, 13, 2, 384, 384] (-384, 384) tensor - mask: [b, 13, 2, 384, 384] {0, 1} tensor - ''' - b, t, c, h, w = frames.shape - assert h == 384 and w == 384 - frames = frames.flatten(0, 1) # [b*13, 3, 256, 256] - sparse_optical_flow = sparse_optical_flow.flatten(0, 1) # [b*13, 2, 256, 256] - mask = mask.flatten(0, 1) # [b*13, 2, 256, 256] - cmp_flow, cmp_output = self.run(frames, sparse_optical_flow, mask) # [b*13, 2, 256, 256] - # cmp_flow = self.run(frames.float(), sparse_optical_flow.float(), mask.float()) # [b*13, 2, 256, 256] - cmp_flow = cmp_flow.reshape(b, t, 2, h, w) - return cmp_flow, cmp_output - # return cmp_flow.to(dtype=dtype) - - def run(self, image, sparse, mask): - image = image * 2 - 1 - cmp_output = self.cmp_model(image, torch.cat([sparse, mask], dim=1)) - flow = self.fuser.convert_flow(cmp_output) - if flow.shape[2] != image.shape[2]: - flow = nn.functional.interpolate( - flow, size=image.shape[2:4], - mode="bilinear", align_corners=True) - - return flow, cmp_output # [b, 2, h, w] - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - added_time_ids: torch.Tensor, - controlnet_cond: torch.FloatTensor = None, # [b, 
3, h, w] - controlnet_flow: torch.FloatTensor = None, # [b, 13, 2, h, w] - # controlnet_mask: torch.FloatTensor = None, # [b, 13, 2, h, w] - # pixel_values_384: torch.FloatTensor = None, - # sparse_optical_flow_384: torch.FloatTensor = None, - # mask_384: torch.FloatTensor = None, - image_only_indicator: Optional[torch.Tensor] = None, - return_dict: bool = True, - guess_mode: bool = False, - conditioning_scale: float = 1.0, - ) -> Union[FlowControlNetOutput, Tuple]: - - - # print(sample.shape) - # print(controlnet_cond.shape) - # print(controlnet_flow.shape) - - # assert False - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - batch_size, num_frames = sample.shape[:2] - timesteps = timesteps.expand(batch_size) - - t_emb = self.time_proj(timesteps) - - # `Timesteps` does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=sample.dtype) - - emb = self.time_embedding(t_emb) - - time_embeds = self.add_time_proj(added_time_ids.flatten()) - time_embeds = time_embeds.reshape((batch_size, -1)) - time_embeds = time_embeds.to(emb.dtype) - aug_emb = self.add_embedding(time_embeds) - emb = emb + aug_emb - - # Flatten the batch and frames dimensions - # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] - sample = sample.flatten(0, 1) - # Repeat the embeddings num_video_frames times - # emb: [batch, channels] -> [batch * frames, channels] - emb = emb.repeat_interleave(num_frames, dim=0) - # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] - encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) - - - - # hourglass_output = self.hourglass_forward( - # controlnet_cond, controlnet_sparse_flow, controlnet_mask, controlnet_init_flow) # [b, l, 3+2+2, h, w] - - # controlnet_flow = controlnet_init_flow + hourglass_output - - # 2. 
pre-process - sample = self.conv_in(sample) # [b*l, 320, h//8, w//8] - - # print(controlnet_cond.shape) - - # controlnet cond - if controlnet_cond != None: - # embed 成 64*64,和latent一个shape - controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) # [b, 320, h//8, w//8] - # sample = sample + controlnet_cond - - # print(controlnet_cond.shape) - - # assert False - controlnet_cond_features = [controlnet_cond] + self.flow_encoder(controlnet_cond) # [4] - - # print(controlnet_cond.shape) - - ''' - torch.Size([2, 320, 32, 32]) - torch.Size([2, 320, 16, 16]) - torch.Size([2, 640, 8, 8]) - torch.Size([2, 1280, 4, 4]) - ''' - - # for x in controlnet_cond_features: - # print(x.shape) - - # assert False - - scales = [8, 16, 32, 64] - scale_flows = {} - fb, fl, fc, fh, fw = controlnet_flow.shape - # print(controlnet_flow.shape) - for scale in scales: - scaled_flow = F.interpolate(controlnet_flow.reshape(-1, fc, fh, fw), scale_factor=1/scale) - scaled_flow = scaled_flow.reshape(fb, fl, fc, fh // scale, fw // scale) / scale - scale_flows[scale] = scaled_flow - - # for k in scale_flows.keys(): - # print(scale_flows[k].shape) - - # assert False - - warped_cond_features = [] - for cond_feature in controlnet_cond_features: - cb, cc, ch, cw = cond_feature.shape - # print(cond_feature.shape) - warped_cond_feature = self.get_warped_frames(cond_feature, scale_flows[fh // ch]) - warped_cond_feature = torch.cat([cond_feature.unsqueeze(1), warped_cond_feature], dim=1) # [b, c, h, w] - wb, wl, wc, wh, ww = warped_cond_feature.shape - # print(warped_cond_feature.shape) - warped_cond_features.append(warped_cond_feature.reshape(wb * wl, wc, wh, ww)) - - # for x in warped_cond_features: - # print(x.shape) - # assert False - - ''' - torch.Size([28, 320, 32, 32]) - torch.Size([28, 320, 16, 16]) - torch.Size([28, 640, 8, 8]) - torch.Size([28, 1280, 4, 4]) - ''' - - image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) - - - count = 0 - length = len(warped_cond_features) - - # print(sample.shape) - # print(warped_cond_features[0].shape) - - # add the warped feature in the first scale - sample = sample + warped_cond_features[count] - count += 1 - - down_block_res_samples = (sample,) - - # print(sample.shape) - - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - image_only_indicator=image_only_indicator, - ) - else: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - image_only_indicator=image_only_indicator, - ) - - # print(sample.shape) - # print(warped_cond_features[min(count, length - 1)].shape) - - sample = sample + warped_cond_features[min(count, length - 1)] - count += 1 - - down_block_res_samples += res_samples - - # print(len(res_samples)) - # for i in range(len(res_samples)): - # print(res_samples[i].shape) - - # [28, 320, 32, 32] - # [28, 320, 32, 32] - # [28, 320, 16, 16] - - # [28, 640, 16, 16] - # [28, 640, 16, 16] - # [28, 640, 8, 8] - - # [28, 1280, 8, 8] - # [28, 1280, 8, 8] - # [28, 1280, 4, 4] - - # [28, 1280, 4, 4] - # [28, 1280, 4, 4] - - # print(sample.shape) - # print(warped_cond_features[-1].shape) - - # add the warped feature in the last scale - sample = sample + warped_cond_features[-1] - - # 4. 
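# A minimal sketch, assuming a 256x256 clip with 13 target frames (matching the shape
# comments above): the dense flow is resized to the feature strides 8/16/32/64 and, since
# flow values are in pixels, divided by the same factor before warping the cond features.
import torch
import torch.nn.functional as F

flow = torch.randn(2, 13, 2, 256, 256)                 # [b, t-1, 2, H, W], pixel units
fb, fl, fc, fh, fw = flow.shape
scale_flows = {}
for scale in (8, 16, 32, 64):
    small = F.interpolate(flow.reshape(-1, fc, fh, fw), scale_factor=1 / scale)
    scale_flows[scale] = small.reshape(fb, fl, fc, fh // scale, fw // scale) / scale
print(scale_flows[8].shape)                            # torch.Size([2, 13, 2, 32, 32])
print(scale_flows[64].shape)                           # torch.Size([2, 13, 2, 4, 4])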
mid - sample = self.mid_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - image_only_indicator=image_only_indicator, - ) # [b*l, 1280, h // 64, w // 64] - - # print(sample.shape) - - # assert False - - controlnet_down_block_res_samples = () - - for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): - down_block_res_sample = controlnet_block(down_block_res_sample) - controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) - - down_block_res_samples = controlnet_down_block_res_samples - - mid_block_res_sample = self.controlnet_mid_block(sample) - - # 6. scaling - - down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] - mid_block_res_sample = mid_block_res_sample * conditioning_scale - - # for sample in down_block_res_samples: - # print(sample.shape) - # print(mid_block_res_sample.shape) - # assert False - - if not return_dict: - return (down_block_res_samples, mid_block_res_sample, controlnet_flow, None) - - return FlowControlNetOutput( - down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample, controlnet_flow=controlnet_flow, cmp_output=None - ) - diff --git a/models/unet_spatio_temporal_condition_controlnet.py b/models/unet_spatio_temporal_condition_controlnet.py deleted file mode 100644 index 1361eeb83ab634ed298c05d3dbddda2b56376c8b..0000000000000000000000000000000000000000 --- a/models/unet_spatio_temporal_condition_controlnet.py +++ /dev/null @@ -1,504 +0,0 @@ -from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union - -import torch -import torch.nn as nn - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.loaders import UNet2DConditionLoadersMixin -from diffusers.utils import BaseOutput, logging -from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor -from diffusers.models.embeddings import TimestepEmbedding, Timesteps -from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class UNetSpatioTemporalConditionOutput(BaseOutput): - """ - The output of [`UNetSpatioTemporalConditionModel`]. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): - The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. - """ - - sample: torch.FloatTensor = None - - -class UNetSpatioTemporalConditionControlNetModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): - r""" - A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample - shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented - for all models (such as downloading or saving). - - Parameters: - sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): - Height and width of input/output sample. - in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. 
- down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): - The tuple of downsample blocks to use. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): - The tuple of upsample blocks to use. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - addition_time_embed_dim: (`int`, defaults to 256): - Dimension to to encode the additional time ids. - projection_class_embeddings_input_dim (`int`, defaults to 768): - The dimension of the projection of encoded `added_time_ids`. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): - The dimension of the cross attention features. - transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for - [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], - [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. - num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): - The number of attention heads. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: Optional[int] = None, - in_channels: int = 8, - out_channels: int = 4, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlockSpatioTemporal", - "CrossAttnDownBlockSpatioTemporal", - "CrossAttnDownBlockSpatioTemporal", - "DownBlockSpatioTemporal", - ), - up_block_types: Tuple[str] = ( - "UpBlockSpatioTemporal", - "CrossAttnUpBlockSpatioTemporal", - "CrossAttnUpBlockSpatioTemporal", - "CrossAttnUpBlockSpatioTemporal", - ), - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - addition_time_embed_dim: int = 256, - projection_class_embeddings_input_dim: int = 768, - layers_per_block: Union[int, Tuple[int]] = 2, - cross_attention_dim: Union[int, Tuple[int]] = 1024, - transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, - num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20), - num_frames: int = 25, - ): - super().__init__() - - self.sample_size = sample_size - - # Check inputs - if len(down_block_types) != len(up_block_types): - raise ValueError( - f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." - ) - - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
- ) - - if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." - ) - - # input - self.conv_in = nn.Conv2d( - in_channels, - block_out_channels[0], - kernel_size=3, - padding=1, - ) - - # time - time_embed_dim = block_out_channels[0] * 4 - - self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0) - timestep_input_dim = block_out_channels[0] - - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - - self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) - self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - - self.down_blocks = nn.ModuleList([]) - self.up_blocks = nn.ModuleList([]) - - if isinstance(num_attention_heads, int): - num_attention_heads = (num_attention_heads,) * len(down_block_types) - - if isinstance(cross_attention_dim, int): - cross_attention_dim = (cross_attention_dim,) * len(down_block_types) - - if isinstance(layers_per_block, int): - layers_per_block = [layers_per_block] * len(down_block_types) - - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) - - blocks_time_embed_dim = time_embed_dim - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block[i], - transformer_layers_per_block=transformer_layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - temb_channels=blocks_time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=1e-5, - cross_attention_dim=cross_attention_dim[i], - num_attention_heads=num_attention_heads[i], - resnet_act_fn="silu", - ) - self.down_blocks.append(down_block) - - # mid - self.mid_block = UNetMidBlockSpatioTemporal( - block_out_channels[-1], - temb_channels=blocks_time_embed_dim, - transformer_layers_per_block=transformer_layers_per_block[-1], - cross_attention_dim=cross_attention_dim[-1], - num_attention_heads=num_attention_heads[-1], - ) - - # count how many layers upsample the images - self.num_upsamplers = 0 - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - reversed_num_attention_heads = list(reversed(num_attention_heads)) - reversed_layers_per_block = list(reversed(layers_per_block)) - reversed_cross_attention_dim = list(reversed(cross_attention_dim)) - reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) - - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - is_final_block = i == len(block_out_channels) - 1 - - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - - # add upsample block for all BUT final layer - if 
not is_final_block: - add_upsample = True - self.num_upsamplers += 1 - else: - add_upsample = False - - up_block = get_up_block( - up_block_type, - num_layers=reversed_layers_per_block[i] + 1, - transformer_layers_per_block=reversed_transformer_layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=blocks_time_embed_dim, - add_upsample=add_upsample, - resnet_eps=1e-5, - resolution_idx=i, - cross_attention_dim=reversed_cross_attention_dim[i], - num_attention_heads=reversed_num_attention_heads[i], - resnet_act_fn="silu", - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5) - self.conv_act = nn.SiLU() - - self.conv_out = nn.Conv2d( - block_out_channels[0], - out_channels, - kernel_size=3, - padding=1, - ) - - @property - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors( - name: str, - module: torch.nn.Module, - processors: Dict[str, AttentionProcessor], - ): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. 
- """ - if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnProcessor() - else: - raise ValueError( - f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" - ) - - self.set_attn_processor(processor) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking - def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: - """ - Sets the attention processor to use [feed forward - chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). - - Parameters: - chunk_size (`int`, *optional*): - The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually - over each tensor of dim=`dim`. - dim (`int`, *optional*, defaults to `0`): - The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) - or dim=1 (sequence length). - """ - if dim not in [0, 1]: - raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") - - # By default chunk size is 1 - chunk_size = chunk_size or 1 - - def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): - if hasattr(module, "set_chunk_feed_forward"): - module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) - - for child in module.children(): - fn_recursive_feed_forward(child, chunk_size, dim) - - for module in self.children(): - fn_recursive_feed_forward(module, chunk_size, dim) - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, - mid_block_additional_residual: Optional[torch.Tensor] = None, - return_dict: bool = True, - added_time_ids: torch.Tensor=None, - ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]: - r""" - The [`UNetSpatioTemporalConditionModel`] forward method. - - Args: - sample (`torch.FloatTensor`): - The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. - timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. - encoder_hidden_states (`torch.FloatTensor`): - The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. - added_time_ids: (`torch.FloatTensor`): - The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal - embeddings and added to the time embeddings. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain - tuple. - Returns: - [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise - a `tuple` is returned where the first element is the sample tensor. - """ - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - batch_size, num_frames = sample.shape[:2] - timesteps = timesteps.expand(batch_size) - - t_emb = self.time_proj(timesteps) - - # `Timesteps` does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=sample.dtype) - - emb = self.time_embedding(t_emb) - - time_embeds = self.add_time_proj(added_time_ids.flatten()) - time_embeds = time_embeds.reshape((batch_size, -1)) - time_embeds = time_embeds.to(emb.dtype) - aug_emb = self.add_embedding(time_embeds) - emb = emb + aug_emb - - # Flatten the batch and frames dimensions - # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] - sample = sample.flatten(0, 1) - # Repeat the embeddings num_video_frames times - # emb: [batch, channels] -> [batch * frames, channels] - emb = emb.repeat_interleave(num_frames, dim=0) - # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] - encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) - - # 2. pre-process - sample = self.conv_in(sample) - - image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) - - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - image_only_indicator=image_only_indicator, - ) - else: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - image_only_indicator=image_only_indicator, - ) - - down_block_res_samples += res_samples - - new_down_block_res_samples = () - - for down_block_res_sample, down_block_additional_residual in zip( - down_block_res_samples, down_block_additional_residuals - ): - down_block_res_sample = down_block_res_sample + down_block_additional_residual - new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) - - down_block_res_samples = new_down_block_res_samples - - - # 4. mid - sample = self.mid_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - image_only_indicator=image_only_indicator, - ) - sample = sample + mid_block_additional_residual - - - # 5. 
up - for i, upsample_block in enumerate(self.up_blocks): - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - - if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - image_only_indicator=image_only_indicator, - ) - else: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - image_only_indicator=image_only_indicator, - ) - - # 6. post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - # 7. Reshape back to original shape - sample = sample.reshape(batch_size, num_frames, *sample.shape[1:]) - - if not return_dict: - return (sample,) - - return UNetSpatioTemporalConditionOutput(sample=sample) diff --git a/pipeline/pipeline.py b/pipeline/pipeline.py deleted file mode 100644 index a4e30aab19574d9a94973e27315c7f8ef01cf687..0000000000000000000000000000000000000000 --- a/pipeline/pipeline.py +++ /dev/null @@ -1,660 +0,0 @@ -import inspect -from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Union - -import numpy as np -import PIL.Image -import torch -import torch.nn.functional as F -from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection - -from diffusers.image_processor import VaeImageProcessor -from diffusers.models import AutoencoderKLTemporalDecoder -from diffusers.utils import BaseOutput, logging -from diffusers.utils.torch_utils import randn_tensor -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from utils.scheduling_euler_discrete_karras_fix import EulerDiscreteScheduler - -from models.unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel -from models.traj_ctrlnet import FlowControlNet as DragControlNet -from models.ldmk_ctrlnet import FlowControlNet as FaceControlNet - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def _get_add_time_ids( - noise_aug_strength, - dtype, - batch_size, - fps=4, - motion_bucket_id=128, - unet=None, - ): - add_time_ids = [fps, motion_bucket_id, noise_aug_strength] - - passed_add_embed_dim = unet.config.addition_time_embed_dim * len(add_time_ids) - expected_add_embed_dim = unet.add_embedding.linear_1.in_features - - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
- ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - # add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) - - - return add_time_ids - - -def _append_dims(x, target_dims): - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") - return x[(...,) + (None,) * dims_to_append] - - -def tensor2vid(video: torch.Tensor, processor, output_type="np"): - # Based on: - # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 - - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - return outputs - - -@dataclass -class FlowControlNetPipelineOutput(BaseOutput): - r""" - Output class for zero-shot text-to-video pipeline. - - Args: - frames (`[List[PIL.Image.Image]`, `np.ndarray`]): - List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, - num_channels)`. - """ - - frames: Union[List[PIL.Image.Image], np.ndarray] - controlnet_flow: torch.Tensor - - -class FlowControlNetPipeline(DiffusionPipeline): - model_cpu_offload_seq = "image_encoder->unet->vae" - _callback_tensor_inputs = ["latents"] - def __init__( - self, - vae: AutoencoderKLTemporalDecoder, - image_encoder: CLIPVisionModelWithProjection, - unet: UNetSpatioTemporalConditionControlNetModel, - drag_controlnet: DragControlNet, - face_controlnet: FaceControlNet, - scheduler: EulerDiscreteScheduler, - feature_extractor: CLIPImageProcessor, - ): - super().__init__() - - self.register_modules( - vae=vae, - image_encoder=image_encoder, - drag_controlnet=drag_controlnet, - face_controlnet=face_controlnet, - unet=unet, - scheduler=scheduler, - feature_extractor=feature_extractor, - ) - - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - - - def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.image_processor.pil_to_numpy(image) - image = self.image_processor.numpy_to_pt(image) - - #image = image.unsqueeze(0) - image = _resize_with_antialiasing(image, (224, 224)) - - image = image.to(device=device, dtype=dtype) - image_embeddings = self.image_encoder(image).image_embeds - image_embeddings = image_embeddings.unsqueeze(1) - - # duplicate image embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = image_embeddings.shape - image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) - image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) - - if do_classifier_free_guidance: - negative_image_embeddings = torch.zeros_like(image_embeddings) - - # For classifier free guidance, we need to do two forward passes. 
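# A minimal sketch of the stack-then-chunk pattern used for classifier-free guidance here
# and in _encode_image and the denoising loop: a zeroed "negative" branch is concatenated
# with the conditional branch so one batched forward serves both, and the two halves are
# recombined with the guidance scale afterwards. Shapes and the scale of 3.0 are assumptions.
import torch

cond = torch.randn(1, 4, 72, 128)             # conditional input (e.g. image latents)
uncond = torch.zeros_like(cond)               # unconditional stand-in
batched = torch.cat([uncond, cond])           # [2, 4, 72, 128]: one forward pass, not two
out_uncond, out_cond = batched.chunk(2)       # split the (stand-in) model output again
guided = out_uncond + 3.0 * (out_cond - out_uncond)   # assumed guidance scale of 3.0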
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) - - return image_embeddings - - def _encode_vae_image( - self, - image: torch.Tensor, - device, - num_videos_per_prompt, - do_classifier_free_guidance, - ): - image = image.to(device=device) - image_latents = self.vae.encode(image).latent_dist.mode() - - if do_classifier_free_guidance: - negative_image_latents = torch.zeros_like(image_latents) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - image_latents = torch.cat([negative_image_latents, image_latents]) - - # duplicate image_latents for each generation per prompt, using mps friendly method - image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) - - return image_latents - - def _get_add_time_ids( - self, - fps, - motion_bucket_id, - noise_aug_strength, - dtype, - batch_size, - num_videos_per_prompt, - do_classifier_free_guidance, - ): - add_time_ids = [fps, motion_bucket_id, noise_aug_strength] - - passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) - expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features - - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." - ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) - - if do_classifier_free_guidance: - add_time_ids = torch.cat([add_time_ids, add_time_ids]) - - return add_time_ids - - def decode_latents(self, latents, num_frames, decode_chunk_size=14): - # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] - latents = latents.flatten(0, 1) - - latents = 1 / self.vae.config.scaling_factor * latents - - accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys()) - - # decode decode_chunk_size frames at a time to avoid OOM - frames = [] - for i in range(0, latents.shape[0], decode_chunk_size): - num_frames_in = latents[i : i + decode_chunk_size].shape[0] - decode_kwargs = {} - if accepts_num_frames: - # we only pass num_frames_in if it's expected - decode_kwargs["num_frames"] = num_frames_in - - frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample - frames.append(frame) - frames = torch.cat(frames, dim=0) - - # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] - frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4) - - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - frames = frames.float() - return frames - - def check_inputs(self, image, height, width): - if ( - not isinstance(image, torch.Tensor) - and not isinstance(image, PIL.Image.Image) - and not isinstance(image, list) - ): - raise ValueError( - "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" - f" {type(image)}" - ) - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to 
be divisible by 8 but are {height} and {width}.") - - def prepare_latents( - self, - batch_size, - num_frames, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - ): - shape = ( - batch_size, - num_frames, - num_channels_latents // 2, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - - # print(shape) - - # assert False - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None - - @property - def num_timesteps(self): - return self._num_timesteps - - @torch.no_grad() - def __call__( - self, - image: Union[PIL.Image.Image, torch.FloatTensor], - controlnet_condition: torch.FloatTensor = None, - - controlnet_flow: torch.FloatTensor = None, - landmarks: torch.FloatTensor = None, - - drag_flow: torch.FloatTensor = None, - mask: torch.FloatTensor = None, - - height: int = 576, - width: int = 1024, - num_frames: Optional[int] = None, - num_inference_steps: int = 25, - min_guidance_scale: float = 1.0, - max_guidance_scale: float = 3.0, - fps: int = 7, - motion_bucket_id: int = 127, - noise_aug_strength: int = 0.02, - decode_chunk_size: Optional[int] = None, - num_videos_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - return_dict: bool = True, - ctrl_scale_traj=1.0, - ctrl_scale_ldmk=1.0, - batch_size=1, - ): - - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - num_frames = num_frames if num_frames is not None else self.unet.config.num_frames - decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames - - # 1. Check inputs. Raise error if not correct - self.check_inputs(image, height, width) - - # 2. Define call parameters - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = max_guidance_scale > 1.0 - - # 3. Encode input image - image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance) - - # NOTE: Stable Diffusion Video was conditioned on fps - 1, which - # is why it is reduced here. - # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188 - fps = fps - 1 - - # 4. 
Encode input image using VAE - image = self.image_processor.preprocess(image, height=height, width=width) - noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype) - image = image + noise_aug_strength * noise - - needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast - if needs_upcasting: - self.vae.to(dtype=torch.float32) - - image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance) - image_latents = image_latents.to(image_embeddings.dtype) - - # cast back to fp16 if needed - if needs_upcasting: - self.vae.to(dtype=torch.float16) - - # Repeat the image latents for each frame so we can concatenate them with the noise - # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width] - image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1) - #image_latents = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents - - # 5. Get Added Time IDs - added_time_ids = self._get_add_time_ids( - fps, - motion_bucket_id, - noise_aug_strength, - image_embeddings.dtype, - batch_size, - num_videos_per_prompt, - do_classifier_free_guidance, - ) - added_time_ids = added_time_ids.to(device) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 5. Prepare latent variables - - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_videos_per_prompt, - num_frames, - num_channels_latents, - height, - width, - image_embeddings.dtype, - device, - generator, - latents, - ) - - - #prepare controlnet condition - controlnet_condition = self.image_processor.preprocess(controlnet_condition, height=height, width=width) - # controlnet_condition = controlnet_condition.unsqueeze(0) - controlnet_condition = torch.cat([controlnet_condition] * 2) if do_classifier_free_guidance else controlnet_condition - controlnet_condition = controlnet_condition.to(device, latents.dtype) - - controlnet_flow = torch.cat([controlnet_flow] * 2) if do_classifier_free_guidance else controlnet_flow - controlnet_flow = controlnet_flow.to(device, latents.dtype) - - drag_flow = torch.cat([drag_flow] * 2) if do_classifier_free_guidance else drag_flow - drag_flow = drag_flow.to(device, latents.dtype) - - # mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask - mask = mask.to(device, latents.dtype) - - landmarks = torch.cat([landmarks] * 2) if do_classifier_free_guidance else landmarks - landmarks = landmarks.to(device, latents.dtype) - - # 7. Prepare guidance scale - # modified num_frames to window_size here !!!!!!!!!!!!!! - guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0) - guidance_scale = guidance_scale.to(device, latents.dtype) - guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1) - guidance_scale = _append_dims(guidance_scale, latents.ndim) - - self._guidance_scale = guidance_scale - - noise_aug_strength = 0.02 #"¯\_(ツ)_/¯ - added_time_ids = _get_add_time_ids( - noise_aug_strength, - image_embeddings.dtype, - batch_size, - 6, - 128, - unet=self.unet, - ) - added_time_ids = torch.cat([added_time_ids] * 2) - added_time_ids = added_time_ids.to(latents.device) - - # 8. 
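# A minimal sketch of the per-frame guidance schedule prepared above, assuming a 14-frame
# clip: guidance grows linearly from min_guidance_scale to max_guidance_scale across the
# clip and is broadcast to the latent dimensions exactly as _append_dims does.
import torch

min_g, max_g, num_frames = 1.0, 3.0, 14
guidance = torch.linspace(min_g, max_g, num_frames).unsqueeze(0)   # [1, 14]
guidance = guidance[(...,) + (None,) * 3]                          # pad to latents.ndim == 5
print(guidance.shape)                                              # torch.Size([1, 14, 1, 1, 1])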
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - self._num_timesteps = len(timesteps) - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - - # expand the latents if we are doing classifier free guidance - latent_model_input_tmp = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input_tmp = self.scheduler.scale_model_input(latent_model_input_tmp, t) - - # Concatenate image_latents over channels dimention - latent_model_input_tmp = torch.cat([latent_model_input_tmp, image_latents], dim=2) - - down_res_face_tmp, mid_res_face_tmp, controlnet_flow, _ = self.face_controlnet( - latent_model_input_tmp, - t, - encoder_hidden_states=image_embeddings, - controlnet_cond=controlnet_condition, - controlnet_flow=controlnet_flow, - landmarks=landmarks, - added_time_ids=added_time_ids, - conditioning_scale=ctrl_scale_ldmk, - guess_mode=False, - return_dict=False, - ) - - down_res_drag_tmp, mid_res_drag_tmp, _, _ = self.drag_controlnet( - latent_model_input_tmp, - t, - encoder_hidden_states=image_embeddings, - controlnet_cond=controlnet_condition, - controlnet_flow=drag_flow, - added_time_ids=added_time_ids, - conditioning_scale=ctrl_scale_traj, - guess_mode=False, - return_dict=False, - ) - - down_block_res_samples_tmp = [] - for down_face, down_drag in zip(down_res_face_tmp, down_res_drag_tmp): - _, _, h, w = down_face.shape - mask_tmp = F.interpolate(mask, (h, w), mode='nearest') - res = down_face * mask_tmp + down_drag * (1 - mask_tmp) - down_block_res_samples_tmp.append(res) - - _, _, h, w = mid_res_face_tmp.shape - mask_tmp = F.interpolate(mask, (h, w), mode='nearest') - mid_block_res_sample_tmp = mid_res_face_tmp * mask_tmp + mid_res_drag_tmp * (1 - mask_tmp) - - # predict the noise residual - noise_pred_tmp = self.unet( - latent_model_input_tmp, - t, - encoder_hidden_states=image_embeddings, - down_block_additional_residuals=down_block_res_samples_tmp, - mid_block_additional_residual=mid_block_res_sample_tmp, - added_time_ids=added_time_ids, - return_dict=False, - )[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond_tmp, noise_pred_cond_tmp = noise_pred_tmp.chunk(2) - noise_pred_tmp = noise_pred_uncond_tmp + self.guidance_scale * (noise_pred_cond_tmp - noise_pred_uncond_tmp) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred_tmp, t, latents).prev_sample - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if not output_type == "latent": - # cast back to fp16 if needed - if needs_upcasting: - self.vae.to(dtype=torch.float16) - frames = self.decode_latents(latents.to(self.vae.dtype), num_frames, decode_chunk_size) - frames = tensor2vid(frames, self.image_processor, output_type=output_type) - else: - frames = latents - - self.maybe_free_model_hooks() - - if not return_dict: - return frames, controlnet_flow - - return FlowControlNetPipelineOutput( - frames=frames, - controlnet_flow=controlnet_flow - ) - - -# resizing utils -# TODO: clean up later -def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): - - if 
input.ndim == 3: - input = input.unsqueeze(0) # Add a batch dimension - - h, w = input.shape[-2:] - factors = (h / size[0], w / size[1]) - - # First, we have to determine sigma - # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 - sigmas = ( - max((factors[0] - 1.0) / 2.0, 0.001), - max((factors[1] - 1.0) / 2.0, 0.001), - ) - - # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma - # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 - # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now - ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) - - # Make sure it is odd - if (ks[0] % 2) == 0: - ks = ks[0] + 1, ks[1] - - if (ks[1] % 2) == 0: - ks = ks[0], ks[1] + 1 - - input = _gaussian_blur2d(input, ks, sigmas) - - output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) - return output - - -def _compute_padding(kernel_size): - """Compute padding tuple.""" - # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) - # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad - if len(kernel_size) < 2: - raise AssertionError(kernel_size) - computed = [k - 1 for k in kernel_size] - - # for even kernels we need to do asymmetric padding :( - out_padding = 2 * len(kernel_size) * [0] - - for i in range(len(kernel_size)): - computed_tmp = computed[-(i + 1)] - - pad_front = computed_tmp // 2 - pad_rear = computed_tmp - pad_front - - out_padding[2 * i + 0] = pad_front - out_padding[2 * i + 1] = pad_rear - - return out_padding - - -def _filter2d(input, kernel): - # prepare kernel - b, c, h, w = input.shape - tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) - - tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) - - height, width = tmp_kernel.shape[-2:] - - padding_shape: list[int] = _compute_padding([height, width]) - input = torch.nn.functional.pad(input, padding_shape, mode="reflect") - - # kernel and input tensor reshape to align element-wise or batch-wise params - tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) - input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) - - # convolve the tensor with the kernel. 
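# A minimal sketch of the blur parameters _resize_with_antialiasing picks above, assuming a
# 512x512 frame being shrunk to the 224x224 CLIP input used in _encode_image; kernel sizes
# are forced to be odd exactly as in the function itself.
h, w, size = 512, 512, (224, 224)
factors = (h / size[0], w / size[1])
sigmas = (max((factors[0] - 1.0) / 2.0, 0.001), max((factors[1] - 1.0) / 2.0, 0.001))
ks = [int(max(2.0 * 2 * s, 3)) for s in sigmas]
ks = [k + 1 if k % 2 == 0 else k for k in ks]      # make the kernel sizes odd
print(sigmas, ks)                                  # (~0.643, ~0.643) [3, 3]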
- output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) - - out = output.view(b, c, h, w) - return out - - -def _gaussian(window_size: int, sigma): - if isinstance(sigma, float): - sigma = torch.tensor([[sigma]]) - - batch_size = sigma.shape[0] - - x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) - - if window_size % 2 == 0: - x = x + 0.5 - - gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) - - return gauss / gauss.sum(-1, keepdim=True) - - -def _gaussian_blur2d(input, kernel_size, sigma): - if isinstance(sigma, tuple): - sigma = torch.tensor([sigma], dtype=input.dtype) - else: - sigma = sigma.to(dtype=input.dtype) - - ky, kx = int(kernel_size[0]), int(kernel_size[1]) - bs = sigma.shape[0] - kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) - kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) - out_x = _filter2d(input, kernel_x[..., None, :]) - out = _filter2d(out_x, kernel_y[..., None]) - - return out - - -def get_views(video_length, window_size=14, stride=7): - num_blocks_time = (video_length - window_size) // stride + 1 - views = [] - for i in range(num_blocks_time): - t_start = int(i * stride) - t_end = t_start + window_size - views.append((t_start,t_end)) - return views \ No newline at end of file diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index bde8762942e5c906456b7ac7973574783e762967..0000000000000000000000000000000000000000 --- a/requirements.txt +++ /dev/null @@ -1,21 +0,0 @@ -diffusers==0.24.0 -gradio==4.5.0 -scikit-image -torch==2.0.1 -torchvision==0.15.2 -einops==0.8.0 -accelerate==0.30.1 -transformers==4.41.1 -colorlog==6.8.2 -cupy-cuda117==10.6.0 -av==12.1.0 -gpustat==1.1.1 -trimesh==4.4.1 -facexlib==0.3.0 -omegaconf==2.3.0 -librosa==0.10.2.post1 -mediapipe==0.10.14 -kornia==0.7.2 -yacs==0.1.8 -gfpgan==1.3.8 -numpy==1.23.0 \ No newline at end of file diff --git a/run_gradio_audio_driven.py b/run_gradio_audio_driven.py deleted file mode 100644 index a82fdeab9ddce69ef91cd91e3fd7dc40ad477b82..0000000000000000000000000000000000000000 --- a/run_gradio_audio_driven.py +++ /dev/null @@ -1,1240 +0,0 @@ -import gradio as gr -import numpy as np -import cv2 -import os -from PIL import Image -from scipy.interpolate import PchipInterpolator -import torchvision -import time -from tqdm import tqdm -import imageio - -import torch -import torch.nn.functional as F -import torchvision -import torchvision.transforms as transforms -from einops import repeat - -from pydub import AudioSegment - -from packaging import version - -from accelerate.utils import set_seed -from transformers import CLIPVisionModelWithProjection - -from diffusers import AutoencoderKLTemporalDecoder -from diffusers.utils.import_utils import is_xformers_available - -from models.unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel -from pipeline.pipeline import FlowControlNetPipeline -from models.traj_ctrlnet import FlowControlNet as DragControlNet, CMP_demo -from models.ldmk_ctrlnet import FlowControlNet as FaceControlNet - -from utils.flow_viz import flow_to_image -from utils.utils import split_filename, image2arr, image2pil, ensure_dirname - - -output_dir = "Output_audio_driven" - - -ensure_dirname(output_dir) - - -def draw_landmarks_cv2(image, landmarks): - for i, point in enumerate(landmarks): - cv2.circle(image, (int(point[0]), int(point[1])), 2, (0, 0, 255), -1) - return image - - -def sample_optical_flow(A, B, h, w): - b, l, k, _ = 
A.shape - - sparse_optical_flow = torch.zeros((b, l, h, w, 2), dtype=B.dtype, device=B.device) - mask = torch.zeros((b, l, h, w), dtype=torch.uint8, device=B.device) - - x_coords = A[..., 0].long() - y_coords = A[..., 1].long() - - x_coords = torch.clip(x_coords, 0, h - 1) - y_coords = torch.clip(y_coords, 0, w - 1) - - b_idx = torch.arange(b)[:, None, None].repeat(1, l, k) - l_idx = torch.arange(l)[None, :, None].repeat(b, 1, k) - - sparse_optical_flow[b_idx, l_idx, x_coords, y_coords] = B - - mask[b_idx, l_idx, x_coords, y_coords] = 1 - - mask = mask.unsqueeze(-1).repeat(1, 1, 1, 1, 2) - - return sparse_optical_flow, mask - - -@torch.no_grad() -def get_sparse_flow(landmarks, h, w, t): - - landmarks = torch.flip(landmarks, dims=[3]) - - pose_flow = (landmarks - landmarks[:, 0:1].repeat(1, t, 1, 1))[:, 1:] # 前向光流 - according_poses = landmarks[:, 0:1].repeat(1, t - 1, 1, 1) - - pose_flow = torch.flip(pose_flow, dims=[3]) - - b, t, K, _ = pose_flow.shape - - sparse_optical_flow, mask = sample_optical_flow(according_poses, pose_flow, h, w) - - return sparse_optical_flow.permute(0, 1, 4, 2, 3), mask.permute(0, 1, 4, 2, 3) - - - -def sample_inputs_face(first_frame, landmarks): - - pc, ph, pw = first_frame.shape - landmarks = landmarks.unsqueeze(0) - - pl = landmarks.shape[1] - - sparse_optical_flow, mask = get_sparse_flow(landmarks, ph, pw, pl) - - if ph != 384 or pw != 384: - - first_frame_384 = F.interpolate(first_frame.unsqueeze(0), (384, 384)) # [3, 384, 384] - - landmarks_384 = torch.zeros_like(landmarks) - landmarks_384[:, :, :, 0] = landmarks[:, :, :, 0] / pw * 384 - landmarks_384[:, :, :, 1] = landmarks[:, :, :, 1] / ph * 384 - - sparse_optical_flow_384, mask_384 = get_sparse_flow(landmarks_384, 384, 384, pl) - - else: - first_frame_384, landmarks_384 = first_frame, landmarks - sparse_optical_flow_384, mask_384 = sparse_optical_flow, mask - - controlnet_image = first_frame.unsqueeze(0) - - return controlnet_image, sparse_optical_flow, mask, first_frame_384, sparse_optical_flow_384, mask_384 - - - -PARTS = [ - ('FACE', [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], (10, 200, 10)), - ('LEFT_EYE', [43, 44, 45, 46, 47, 48, 43], (180, 200, 10)), - ('LEFT_EYEBROW', [23, 24, 25, 26, 27], (180, 220, 10)), - ('RIGHT_EYE', [37, 38, 39, 40, 41, 42, 37], (10, 200, 180)), - ('RIGHT_EYEBROW', [18, 19, 20, 21, 22], (10, 220, 180)), - ('NOSE_UP', [28, 29, 30, 31], (10, 200, 250)), - ('NOSE_DOWN', [32, 33, 34, 35, 36], (250, 200, 10)), - ('LIPS_OUTER_BOTTOM_LEFT', [55, 56, 57, 58], (10, 180, 20)), - ('LIPS_OUTER_BOTTOM_RIGHT', [49, 60, 59, 58], (20, 10, 180)), - ('LIPS_INNER_BOTTOM_LEFT', [65, 66, 67], (100, 100, 30)), - ('LIPS_INNER_BOTTOM_RIGHT', [61, 68, 67], (100, 150, 50)), - ('LIPS_OUTER_TOP_LEFT', [52, 53, 54, 55], (20, 80, 100)), - ('LIPS_OUTER_TOP_RIGHT', [52, 51, 50, 49], (80, 100, 20)), - ('LIPS_INNER_TOP_LEFT', [63, 64, 65], (120, 100, 200)), - ('LIPS_INNER_TOP_RIGHT', [63, 62, 61], (150, 120, 100)), -] - - -def draw_landmarks(keypoints, h, w): - - image = np.zeros((h, w, 3)) - - for name, indices, color in PARTS: - indices = np.array(indices) - 1 - current_part_keypoints = keypoints[indices] - - for i in range(len(indices) - 1): - x1, y1 = current_part_keypoints[i] - x2, y2 = current_part_keypoints[i + 1] - cv2.line(image, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness=2) - - return image - - -def divide_points_afterinterpolate(resized_all_points, motion_brush_mask): - k = resized_all_points.shape[0] - starts = resized_all_points[:, 0] # [K, 2] - - in_masks = [] - 
out_masks = [] - - for i in range(k): - x, y = int(starts[i][1]), int(starts[i][0]) - if motion_brush_mask[x][y] == 255: - in_masks.append(resized_all_points[i]) - else: - out_masks.append(resized_all_points[i]) - - in_masks = np.array(in_masks) - out_masks = np.array(out_masks) - - return in_masks, out_masks - - -def get_sparseflow_and_mask_forward( - resized_all_points, - n_steps, H, W, - is_backward_flow=False - ): - - K = resized_all_points.shape[0] - - starts = resized_all_points[:, 0] - - interpolated_ends = resized_all_points[:, 1:] - - s_flow = np.zeros((K, n_steps, H, W, 2)) - mask = np.zeros((K, n_steps, H, W)) - - for k in range(K): - for i in range(n_steps): - start, end = starts[k], interpolated_ends[k][i] - flow = np.int64(end - start) * (-1 if is_backward_flow is True else 1) - s_flow[k][i][int(start[1]), int(start[0])] = flow - mask[k][i][int(start[1]), int(start[0])] = 1 - - s_flow = np.sum(s_flow, axis=0) - mask = np.sum(mask, axis=0) - - return s_flow, mask - - -def init_models(pretrained_model_name_or_path, weight_dtype, device='cuda', enable_xformers_memory_efficient_attention=False, allow_tf32=False): - - drag_ckpt = "./ckpts/mofa/traj_controlnet" - face_ckpt = "./ckpts/mofa/ldmk_controlnet" - - print('start loading models...') - - image_encoder = CLIPVisionModelWithProjection.from_pretrained( - pretrained_model_name_or_path, subfolder="image_encoder", revision=None, variant="fp16" - ) - vae = AutoencoderKLTemporalDecoder.from_pretrained( - pretrained_model_name_or_path, subfolder="vae", revision=None, variant="fp16") - unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained( - pretrained_model_name_or_path, - subfolder="unet", - low_cpu_mem_usage=True, - variant="fp16", - ) - - drag_controlnet = DragControlNet.from_pretrained(drag_ckpt) - face_controlnet = FaceControlNet.from_pretrained(face_ckpt) - - cmp = CMP_demo( - './models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml', - 42000 - ).to(device) - cmp.requires_grad_(False) - - # Freeze vae and image_encoder - vae.requires_grad_(False) - image_encoder.requires_grad_(False) - unet.requires_grad_(False) - drag_controlnet.requires_grad_(False) - face_controlnet.requires_grad_(False) - - # Move image_encoder and vae to gpu and cast to weight_dtype - image_encoder.to(device, dtype=weight_dtype) - vae.to(device, dtype=weight_dtype) - unet.to(device, dtype=weight_dtype) - drag_controlnet.to(device, dtype=weight_dtype) - face_controlnet.to(device, dtype=weight_dtype) - - if enable_xformers_memory_efficient_attention: - if is_xformers_available(): - import xformers - - xformers_version = version.parse(xformers.__version__) - if xformers_version == version.parse("0.0.16"): - print( - "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." - ) - unet.enable_xformers_memory_efficient_attention() - else: - raise ValueError( - "xformers is not available. 
Make sure it is installed correctly") - - if allow_tf32: - torch.backends.cuda.matmul.allow_tf32 = True - - pipeline = FlowControlNetPipeline.from_pretrained( - pretrained_model_name_or_path, - unet=unet, - face_controlnet=face_controlnet, - drag_controlnet=drag_controlnet, - image_encoder=image_encoder, - vae=vae, - torch_dtype=weight_dtype, - ) - pipeline = pipeline.to(device) - - print('models loaded.') - - return pipeline, cmp - - -def interpolate_trajectory(points, n_points): - x = [point[0] for point in points] - y = [point[1] for point in points] - - t = np.linspace(0, 1, len(points)) - - fx = PchipInterpolator(t, x) - fy = PchipInterpolator(t, y) - - new_t = np.linspace(0, 1, n_points) - - new_x = fx(new_t) - new_y = fy(new_t) - new_points = list(zip(new_x, new_y)) - - return new_points - - -def visualize_drag_v2(background_image_path, splited_tracks, width, height): - trajectory_maps = [] - - background_image = Image.open(background_image_path).convert('RGBA') - background_image = background_image.resize((width, height)) - w, h = background_image.size - transparent_background = np.array(background_image) - transparent_background[:, :, -1] = 128 - transparent_background = Image.fromarray(transparent_background) - - # Create a transparent layer with the same size as the background image - transparent_layer = np.zeros((h, w, 4)) - for splited_track in splited_tracks: - if len(splited_track) > 1: - splited_track = interpolate_trajectory(splited_track, 16) - splited_track = splited_track[:16] - for i in range(len(splited_track)-1): - start_point = (int(splited_track[i][0]), int(splited_track[i][1])) - end_point = (int(splited_track[i+1][0]), int(splited_track[i+1][1])) - vx = end_point[0] - start_point[0] - vy = end_point[1] - start_point[1] - arrow_length = np.sqrt(vx**2 + vy**2) - if i == len(splited_track)-2: - cv2.arrowedLine(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2, tipLength=8 / arrow_length) - else: - cv2.line(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2) - else: - cv2.circle(transparent_layer, (int(splited_track[0][0]), int(splited_track[0][1])), 2, (255, 0, 0, 192), -1) - - transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) - trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) - trajectory_maps.append(trajectory_map) - return trajectory_maps, transparent_layer - - -class Drag: - def __init__(self, device, height, width, model_length): - self.device = device - - pretrained_model_name_or_path = "./ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1" - - self.device = 'cuda' - self.weight_dtype = torch.float16 - - self.pipeline, self.cmp = init_models( - pretrained_model_name_or_path, - weight_dtype=self.weight_dtype, - device=self.device, - ) - - self.height = height - self.width = width - self.model_length = model_length - - def get_cmp_flow(self, frames, sparse_optical_flow, mask, brush_mask=None): - - b, t, c, h, w = frames.shape - assert h == 384 and w == 384 - frames = frames.flatten(0, 1) # [b*13, 3, 256, 256] - sparse_optical_flow = sparse_optical_flow.flatten(0, 1) # [b*13, 2, 256, 256] - mask = mask.flatten(0, 1) # [b*13, 2, 256, 256] - - cmp_flow = [] - for i in range(b*t): - tmp_flow = self.cmp.run(frames[i:i+1], sparse_optical_flow[i:i+1], mask[i:i+1]) # [1, 2, 256, 256] - cmp_flow.append(tmp_flow) - cmp_flow = torch.cat(cmp_flow, dim=0) # [b*13, 2, 256, 256] - - if brush_mask is not None: - brush_mask = torch.from_numpy(brush_mask) / 255. 
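# the motion-brush mask gates the dense CMP flow: flow predicted outside the user-painted region is zeroed out.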
- brush_mask = brush_mask.to(cmp_flow.device, dtype=cmp_flow.dtype) - brush_mask = brush_mask.unsqueeze(0).unsqueeze(0) - cmp_flow = cmp_flow * brush_mask - - cmp_flow = cmp_flow.reshape(b, t, 2, h, w) - - return cmp_flow - - - def get_flow(self, pixel_values_384, sparse_optical_flow_384, mask_384, motion_brush_mask=None): - - fb, fl, fc, _, _ = pixel_values_384.shape - - controlnet_flow = self.get_cmp_flow( - pixel_values_384[:, 0:1, :, :, :].repeat(1, fl, 1, 1, 1), - sparse_optical_flow_384, - mask_384, motion_brush_mask - ) - - if self.height != 384 or self.width != 384: - scales = [self.height / 384, self.width / 384] - controlnet_flow = F.interpolate(controlnet_flow.flatten(0, 1), (self.height, self.width), mode='nearest').reshape(fb, fl, 2, self.height, self.width) - controlnet_flow[:, :, 0] *= scales[1] - controlnet_flow[:, :, 1] *= scales[0] - - return controlnet_flow - - @torch.no_grad() - def forward_sample(self, save_root, first_frame_path, audio_path, hint_path, input_drag_384_inmask, input_drag_384_outmask, input_first_frame, input_mask_384_inmask, input_mask_384_outmask, in_mask_flag, out_mask_flag, motion_brush_mask_384=None, ldmk_mask_mask_origin=None, ctrl_scale_traj=1., ctrl_scale_ldmk=1., ldmk_render='sadtalker'): - - seed = 42 - - num_frames = self.model_length - - set_seed(seed) - - input_first_frame_384 = F.interpolate(input_first_frame, (384, 384)) - input_first_frame_384 = input_first_frame_384.repeat(num_frames - 1, 1, 1, 1).unsqueeze(0) - input_first_frame_pil = Image.fromarray(np.uint8(input_first_frame[0].cpu().permute(1, 2, 0)*255)) - height, width = input_first_frame.shape[-2:] - - input_drag_384_inmask = input_drag_384_inmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384] - mask_384_inmask = input_mask_384_inmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384] - input_drag_384_outmask = input_drag_384_outmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384] - mask_384_outmask = input_mask_384_outmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384] - - input_drag_384_inmask = input_drag_384_inmask.to(self.device, dtype=self.weight_dtype) - mask_384_inmask = mask_384_inmask.to(self.device, dtype=self.weight_dtype) - input_drag_384_outmask = input_drag_384_outmask.to(self.device, dtype=self.weight_dtype) - mask_384_outmask = mask_384_outmask.to(self.device, dtype=self.weight_dtype) - - input_first_frame_384 = input_first_frame_384.to(self.device, dtype=self.weight_dtype) - - if in_mask_flag: - flow_inmask = self.get_flow( - input_first_frame_384, - input_drag_384_inmask, mask_384_inmask, motion_brush_mask_384 - ) - else: - fb, fl = mask_384_inmask.shape[:2] - flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype) - - if out_mask_flag: - flow_outmask = self.get_flow( - input_first_frame_384, - input_drag_384_outmask, mask_384_outmask - ) - else: - fb, fl = mask_384_outmask.shape[:2] - flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype) - - inmask_no_zero = (flow_inmask != 0).all(dim=2) - inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask) - - controlnet_flow = torch.where(inmask_no_zero, flow_inmask, flow_outmask) - - ldmk_controlnet_flow, ldmk_pose_imgs, landmarks, num_frames = self.get_landmarks(save_root, first_frame_path, audio_path, input_first_frame[0], self.model_length, ldmk_render=ldmk_render) - - ldmk_flow_len = ldmk_controlnet_flow.shape[1] - drag_flow_len = controlnet_flow.shape[1] - repeat_num = ldmk_flow_len // 
drag_flow_len + 1 - drag_controlnet_flow = controlnet_flow.repeat(1, repeat_num, 1, 1, 1) - drag_controlnet_flow = drag_controlnet_flow[:, :ldmk_flow_len] - - ldmk_mask_mask_origin = ldmk_mask_mask_origin.unsqueeze(0).unsqueeze(0) # [1, 1, h, w] - - val_output = self.pipeline( - input_first_frame_pil, - input_first_frame_pil, - - ldmk_controlnet_flow, - ldmk_pose_imgs, - - drag_controlnet_flow, - ldmk_mask_mask_origin, - - height=height, - width=width, - num_frames=num_frames, - decode_chunk_size=8, - motion_bucket_id=127, - fps=7, - noise_aug_strength=0.02, - ctrl_scale_traj=ctrl_scale_traj, - ctrl_scale_ldmk=ctrl_scale_ldmk, - ) - - video_frames, estimated_flow = val_output.frames[0], val_output.controlnet_flow - - for i in range(num_frames): - img = video_frames[i] - video_frames[i] = np.array(img) - - video_frames = np.array(video_frames) - - outputs = self.save_video(ldmk_pose_imgs, first_frame_path, hint_path, landmarks, video_frames, estimated_flow, drag_controlnet_flow) - - return outputs - - def save_video(self, pose_imgs, image_path, hint_path, landmarks, video_frames, estimated_flow, drag_controlnet_flow, outputs=dict()): - - pose_img_nps = (pose_imgs[0].permute(0, 2, 3, 1).cpu().numpy()*255).astype(np.uint8) - - cv2_firstframe = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) - cv2_hint = cv2.cvtColor(cv2.imread(hint_path), cv2.COLOR_BGR2RGB) - - viz_landmarks = [] - for k in tqdm(range(len(landmarks))): - im = draw_landmarks_cv2(video_frames[k].copy(), landmarks[k]) - viz_landmarks.append(im) - viz_landmarks = np.stack(viz_landmarks) - - viz_esti_flows = [] - for i in range(estimated_flow.shape[1]): - temp_flow = estimated_flow[0][i].permute(1, 2, 0) - viz_esti_flows.append(flow_to_image(temp_flow)) - viz_esti_flows = [np.uint8(np.ones_like(viz_esti_flows[-1]) * 255)] + viz_esti_flows - viz_esti_flows = np.stack(viz_esti_flows) # [t-1, h, w, c] - - viz_drag_flows = [] - for i in range(drag_controlnet_flow.shape[1]): - temp_flow = drag_controlnet_flow[0][i].permute(1, 2, 0) - viz_drag_flows.append(flow_to_image(temp_flow)) - viz_drag_flows = [np.uint8(np.ones_like(viz_drag_flows[-1]) * 255)] + viz_drag_flows - viz_drag_flows = np.stack(viz_drag_flows) # [t-1, h, w, c] - - out_nps = [] - for plen in range(video_frames.shape[0]): - out_nps.append(video_frames[plen]) - out_nps = np.stack(out_nps) - - first_frames = np.stack([cv2_firstframe] * out_nps.shape[0]) - hints = np.stack([cv2_hint] * out_nps.shape[0]) - - total_nps = np.concatenate([ - first_frames, hints, viz_drag_flows, viz_esti_flows, pose_img_nps, viz_landmarks, out_nps - ], axis=2) - - video_frames_tensor = torch.from_numpy(video_frames).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255. - - outputs['logits_imgs'] = video_frames_tensor - outputs['traj_flows'] = torch.from_numpy(viz_drag_flows).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255. - outputs['ldmk_flows'] = torch.from_numpy(viz_esti_flows).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255. - outputs['viz_ldmk'] = torch.from_numpy(pose_img_nps).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255. - outputs['out_with_ldmk'] = torch.from_numpy(viz_landmarks).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255. - outputs['total'] = torch.from_numpy(total_nps).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255. 
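# 'total' stitches the first frame, hint image, trajectory flow, landmark flow, rendered landmarks, landmark-overlaid output and the raw output side by side (concatenation along width) for quick inspection.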
- - return outputs - - @torch.no_grad() - def get_cmp_flow_from_tracking_points(self, tracking_points, motion_brush_mask, first_frame_path): - - original_width, original_height = self.width, self.height - - flow_div = self.model_length - - input_all_points = tracking_points.constructor_args['value'] - - if len(input_all_points) == 0 or len(input_all_points[-1]) == 1: - return np.uint8(np.ones((original_width, original_height, 3))*255) - - resized_all_points = [tuple([tuple([int(e1[0]*self.width/original_width), int(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points] - resized_all_points_384 = [tuple([tuple([int(e1[0]*384/original_width), int(e1[1]*384/original_height)]) for e1 in e]) for e in input_all_points] - - new_resized_all_points = [] - new_resized_all_points_384 = [] - for tnum in range(len(resized_all_points)): - new_resized_all_points.append(interpolate_trajectory(input_all_points[tnum], flow_div)) - new_resized_all_points_384.append(interpolate_trajectory(resized_all_points_384[tnum], flow_div)) - - resized_all_points = np.array(new_resized_all_points) - resized_all_points_384 = np.array(new_resized_all_points_384) - - motion_brush_mask_384 = cv2.resize(motion_brush_mask, (384, 384), cv2.INTER_NEAREST) - - resized_all_points_384_inmask, resized_all_points_384_outmask = \ - divide_points_afterinterpolate(resized_all_points_384, motion_brush_mask_384) - - in_mask_flag = False - out_mask_flag = False - - if resized_all_points_384_inmask.shape[0] != 0: - in_mask_flag = True - input_drag_384_inmask, input_mask_384_inmask = \ - get_sparseflow_and_mask_forward( - resized_all_points_384_inmask, - flow_div - 1, 384, 384 - ) - else: - input_drag_384_inmask, input_mask_384_inmask = \ - np.zeros((flow_div - 1, 384, 384, 2)), \ - np.zeros((flow_div - 1, 384, 384)) - - if resized_all_points_384_outmask.shape[0] != 0: - out_mask_flag = True - input_drag_384_outmask, input_mask_384_outmask = \ - get_sparseflow_and_mask_forward( - resized_all_points_384_outmask, - flow_div - 1, 384, 384 - ) - else: - input_drag_384_outmask, input_mask_384_outmask = \ - np.zeros((flow_div - 1, 384, 384, 2)), \ - np.zeros((flow_div - 1, 384, 384)) - - input_drag_384_inmask = torch.from_numpy(input_drag_384_inmask).unsqueeze(0).to(self.device) # [1, 13, h, w, 2] - input_mask_384_inmask = torch.from_numpy(input_mask_384_inmask).unsqueeze(0).to(self.device) # [1, 13, h, w] - input_drag_384_outmask = torch.from_numpy(input_drag_384_outmask).unsqueeze(0).to(self.device) # [1, 13, h, w, 2] - input_mask_384_outmask = torch.from_numpy(input_mask_384_outmask).unsqueeze(0).to(self.device) # [1, 13, h, w] - - first_frames_transform = transforms.Compose([ - lambda x: Image.fromarray(x), - transforms.ToTensor(), - ]) - - input_first_frame = image2arr(first_frame_path) - input_first_frame = repeat(first_frames_transform(input_first_frame), 'c h w -> b c h w', b=1).to(self.device) - - seed = 42 - num_frames = flow_div - - set_seed(seed) - - input_first_frame_384 = F.interpolate(input_first_frame, (384, 384)) - input_first_frame_384 = input_first_frame_384.repeat(num_frames - 1, 1, 1, 1).unsqueeze(0) - - input_drag_384_inmask = input_drag_384_inmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384] - mask_384_inmask = input_mask_384_inmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384] - input_drag_384_outmask = input_drag_384_outmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384] - mask_384_outmask = input_mask_384_outmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384] - - 
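# NOTE: the [1, 13, ...] shape comments are inherited from the 14-frame SVD setting; with model_length = 25 the temporal dimension here is flow_div - 1 = 24.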
input_drag_384_inmask = input_drag_384_inmask.to(self.device, dtype=self.weight_dtype) - mask_384_inmask = mask_384_inmask.to(self.device, dtype=self.weight_dtype) - input_drag_384_outmask = input_drag_384_outmask.to(self.device, dtype=self.weight_dtype) - mask_384_outmask = mask_384_outmask.to(self.device, dtype=self.weight_dtype) - - input_first_frame_384 = input_first_frame_384.to(self.device, dtype=self.weight_dtype) - - if in_mask_flag: - flow_inmask = self.get_flow( - input_first_frame_384, - input_drag_384_inmask, mask_384_inmask, motion_brush_mask_384 - ) - else: - fb, fl = mask_384_inmask.shape[:2] - flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype) - - if out_mask_flag: - flow_outmask = self.get_flow( - input_first_frame_384, - input_drag_384_outmask, mask_384_outmask - ) - else: - fb, fl = mask_384_outmask.shape[:2] - flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype) - - inmask_no_zero = (flow_inmask != 0).all(dim=2) - inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask) - - controlnet_flow = torch.where(inmask_no_zero, flow_inmask, flow_outmask) - - print(controlnet_flow.shape) - - controlnet_flow = controlnet_flow[0, -1].permute(1, 2, 0) - viz_esti_flows = flow_to_image(controlnet_flow) # [h, w, c] - - return viz_esti_flows - - @torch.no_grad() - def get_cmp_flow_landmarks(self, frames, sparse_optical_flow, mask): - - dtype = frames.dtype - b, t, c, h, w = sparse_optical_flow.shape - assert h == 384 and w == 384 - frames = frames.flatten(0, 1) # [b*13, 3, 256, 256] - sparse_optical_flow = sparse_optical_flow.flatten(0, 1) # [b*13, 2, 256, 256] - mask = mask.flatten(0, 1) # [b*13, 2, 256, 256] - - cmp_flow = [] - for i in range(b*t): - tmp_flow = self.cmp.run(frames[i:i+1].float(), sparse_optical_flow[i:i+1].float(), mask[i:i+1].float()) # [b*13, 2, 256, 256] - cmp_flow.append(tmp_flow) - cmp_flow = torch.cat(cmp_flow, dim=0) - cmp_flow = cmp_flow.reshape(b, t, 2, h, w) - - return cmp_flow.to(dtype=dtype) - - def audio2landmark(self, audio_path, img_path, ldmk_result_dir, ldmk_render=0): - - if ldmk_render == 'sadtalker': - return_code = os.system( - f''' - python sadtalker_audio2pose/inference.py \ - --preprocess full \ - --size 256 \ - --driven_audio {audio_path} \ - --source_image {img_path} \ - --result_dir {ldmk_result_dir} \ - --facerender pirender \ - --verbose \ - --face3dvis - ''') - assert return_code == 0, "Errors in generating landmarks! Please trace back up for detailed error report." - elif ldmk_render == 'aniportrait': - return_code = os.system( - f''' - python aniportrait/audio2ldmk.py \ - --ref_image_path {img_path} \ - --audio_path {audio_path} \ - --save_dir {ldmk_result_dir} \ - ''' - ) - assert return_code == 0, "Errors in generating landmarks! Please trace back up for detailed error report." 
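# both renderers are expected to write <ldmk_result_dir>/landmarks.npy; that file is the only artifact consumed downstream (see get_landmarks).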
- else: - assert False - - return os.path.join(ldmk_result_dir, 'landmarks.npy') - - - def get_landmarks(self, save_root, first_frame_path, audio_path, first_frame, num_frames=25, ldmk_render='sadtalker'): - - ldmk_dir = os.path.join(save_root, 'landmarks') - ldmknpy_dir = self.audio2landmark(audio_path, first_frame_path, ldmk_dir, ldmk_render) - - landmarks = np.load(ldmknpy_dir) - landmarks = landmarks[:num_frames] # [25, 68, 2] - flow_len = landmarks.shape[0] - - ldmk_clip = landmarks.copy() - - assert ldmk_clip.ndim == 3 - - ldmk_clip[:, :, 0] = ldmk_clip[:, :, 0] / self.width * 320 - ldmk_clip[:, :, 1] = ldmk_clip[:, :, 1] / self.height * 320 - - pose_imgs = [] - for i in range(ldmk_clip.shape[0]): - pose_img = draw_landmarks(ldmk_clip[i], 320, 320) - pose_img = cv2.resize(pose_img, (self.width, self.height), cv2.INTER_NEAREST) - pose_imgs.append(pose_img) - pose_imgs = np.array(pose_imgs) - pose_imgs = torch.from_numpy(pose_imgs).permute(0, 3, 1, 2).float() / 255. - pose_imgs = pose_imgs.unsqueeze(0).to(self.weight_dtype).to(self.device) - - landmarks = torch.from_numpy(landmarks).to(self.weight_dtype).to(self.device) - - val_controlnet_image, val_sparse_optical_flow, \ - val_mask, val_first_frame_384, \ - val_sparse_optical_flow_384, val_mask_384 = sample_inputs_face(first_frame, landmarks) - - fb, fl, fc, fh, fw = val_sparse_optical_flow.shape - - val_controlnet_flow = self.get_cmp_flow_landmarks( - val_first_frame_384.unsqueeze(0).repeat(1, fl, 1, 1, 1), - val_sparse_optical_flow_384, - val_mask_384 - ) - - if fh != 384 or fw != 384: - scales = [fh / 384, fw / 384] - val_controlnet_flow = F.interpolate(val_controlnet_flow.flatten(0, 1), (fh, fw), mode='nearest').reshape(fb, fl, 2, fh, fw) - val_controlnet_flow[:, :, 0] *= scales[1] - val_controlnet_flow[:, :, 1] *= scales[0] - - val_controlnet_image = val_controlnet_image.unsqueeze(0).repeat(1, fl, 1, 1, 1) - - return val_controlnet_flow, pose_imgs, landmarks, flow_len - - - def run(self, first_frame_path, audio_path, tracking_points, motion_brush_mask, motion_brush_viz, ldmk_mask_mask, ldmk_mask_viz, ctrl_scale_traj, ctrl_scale_ldmk, ldmk_render): - - - timestamp = str(time.time()).split('.')[0] - save_name = f"trajscale{ctrl_scale_traj}_ldmkscale{ctrl_scale_ldmk}_{ldmk_render}_ts{timestamp}" - save_root = os.path.join(os.path.dirname(audio_path), save_name) - os.makedirs(save_root, exist_ok=True) - - - original_width, original_height = self.width, self.height - - flow_div = self.model_length - - input_all_points = tracking_points.constructor_args['value'] - - # print(input_all_points) - - resized_all_points = [tuple([tuple([int(e1[0]*self.width/original_width), int(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points] - resized_all_points_384 = [tuple([tuple([int(e1[0]*384/original_width), int(e1[1]*384/original_height)]) for e1 in e]) for e in input_all_points] - - new_resized_all_points = [] - new_resized_all_points_384 = [] - for tnum in range(len(resized_all_points)): - new_resized_all_points.append(interpolate_trajectory(input_all_points[tnum], flow_div)) - new_resized_all_points_384.append(interpolate_trajectory(resized_all_points_384[tnum], flow_div)) - - resized_all_points = np.array(new_resized_all_points) - resized_all_points_384 = np.array(new_resized_all_points_384) - - motion_brush_mask_384 = cv2.resize(motion_brush_mask, (384, 384), cv2.INTER_NEAREST) - # ldmk_mask_mask_384 = cv2.resize(ldmk_mask_mask, (384, 384), cv2.INTER_NEAREST) - - # motion_brush_mask = 
torch.from_numpy(motion_brush_mask) / 255. - # motion_brush_mask = motion_brush_mask.to(self.device) - - ldmk_mask_mask = torch.from_numpy(ldmk_mask_mask) / 255. - ldmk_mask_mask = ldmk_mask_mask.to(self.device) - - if resized_all_points_384.shape[0] != 0: - resized_all_points_384_inmask, resized_all_points_384_outmask = \ - divide_points_afterinterpolate(resized_all_points_384, motion_brush_mask_384) - else: - resized_all_points_384_inmask = np.array([]) - resized_all_points_384_outmask = np.array([]) - - in_mask_flag = False - out_mask_flag = False - - if resized_all_points_384_inmask.shape[0] != 0: - in_mask_flag = True - input_drag_384_inmask, input_mask_384_inmask = \ - get_sparseflow_and_mask_forward( - resized_all_points_384_inmask, - flow_div - 1, 384, 384 - ) - else: - input_drag_384_inmask, input_mask_384_inmask = \ - np.zeros((flow_div - 1, 384, 384, 2)), \ - np.zeros((flow_div - 1, 384, 384)) - - if resized_all_points_384_outmask.shape[0] != 0: - out_mask_flag = True - input_drag_384_outmask, input_mask_384_outmask = \ - get_sparseflow_and_mask_forward( - resized_all_points_384_outmask, - flow_div - 1, 384, 384 - ) - else: - input_drag_384_outmask, input_mask_384_outmask = \ - np.zeros((flow_div - 1, 384, 384, 2)), \ - np.zeros((flow_div - 1, 384, 384)) - - input_drag_384_inmask = torch.from_numpy(input_drag_384_inmask).unsqueeze(0) # [1, 13, h, w, 2] - input_mask_384_inmask = torch.from_numpy(input_mask_384_inmask).unsqueeze(0) # [1, 13, h, w] - input_drag_384_outmask = torch.from_numpy(input_drag_384_outmask).unsqueeze(0) # [1, 13, h, w, 2] - input_mask_384_outmask = torch.from_numpy(input_mask_384_outmask).unsqueeze(0) # [1, 13, h, w] - - dir, base, ext = split_filename(first_frame_path) - id = base.split('_')[0] - - image_pil = image2pil(first_frame_path) - image_pil = image_pil.resize((self.width, self.height), Image.BILINEAR).convert('RGBA') - - visualized_drag, _ = visualize_drag_v2(first_frame_path, resized_all_points, self.width, self.height) - - motion_brush_viz_pil = Image.fromarray(motion_brush_viz.astype(np.uint8)).convert('RGBA') - visualized_drag = visualized_drag[0].convert('RGBA') - ldmk_mask_viz_pil = Image.fromarray(ldmk_mask_viz.astype(np.uint8)).convert('RGBA') - - drag_input = Image.alpha_composite(image_pil, visualized_drag) - motionbrush_ldmkmask = Image.alpha_composite(motion_brush_viz_pil, ldmk_mask_viz_pil) - - visualized_drag_brush_ldmk_mask = Image.alpha_composite(drag_input, motionbrush_ldmkmask) - - first_frames_transform = transforms.Compose([ - lambda x: Image.fromarray(x), - transforms.ToTensor(), - ]) - - hint_path = os.path.join(save_root, f'hint.png') - visualized_drag_brush_ldmk_mask.save(hint_path) - - first_frames = image2arr(first_frame_path) - first_frames = repeat(first_frames_transform(first_frames), 'c h w -> b c h w', b=1).to(self.device) - - outputs = self.forward_sample( - save_root, - first_frame_path, - audio_path, - hint_path, - input_drag_384_inmask.to(self.device), - input_drag_384_outmask.to(self.device), - first_frames.to(self.device), - input_mask_384_inmask.to(self.device), - input_mask_384_outmask.to(self.device), - in_mask_flag, - out_mask_flag, - motion_brush_mask_384, ldmk_mask_mask, - ctrl_scale_traj, ctrl_scale_ldmk, ldmk_render=ldmk_render) - - traj_flow_tensor = outputs['traj_flows'][0] # [25, 3, h, w] - ldmk_flow_tensor = outputs['ldmk_flows'][0] # [25, 3, h, w] - viz_ldmk_tensor = outputs['viz_ldmk'][0] # [25, 3, h, w] - out_with_ldmk_tensor = outputs['out_with_ldmk'][0] # [25, 3, h, w] - output_tensor = 
outputs['logits_imgs'][0] # [25, 3, h, w] - total_tensor = outputs['total'][0] # [25, 3, h, w] - - traj_flows_path = os.path.join(save_root, f'traj_flow.gif') - ldmk_flows_path = os.path.join(save_root, f'ldmk_flow.gif') - viz_ldmk_path = os.path.join(save_root, f'viz_ldmk.gif') - out_with_ldmk_path = os.path.join(save_root, f'output_w_ldmk.gif') - outputs_path = os.path.join(save_root, f'output.gif') - total_path = os.path.join(save_root, f'total.gif') - - traj_flows_path_mp4 = os.path.join(save_root, f'traj_flow.mp4') - ldmk_flows_path_mp4 = os.path.join(save_root, f'ldmk_flow.mp4') - viz_ldmk_path_mp4 = os.path.join(save_root, f'viz_ldmk.mp4') - out_with_ldmk_path_mp4 = os.path.join(save_root, f'output_w_ldmk.mp4') - outputs_path_mp4 = os.path.join(save_root, f'output.mp4') - total_path_mp4 = os.path.join(save_root, f'total.mp4') - - # print(output_tensor.shape) - - traj_flow_np = traj_flow_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy() - ldmk_flow_np = ldmk_flow_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy() - viz_ldmk_np = viz_ldmk_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy() - out_with_ldmk_np = out_with_ldmk_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy() - output_np = output_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy() - total_np = total_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy() - - torchvision.io.write_video( - traj_flows_path_mp4, - traj_flow_np, - fps=20, video_codec='h264', options={'crf': '10'} - ) - torchvision.io.write_video( - ldmk_flows_path_mp4, - ldmk_flow_np, - fps=20, video_codec='h264', options={'crf': '10'} - ) - torchvision.io.write_video( - viz_ldmk_path_mp4, - viz_ldmk_np, - fps=20, video_codec='h264', options={'crf': '10'} - ) - torchvision.io.write_video( - out_with_ldmk_path_mp4, - out_with_ldmk_np, - fps=20, video_codec='h264', options={'crf': '10'} - ) - torchvision.io.write_video( - outputs_path_mp4, - output_np, - fps=20, video_codec='h264', options={'crf': '10'} - ) - - imageio.mimsave(traj_flows_path, np.uint8(traj_flow_np), fps=20, loop=0) - imageio.mimsave(ldmk_flows_path, np.uint8(ldmk_flow_np), fps=20, loop=0) - imageio.mimsave(viz_ldmk_path, np.uint8(viz_ldmk_np), fps=20, loop=0) - imageio.mimsave(out_with_ldmk_path, np.uint8(out_with_ldmk_np), fps=20, loop=0) - imageio.mimsave(outputs_path, np.uint8(output_np), fps=20, loop=0) - - torchvision.io.write_video(total_path_mp4, total_np, fps=20, video_codec='h264', options={'crf': '10'}) - imageio.mimsave(total_path, np.uint8(total_np), fps=20, loop=0) - - return hint_path, traj_flows_path, ldmk_flows_path, viz_ldmk_path, outputs_path, traj_flows_path_mp4, ldmk_flows_path_mp4, viz_ldmk_path_mp4, outputs_path_mp4 - - -with gr.Blocks() as demo: - gr.Markdown("""

MOFA-Video


""") - - gr.Markdown("""

Official Gradio Demo for MOFA-Video: Controllable Image Animation via Generative Motion Field Adaptions in Frozen Image-to-Video Diffusion Model.

""") - - gr.Markdown( - """ -

1. Use the "Upload Image" button to upload an image. Avoid dragging the image directly into the window.

-

2. Proceed to trajectory control:

- 2.1. Click "Add Trajectory" first, then select points on the "Add Trajectory Here" image. The first click sets the starting point. Click multiple points to create a non-linear trajectory. To add a new trajectory, click "Add Trajectory" again and select points on the image.
- 2.2. After adding each trajectory, an optical flow image will be displayed automatically in "Temporary Trajectory Flow Visualization". Use it as a reference to adjust the trajectory for desired effects (e.g., area, intensity).
- 2.3. To delete the latest trajectory, click "Delete Last Trajectory."
- 2.4. To use the motion brush to restrict the area a trajectory controls, click to add masks on the "Add Motion Brush Here" image. The motion brush limits the optical flow derived from any trajectory whose starting point lies inside the brushed region, and the displayed optical flow image updates accordingly (see the sketch after this list for how trajectories are converted to sparse optical flow). Adjust the motion brush radius using the "Motion Brush Radius" slider.
- 2.5. Choose the control scale for the trajectory using the "Control Scale for Trajectory" slider. This determines the control intensity of the trajectory: 0 means no control (the pure generation result of SVD itself), while 1 applies the strongest control (which usually does not lead to good results because of twisting artifacts). The preset value of 0.6 is recommended for most cases.
-

3. Proceed to landmark control from audio:

- 3.1. Use the "Upload Audio" button to upload an audio file (.wav and .mp3 are currently supported).
- 3.2. Click to add masks on the "Add Landmark Mask Here" image. This mask restricts the area of the optical flow derived from the landmarks; it should usually cover the person's head and, if desired, the body, so that these parts move naturally instead of staying stationary. Adjust the landmark brush radius using the "Landmark Brush Radius" slider.
- 3.3. Choose the control scale for the landmarks using the "Control Scale for Landmark" slider. This determines the control intensity of the landmarks. Unlike the trajectory control, the preset value of 1 is recommended for most cases.
- 3.4. Choose the landmark renderer used to generate landmark sequences from the input audio. The landmark generation code is based on either SadTalker or AniPortrait. We empirically find that SadTalker produces landmarks that follow the audio more precisely around the lips, while AniPortrait produces more pronounced lip movement. Note that while pure landmark-based control of MOFA-Video supports long video generation via the periodic sampling strategy, the current version of hybrid control only supports short videos (25 frames), which means that only the first 25 frames of the generated landmark sequence are used to obtain the result. -

4. Click the "Run" button to animate the image according to the trajectory and the landmark.

- """ - ) - - target_size = 512 # NOTICE: changing to lower resolution may impair the performance of the model. - DragNUWA_net = Drag("cuda:0", target_size, target_size, 25) - first_frame_path = gr.State() - audio_path = gr.State() - tracking_points = gr.State([]) - motion_brush_points = gr.State([]) - motion_brush_mask = gr.State() - motion_brush_viz = gr.State() - ldmk_mask_mask = gr.State() - ldmk_mask_viz = gr.State() - - def preprocess_image(image): - - image_pil = image2pil(image.name) - raw_w, raw_h = image_pil.size - - max_edge = min(raw_w, raw_h) - resize_ratio = target_size / max_edge - - image_pil = image_pil.resize((round(raw_w * resize_ratio), round(raw_h * resize_ratio)), Image.BILINEAR) - - new_w, new_h = image_pil.size - crop_w = new_w - (new_w % 64) - crop_h = new_h - (new_h % 64) - - image_pil = transforms.CenterCrop((crop_h, crop_w))(image_pil.convert('RGB')) - - DragNUWA_net.width = crop_w - DragNUWA_net.height = crop_h - - id = str(time.time()).split('.')[0] - os.makedirs(os.path.join(output_dir, str(id)), exist_ok=True) - - first_frame_path = os.path.join(output_dir, str(id), f"input.png") - image_pil.save(first_frame_path) - - return first_frame_path, first_frame_path, first_frame_path, first_frame_path, gr.State([]), gr.State([]), np.zeros((crop_h, crop_w)), np.zeros((crop_h, crop_w, 4)), np.zeros((crop_h, crop_w)), np.zeros((crop_h, crop_w, 4)) - - def convert_audio_to_wav(input_audio_file, output_wav_file): - - extension = os.path.splitext(os.path.basename(input_audio_file))[-1] - - if extension.lower() == ".mp3": - audio = AudioSegment.from_mp3(input_audio_file) - elif extension.lower() == ".wav": - audio = AudioSegment.from_wav(input_audio_file) - elif extension.lower() == ".ogg": - audio = AudioSegment.from_ogg(input_audio_file) - elif extension.lower() == ".flac": - audio = AudioSegment.from_file(input_audio_file, "flac") - else: - raise ValueError(f"Not supported extension: {extension}") - - audio.export(output_wav_file, format="wav") - - def save_audio(audio, first_frame_path): - - assert first_frame_path is not None, "First upload image, then audio!" 
- - img_basedir = os.path.dirname(first_frame_path) - - id = str(time.time()).split('.')[0] - - audio_path = os.path.join(img_basedir, f'audio_{str(id)}', 'audio.wav') - os.makedirs(os.path.dirname(audio_path), exist_ok=True) - - # os.system(f'cp -r {audio.name} {audio_path}') - - convert_audio_to_wav(audio.name, audio_path) - - return audio_path, audio_path - - def add_drag(tracking_points): - if len(tracking_points.constructor_args['value']) != 0 and tracking_points.constructor_args['value'][-1] == []: - return tracking_points - tracking_points.constructor_args['value'].append([]) - return tracking_points - - def delete_last_drag(tracking_points, first_frame_path, motion_brush_mask): - - if len(tracking_points.constructor_args['value']) > 0: - tracking_points.constructor_args['value'].pop() - - transparent_background = Image.open(first_frame_path).convert('RGBA') - w, h = transparent_background.size - transparent_layer = np.zeros((h, w, 4)) - for track in tracking_points.constructor_args['value']: - if len(track) > 1: - for i in range(len(track)-1): - start_point = track[i] - end_point = track[i+1] - vx = end_point[0] - start_point[0] - vy = end_point[1] - start_point[1] - arrow_length = np.sqrt(vx**2 + vy**2) - if i == len(track)-2: - cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length) - else: - cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,) - else: - cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1) - - transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) - trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) - - viz_flow = DragNUWA_net.get_cmp_flow_from_tracking_points(tracking_points, motion_brush_mask, first_frame_path) - - return tracking_points, trajectory_map, viz_flow - - def add_motion_brushes(motion_brush_points, motion_brush_mask, transparent_layer, first_frame_path, radius, tracking_points, evt: gr.SelectData): - - transparent_background = Image.open(first_frame_path).convert('RGBA') - w, h = transparent_background.size - - motion_points = motion_brush_points.constructor_args['value'] - motion_points.append(evt.index) - - x, y = evt.index - - cv2.circle(motion_brush_mask, (x, y), radius, 255, -1) - cv2.circle(transparent_layer, (x, y), radius, (128, 0, 128, 127), -1) - - transparent_layer_pil = Image.fromarray(transparent_layer.astype(np.uint8)) - motion_map = Image.alpha_composite(transparent_background, transparent_layer_pil) - - viz_flow = DragNUWA_net.get_cmp_flow_from_tracking_points(tracking_points, motion_brush_mask, first_frame_path) - - return motion_brush_mask, transparent_layer, motion_map, viz_flow - - - def add_ldmk_mask(motion_brush_points, motion_brush_mask, transparent_layer, first_frame_path, radius, evt: gr.SelectData): - - transparent_background = Image.open(first_frame_path).convert('RGBA') - w, h = transparent_background.size - - motion_points = motion_brush_points.constructor_args['value'] - motion_points.append(evt.index) - - x, y = evt.index - - cv2.circle(motion_brush_mask, (x, y), radius, 255, -1) - cv2.circle(transparent_layer, (x, y), radius, (0, 0, 255, 127), -1) - - transparent_layer_pil = Image.fromarray(transparent_layer.astype(np.uint8)) - motion_map = Image.alpha_composite(transparent_background, transparent_layer_pil) - - return motion_brush_mask, transparent_layer, motion_map - - - - def add_tracking_points(tracking_points, first_frame_path, motion_brush_mask, 
evt: gr.SelectData): # SelectData is a subclass of EventData - print(f"You selected {evt.value} at {evt.index} from {evt.target}") - - if len(tracking_points.constructor_args['value']) == 0: - tracking_points.constructor_args['value'].append([]) - - tracking_points.constructor_args['value'][-1].append(evt.index) - - print(tracking_points.constructor_args['value']) - - transparent_background = Image.open(first_frame_path).convert('RGBA') - w, h = transparent_background.size - transparent_layer = np.zeros((h, w, 4)) - for track in tracking_points.constructor_args['value']: - if len(track) > 1: - for i in range(len(track)-1): - start_point = track[i] - end_point = track[i+1] - vx = end_point[0] - start_point[0] - vy = end_point[1] - start_point[1] - arrow_length = np.sqrt(vx**2 + vy**2) - if i == len(track)-2: - cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length) - else: - cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,) - else: - cv2.circle(transparent_layer, tuple(track[0]), 3, (255, 0, 0, 255), -1) - - transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) - trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) - - viz_flow = DragNUWA_net.get_cmp_flow_from_tracking_points(tracking_points, motion_brush_mask, first_frame_path) - - return tracking_points, trajectory_map, viz_flow - - with gr.Row(): - with gr.Column(scale=3): - image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"]) - audio_upload_button = gr.UploadButton(label="Upload Audio", file_types=["audio"]) - input_audio = gr.Audio(label="Audio") - with gr.Column(scale=3): - add_drag_button = gr.Button(value="Add Trajectory") - delete_last_drag_button = gr.Button(value="Delete Last Trajectory") - run_button = gr.Button(value="Run") - with gr.Column(scale=3): - motion_brush_radius = gr.Slider(label='Motion Brush Radius', - minimum=1, - maximum=200, - step=1, - value=10) - ldmk_mask_radius = gr.Slider(label='Landmark Brush Radius', - minimum=1, - maximum=200, - step=1, - value=10) - with gr.Column(scale=3): - ctrl_scale_traj = gr.Slider(label='Control Scale for Trajectory', - minimum=0, - maximum=1., - step=0.01, - value=0.6) - ctrl_scale_ldmk = gr.Slider(label='Control Scale for Landmark', - minimum=0, - maximum=1., - step=0.01, - value=1.) 
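# the slider defaults mirror the recommendations in the instructions above: 0.6 for the trajectory scale and 1.0 for the landmark scale.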
- ldmk_render = gr.Radio(label='Landmark Renderer', - choices=['sadtalker', 'aniportrait'], - value='aniportrait') - - with gr.Column(scale=4): - input_image = gr.Image(label="Add Trajectory Here", - interactive=True) - with gr.Column(scale=4): - motion_brush_image = gr.Image(label="Add Motion Brush Here", - interactive=True) - with gr.Column(scale=4): - ldmk_mask_image = gr.Image(label="Add Landmark Mask Here", - interactive=True) - - with gr.Row(): - with gr.Column(scale=6): - viz_flow = gr.Image(label="Temporary Trajectory Flow Visualization") - with gr.Column(scale=6): - hint_image = gr.Image(label="Final Hint Image") - - with gr.Row(): - with gr.Column(scale=6): - traj_flows_gif = gr.Image(label="Trajectory Flow GIF") - with gr.Column(scale=6): - ldmk_flows_gif = gr.Image(label="Landmark Flow GIF") - with gr.Row(): - with gr.Column(scale=6): - viz_ldmk_gif = gr.Image(label="Landmark Visualization GIF") - with gr.Column(scale=6): - outputs_gif = gr.Image(label="Output GIF") - - with gr.Row(): - with gr.Column(scale=6): - traj_flows_mp4 = gr.Video(label="Trajectory Flow MP4") - with gr.Column(scale=6): - ldmk_flows_mp4 = gr.Video(label="Landmark Flow MP4") - with gr.Row(): - with gr.Column(scale=6): - viz_ldmk_mp4 = gr.Video(label="Landmark Visualization MP4") - with gr.Column(scale=6): - outputs_mp4 = gr.Video(label="Output MP4") - - image_upload_button.upload(preprocess_image, image_upload_button, [input_image, motion_brush_image, ldmk_mask_image, first_frame_path, tracking_points, motion_brush_points, motion_brush_mask, motion_brush_viz, ldmk_mask_mask, ldmk_mask_viz]) - - audio_upload_button.upload(save_audio, [audio_upload_button, first_frame_path], [input_audio, audio_path]) - - add_drag_button.click(add_drag, tracking_points, tracking_points) - - delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path, motion_brush_mask], [tracking_points, input_image, viz_flow]) - - input_image.select(add_tracking_points, [tracking_points, first_frame_path, motion_brush_mask], [tracking_points, input_image, viz_flow]) - - motion_brush_image.select(add_motion_brushes, [motion_brush_points, motion_brush_mask, motion_brush_viz, first_frame_path, motion_brush_radius, tracking_points], [motion_brush_mask, motion_brush_viz, motion_brush_image, viz_flow]) - - ldmk_mask_image.select(add_ldmk_mask, [motion_brush_points, ldmk_mask_mask, ldmk_mask_viz, first_frame_path, ldmk_mask_radius], [ldmk_mask_mask, ldmk_mask_viz, ldmk_mask_image]) - - run_button.click(DragNUWA_net.run, [first_frame_path, audio_path, tracking_points, motion_brush_mask, motion_brush_viz, ldmk_mask_mask, ldmk_mask_viz, ctrl_scale_traj, ctrl_scale_ldmk, ldmk_render], [hint_image, traj_flows_gif, ldmk_flows_gif, viz_ldmk_gif, outputs_gif, traj_flows_mp4, ldmk_flows_mp4, viz_ldmk_mp4, outputs_mp4]) - - # demo.launch(server_name="0.0.0.0", debug=True, server_port=80) - demo.launch(server_name="127.0.0.1", debug=True, server_port=9080) diff --git a/run_gradio_video_driven.py b/run_gradio_video_driven.py deleted file mode 100644 index db708c711ebb68d7da9afcaf93a70496a842dff0..0000000000000000000000000000000000000000 --- a/run_gradio_video_driven.py +++ /dev/null @@ -1,1234 +0,0 @@ -import gradio as gr -import numpy as np -import cv2 -import os -from PIL import Image -from scipy.interpolate import PchipInterpolator -import torchvision -import time -from tqdm import tqdm -import imageio - -import torch -import torch.nn.functional as F -import torchvision -import torchvision.transforms as transforms -from einops 
import repeat - -from pydub import AudioSegment - -from packaging import version - -from accelerate.utils import set_seed -from transformers import CLIPVisionModelWithProjection - -from diffusers import AutoencoderKLTemporalDecoder -from diffusers.utils.import_utils import is_xformers_available - -from models.unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel -from pipeline.pipeline import FlowControlNetPipeline -from models.traj_ctrlnet import FlowControlNet as DragControlNet, CMP_demo -from models.ldmk_ctrlnet import FlowControlNet as FaceControlNet - -from utils.flow_viz import flow_to_image -from utils.utils import split_filename, image2arr, image2pil, ensure_dirname - - -output_dir = "Output_video_driven" - - -ensure_dirname(output_dir) - - -def draw_landmarks_cv2(image, landmarks): - for i, point in enumerate(landmarks): - cv2.circle(image, (int(point[0]), int(point[1])), 2, (0, 0, 255), -1) - return image - - -def sample_optical_flow(A, B, h, w): - b, l, k, _ = A.shape - - sparse_optical_flow = torch.zeros((b, l, h, w, 2), dtype=B.dtype, device=B.device) - mask = torch.zeros((b, l, h, w), dtype=torch.uint8, device=B.device) - - x_coords = A[..., 0].long() - y_coords = A[..., 1].long() - - x_coords = torch.clip(x_coords, 0, h - 1) - y_coords = torch.clip(y_coords, 0, w - 1) - - b_idx = torch.arange(b)[:, None, None].repeat(1, l, k) - l_idx = torch.arange(l)[None, :, None].repeat(b, 1, k) - - sparse_optical_flow[b_idx, l_idx, x_coords, y_coords] = B - - mask[b_idx, l_idx, x_coords, y_coords] = 1 - - mask = mask.unsqueeze(-1).repeat(1, 1, 1, 1, 2) - - return sparse_optical_flow, mask - - -@torch.no_grad() -def get_sparse_flow(landmarks, h, w, t): - - landmarks = torch.flip(landmarks, dims=[3]) - - pose_flow = (landmarks - landmarks[:, 0:1].repeat(1, t, 1, 1))[:, 1:] # 前向光流 - according_poses = landmarks[:, 0:1].repeat(1, t - 1, 1, 1) - - pose_flow = torch.flip(pose_flow, dims=[3]) - - b, t, K, _ = pose_flow.shape - - sparse_optical_flow, mask = sample_optical_flow(according_poses, pose_flow, h, w) - - return sparse_optical_flow.permute(0, 1, 4, 2, 3), mask.permute(0, 1, 4, 2, 3) - - - -def sample_inputs_face(first_frame, landmarks): - - pc, ph, pw = first_frame.shape - landmarks = landmarks.unsqueeze(0) - - pl = landmarks.shape[1] - - sparse_optical_flow, mask = get_sparse_flow(landmarks, ph, pw, pl) - - if ph != 384 or pw != 384: - - first_frame_384 = F.interpolate(first_frame.unsqueeze(0), (384, 384)) # [3, 384, 384] - - landmarks_384 = torch.zeros_like(landmarks) - landmarks_384[:, :, :, 0] = landmarks[:, :, :, 0] / pw * 384 - landmarks_384[:, :, :, 1] = landmarks[:, :, :, 1] / ph * 384 - - sparse_optical_flow_384, mask_384 = get_sparse_flow(landmarks_384, 384, 384, pl) - - else: - first_frame_384, landmarks_384 = first_frame, landmarks - sparse_optical_flow_384, mask_384 = sparse_optical_flow, mask - - controlnet_image = first_frame.unsqueeze(0) - - return controlnet_image, sparse_optical_flow, mask, first_frame_384, sparse_optical_flow_384, mask_384 - - - -PARTS = [ - ('FACE', [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], (10, 200, 10)), - ('LEFT_EYE', [43, 44, 45, 46, 47, 48, 43], (180, 200, 10)), - ('LEFT_EYEBROW', [23, 24, 25, 26, 27], (180, 220, 10)), - ('RIGHT_EYE', [37, 38, 39, 40, 41, 42, 37], (10, 200, 180)), - ('RIGHT_EYEBROW', [18, 19, 20, 21, 22], (10, 220, 180)), - ('NOSE_UP', [28, 29, 30, 31], (10, 200, 250)), - ('NOSE_DOWN', [32, 33, 34, 35, 36], (250, 200, 10)), - ('LIPS_OUTER_BOTTOM_LEFT', [55, 56, 57, 58], 
(10, 180, 20)), - ('LIPS_OUTER_BOTTOM_RIGHT', [49, 60, 59, 58], (20, 10, 180)), - ('LIPS_INNER_BOTTOM_LEFT', [65, 66, 67], (100, 100, 30)), - ('LIPS_INNER_BOTTOM_RIGHT', [61, 68, 67], (100, 150, 50)), - ('LIPS_OUTER_TOP_LEFT', [52, 53, 54, 55], (20, 80, 100)), - ('LIPS_OUTER_TOP_RIGHT', [52, 51, 50, 49], (80, 100, 20)), - ('LIPS_INNER_TOP_LEFT', [63, 64, 65], (120, 100, 200)), - ('LIPS_INNER_TOP_RIGHT', [63, 62, 61], (150, 120, 100)), -] - - -def draw_landmarks(keypoints, h, w): - - image = np.zeros((h, w, 3)) - - for name, indices, color in PARTS: - indices = np.array(indices) - 1 - current_part_keypoints = keypoints[indices] - - for i in range(len(indices) - 1): - x1, y1 = current_part_keypoints[i] - x2, y2 = current_part_keypoints[i + 1] - cv2.line(image, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness=2) - - return image - - -def divide_points_afterinterpolate(resized_all_points, motion_brush_mask): - k = resized_all_points.shape[0] - starts = resized_all_points[:, 0] # [K, 2] - - in_masks = [] - out_masks = [] - - for i in range(k): - x, y = int(starts[i][1]), int(starts[i][0]) - if motion_brush_mask[x][y] == 255: - in_masks.append(resized_all_points[i]) - else: - out_masks.append(resized_all_points[i]) - - in_masks = np.array(in_masks) - out_masks = np.array(out_masks) - - return in_masks, out_masks - - -def get_sparseflow_and_mask_forward( - resized_all_points, - n_steps, H, W, - is_backward_flow=False - ): - - K = resized_all_points.shape[0] - - starts = resized_all_points[:, 0] - - interpolated_ends = resized_all_points[:, 1:] - - s_flow = np.zeros((K, n_steps, H, W, 2)) - mask = np.zeros((K, n_steps, H, W)) - - for k in range(K): - for i in range(n_steps): - start, end = starts[k], interpolated_ends[k][i] - flow = np.int64(end - start) * (-1 if is_backward_flow is True else 1) - s_flow[k][i][int(start[1]), int(start[0])] = flow - mask[k][i][int(start[1]), int(start[0])] = 1 - - s_flow = np.sum(s_flow, axis=0) - mask = np.sum(mask, axis=0) - - return s_flow, mask - - -def init_models(pretrained_model_name_or_path, weight_dtype, device='cuda', enable_xformers_memory_efficient_attention=False, allow_tf32=False): - - drag_ckpt = "./ckpts/mofa/traj_controlnet" - face_ckpt = "./ckpts/mofa/ldmk_controlnet" - - print('start loading models...') - - image_encoder = CLIPVisionModelWithProjection.from_pretrained( - pretrained_model_name_or_path, subfolder="image_encoder", revision=None, variant="fp16" - ) - vae = AutoencoderKLTemporalDecoder.from_pretrained( - pretrained_model_name_or_path, subfolder="vae", revision=None, variant="fp16") - unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained( - pretrained_model_name_or_path, - subfolder="unet", - low_cpu_mem_usage=True, - variant="fp16", - ) - - drag_controlnet = DragControlNet.from_pretrained(drag_ckpt) - face_controlnet = FaceControlNet.from_pretrained(face_ckpt) - - cmp = CMP_demo( - './models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml', - 42000 - ).to(device) - cmp.requires_grad_(False) - - # Freeze vae and image_encoder - vae.requires_grad_(False) - image_encoder.requires_grad_(False) - unet.requires_grad_(False) - drag_controlnet.requires_grad_(False) - face_controlnet.requires_grad_(False) - - # Move image_encoder and vae to gpu and cast to weight_dtype - image_encoder.to(device, dtype=weight_dtype) - vae.to(device, dtype=weight_dtype) - unet.to(device, dtype=weight_dtype) - drag_controlnet.to(device, dtype=weight_dtype) - face_controlnet.to(device, dtype=weight_dtype) - - if 
enable_xformers_memory_efficient_attention: - if is_xformers_available(): - import xformers - - xformers_version = version.parse(xformers.__version__) - if xformers_version == version.parse("0.0.16"): - print( - "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." - ) - unet.enable_xformers_memory_efficient_attention() - else: - raise ValueError( - "xformers is not available. Make sure it is installed correctly") - - if allow_tf32: - torch.backends.cuda.matmul.allow_tf32 = True - - pipeline = FlowControlNetPipeline.from_pretrained( - pretrained_model_name_or_path, - unet=unet, - face_controlnet=face_controlnet, - drag_controlnet=drag_controlnet, - image_encoder=image_encoder, - vae=vae, - torch_dtype=weight_dtype, - ) - pipeline = pipeline.to(device) - - print('models loaded.') - - return pipeline, cmp - - -def interpolate_trajectory(points, n_points): - x = [point[0] for point in points] - y = [point[1] for point in points] - - t = np.linspace(0, 1, len(points)) - - fx = PchipInterpolator(t, x) - fy = PchipInterpolator(t, y) - - new_t = np.linspace(0, 1, n_points) - - new_x = fx(new_t) - new_y = fy(new_t) - new_points = list(zip(new_x, new_y)) - - return new_points - - -def visualize_drag_v2(background_image_path, splited_tracks, width, height): - trajectory_maps = [] - - background_image = Image.open(background_image_path).convert('RGBA') - background_image = background_image.resize((width, height)) - w, h = background_image.size - transparent_background = np.array(background_image) - transparent_background[:, :, -1] = 128 - transparent_background = Image.fromarray(transparent_background) - - # Create a transparent layer with the same size as the background image - transparent_layer = np.zeros((h, w, 4)) - for splited_track in splited_tracks: - if len(splited_track) > 1: - splited_track = interpolate_trajectory(splited_track, 16) - splited_track = splited_track[:16] - for i in range(len(splited_track)-1): - start_point = (int(splited_track[i][0]), int(splited_track[i][1])) - end_point = (int(splited_track[i+1][0]), int(splited_track[i+1][1])) - vx = end_point[0] - start_point[0] - vy = end_point[1] - start_point[1] - arrow_length = np.sqrt(vx**2 + vy**2) - if i == len(splited_track)-2: - cv2.arrowedLine(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2, tipLength=8 / arrow_length) - else: - cv2.line(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2) - else: - cv2.circle(transparent_layer, (int(splited_track[0][0]), int(splited_track[0][1])), 2, (255, 0, 0, 192), -1) - - transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) - trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) - trajectory_maps.append(trajectory_map) - return trajectory_maps, transparent_layer - - -class Drag: - def __init__(self, device, height, width, model_length): - self.device = device - - pretrained_model_name_or_path = "./ckpts/mofa/stable-video-diffusion-img2vid-xt-1-1" - - self.device = 'cuda' - self.weight_dtype = torch.float16 - - self.pipeline, self.cmp = init_models( - pretrained_model_name_or_path, - weight_dtype=self.weight_dtype, - device=self.device, - ) - - self.height = height - self.width = width - self.model_length = model_length - - def get_cmp_flow(self, frames, sparse_optical_flow, mask, brush_mask=None): - - b, t, c, h, w = frames.shape - assert h 
== 384 and w == 384 - frames = frames.flatten(0, 1) # [b*13, 3, 256, 256] - sparse_optical_flow = sparse_optical_flow.flatten(0, 1) # [b*13, 2, 256, 256] - mask = mask.flatten(0, 1) # [b*13, 2, 256, 256] - - cmp_flow = [] - for i in range(b*t): - tmp_flow = self.cmp.run(frames[i:i+1], sparse_optical_flow[i:i+1], mask[i:i+1]) # [1, 2, 256, 256] - cmp_flow.append(tmp_flow) - cmp_flow = torch.cat(cmp_flow, dim=0) # [b*13, 2, 256, 256] - - if brush_mask is not None: - brush_mask = torch.from_numpy(brush_mask) / 255. - brush_mask = brush_mask.to(cmp_flow.device, dtype=cmp_flow.dtype) - brush_mask = brush_mask.unsqueeze(0).unsqueeze(0) - cmp_flow = cmp_flow * brush_mask - - cmp_flow = cmp_flow.reshape(b, t, 2, h, w) - - return cmp_flow - - - def get_flow(self, pixel_values_384, sparse_optical_flow_384, mask_384, motion_brush_mask=None): - - fb, fl, fc, _, _ = pixel_values_384.shape - - controlnet_flow = self.get_cmp_flow( - pixel_values_384[:, 0:1, :, :, :].repeat(1, fl, 1, 1, 1), - sparse_optical_flow_384, - mask_384, motion_brush_mask - ) - - if self.height != 384 or self.width != 384: - scales = [self.height / 384, self.width / 384] - controlnet_flow = F.interpolate(controlnet_flow.flatten(0, 1), (self.height, self.width), mode='nearest').reshape(fb, fl, 2, self.height, self.width) - controlnet_flow[:, :, 0] *= scales[1] - controlnet_flow[:, :, 1] *= scales[0] - - return controlnet_flow - - @torch.no_grad() - def forward_sample(self, save_root, first_frame_path, driven_video_path, hint_path, input_drag_384_inmask, input_drag_384_outmask, input_first_frame, input_mask_384_inmask, input_mask_384_outmask, in_mask_flag, out_mask_flag, motion_brush_mask_384=None, ldmk_mask_mask_origin=None, ctrl_scale_traj=1., ctrl_scale_ldmk=1., ldmk_render='sadtalker'): - - seed = 42 - - num_frames = self.model_length - - set_seed(seed) - - input_first_frame_384 = F.interpolate(input_first_frame, (384, 384)) - input_first_frame_384 = input_first_frame_384.repeat(num_frames - 1, 1, 1, 1).unsqueeze(0) - input_first_frame_pil = Image.fromarray(np.uint8(input_first_frame[0].cpu().permute(1, 2, 0)*255)) - height, width = input_first_frame.shape[-2:] - - input_drag_384_inmask = input_drag_384_inmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384] - mask_384_inmask = input_mask_384_inmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384] - input_drag_384_outmask = input_drag_384_outmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384] - mask_384_outmask = input_mask_384_outmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384] - - input_drag_384_inmask = input_drag_384_inmask.to(self.device, dtype=self.weight_dtype) - mask_384_inmask = mask_384_inmask.to(self.device, dtype=self.weight_dtype) - input_drag_384_outmask = input_drag_384_outmask.to(self.device, dtype=self.weight_dtype) - mask_384_outmask = mask_384_outmask.to(self.device, dtype=self.weight_dtype) - - input_first_frame_384 = input_first_frame_384.to(self.device, dtype=self.weight_dtype) - - if in_mask_flag: - flow_inmask = self.get_flow( - input_first_frame_384, - input_drag_384_inmask, mask_384_inmask, motion_brush_mask_384 - ) - else: - fb, fl = mask_384_inmask.shape[:2] - flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype) - - if out_mask_flag: - flow_outmask = self.get_flow( - input_first_frame_384, - input_drag_384_outmask, mask_384_outmask - ) - else: - fb, fl = mask_384_outmask.shape[:2] - flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, 
dtype=self.weight_dtype) - - inmask_no_zero = (flow_inmask != 0).all(dim=2) - inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask) - - controlnet_flow = torch.where(inmask_no_zero, flow_inmask, flow_outmask) - - ldmk_controlnet_flow, ldmk_pose_imgs, landmarks, num_frames = self.get_landmarks(save_root, first_frame_path, driven_video_path, input_first_frame[0], self.model_length, ldmk_render=ldmk_render) - - ldmk_flow_len = ldmk_controlnet_flow.shape[1] - drag_flow_len = controlnet_flow.shape[1] - repeat_num = ldmk_flow_len // drag_flow_len + 1 - drag_controlnet_flow = controlnet_flow.repeat(1, repeat_num, 1, 1, 1) - drag_controlnet_flow = drag_controlnet_flow[:, :ldmk_flow_len] - - ldmk_mask_mask_origin = ldmk_mask_mask_origin.unsqueeze(0).unsqueeze(0) # [1, 1, h, w] - - val_output = self.pipeline( - input_first_frame_pil, - input_first_frame_pil, - - ldmk_controlnet_flow, - ldmk_pose_imgs, - - drag_controlnet_flow, - ldmk_mask_mask_origin, - - height=height, - width=width, - num_frames=num_frames, - decode_chunk_size=8, - motion_bucket_id=127, - fps=7, - noise_aug_strength=0.02, - ctrl_scale_traj=ctrl_scale_traj, - ctrl_scale_ldmk=ctrl_scale_ldmk, - ) - - video_frames, estimated_flow = val_output.frames[0], val_output.controlnet_flow - - for i in range(num_frames): - img = video_frames[i] - video_frames[i] = np.array(img) - - video_frames = np.array(video_frames) - - outputs = self.save_video(ldmk_pose_imgs, first_frame_path, hint_path, landmarks, video_frames, estimated_flow, drag_controlnet_flow) - - return outputs - - def save_video(self, pose_imgs, image_path, hint_path, landmarks, video_frames, estimated_flow, drag_controlnet_flow, outputs=dict()): - - pose_img_nps = (pose_imgs[0].permute(0, 2, 3, 1).cpu().numpy()*255).astype(np.uint8) - - cv2_firstframe = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) - cv2_hint = cv2.cvtColor(cv2.imread(hint_path), cv2.COLOR_BGR2RGB) - - viz_landmarks = [] - for k in tqdm(range(len(landmarks))): - im = draw_landmarks_cv2(video_frames[k].copy(), landmarks[k]) - viz_landmarks.append(im) - viz_landmarks = np.stack(viz_landmarks) - - viz_esti_flows = [] - for i in range(estimated_flow.shape[1]): - temp_flow = estimated_flow[0][i].permute(1, 2, 0) - viz_esti_flows.append(flow_to_image(temp_flow)) - viz_esti_flows = [np.uint8(np.ones_like(viz_esti_flows[-1]) * 255)] + viz_esti_flows - viz_esti_flows = np.stack(viz_esti_flows) # [t-1, h, w, c] - - viz_drag_flows = [] - for i in range(drag_controlnet_flow.shape[1]): - temp_flow = drag_controlnet_flow[0][i].permute(1, 2, 0) - viz_drag_flows.append(flow_to_image(temp_flow)) - viz_drag_flows = [np.uint8(np.ones_like(viz_drag_flows[-1]) * 255)] + viz_drag_flows - viz_drag_flows = np.stack(viz_drag_flows) # [t-1, h, w, c] - - out_nps = [] - for plen in range(video_frames.shape[0]): - out_nps.append(video_frames[plen]) - out_nps = np.stack(out_nps) - - first_frames = np.stack([cv2_firstframe] * out_nps.shape[0]) - hints = np.stack([cv2_hint] * out_nps.shape[0]) - - total_nps = np.concatenate([ - first_frames, hints, viz_drag_flows, viz_esti_flows, pose_img_nps, viz_landmarks, out_nps - ], axis=2) - - video_frames_tensor = torch.from_numpy(video_frames).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255. - - outputs['logits_imgs'] = video_frames_tensor - outputs['traj_flows'] = torch.from_numpy(viz_drag_flows).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255. - outputs['ldmk_flows'] = torch.from_numpy(viz_esti_flows).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255. 
- outputs['viz_ldmk'] = torch.from_numpy(pose_img_nps).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255. - outputs['out_with_ldmk'] = torch.from_numpy(viz_landmarks).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255. - outputs['total'] = torch.from_numpy(total_nps).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255. - - return outputs - - @torch.no_grad() - def get_cmp_flow_from_tracking_points(self, tracking_points, motion_brush_mask, first_frame_path): - - original_width, original_height = self.width, self.height - - flow_div = self.model_length - - input_all_points = tracking_points.constructor_args['value'] - - if len(input_all_points) == 0 or len(input_all_points[-1]) == 1: - return np.uint8(np.ones((original_width, original_height, 3))*255) - - resized_all_points = [tuple([tuple([int(e1[0]*self.width/original_width), int(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points] - resized_all_points_384 = [tuple([tuple([int(e1[0]*384/original_width), int(e1[1]*384/original_height)]) for e1 in e]) for e in input_all_points] - - new_resized_all_points = [] - new_resized_all_points_384 = [] - for tnum in range(len(resized_all_points)): - new_resized_all_points.append(interpolate_trajectory(input_all_points[tnum], flow_div)) - new_resized_all_points_384.append(interpolate_trajectory(resized_all_points_384[tnum], flow_div)) - - resized_all_points = np.array(new_resized_all_points) - resized_all_points_384 = np.array(new_resized_all_points_384) - - motion_brush_mask_384 = cv2.resize(motion_brush_mask, (384, 384), cv2.INTER_NEAREST) - - resized_all_points_384_inmask, resized_all_points_384_outmask = \ - divide_points_afterinterpolate(resized_all_points_384, motion_brush_mask_384) - - in_mask_flag = False - out_mask_flag = False - - if resized_all_points_384_inmask.shape[0] != 0: - in_mask_flag = True - input_drag_384_inmask, input_mask_384_inmask = \ - get_sparseflow_and_mask_forward( - resized_all_points_384_inmask, - flow_div - 1, 384, 384 - ) - else: - input_drag_384_inmask, input_mask_384_inmask = \ - np.zeros((flow_div - 1, 384, 384, 2)), \ - np.zeros((flow_div - 1, 384, 384)) - - if resized_all_points_384_outmask.shape[0] != 0: - out_mask_flag = True - input_drag_384_outmask, input_mask_384_outmask = \ - get_sparseflow_and_mask_forward( - resized_all_points_384_outmask, - flow_div - 1, 384, 384 - ) - else: - input_drag_384_outmask, input_mask_384_outmask = \ - np.zeros((flow_div - 1, 384, 384, 2)), \ - np.zeros((flow_div - 1, 384, 384)) - - input_drag_384_inmask = torch.from_numpy(input_drag_384_inmask).unsqueeze(0).to(self.device) # [1, 13, h, w, 2] - input_mask_384_inmask = torch.from_numpy(input_mask_384_inmask).unsqueeze(0).to(self.device) # [1, 13, h, w] - input_drag_384_outmask = torch.from_numpy(input_drag_384_outmask).unsqueeze(0).to(self.device) # [1, 13, h, w, 2] - input_mask_384_outmask = torch.from_numpy(input_mask_384_outmask).unsqueeze(0).to(self.device) # [1, 13, h, w] - - first_frames_transform = transforms.Compose([ - lambda x: Image.fromarray(x), - transforms.ToTensor(), - ]) - - input_first_frame = image2arr(first_frame_path) - input_first_frame = repeat(first_frames_transform(input_first_frame), 'c h w -> b c h w', b=1).to(self.device) - - seed = 42 - num_frames = flow_div - - set_seed(seed) - - input_first_frame_384 = F.interpolate(input_first_frame, (384, 384)) - input_first_frame_384 = input_first_frame_384.repeat(num_frames - 1, 1, 1, 1).unsqueeze(0) - - input_drag_384_inmask = input_drag_384_inmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384] - 
mask_384_inmask = input_mask_384_inmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384] - input_drag_384_outmask = input_drag_384_outmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384] - mask_384_outmask = input_mask_384_outmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384] - - input_drag_384_inmask = input_drag_384_inmask.to(self.device, dtype=self.weight_dtype) - mask_384_inmask = mask_384_inmask.to(self.device, dtype=self.weight_dtype) - input_drag_384_outmask = input_drag_384_outmask.to(self.device, dtype=self.weight_dtype) - mask_384_outmask = mask_384_outmask.to(self.device, dtype=self.weight_dtype) - - input_first_frame_384 = input_first_frame_384.to(self.device, dtype=self.weight_dtype) - - if in_mask_flag: - flow_inmask = self.get_flow( - input_first_frame_384, - input_drag_384_inmask, mask_384_inmask, motion_brush_mask_384 - ) - else: - fb, fl = mask_384_inmask.shape[:2] - flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype) - - if out_mask_flag: - flow_outmask = self.get_flow( - input_first_frame_384, - input_drag_384_outmask, mask_384_outmask - ) - else: - fb, fl = mask_384_outmask.shape[:2] - flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype) - - inmask_no_zero = (flow_inmask != 0).all(dim=2) - inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask) - - controlnet_flow = torch.where(inmask_no_zero, flow_inmask, flow_outmask) - - print(controlnet_flow.shape) - - controlnet_flow = controlnet_flow[0, -1].permute(1, 2, 0) - viz_esti_flows = flow_to_image(controlnet_flow) # [h, w, c] - - return viz_esti_flows - - @torch.no_grad() - def get_cmp_flow_landmarks(self, frames, sparse_optical_flow, mask): - - dtype = frames.dtype - b, t, c, h, w = sparse_optical_flow.shape - assert h == 384 and w == 384 - frames = frames.flatten(0, 1) # [b*13, 3, 256, 256] - sparse_optical_flow = sparse_optical_flow.flatten(0, 1) # [b*13, 2, 256, 256] - mask = mask.flatten(0, 1) # [b*13, 2, 256, 256] - - cmp_flow = [] - for i in range(b*t): - tmp_flow = self.cmp.run(frames[i:i+1].float(), sparse_optical_flow[i:i+1].float(), mask[i:i+1].float()) # [b*13, 2, 256, 256] - cmp_flow.append(tmp_flow) - cmp_flow = torch.cat(cmp_flow, dim=0) - cmp_flow = cmp_flow.reshape(b, t, 2, h, w) - - return cmp_flow.to(dtype=dtype) - - def video2landmark(self, driven_video_path, img_path, ldmk_result_dir, ldmk_render=0): - - if ldmk_render == 'sadtalker': - return_code = os.system( - f''' - python sadtalker_video2pose/inference.py \ - --preprocess full \ - --size 256 \ - --ref_pose {driven_video_path} \ - --source_image {img_path} \ - --result_dir {ldmk_result_dir} \ - --facerender pirender \ - --verbose \ - --face3dvis - ''') - assert return_code == 0, "Errors in generating landmarks! Maybe Sadtalker can not detect the landmark from source video. Please trace back up for detailed error report." 
- else: - assert False - - return os.path.join(ldmk_result_dir, 'landmarks.npy') - - - def get_landmarks(self, save_root, first_frame_path, driven_video_path, first_frame, num_frames=25, ldmk_render='sadtalker'): - - ldmk_dir = os.path.join(save_root, 'landmarks') - ldmknpy_dir = self.video2landmark(driven_video_path, first_frame_path, ldmk_dir, ldmk_render) - - landmarks = np.load(ldmknpy_dir) - landmarks = landmarks[:num_frames] # [25, 68, 2] - flow_len = landmarks.shape[0] - - ldmk_clip = landmarks.copy() - - assert ldmk_clip.ndim == 3 - - ldmk_clip[:, :, 0] = ldmk_clip[:, :, 0] / self.width * 320 - ldmk_clip[:, :, 1] = ldmk_clip[:, :, 1] / self.height * 320 - - pose_imgs = [] - for i in range(ldmk_clip.shape[0]): - pose_img = draw_landmarks(ldmk_clip[i], 320, 320) - pose_img = cv2.resize(pose_img, (self.width, self.height), cv2.INTER_NEAREST) - pose_imgs.append(pose_img) - pose_imgs = np.array(pose_imgs) - pose_imgs = torch.from_numpy(pose_imgs).permute(0, 3, 1, 2).float() / 255. - pose_imgs = pose_imgs.unsqueeze(0).to(self.weight_dtype).to(self.device) - - landmarks = torch.from_numpy(landmarks).to(self.weight_dtype).to(self.device) - - val_controlnet_image, val_sparse_optical_flow, \ - val_mask, val_first_frame_384, \ - val_sparse_optical_flow_384, val_mask_384 = sample_inputs_face(first_frame, landmarks) - - fb, fl, fc, fh, fw = val_sparse_optical_flow.shape - - val_controlnet_flow = self.get_cmp_flow_landmarks( - val_first_frame_384.unsqueeze(0).repeat(1, fl, 1, 1, 1), - val_sparse_optical_flow_384, - val_mask_384 - ) - - if fh != 384 or fw != 384: - scales = [fh / 384, fw / 384] - val_controlnet_flow = F.interpolate(val_controlnet_flow.flatten(0, 1), (fh, fw), mode='nearest').reshape(fb, fl, 2, fh, fw) - val_controlnet_flow[:, :, 0] *= scales[1] - val_controlnet_flow[:, :, 1] *= scales[0] - - val_controlnet_image = val_controlnet_image.unsqueeze(0).repeat(1, fl, 1, 1, 1) - - return val_controlnet_flow, pose_imgs, landmarks, flow_len - - - def run(self, first_frame_path, driven_video_path, tracking_points, motion_brush_mask, motion_brush_viz, ldmk_mask_mask, ldmk_mask_viz, ctrl_scale_traj, ctrl_scale_ldmk, ldmk_render): - - - timestamp = str(time.time()).split('.')[0] - save_name = f"trajscale{ctrl_scale_traj}_ldmkscale{ctrl_scale_ldmk}_{ldmk_render}_ts{timestamp}" - save_root = os.path.join(os.path.dirname(driven_video_path), save_name) - os.makedirs(save_root, exist_ok=True) - - - original_width, original_height = self.width, self.height - - flow_div = self.model_length - - input_all_points = tracking_points.constructor_args['value'] - - # print(input_all_points) - - resized_all_points = [tuple([tuple([int(e1[0]*self.width/original_width), int(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points] - resized_all_points_384 = [tuple([tuple([int(e1[0]*384/original_width), int(e1[1]*384/original_height)]) for e1 in e]) for e in input_all_points] - - new_resized_all_points = [] - new_resized_all_points_384 = [] - for tnum in range(len(resized_all_points)): - new_resized_all_points.append(interpolate_trajectory(input_all_points[tnum], flow_div)) - new_resized_all_points_384.append(interpolate_trajectory(resized_all_points_384[tnum], flow_div)) - - resized_all_points = np.array(new_resized_all_points) - resized_all_points_384 = np.array(new_resized_all_points_384) - - motion_brush_mask_384 = cv2.resize(motion_brush_mask, (384, 384), cv2.INTER_NEAREST) - # ldmk_mask_mask_384 = cv2.resize(ldmk_mask_mask, (384, 384), cv2.INTER_NEAREST) - - # motion_brush_mask = 
torch.from_numpy(motion_brush_mask) / 255. - # motion_brush_mask = motion_brush_mask.to(self.device) - - ldmk_mask_mask = torch.from_numpy(ldmk_mask_mask) / 255. - ldmk_mask_mask = ldmk_mask_mask.to(self.device) - - if resized_all_points_384.shape[0] != 0: - resized_all_points_384_inmask, resized_all_points_384_outmask = \ - divide_points_afterinterpolate(resized_all_points_384, motion_brush_mask_384) - else: - resized_all_points_384_inmask = np.array([]) - resized_all_points_384_outmask = np.array([]) - - in_mask_flag = False - out_mask_flag = False - - if resized_all_points_384_inmask.shape[0] != 0: - in_mask_flag = True - input_drag_384_inmask, input_mask_384_inmask = \ - get_sparseflow_and_mask_forward( - resized_all_points_384_inmask, - flow_div - 1, 384, 384 - ) - else: - input_drag_384_inmask, input_mask_384_inmask = \ - np.zeros((flow_div - 1, 384, 384, 2)), \ - np.zeros((flow_div - 1, 384, 384)) - - if resized_all_points_384_outmask.shape[0] != 0: - out_mask_flag = True - input_drag_384_outmask, input_mask_384_outmask = \ - get_sparseflow_and_mask_forward( - resized_all_points_384_outmask, - flow_div - 1, 384, 384 - ) - else: - input_drag_384_outmask, input_mask_384_outmask = \ - np.zeros((flow_div - 1, 384, 384, 2)), \ - np.zeros((flow_div - 1, 384, 384)) - - input_drag_384_inmask = torch.from_numpy(input_drag_384_inmask).unsqueeze(0) # [1, 13, h, w, 2] - input_mask_384_inmask = torch.from_numpy(input_mask_384_inmask).unsqueeze(0) # [1, 13, h, w] - input_drag_384_outmask = torch.from_numpy(input_drag_384_outmask).unsqueeze(0) # [1, 13, h, w, 2] - input_mask_384_outmask = torch.from_numpy(input_mask_384_outmask).unsqueeze(0) # [1, 13, h, w] - - dir, base, ext = split_filename(first_frame_path) - id = base.split('_')[0] - - image_pil = image2pil(first_frame_path) - image_pil = image_pil.resize((self.width, self.height), Image.BILINEAR).convert('RGBA') - - visualized_drag, _ = visualize_drag_v2(first_frame_path, resized_all_points, self.width, self.height) - - motion_brush_viz_pil = Image.fromarray(motion_brush_viz.astype(np.uint8)).convert('RGBA') - visualized_drag = visualized_drag[0].convert('RGBA') - ldmk_mask_viz_pil = Image.fromarray(ldmk_mask_viz.astype(np.uint8)).convert('RGBA') - - drag_input = Image.alpha_composite(image_pil, visualized_drag) - motionbrush_ldmkmask = Image.alpha_composite(motion_brush_viz_pil, ldmk_mask_viz_pil) - - visualized_drag_brush_ldmk_mask = Image.alpha_composite(drag_input, motionbrush_ldmkmask) - - first_frames_transform = transforms.Compose([ - lambda x: Image.fromarray(x), - transforms.ToTensor(), - ]) - - hint_path = os.path.join(save_root, f'hint.png') - visualized_drag_brush_ldmk_mask.save(hint_path) - - first_frames = image2arr(first_frame_path) - first_frames = repeat(first_frames_transform(first_frames), 'c h w -> b c h w', b=1).to(self.device) - - outputs = self.forward_sample( - save_root, - first_frame_path, - driven_video_path, - hint_path, - input_drag_384_inmask.to(self.device), - input_drag_384_outmask.to(self.device), - first_frames.to(self.device), - input_mask_384_inmask.to(self.device), - input_mask_384_outmask.to(self.device), - in_mask_flag, - out_mask_flag, - motion_brush_mask_384, ldmk_mask_mask, - ctrl_scale_traj, ctrl_scale_ldmk, ldmk_render=ldmk_render) - - traj_flow_tensor = outputs['traj_flows'][0] # [25, 3, h, w] - ldmk_flow_tensor = outputs['ldmk_flows'][0] # [25, 3, h, w] - viz_ldmk_tensor = outputs['viz_ldmk'][0] # [25, 3, h, w] - out_with_ldmk_tensor = outputs['out_with_ldmk'][0] # [25, 3, h, w] - output_tensor = 
outputs['logits_imgs'][0] # [25, 3, h, w] - total_tensor = outputs['total'][0] # [25, 3, h, w] - - traj_flows_path = os.path.join(save_root, f'traj_flow.gif') - ldmk_flows_path = os.path.join(save_root, f'ldmk_flow.gif') - viz_ldmk_path = os.path.join(save_root, f'viz_ldmk.gif') - out_with_ldmk_path = os.path.join(save_root, f'output_w_ldmk.gif') - outputs_path = os.path.join(save_root, f'output.gif') - total_path = os.path.join(save_root, f'total.gif') - - traj_flows_path_mp4 = os.path.join(save_root, f'traj_flow.mp4') - ldmk_flows_path_mp4 = os.path.join(save_root, f'ldmk_flow.mp4') - viz_ldmk_path_mp4 = os.path.join(save_root, f'viz_ldmk.mp4') - out_with_ldmk_path_mp4 = os.path.join(save_root, f'output_w_ldmk.mp4') - outputs_path_mp4 = os.path.join(save_root, f'output.mp4') - total_path_mp4 = os.path.join(save_root, f'total.mp4') - - # print(output_tensor.shape) - - traj_flow_np = traj_flow_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy() - ldmk_flow_np = ldmk_flow_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy() - viz_ldmk_np = viz_ldmk_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy() - out_with_ldmk_np = out_with_ldmk_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy() - output_np = output_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy() - total_np = total_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy() - - torchvision.io.write_video( - traj_flows_path_mp4, - traj_flow_np, - fps=20, video_codec='h264', options={'crf': '10'} - ) - torchvision.io.write_video( - ldmk_flows_path_mp4, - ldmk_flow_np, - fps=20, video_codec='h264', options={'crf': '10'} - ) - torchvision.io.write_video( - viz_ldmk_path_mp4, - viz_ldmk_np, - fps=20, video_codec='h264', options={'crf': '10'} - ) - torchvision.io.write_video( - out_with_ldmk_path_mp4, - out_with_ldmk_np, - fps=20, video_codec='h264', options={'crf': '10'} - ) - torchvision.io.write_video( - outputs_path_mp4, - output_np, - fps=20, video_codec='h264', options={'crf': '10'} - ) - - imageio.mimsave(traj_flows_path, np.uint8(traj_flow_np), fps=20, loop=0) - imageio.mimsave(ldmk_flows_path, np.uint8(ldmk_flow_np), fps=20, loop=0) - imageio.mimsave(viz_ldmk_path, np.uint8(viz_ldmk_np), fps=20, loop=0) - imageio.mimsave(out_with_ldmk_path, np.uint8(out_with_ldmk_np), fps=20, loop=0) - imageio.mimsave(outputs_path, np.uint8(output_np), fps=20, loop=0) - - torchvision.io.write_video(total_path_mp4, total_np, fps=20, video_codec='h264', options={'crf': '10'}) - imageio.mimsave(total_path, np.uint8(total_np), fps=20, loop=0) - - return hint_path, traj_flows_path, ldmk_flows_path, viz_ldmk_path, outputs_path, traj_flows_path_mp4, ldmk_flows_path_mp4, viz_ldmk_path_mp4, outputs_path_mp4 - - -with gr.Blocks() as demo: - gr.Markdown("""

MOFA-Video


""") - - gr.Markdown("""

Official Gradio Demo for MOFA-Video: Controllable Image Animation via Generative Motion Field Adaptions in Frozen Image-to-Video Diffusion Model.

""") - - gr.Markdown( - """ -

1. Use the "Upload Image" button to upload an image. Avoid dragging the image directly into the window.

-

2. Proceed to trajectory control:

- 2.1. Click "Add Trajectory" first, then select points on the "Add Trajectory Here" image. The first click sets the starting point. Click multiple points to create a non-linear trajectory. To add a new trajectory, click "Add Trajectory" again and select points on the image.
- 2.2. After adding each trajectory, an optical flow image is displayed automatically in "Temporary Trajectory Flow Visualization". Use it as a reference for adjusting the trajectory to achieve the desired effect (e.g., motion area and intensity).
- 2.3. To delete the latest trajectory, click "Delete Last Trajectory."
- 2.4. To use the motion brush to restrict the control area of a trajectory, click to add masks on the "Add Motion Brush Here" image. The motion brush confines the optical flow derived from any trajectory whose starting point lies inside the brushed region, and the displayed optical flow image updates accordingly. Adjust the motion brush radius using the "Motion Brush Radius" slider.
- 2.5. Choose the control scale for the trajectories using the "Control Scale for Trajectory" slider. This determines how strongly the trajectories control the motion: setting it to 0 means no control (the pure generation result of SVD itself), while setting it to 1 applies the strongest control (which in most cases does not lead to good results because of twisting artifacts). A preset value of 0.6 is recommended for most cases. A sketch of how drawn trajectories are turned into sparse flow is given after these instructions.
-

3. Proceed to landmark control from the driven video:

- 3.1. Use the "Upload Driven Video" button to upload a driven video (we have tested .mp4 files; other formats compatible with `cv2.VideoCapture` may also work without causing errors).
- 3.2. Click to add masks on the "Add Landmark Mask Here" image. This mask restricts the optical flow area derived from the landmarks; it should usually cover the person's head and, if desired, body parts so that the body moves naturally instead of staying stationary (see the mask-gating sketch after these instructions). Adjust the landmark brush radius using the "Landmark Brush Radius" slider.
- 3.3. Choose the control scale for the landmarks using the "Control Scale for Landmark" slider. This determines how strongly the landmarks control the motion. Unlike the trajectory control, a preset value of 1 is recommended for most cases.
- 3.4. For video-driven landmark generation, our code is adapted from SadTalker. Note that while pure landmark-based control in MOFA-Video supports long video generation via the periodic sampling strategy, the current version of hybrid control only supports short videos (25 frames), which means that only the first 25 frames of the generated landmark sequence are used to obtain the result. -

4. Click the "Run" button to animate the image according to the trajectory and the landmark.

- """ - ) - - target_size = 512 # NOTICE: changing to lower resolution may impair the performance of the model. - DragNUWA_net = Drag("cuda:0", target_size, target_size, 25) - first_frame_path = gr.State() - driven_video_path = gr.State() - tracking_points = gr.State([]) - motion_brush_points = gr.State([]) - motion_brush_mask = gr.State() - motion_brush_viz = gr.State() - ldmk_mask_mask = gr.State() - ldmk_mask_viz = gr.State() - - def preprocess_image(image): - - image_pil = image2pil(image.name) - raw_w, raw_h = image_pil.size - - max_edge = min(raw_w, raw_h) - resize_ratio = target_size / max_edge - - image_pil = image_pil.resize((round(raw_w * resize_ratio), round(raw_h * resize_ratio)), Image.BILINEAR) - - new_w, new_h = image_pil.size - crop_w = new_w - (new_w % 64) - crop_h = new_h - (new_h % 64) - - image_pil = transforms.CenterCrop((crop_h, crop_w))(image_pil.convert('RGB')) - - DragNUWA_net.width = crop_w - DragNUWA_net.height = crop_h - - id = str(time.time()).split('.')[0] - os.makedirs(os.path.join(output_dir, str(id)), exist_ok=True) - - first_frame_path = os.path.join(output_dir, str(id), f"input.png") - image_pil.save(first_frame_path) - - return first_frame_path, first_frame_path, first_frame_path, first_frame_path, gr.State([]), gr.State([]), np.zeros((crop_h, crop_w)), np.zeros((crop_h, crop_w, 4)), np.zeros((crop_h, crop_w)), np.zeros((crop_h, crop_w, 4)) - - def video_to_numpy_array(video_path): - video = cv2.VideoCapture(video_path) - num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) - frames = [] - for i in range(num_frames): - ret, frame = video.read() - if ret: - frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) - frames.append(frame) - else: - break - video.release() - frames = np.stack(frames, axis=0) - return frames - - def convert_video_to_mp4(input_audio_file, output_wav_file): - video_np = np.uint8(video_to_numpy_array(input_audio_file)) - torchvision.io.write_video( - output_wav_file, - video_np, - fps=25, video_codec='h264', options={'crf': '10'} - ) - - def save_driven_video(driven_video, first_frame_path): - - assert first_frame_path is not None, "Please first upload image, then upload audio." 
- - img_basedir = os.path.dirname(first_frame_path) - - id = str(time.time()).split('.')[0] - - driven_video_path = os.path.join(img_basedir, f'driven_video_{str(id)}', 'driven_video.mp4') - os.makedirs(os.path.dirname(driven_video_path), exist_ok=True) - - convert_video_to_mp4(driven_video.name, driven_video_path) - - return driven_video_path, driven_video_path - - def add_drag(tracking_points): - if len(tracking_points.constructor_args['value']) != 0 and tracking_points.constructor_args['value'][-1] == []: - return tracking_points - tracking_points.constructor_args['value'].append([]) - return tracking_points - - def delete_last_drag(tracking_points, first_frame_path, motion_brush_mask): - - if len(tracking_points.constructor_args['value']) > 0: - tracking_points.constructor_args['value'].pop() - - transparent_background = Image.open(first_frame_path).convert('RGBA') - w, h = transparent_background.size - transparent_layer = np.zeros((h, w, 4)) - for track in tracking_points.constructor_args['value']: - if len(track) > 1: - for i in range(len(track)-1): - start_point = track[i] - end_point = track[i+1] - vx = end_point[0] - start_point[0] - vy = end_point[1] - start_point[1] - arrow_length = np.sqrt(vx**2 + vy**2) - if i == len(track)-2: - cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length) - else: - cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,) - else: - cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1) - - transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) - trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) - - viz_flow = DragNUWA_net.get_cmp_flow_from_tracking_points(tracking_points, motion_brush_mask, first_frame_path) - - return tracking_points, trajectory_map, viz_flow - - def add_motion_brushes(motion_brush_points, motion_brush_mask, transparent_layer, first_frame_path, radius, tracking_points, evt: gr.SelectData): - - transparent_background = Image.open(first_frame_path).convert('RGBA') - w, h = transparent_background.size - - motion_points = motion_brush_points.constructor_args['value'] - motion_points.append(evt.index) - - x, y = evt.index - - cv2.circle(motion_brush_mask, (x, y), radius, 255, -1) - cv2.circle(transparent_layer, (x, y), radius, (128, 0, 128, 127), -1) - - transparent_layer_pil = Image.fromarray(transparent_layer.astype(np.uint8)) - motion_map = Image.alpha_composite(transparent_background, transparent_layer_pil) - - viz_flow = DragNUWA_net.get_cmp_flow_from_tracking_points(tracking_points, motion_brush_mask, first_frame_path) - - return motion_brush_mask, transparent_layer, motion_map, viz_flow - - - def add_ldmk_mask(motion_brush_points, motion_brush_mask, transparent_layer, first_frame_path, radius, evt: gr.SelectData): - - transparent_background = Image.open(first_frame_path).convert('RGBA') - w, h = transparent_background.size - - motion_points = motion_brush_points.constructor_args['value'] - motion_points.append(evt.index) - - x, y = evt.index - - cv2.circle(motion_brush_mask, (x, y), radius, 255, -1) - cv2.circle(transparent_layer, (x, y), radius, (0, 0, 255, 127), -1) - - transparent_layer_pil = Image.fromarray(transparent_layer.astype(np.uint8)) - motion_map = Image.alpha_composite(transparent_background, transparent_layer_pil) - - return motion_brush_mask, transparent_layer, motion_map - - - - def add_tracking_points(tracking_points, first_frame_path, 
motion_brush_mask, evt: gr.SelectData): # SelectData is a subclass of EventData - print(f"You selected {evt.value} at {evt.index} from {evt.target}") - - if len(tracking_points.constructor_args['value']) == 0: - tracking_points.constructor_args['value'].append([]) - - tracking_points.constructor_args['value'][-1].append(evt.index) - - print(tracking_points.constructor_args['value']) - - transparent_background = Image.open(first_frame_path).convert('RGBA') - w, h = transparent_background.size - transparent_layer = np.zeros((h, w, 4)) - for track in tracking_points.constructor_args['value']: - if len(track) > 1: - for i in range(len(track)-1): - start_point = track[i] - end_point = track[i+1] - vx = end_point[0] - start_point[0] - vy = end_point[1] - start_point[1] - arrow_length = np.sqrt(vx**2 + vy**2) - if i == len(track)-2: - cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length) - else: - cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,) - else: - cv2.circle(transparent_layer, tuple(track[0]), 3, (255, 0, 0, 255), -1) - - transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) - trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) - - viz_flow = DragNUWA_net.get_cmp_flow_from_tracking_points(tracking_points, motion_brush_mask, first_frame_path) - - return tracking_points, trajectory_map, viz_flow - - with gr.Row(): - with gr.Column(scale=3): - image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"]) - video_upload_button = gr.UploadButton(label="Upload Driven Video", file_types=["video"]) - driven_video = gr.Video(label="Driven Video") - with gr.Column(scale=3): - add_drag_button = gr.Button(value="Add Trajectory") - delete_last_drag_button = gr.Button(value="Delete Last Trajectory") - run_button = gr.Button(value="Run") - with gr.Column(scale=3): - motion_brush_radius = gr.Slider(label='Motion Brush Radius', - minimum=1, - maximum=200, - step=1, - value=10) - ldmk_mask_radius = gr.Slider(label='Landmark Brush Radius', - minimum=1, - maximum=200, - step=1, - value=10) - with gr.Column(scale=3): - ctrl_scale_traj = gr.Slider(label='Control Scale for Trajectory', - minimum=0, - maximum=1., - step=0.01, - value=0.6) - ctrl_scale_ldmk = gr.Slider(label='Control Scale for Landmark', - minimum=0, - maximum=1., - step=0.01, - value=1.) 
- ldmk_render = gr.Radio(label='Landmark Renderer', - choices=['sadtalker'], - value='sadtalker') - - with gr.Column(scale=4): - input_image = gr.Image(label="Add Trajectory Here", - interactive=True) - with gr.Column(scale=4): - motion_brush_image = gr.Image(label="Add Motion Brush Here", - interactive=True) - with gr.Column(scale=4): - ldmk_mask_image = gr.Image(label="Add Landmark Mask Here", - interactive=True) - - with gr.Row(): - with gr.Column(scale=6): - viz_flow = gr.Image(label="Temporary Trajectory Flow Visualization") - with gr.Column(scale=6): - hint_image = gr.Image(label="Final Hint Image") - - with gr.Row(): - with gr.Column(scale=6): - traj_flows_gif = gr.Image(label="Trajectory Flow GIF") - with gr.Column(scale=6): - ldmk_flows_gif = gr.Image(label="Landmark Flow GIF") - with gr.Row(): - with gr.Column(scale=6): - viz_ldmk_gif = gr.Image(label="Landmark Visualization GIF") - with gr.Column(scale=6): - outputs_gif = gr.Image(label="Output GIF") - - with gr.Row(): - with gr.Column(scale=6): - traj_flows_mp4 = gr.Video(label="Trajectory Flow MP4") - with gr.Column(scale=6): - ldmk_flows_mp4 = gr.Video(label="Landmark Flow MP4") - with gr.Row(): - with gr.Column(scale=6): - viz_ldmk_mp4 = gr.Video(label="Landmark Visualization MP4") - with gr.Column(scale=6): - outputs_mp4 = gr.Video(label="Output MP4") - - image_upload_button.upload(preprocess_image, image_upload_button, [input_image, motion_brush_image, ldmk_mask_image, first_frame_path, tracking_points, motion_brush_points, motion_brush_mask, motion_brush_viz, ldmk_mask_mask, ldmk_mask_viz]) - - video_upload_button.upload(save_driven_video, [video_upload_button, first_frame_path], [driven_video, driven_video_path]) - - add_drag_button.click(add_drag, tracking_points, tracking_points) - - delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path, motion_brush_mask], [tracking_points, input_image, viz_flow]) - - input_image.select(add_tracking_points, [tracking_points, first_frame_path, motion_brush_mask], [tracking_points, input_image, viz_flow]) - - motion_brush_image.select(add_motion_brushes, [motion_brush_points, motion_brush_mask, motion_brush_viz, first_frame_path, motion_brush_radius, tracking_points], [motion_brush_mask, motion_brush_viz, motion_brush_image, viz_flow]) - - ldmk_mask_image.select(add_ldmk_mask, [motion_brush_points, ldmk_mask_mask, ldmk_mask_viz, first_frame_path, ldmk_mask_radius], [ldmk_mask_mask, ldmk_mask_viz, ldmk_mask_image]) - - run_button.click(DragNUWA_net.run, [first_frame_path, driven_video_path, tracking_points, motion_brush_mask, motion_brush_viz, ldmk_mask_mask, ldmk_mask_viz, ctrl_scale_traj, ctrl_scale_ldmk, ldmk_render], [hint_image, traj_flows_gif, ldmk_flows_gif, viz_ldmk_gif, outputs_gif, traj_flows_mp4, ldmk_flows_mp4, viz_ldmk_mp4, outputs_mp4]) - - # demo.launch(server_name="0.0.0.0", debug=True, server_port=80) - demo.launch(server_name="127.0.0.1", debug=True, server_port=9080) diff --git a/sadtalker_audio2pose/.DS_Store b/sadtalker_audio2pose/.DS_Store deleted file mode 100644 index 27cf1b75b1e9bcb455443bb0050ad3478243664f..0000000000000000000000000000000000000000 Binary files a/sadtalker_audio2pose/.DS_Store and /dev/null differ diff --git a/sadtalker_audio2pose/inference.py b/sadtalker_audio2pose/inference.py deleted file mode 100644 index 6198f61aa619cd7ee7a0ce973541a2c78747011b..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/inference.py +++ /dev/null @@ -1,188 +0,0 @@ -from glob import glob -import shutil -import torch -from 
time import strftime -import os, sys, time -from argparse import ArgumentParser -import platform - -from src.utils.preprocess import CropAndExtract -from src.test_audio2coeff import Audio2Coeff -from src.facerender.animate import AnimateFromCoeff -from src.facerender.pirender_animate import AnimateFromCoeff_PIRender -from src.generate_batch import get_data -from src.generate_facerender_batch import get_facerender_data -from src.utils.init_path import init_path - -import random -import numpy as np - - -def set_seed(seed): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - - -def main(args): - #torch.backends.cudnn.enabled = False - - set_seed(42) - - # args.facerender = 'pirender' - - - - pic_path = args.source_image - audio_path = args.driven_audio - save_dir = args.result_dir - os.makedirs(save_dir, exist_ok=True) - pose_style = args.pose_style - device = args.device - batch_size = args.batch_size - input_yaw_list = args.input_yaw - input_pitch_list = args.input_pitch - input_roll_list = args.input_roll - ref_eyeblink = args.ref_eyeblink - ref_pose = args.ref_pose - - # print(args.still) - # assert False - - current_root_path = os.path.split(sys.argv[0])[0] - - sadtalker_paths = init_path(args.checkpoint_dir, os.path.join(current_root_path, 'src/config'), args.size, args.old_version, args.preprocess) - - #init model - preprocess_model = CropAndExtract(sadtalker_paths, device) - - audio_to_coeff = Audio2Coeff(sadtalker_paths, device) - - if args.facerender == 'facevid2vid': - animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device) - elif args.facerender == 'pirender': - animate_from_coeff = AnimateFromCoeff_PIRender(sadtalker_paths, device) - else: - raise(RuntimeError('Unknown model: {}'.format(args.facerender))) - - #crop image and extract 3dmm from image - first_frame_dir = os.path.join(save_dir, 'first_frame_dir') - os.makedirs(first_frame_dir, exist_ok=True) - print('3DMM Extraction for source image') - first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess,\ - source_image_flag=True, pic_size=args.size) - if first_coeff_path is None: - print("Can't get the coeffs of the input") - return - - if ref_eyeblink is not None: - ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[0] - ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname) - os.makedirs(ref_eyeblink_frame_dir, exist_ok=True) - print('3DMM Extraction for the reference video providing eye blinking') - ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir, args.preprocess, source_image_flag=False) - else: - ref_eyeblink_coeff_path=None - - if ref_pose is not None: - if ref_pose == ref_eyeblink: - ref_pose_coeff_path = ref_eyeblink_coeff_path - else: - ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0] - ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname) - os.makedirs(ref_pose_frame_dir, exist_ok=True) - print('3DMM Extraction for the reference video providing pose') - ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir, args.preprocess, source_image_flag=False) - else: - ref_pose_coeff_path=None - - #audio2ceoff - batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still) - coeff_path = audio_to_coeff.generate(batch, 
save_dir, pose_style, ref_pose_coeff_path) - - # print(ref_pose_coeff_path.shape) - # print(coeff_path.shape) - - # assert False - - # 3dface render - if args.face3dvis: - from src.face3d.visualize import gen_composed_video - gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, \ - os.path.join(save_dir, '3dface.mp4'), os.path.join(save_dir, 'landmarks.mp4'), crop_info, extended_crop= True if 'ext' in args.preprocess else False ) - return - - #coeff2video - data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, - batch_size, input_yaw_list, input_pitch_list, input_roll_list, - expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess, size=args.size, facemodel=args.facerender) - - result = animate_from_coeff.generate(data, save_dir, pic_path, crop_info, \ - enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess, img_size=args.size) - - shutil.move(result, save_dir+'.mp4') - print('The generated video is named:', save_dir+'.mp4') - - # result = animate_from_coeff.generate_flow(data, args.result_dir, pic_path, crop_info, \ - # enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess, img_size=args.size) - - # if not args.verbose: - # shutil.rmtree(save_dir) - - -if __name__ == '__main__': - - parser = ArgumentParser() - parser.add_argument("--driven_audio", default='./examples/driven_audio/bus_chinese.wav', help="path to driven audio") - parser.add_argument("--source_image", default='./examples/source_image/full_body_1.png', help="path to source image") - parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking") - parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose") - parser.add_argument("--checkpoint_dir", default='./ckpts/sad_talker', help="path to output") - parser.add_argument("--result_dir", default='./results', help="path to output") - parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)") - parser.add_argument("--batch_size", type=int, default=1, help="the batch size of facerender") - parser.add_argument("--size", type=int, default=256, help="the image size of the facerender") - parser.add_argument("--expression_scale", type=float, default=1., help="the batch size of facerender") - parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user ") - parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user") - parser.add_argument('--input_roll', nargs='+', type=int, default=None, help="the input roll degree of the user") - parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [gfpgan, RestoreFormer]") - parser.add_argument('--background_enhancer', type=str, default=None, help="background enhancer, [realesrgan]") - parser.add_argument("--cpu", dest="cpu", action="store_true") - parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks") - parser.add_argument("--still", action="store_true", help="can crop back to the original videos for the full body aniamtion") - parser.add_argument("--preprocess", default='crop', choices=['crop', 'extcrop', 'resize', 'full', 'extfull'], help="how to preprocess the images" ) - parser.add_argument("--verbose",action="store_true", help="saving the intermedia output or not" ) - 
parser.add_argument("--old_version",action="store_true", help="use the pth other than safetensor version" ) - parser.add_argument("--facerender", default='facevid2vid', choices=['pirender', 'facevid2vid'] ) - - - # net structure and parameters - parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='useless') - parser.add_argument('--init_path', type=str, default=None, help='Useless') - parser.add_argument('--use_last_fc',default=False, help='zero initialize the last fc') - parser.add_argument('--bfm_folder', type=str, default='./ckpts/sad_talker/BFM_Fitting/') - parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model') - - # default renderer parameters - parser.add_argument('--focal', type=float, default=1015.) - parser.add_argument('--center', type=float, default=112.) - parser.add_argument('--camera_d', type=float, default=10.) - parser.add_argument('--z_near', type=float, default=5.) - parser.add_argument('--z_far', type=float, default=15.) - - args = parser.parse_args() - - if torch.cuda.is_available() and not args.cpu: - args.device = "cuda" - elif platform.system() == 'Darwin' and args.facerender == 'pirender': # macos - args.device = "mps" - else: - args.device = "cpu" - - main(args) - diff --git a/sadtalker_audio2pose/src/.DS_Store b/sadtalker_audio2pose/src/.DS_Store deleted file mode 100644 index e3eba8349f8b3d6836329847e5a266205475acf2..0000000000000000000000000000000000000000 Binary files a/sadtalker_audio2pose/src/.DS_Store and /dev/null differ diff --git a/sadtalker_audio2pose/src/audio2exp_models/audio2exp.py b/sadtalker_audio2pose/src/audio2exp_models/audio2exp.py deleted file mode 100644 index e1062ab6684df01e0b3c48b6b577cc8df0503c91..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/audio2exp_models/audio2exp.py +++ /dev/null @@ -1,41 +0,0 @@ -from tqdm import tqdm -import torch -from torch import nn - - -class Audio2Exp(nn.Module): - def __init__(self, netG, cfg, device, prepare_training_loss=False): - super(Audio2Exp, self).__init__() - self.cfg = cfg - self.device = device - self.netG = netG.to(device) - - def test(self, batch): - - mel_input = batch['indiv_mels'] # bs T 1 80 16 - bs = mel_input.shape[0] - T = mel_input.shape[1] - - exp_coeff_pred = [] - - for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames - - current_mel_input = mel_input[:,i:i+10] - - #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64 - ref = batch['ref'][:, :, :64][:, i:i+10] - ratio = batch['ratio_gt'][:, i:i+10] #bs T - - audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16 - - curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64 - - exp_coeff_pred += [curr_exp_coeff_pred] - - # BS x T x 64 - results_dict = { - 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1) - } - return results_dict - - diff --git a/sadtalker_audio2pose/src/audio2exp_models/networks.py b/sadtalker_audio2pose/src/audio2exp_models/networks.py deleted file mode 100644 index cd77a2f48d7c00ce85fe2eefe3a3e820730fbb74..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/audio2exp_models/networks.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import nn - -class Conv2d(nn.Module): - def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs): - super().__init__(*args, **kwargs) - self.conv_block = nn.Sequential( - nn.Conv2d(cin, cout, 
kernel_size, stride, padding), - nn.BatchNorm2d(cout) - ) - self.act = nn.ReLU() - self.residual = residual - self.use_act = use_act - - def forward(self, x): - out = self.conv_block(x) - if self.residual: - out += x - - if self.use_act: - return self.act(out) - else: - return out - -class SimpleWrapperV2(nn.Module): - def __init__(self) -> None: - super().__init__() - self.audio_encoder = nn.Sequential( - Conv2d(1, 32, kernel_size=3, stride=1, padding=1), - Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), - Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), - - Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), - Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), - Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), - - Conv2d(64, 128, kernel_size=3, stride=3, padding=1), - Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), - Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), - - Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), - Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), - - Conv2d(256, 512, kernel_size=3, stride=1, padding=0), - Conv2d(512, 512, kernel_size=1, stride=1, padding=0), - ) - - #### load the pre-trained audio_encoder - #self.audio_encoder = self.audio_encoder.to(device) - ''' - wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict'] - state_dict = self.audio_encoder.state_dict() - - for k,v in wav2lip_state_dict.items(): - if 'audio_encoder' in k: - print('init:', k) - state_dict[k.replace('module.audio_encoder.', '')] = v - self.audio_encoder.load_state_dict(state_dict) - ''' - - self.mapping1 = nn.Linear(512+64+1, 64) - #self.mapping2 = nn.Linear(30, 64) - #nn.init.constant_(self.mapping1.weight, 0.) - nn.init.constant_(self.mapping1.bias, 0.) 
- - def forward(self, x, ref, ratio): - x = self.audio_encoder(x).view(x.size(0), -1) - ref_reshape = ref.reshape(x.size(0), -1) - ratio = ratio.reshape(x.size(0), -1) - - y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1)) - out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial - return out diff --git a/sadtalker_audio2pose/src/audio2pose_models/audio2pose.py b/sadtalker_audio2pose/src/audio2pose_models/audio2pose.py deleted file mode 100644 index 53883adc508037294ba664d05d34e5459f1879f8..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/audio2pose_models/audio2pose.py +++ /dev/null @@ -1,94 +0,0 @@ -import torch -from torch import nn -from src.audio2pose_models.cvae import CVAE -from src.audio2pose_models.discriminator import PoseSequenceDiscriminator -from src.audio2pose_models.audio_encoder import AudioEncoder - -class Audio2Pose(nn.Module): - def __init__(self, cfg, wav2lip_checkpoint, device='cuda'): - super().__init__() - self.cfg = cfg - self.seq_len = cfg.MODEL.CVAE.SEQ_LEN - self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE - self.device = device - - self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device) - self.audio_encoder.eval() - for param in self.audio_encoder.parameters(): - param.requires_grad = False - - self.netG = CVAE(cfg) - self.netD_motion = PoseSequenceDiscriminator(cfg) - - - def forward(self, x): - - batch = {} - coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73 - batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6 - batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6 - batch['class'] = x['class'].squeeze(0).cuda() # bs - indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16 - - # forward - audio_emb_list = [] - audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512 - batch['audio_emb'] = audio_emb - batch = self.netG(batch) - - pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6 - pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6 - pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6 - - batch['pose_pred'] = pose_pred - batch['pose_gt'] = pose_gt - - return batch - - def test(self, x): - - batch = {} - ref = x['ref'] #bs 1 70 - batch['ref'] = x['ref'][:,0,-6:] - batch['class'] = x['class'] - bs = ref.shape[0] - - indiv_mels= x['indiv_mels'] # bs T 1 80 16 - indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame - num_frames = x['num_frames'] - num_frames = int(num_frames) - 1 - - # - div = num_frames//self.seq_len - re = num_frames%self.seq_len - audio_emb_list = [] - pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype, - device=batch['ref'].device)] - - for i in range(div): - z = torch.randn(bs, self.latent_dim).to(ref.device) - batch['z'] = z - audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512 - batch['audio_emb'] = audio_emb - batch = self.netG.test(batch) - pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6 - - if re != 0: - z = torch.randn(bs, self.latent_dim).to(ref.device) - batch['z'] = z - audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512 - if audio_emb.shape[1] != self.seq_len: - pad_dim = self.seq_len-audio_emb.shape[1] - pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1) - audio_emb = torch.cat([pad_audio_emb, audio_emb], 1) - batch['audio_emb'] = audio_emb - batch = self.netG.test(batch) - 
pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:]) - - pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1) - batch['pose_motion_pred'] = pose_motion_pred - - pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6 - - batch['pose_pred'] = pose_pred - return batch diff --git a/sadtalker_audio2pose/src/audio2pose_models/audio_encoder.py b/sadtalker_audio2pose/src/audio2pose_models/audio_encoder.py deleted file mode 100644 index a0c165afbc25910cb66828d8676973fe727cb3a3..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/audio2pose_models/audio_encoder.py +++ /dev/null @@ -1,64 +0,0 @@ -import torch -from torch import nn -from torch.nn import functional as F - -class Conv2d(nn.Module): - def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): - super().__init__(*args, **kwargs) - self.conv_block = nn.Sequential( - nn.Conv2d(cin, cout, kernel_size, stride, padding), - nn.BatchNorm2d(cout) - ) - self.act = nn.ReLU() - self.residual = residual - - def forward(self, x): - out = self.conv_block(x) - if self.residual: - out += x - return self.act(out) - -class AudioEncoder(nn.Module): - def __init__(self, wav2lip_checkpoint, device): - super(AudioEncoder, self).__init__() - - self.audio_encoder = nn.Sequential( - Conv2d(1, 32, kernel_size=3, stride=1, padding=1), - Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), - Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), - - Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), - Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), - Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), - - Conv2d(64, 128, kernel_size=3, stride=3, padding=1), - Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), - Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), - - Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), - Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), - - Conv2d(256, 512, kernel_size=3, stride=1, padding=0), - Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) - - #### load the pre-trained audio_encoder, we do not need to load wav2lip model here. 
- # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict'] - # state_dict = self.audio_encoder.state_dict() - - # for k,v in wav2lip_state_dict.items(): - # if 'audio_encoder' in k: - # state_dict[k.replace('module.audio_encoder.', '')] = v - # self.audio_encoder.load_state_dict(state_dict) - - - def forward(self, audio_sequences): - # audio_sequences = (B, T, 1, 80, 16) - B = audio_sequences.size(0) - - audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0) - - audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1 - dim = audio_embedding.shape[1] - audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1)) - - return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512 diff --git a/sadtalker_audio2pose/src/audio2pose_models/cvae.py b/sadtalker_audio2pose/src/audio2pose_models/cvae.py deleted file mode 100644 index 407b78894cde564dd3f2819772a84e8bb1de251d..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/audio2pose_models/cvae.py +++ /dev/null @@ -1,149 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import nn -from src.audio2pose_models.res_unet import ResUnet - -def class2onehot(idx, class_num): - - assert torch.max(idx).item() < class_num - onehot = torch.zeros(idx.size(0), class_num).to(idx.device) - onehot.scatter_(1, idx, 1) - return onehot - -class CVAE(nn.Module): - def __init__(self, cfg): - super().__init__() - encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES - decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES - latent_size = cfg.MODEL.CVAE.LATENT_SIZE - num_classes = cfg.DATASET.NUM_CLASSES - audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE - audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE - seq_len = cfg.MODEL.CVAE.SEQ_LEN - - self.latent_size = latent_size - - self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes, - audio_emb_in_size, audio_emb_out_size, seq_len) - self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes, - audio_emb_in_size, audio_emb_out_size, seq_len) - def reparameterize(self, mu, logvar): - std = torch.exp(0.5 * logvar) - eps = torch.randn_like(std) - return mu + eps * std - - def forward(self, batch): - batch = self.encoder(batch) - mu = batch['mu'] - logvar = batch['logvar'] - z = self.reparameterize(mu, logvar) - batch['z'] = z - return self.decoder(batch) - - def test(self, batch): - ''' - class_id = batch['class'] - z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device) - batch['z'] = z - ''' - return self.decoder(batch) - -class ENCODER(nn.Module): - def __init__(self, layer_sizes, latent_size, num_classes, - audio_emb_in_size, audio_emb_out_size, seq_len): - super().__init__() - - self.resunet = ResUnet() - self.num_classes = num_classes - self.seq_len = seq_len - - self.MLP = nn.Sequential() - layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6 - for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])): - self.MLP.add_module( - name="L{:d}".format(i), module=nn.Linear(in_size, out_size)) - self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU()) - - self.linear_means = nn.Linear(layer_sizes[-1], latent_size) - self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size) - self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size) - - self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size)) - - def forward(self, batch): - class_id = batch['class'] - 
pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6 - ref = batch['ref'] #bs 6 - bs = pose_motion_gt.shape[0] - audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size - - #pose encode - pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6 - pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6 - - #audio mapping - print(audio_in.shape) - audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size - audio_out = audio_out.reshape(bs, -1) - - class_bias = self.classbias[class_id] #bs latent_size - x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size - x_out = self.MLP(x_in) - - mu = self.linear_means(x_out) - logvar = self.linear_logvar(x_out) #bs latent_size (use the dedicated logvar head, not linear_means) - - batch.update({'mu':mu, 'logvar':logvar}) - return batch - -class DECODER(nn.Module): - def __init__(self, layer_sizes, latent_size, num_classes, - audio_emb_in_size, audio_emb_out_size, seq_len): - super().__init__() - - self.resunet = ResUnet() - self.num_classes = num_classes - self.seq_len = seq_len - - self.MLP = nn.Sequential() - input_size = latent_size + seq_len*audio_emb_out_size + 6 - for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)): - self.MLP.add_module( - name="L{:d}".format(i), module=nn.Linear(in_size, out_size)) - if i+1 < len(layer_sizes): - self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU()) - else: - self.MLP.add_module(name="sigmoid", module=nn.Sigmoid()) - - self.pose_linear = nn.Linear(6, 6) - self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size) - - self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size)) - - def forward(self, batch): - - z = batch['z'] #bs latent_size - bs = z.shape[0] - class_id = batch['class'] - ref = batch['ref'] #bs 6 - audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size - #print('audio_in: ', audio_in[:, :, :10]) - - audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size - #print('audio_out: ', audio_out[:, :, :10]) - audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size - class_bias = self.classbias[class_id] #bs latent_size - - z = z + class_bias - x_in = torch.cat([ref, z, audio_out], dim=-1) - x_out = self.MLP(x_in) # bs layer_sizes[-1] - x_out = x_out.reshape((bs, self.seq_len, -1)) - - #print('x_out: ', x_out) - - pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6 - - pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6 - - batch.update({'pose_motion_pred':pose_motion_pred}) - return batch diff --git a/sadtalker_audio2pose/src/audio2pose_models/discriminator.py b/sadtalker_audio2pose/src/audio2pose_models/discriminator.py deleted file mode 100644 index 2f8ed6e36708d4a70227ff90109f56c6f73a17d2..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/audio2pose_models/discriminator.py +++ /dev/null @@ -1,76 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import nn - -class ConvNormRelu(nn.Module): - def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False, - kernel_size=None, stride=None, padding=None, norm='BN', leaky=False): - super().__init__() - if kernel_size is None: - if downsample: - kernel_size, stride, padding = 4, 2, 1 - else: - kernel_size, stride, padding = 3, 1, 1 - - if conv_type == '2d': - self.conv = nn.Conv2d( - in_channels, - out_channels, - kernel_size, - stride, - padding, - bias=False, - ) - if norm == 'BN': - self.norm = nn.BatchNorm2d(out_channels) - elif norm
== 'IN': - self.norm = nn.InstanceNorm2d(out_channels) - else: - raise NotImplementedError - elif conv_type == '1d': - self.conv = nn.Conv1d( - in_channels, - out_channels, - kernel_size, - stride, - padding, - bias=False, - ) - if norm == 'BN': - self.norm = nn.BatchNorm1d(out_channels) - elif norm == 'IN': - self.norm = nn.InstanceNorm1d(out_channels) - else: - raise NotImplementedError - nn.init.kaiming_normal_(self.conv.weight) - - self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True) - - def forward(self, x): - x = self.conv(x) - if isinstance(self.norm, nn.InstanceNorm1d): - x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C] - else: - x = self.norm(x) - x = self.act(x) - return x - - -class PoseSequenceDiscriminator(nn.Module): - def __init__(self, cfg): - super().__init__() - self.cfg = cfg - leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU - - self.seq = nn.Sequential( - ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64 - ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32 - ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16 - nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16 - ) - - def forward(self, x): - x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2) - x = self.seq(x) - x = x.squeeze(1) - return x \ No newline at end of file diff --git a/sadtalker_audio2pose/src/audio2pose_models/networks.py b/sadtalker_audio2pose/src/audio2pose_models/networks.py deleted file mode 100644 index 9212b49836d9221895993d1d490a476707599922..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/audio2pose_models/networks.py +++ /dev/null @@ -1,140 +0,0 @@ -import torch.nn as nn -import torch - - -class ResidualConv(nn.Module): - def __init__(self, input_dim, output_dim, stride, padding): - super(ResidualConv, self).__init__() - - self.conv_block = nn.Sequential( - nn.BatchNorm2d(input_dim), - nn.ReLU(), - nn.Conv2d( - input_dim, output_dim, kernel_size=3, stride=stride, padding=padding - ), - nn.BatchNorm2d(output_dim), - nn.ReLU(), - nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), - ) - self.conv_skip = nn.Sequential( - nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1), - nn.BatchNorm2d(output_dim), - ) - - def forward(self, x): - - return self.conv_block(x) + self.conv_skip(x) - - -class Upsample(nn.Module): - def __init__(self, input_dim, output_dim, kernel, stride): - super(Upsample, self).__init__() - - self.upsample = nn.ConvTranspose2d( - input_dim, output_dim, kernel_size=kernel, stride=stride - ) - - def forward(self, x): - return self.upsample(x) - - -class Squeeze_Excite_Block(nn.Module): - def __init__(self, channel, reduction=16): - super(Squeeze_Excite_Block, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.fc = nn.Sequential( - nn.Linear(channel, channel // reduction, bias=False), - nn.ReLU(inplace=True), - nn.Linear(channel // reduction, channel, bias=False), - nn.Sigmoid(), - ) - - def forward(self, x): - b, c, _, _ = x.size() - y = self.avg_pool(x).view(b, c) - y = self.fc(y).view(b, c, 1, 1) - return x * y.expand_as(x) - - -class ASPP(nn.Module): - def __init__(self, in_dims, out_dims, rate=[6, 12, 18]): - super(ASPP, self).__init__() - - self.aspp_block1 = nn.Sequential( - nn.Conv2d( - in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0] - ), - nn.ReLU(inplace=True), - 
nn.BatchNorm2d(out_dims), - ) - self.aspp_block2 = nn.Sequential( - nn.Conv2d( - in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1] - ), - nn.ReLU(inplace=True), - nn.BatchNorm2d(out_dims), - ) - self.aspp_block3 = nn.Sequential( - nn.Conv2d( - in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2] - ), - nn.ReLU(inplace=True), - nn.BatchNorm2d(out_dims), - ) - - self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1) - self._init_weights() - - def forward(self, x): - x1 = self.aspp_block1(x) - x2 = self.aspp_block2(x) - x3 = self.aspp_block3(x) - out = torch.cat([x1, x2, x3], dim=1) - return self.output(out) - - def _init_weights(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight) - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - - -class Upsample_(nn.Module): - def __init__(self, scale=2): - super(Upsample_, self).__init__() - - self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale) - - def forward(self, x): - return self.upsample(x) - - -class AttentionBlock(nn.Module): - def __init__(self, input_encoder, input_decoder, output_dim): - super(AttentionBlock, self).__init__() - - self.conv_encoder = nn.Sequential( - nn.BatchNorm2d(input_encoder), - nn.ReLU(), - nn.Conv2d(input_encoder, output_dim, 3, padding=1), - nn.MaxPool2d(2, 2), - ) - - self.conv_decoder = nn.Sequential( - nn.BatchNorm2d(input_decoder), - nn.ReLU(), - nn.Conv2d(input_decoder, output_dim, 3, padding=1), - ) - - self.conv_attn = nn.Sequential( - nn.BatchNorm2d(output_dim), - nn.ReLU(), - nn.Conv2d(output_dim, 1, 1), - ) - - def forward(self, x1, x2): - out = self.conv_encoder(x1) + self.conv_decoder(x2) - out = self.conv_attn(out) - return out * x2 \ No newline at end of file diff --git a/sadtalker_audio2pose/src/audio2pose_models/res_unet.py b/sadtalker_audio2pose/src/audio2pose_models/res_unet.py deleted file mode 100644 index 280404c2a2804038705f792dd800ddf707b75cf8..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/audio2pose_models/res_unet.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch -import torch.nn as nn -from src.audio2pose_models.networks import ResidualConv, Upsample - - -class ResUnet(nn.Module): - def __init__(self, channel=1, filters=[32, 64, 128, 256]): - super(ResUnet, self).__init__() - - self.input_layer = nn.Sequential( - nn.Conv2d(channel, filters[0], kernel_size=3, padding=1), - nn.BatchNorm2d(filters[0]), - nn.ReLU(), - nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), - ) - self.input_skip = nn.Sequential( - nn.Conv2d(channel, filters[0], kernel_size=3, padding=1) - ) - - self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1) - self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1) - - self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1) - - self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1)) - self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1) - - self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1)) - self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1) - - self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1)) - self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1) - - self.output_layer = nn.Sequential( - nn.Conv2d(filters[0], 1, 1, 1), - nn.Sigmoid(), 
- ) - - def forward(self, x): - # Encode - x1 = self.input_layer(x) + self.input_skip(x) - x2 = self.residual_conv_1(x1) - x3 = self.residual_conv_2(x2) - # Bridge - x4 = self.bridge(x3) - - # Decode - x4 = self.upsample_1(x4) - x5 = torch.cat([x4, x3], dim=1) - - x6 = self.up_residual_conv1(x5) - - x6 = self.upsample_2(x6) - x7 = torch.cat([x6, x2], dim=1) - - x8 = self.up_residual_conv2(x7) - - x8 = self.upsample_3(x8) - x9 = torch.cat([x8, x1], dim=1) - - x10 = self.up_residual_conv3(x9) - - output = self.output_layer(x10) - - return output \ No newline at end of file diff --git a/sadtalker_audio2pose/src/config/auido2exp.yaml b/sadtalker_audio2pose/src/config/auido2exp.yaml deleted file mode 100644 index 7e0e8fbba267158d26a147c8cb2ec5acdd73f432..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/config/auido2exp.yaml +++ /dev/null @@ -1,58 +0,0 @@ -DATASET: - TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt - EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt - TRAIN_BATCH_SIZE: 32 - EVAL_BATCH_SIZE: 32 - EXP: True - EXP_DIM: 64 - FRAME_LEN: 32 - COEFF_LEN: 73 - NUM_CLASSES: 46 - AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav - COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm - LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb - DEBUG: True - NUM_REPEATS: 2 - T: 40 - - -MODEL: - FRAMEWORK: V2 - AUDIOENCODER: - LEAKY_RELU: True - NORM: 'IN' - DISCRIMINATOR: - LEAKY_RELU: False - INPUT_CHANNELS: 6 - CVAE: - AUDIO_EMB_IN_SIZE: 512 - AUDIO_EMB_OUT_SIZE: 128 - SEQ_LEN: 32 - LATENT_SIZE: 256 - ENCODER_LAYER_SIZES: [192, 1024] - DECODER_LAYER_SIZES: [1024, 192] - - -TRAIN: - MAX_EPOCH: 300 - GENERATOR: - LR: 2.0e-5 - DISCRIMINATOR: - LR: 1.0e-5 - LOSS: - W_FEAT: 0 - W_COEFF_EXP: 2 - W_LM: 1.0e-2 - W_LM_MOUTH: 0 - W_REG: 0 - W_SYNC: 0 - W_COLOR: 0 - W_EXPRESSION: 0 - W_LIPREADING: 0.01 - W_LIPREADING_VV: 0 - W_EYE_BLINK: 4 - -TAG: - NAME: small_dataset - - diff --git a/sadtalker_audio2pose/src/config/auido2pose.yaml b/sadtalker_audio2pose/src/config/auido2pose.yaml deleted file mode 100644 index 7702414b11581ff99aef7a3187f0d0d1388ae3f3..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/config/auido2pose.yaml +++ /dev/null @@ -1,49 +0,0 @@ -DATASET: - TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt - EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt - TRAIN_BATCH_SIZE: 64 - EVAL_BATCH_SIZE: 1 - EXP: True - EXP_DIM: 64 - FRAME_LEN: 32 - COEFF_LEN: 73 - NUM_CLASSES: 46 - AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav - COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb - DEBUG: True - - -MODEL: - AUDIOENCODER: - LEAKY_RELU: True - NORM: 'IN' - DISCRIMINATOR: - LEAKY_RELU: False - INPUT_CHANNELS: 6 - CVAE: - AUDIO_EMB_IN_SIZE: 512 - AUDIO_EMB_OUT_SIZE: 6 - SEQ_LEN: 32 - LATENT_SIZE: 64 - ENCODER_LAYER_SIZES: [192, 128] - DECODER_LAYER_SIZES: [128, 192] - - -TRAIN: - MAX_EPOCH: 150 - GENERATOR: - LR: 1.0e-4 - DISCRIMINATOR: - LR: 1.0e-4 - LOSS: - LAMBDA_REG: 1 - LAMBDA_LANDMARKS: 0 - LAMBDA_VERTICES: 0 - LAMBDA_GAN_MOTION: 0.7 - LAMBDA_GAN_COEFF: 0 - LAMBDA_KL: 1 - -TAG: - NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder - - diff --git a/sadtalker_audio2pose/src/config/facerender.yaml b/sadtalker_audio2pose/src/config/facerender.yaml deleted file 
mode 100644 index dd1e1ddfe265698e49dac4a6e103cba0aac4f3ce..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/config/facerender.yaml +++ /dev/null @@ -1,45 +0,0 @@ -model_params: - common_params: - num_kp: 15 - image_channel: 3 - feature_channel: 32 - estimate_jacobian: False # True - kp_detector_params: - temperature: 0.1 - block_expansion: 32 - max_features: 1024 - scale_factor: 0.25 # 0.25 - num_blocks: 5 - reshape_channel: 16384 # 16384 = 1024 * 16 - reshape_depth: 16 - he_estimator_params: - block_expansion: 64 - max_features: 2048 - num_bins: 66 - generator_params: - block_expansion: 64 - max_features: 512 - num_down_blocks: 2 - reshape_channel: 32 - reshape_depth: 16 # 512 = 32 * 16 - num_resblocks: 6 - estimate_occlusion_map: True - dense_motion_params: - block_expansion: 32 - max_features: 1024 - num_blocks: 5 - reshape_depth: 16 - compress: 4 - discriminator_params: - scales: [1] - block_expansion: 32 - max_features: 512 - num_blocks: 4 - sn: True - mapping_params: - coeff_nc: 70 - descriptor_nc: 1024 - layer: 3 - num_kp: 15 - num_bins: 66 - diff --git a/sadtalker_audio2pose/src/config/facerender_pirender.yaml b/sadtalker_audio2pose/src/config/facerender_pirender.yaml deleted file mode 100644 index f893b5d0a22f0546642c2d2bdafda88740c81138..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/config/facerender_pirender.yaml +++ /dev/null @@ -1,83 +0,0 @@ -# How often do you want to log the training stats. -# network_list: -# gen: gen_optimizer -# dis: dis_optimizer - -distributed: False -image_to_tensorboard: True -snapshot_save_iter: 40000 -snapshot_save_epoch: 20 -snapshot_save_start_iter: 20000 -snapshot_save_start_epoch: 10 -image_save_iter: 1000 -max_epoch: 200 -logging_iter: 100 -results_dir: ./eval_results - -gen_optimizer: - type: adam - lr: 0.0001 - adam_beta1: 0.5 - adam_beta2: 0.999 - lr_policy: - iteration_mode: True - type: step - step_size: 300000 - gamma: 0.2 - -trainer: - type: trainers.face_trainer::FaceTrainer - pretrain_warp_iteration: 200000 - loss_weight: - weight_perceptual_warp: 2.5 - weight_perceptual_final: 4 - vgg_param_warp: - network: vgg19 - layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1'] - use_style_loss: False - num_scales: 4 - vgg_param_final: - network: vgg19 - layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1'] - use_style_loss: True - num_scales: 4 - style_to_perceptual: 250 - init: - type: 'normal' - gain: 0.02 -gen: - type: generators.face_model::FaceGenerator - param: - mapping_net: - coeff_nc: 73 - descriptor_nc: 256 - layer: 3 - warpping_net: - encoder_layer: 5 - decoder_layer: 3 - base_nc: 32 - editing_net: - layer: 3 - num_res_blocks: 2 - base_nc: 64 - common: - image_nc: 3 - descriptor_nc: 256 - max_nc: 256 - use_spect: False - - -# Data options. 
-data: - type: data.vox_dataset::VoxDataset - path: ./dataset/vox_lmdb - resolution: 256 - semantic_radius: 13 - train: - batch_size: 5 - distributed: True - val: - batch_size: 8 - distributed: True - - diff --git a/sadtalker_audio2pose/src/config/facerender_still.yaml b/sadtalker_audio2pose/src/config/facerender_still.yaml deleted file mode 100644 index d6b84181763caf7184a0769e53a7e419e2e3f604..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/config/facerender_still.yaml +++ /dev/null @@ -1,45 +0,0 @@ -model_params: - common_params: - num_kp: 15 - image_channel: 3 - feature_channel: 32 - estimate_jacobian: False # True - kp_detector_params: - temperature: 0.1 - block_expansion: 32 - max_features: 1024 - scale_factor: 0.25 # 0.25 - num_blocks: 5 - reshape_channel: 16384 # 16384 = 1024 * 16 - reshape_depth: 16 - he_estimator_params: - block_expansion: 64 - max_features: 2048 - num_bins: 66 - generator_params: - block_expansion: 64 - max_features: 512 - num_down_blocks: 2 - reshape_channel: 32 - reshape_depth: 16 # 512 = 32 * 16 - num_resblocks: 6 - estimate_occlusion_map: True - dense_motion_params: - block_expansion: 32 - max_features: 1024 - num_blocks: 5 - reshape_depth: 16 - compress: 4 - discriminator_params: - scales: [1] - block_expansion: 32 - max_features: 512 - num_blocks: 4 - sn: True - mapping_params: - coeff_nc: 73 - descriptor_nc: 1024 - layer: 3 - num_kp: 15 - num_bins: 66 - diff --git a/sadtalker_audio2pose/src/config/similarity_Lm3D_all.mat b/sadtalker_audio2pose/src/config/similarity_Lm3D_all.mat deleted file mode 100644 index 9f5b0bd4ecffb926128a29cb1bbf9d9081c3d4e7..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/config/similarity_Lm3D_all.mat +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:53b83ce6e35c50ddc3e97603650cef4970320c157e75c241c844f29c1dcba65a -size 994 diff --git a/sadtalker_audio2pose/src/face3d/data/__init__.py b/sadtalker_audio2pose/src/face3d/data/__init__.py deleted file mode 100644 index be2378c5877af8e749db18d8a67a382f3eb0912b..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/data/__init__.py +++ /dev/null @@ -1,116 +0,0 @@ -"""This package includes all the modules related to data loading and preprocessing - - To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. - You need to implement four functions: - -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). - -- <__len__>: return the size of dataset. - -- <__getitem__>: get a data point from data loader. - -- : (optionally) add dataset-specific options and set default options. - -Now you can use the dataset class by specifying flag '--dataset_mode dummy'. -See our template dataset class 'template_dataset.py' for more details. -""" -import numpy as np -import importlib -import torch.utils.data -from face3d.data.base_dataset import BaseDataset - - -def find_dataset_using_name(dataset_name): - """Import the module "data/[dataset_name]_dataset.py". - - In the file, the class called DatasetNameDataset() will - be instantiated. It has to be a subclass of BaseDataset, - and it is case-insensitive. - """ - dataset_filename = "data." 
+ dataset_name + "_dataset" - datasetlib = importlib.import_module(dataset_filename) - - dataset = None - target_dataset_name = dataset_name.replace('_', '') + 'dataset' - for name, cls in datasetlib.__dict__.items(): - if name.lower() == target_dataset_name.lower() \ - and issubclass(cls, BaseDataset): - dataset = cls - - if dataset is None: - raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) - - return dataset - - -def get_option_setter(dataset_name): - """Return the static method of the dataset class.""" - dataset_class = find_dataset_using_name(dataset_name) - return dataset_class.modify_commandline_options - - -def create_dataset(opt, rank=0): - """Create a dataset given the option. - - This function wraps the class CustomDatasetDataLoader. - This is the main interface between this package and 'train.py'/'test.py' - - Example: - >>> from data import create_dataset - >>> dataset = create_dataset(opt) - """ - data_loader = CustomDatasetDataLoader(opt, rank=rank) - dataset = data_loader.load_data() - return dataset - -class CustomDatasetDataLoader(): - """Wrapper class of Dataset class that performs multi-threaded data loading""" - - def __init__(self, opt, rank=0): - """Initialize this class - - Step 1: create a dataset instance given the name [dataset_mode] - Step 2: create a multi-threaded data loader. - """ - self.opt = opt - dataset_class = find_dataset_using_name(opt.dataset_mode) - self.dataset = dataset_class(opt) - self.sampler = None - print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__)) - if opt.use_ddp and opt.isTrain: - world_size = opt.world_size - self.sampler = torch.utils.data.distributed.DistributedSampler( - self.dataset, - num_replicas=world_size, - rank=rank, - shuffle=not opt.serial_batches - ) - self.dataloader = torch.utils.data.DataLoader( - self.dataset, - sampler=self.sampler, - num_workers=int(opt.num_threads / world_size), - batch_size=int(opt.batch_size / world_size), - drop_last=True) - else: - self.dataloader = torch.utils.data.DataLoader( - self.dataset, - batch_size=opt.batch_size, - shuffle=(not opt.serial_batches) and opt.isTrain, - num_workers=int(opt.num_threads), - drop_last=True - ) - - def set_epoch(self, epoch): - self.dataset.current_epoch = epoch - if self.sampler is not None: - self.sampler.set_epoch(epoch) - - def load_data(self): - return self - - def __len__(self): - """Return the number of data in the dataset""" - return min(len(self.dataset), self.opt.max_dataset_size) - - def __iter__(self): - """Return a batch of data""" - for i, data in enumerate(self.dataloader): - if i * self.opt.batch_size >= self.opt.max_dataset_size: - break - yield data diff --git a/sadtalker_audio2pose/src/face3d/data/base_dataset.py b/sadtalker_audio2pose/src/face3d/data/base_dataset.py deleted file mode 100644 index 34a7ea5024206e6e58c2f404ac6a1bf0987f5fd4..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/data/base_dataset.py +++ /dev/null @@ -1,125 +0,0 @@ -"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets. - -It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 
-""" -import random -import numpy as np -import torch.utils.data as data -from PIL import Image -import torchvision.transforms as transforms -from abc import ABC, abstractmethod - - -class BaseDataset(data.Dataset, ABC): - """This class is an abstract base class (ABC) for datasets. - - To create a subclass, you need to implement the following four functions: - -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). - -- <__len__>: return the size of dataset. - -- <__getitem__>: get a data point. - -- : (optionally) add dataset-specific options and set default options. - """ - - def __init__(self, opt): - """Initialize the class; save the options in the class - - Parameters: - opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions - """ - self.opt = opt - # self.root = opt.dataroot - self.current_epoch = 0 - - @staticmethod - def modify_commandline_options(parser, is_train): - """Add new dataset-specific options, and rewrite default values for existing options. - - Parameters: - parser -- original option parser - is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. - - Returns: - the modified parser. - """ - return parser - - @abstractmethod - def __len__(self): - """Return the total number of images in the dataset.""" - return 0 - - @abstractmethod - def __getitem__(self, index): - """Return a data point and its metadata information. - - Parameters: - index - - a random integer for data indexing - - Returns: - a dictionary of data with their names. It ususally contains the data itself and its metadata information. - """ - pass - - -def get_transform(grayscale=False): - transform_list = [] - if grayscale: - transform_list.append(transforms.Grayscale(1)) - transform_list += [transforms.ToTensor()] - return transforms.Compose(transform_list) - -def get_affine_mat(opt, size): - shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False - w, h = size - - if 'shift' in opt.preprocess: - shift_pixs = int(opt.shift_pixs) - shift_x = random.randint(-shift_pixs, shift_pixs) - shift_y = random.randint(-shift_pixs, shift_pixs) - if 'scale' in opt.preprocess: - scale = 1 + opt.scale_delta * (2 * random.random() - 1) - if 'rot' in opt.preprocess: - rot_angle = opt.rot_angle * (2 * random.random() - 1) - rot_rad = -rot_angle * np.pi/180 - if 'flip' in opt.preprocess: - flip = random.random() > 0.5 - - shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3]) - flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3]) - shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3]) - rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3]) - scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3]) - shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3]) - - affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin - affine_inv = np.linalg.inv(affine) - return affine, affine_inv, flip - -def apply_img_affine(img, affine_inv, method=Image.BICUBIC): - return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC) - -def apply_lm_affine(landmark, affine, flip, size): - _, h = size - lm = landmark.copy() - lm[:, 1] = h - 1 - lm[:, 1] - lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1) - lm = lm @ np.transpose(affine) - lm[:, :2] = 
lm[:, :2] / lm[:, 2:] - lm = lm[:, :2] - lm[:, 1] = h - 1 - lm[:, 1] - if flip: - lm_ = lm.copy() - lm_[:17] = lm[16::-1] - lm_[17:22] = lm[26:21:-1] - lm_[22:27] = lm[21:16:-1] - lm_[31:36] = lm[35:30:-1] - lm_[36:40] = lm[45:41:-1] - lm_[40:42] = lm[47:45:-1] - lm_[42:46] = lm[39:35:-1] - lm_[46:48] = lm[41:39:-1] - lm_[48:55] = lm[54:47:-1] - lm_[55:60] = lm[59:54:-1] - lm_[60:65] = lm[64:59:-1] - lm_[65:68] = lm[67:64:-1] - lm = lm_ - return lm diff --git a/sadtalker_audio2pose/src/face3d/data/flist_dataset.py b/sadtalker_audio2pose/src/face3d/data/flist_dataset.py deleted file mode 100644 index 63b49caa8020f8e9aedb73a839b7112320cad68a..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/data/flist_dataset.py +++ /dev/null @@ -1,125 +0,0 @@ -"""This script defines the custom dataset for Deep3DFaceRecon_pytorch -""" - -import os.path -from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine -from data.image_folder import make_dataset -from PIL import Image -import random -import util.util as util -import numpy as np -import json -import torch -from scipy.io import loadmat, savemat -import pickle -from util.preprocess import align_img, estimate_norm -from util.load_mats import load_lm3d - - -def default_flist_reader(flist): - """ - flist format: impath label\nimpath label\n ...(same to caffe's filelist) - """ - imlist = [] - with open(flist, 'r') as rf: - for line in rf.readlines(): - impath = line.strip() - imlist.append(impath) - - return imlist - -def jason_flist_reader(flist): - with open(flist, 'r') as fp: - info = json.load(fp) - return info - -def parse_label(label): - return torch.tensor(np.array(label).astype(np.float32)) - - -class FlistDataset(BaseDataset): - """ - It requires one directories to host training images '/path/to/data/train' - You can train the model with the dataset flag '--dataroot /path/to/data'. - """ - - def __init__(self, opt): - """Initialize this dataset class. - - Parameters: - opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions - """ - BaseDataset.__init__(self, opt) - - self.lm3d_std = load_lm3d(opt.bfm_folder) - - msk_names = default_flist_reader(opt.flist) - self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names] - - self.size = len(self.msk_paths) - self.opt = opt - - self.name = 'train' if opt.isTrain else 'val' - if '_' in opt.flist: - self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0] - - - def __getitem__(self, index): - """Return a data point and its metadata information. 
- - Parameters: - index (int) -- a random integer for data indexing - - Returns a dictionary that contains A, B, A_paths and B_paths - img (tensor) -- an image in the input domain - msk (tensor) -- its corresponding attention mask - lm (tensor) -- its corresponding 3d landmarks - im_paths (str) -- image paths - aug_flag (bool) -- a flag used to tell whether its raw or augmented - """ - msk_path = self.msk_paths[index % self.size] # make sure index is within then range - img_path = msk_path.replace('mask/', '') - lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt' - - raw_img = Image.open(img_path).convert('RGB') - raw_msk = Image.open(msk_path).convert('RGB') - raw_lm = np.loadtxt(lm_path).astype(np.float32) - - _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk) - - aug_flag = self.opt.use_aug and self.opt.isTrain - if aug_flag: - img, lm, msk = self._augmentation(img, lm, self.opt, msk) - - _, H = img.size - M = estimate_norm(lm, H) - transform = get_transform() - img_tensor = transform(img) - msk_tensor = transform(msk)[:1, ...] - lm_tensor = parse_label(lm) - M_tensor = parse_label(M) - - - return {'imgs': img_tensor, - 'lms': lm_tensor, - 'msks': msk_tensor, - 'M': M_tensor, - 'im_paths': img_path, - 'aug_flag': aug_flag, - 'dataset': self.name} - - def _augmentation(self, img, lm, opt, msk=None): - affine, affine_inv, flip = get_affine_mat(opt, img.size) - img = apply_img_affine(img, affine_inv) - lm = apply_lm_affine(lm, affine, flip, img.size) - if msk is not None: - msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR) - return img, lm, msk - - - - - def __len__(self): - """Return the total number of images in the dataset. - """ - return self.size diff --git a/sadtalker_audio2pose/src/face3d/data/image_folder.py b/sadtalker_audio2pose/src/face3d/data/image_folder.py deleted file mode 100644 index 07ef069029b0db1fc40b9b5f9a6f52a48c1cd162..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/data/image_folder.py +++ /dev/null @@ -1,66 +0,0 @@ -"""A modified image folder class - -We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) -so that this class can load images from both current directory and its subdirectories. 
-""" -import numpy as np -import torch.utils.data as data - -from PIL import Image -import os -import os.path - -IMG_EXTENSIONS = [ - '.jpg', '.JPG', '.jpeg', '.JPEG', - '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', - '.tif', '.TIF', '.tiff', '.TIFF', -] - - -def is_image_file(filename): - return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) - - -def make_dataset(dir, max_dataset_size=float("inf")): - images = [] - assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir - - for root, _, fnames in sorted(os.walk(dir, followlinks=True)): - for fname in fnames: - if is_image_file(fname): - path = os.path.join(root, fname) - images.append(path) - return images[:min(max_dataset_size, len(images))] - - -def default_loader(path): - return Image.open(path).convert('RGB') - - -class ImageFolder(data.Dataset): - - def __init__(self, root, transform=None, return_paths=False, - loader=default_loader): - imgs = make_dataset(root) - if len(imgs) == 0: - raise(RuntimeError("Found 0 images in: " + root + "\n" - "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) - - self.root = root - self.imgs = imgs - self.transform = transform - self.return_paths = return_paths - self.loader = loader - - def __getitem__(self, index): - path = self.imgs[index] - img = self.loader(path) - if self.transform is not None: - img = self.transform(img) - if self.return_paths: - return img, path - else: - return img - - def __len__(self): - return len(self.imgs) diff --git a/sadtalker_audio2pose/src/face3d/data/template_dataset.py b/sadtalker_audio2pose/src/face3d/data/template_dataset.py deleted file mode 100644 index 693b6b09085ad424e53f26e0938b61eea30ed644..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/data/template_dataset.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Dataset class template - -This module provides a template for users to implement custom datasets. -You can specify '--dataset_mode template' to use this dataset. -The class name should be consistent with both the filename and its dataset_mode option. -The filename should be _dataset.py -The class name should be Dataset.py -You need to implement the following functions: - -- : Add dataset-specific options and rewrite default values for existing options. - -- <__init__>: Initialize this dataset class. - -- <__getitem__>: Return a data point and its metadata information. - -- <__len__>: Return the number of images. -""" -from data.base_dataset import BaseDataset, get_transform -# from data.image_folder import make_dataset -# from PIL import Image - - -class TemplateDataset(BaseDataset): - """A template dataset class for you to implement custom datasets.""" - @staticmethod - def modify_commandline_options(parser, is_train): - """Add new dataset-specific options, and rewrite default values for existing options. - - Parameters: - parser -- original option parser - is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. - - Returns: - the modified parser. - """ - parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option') - parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values - return parser - - def __init__(self, opt): - """Initialize this dataset class. - - Parameters: - opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions - - A few things can be done here. 
- - save the options (have been done in BaseDataset) - - get image paths and meta information of the dataset. - - define the image transformation. - """ - # save the option and dataset root - BaseDataset.__init__(self, opt) - # get the image paths of your dataset; - self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root - # define the default transform function. You can use ; You can also define your custom transform function - self.transform = get_transform(opt) - - def __getitem__(self, index): - """Return a data point and its metadata information. - - Parameters: - index -- a random integer for data indexing - - Returns: - a dictionary of data with their names. It usually contains the data itself and its metadata information. - - Step 1: get a random image path: e.g., path = self.image_paths[index] - Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB'). - Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image) - Step 4: return a data point as a dictionary. - """ - path = 'temp' # needs to be a string - data_A = None # needs to be a tensor - data_B = None # needs to be a tensor - return {'data_A': data_A, 'data_B': data_B, 'path': path} - - def __len__(self): - """Return the total number of images.""" - return len(self.image_paths) diff --git a/sadtalker_audio2pose/src/face3d/extract_kp_videos.py b/sadtalker_audio2pose/src/face3d/extract_kp_videos.py deleted file mode 100644 index 68dd79badafd406113ee85cde83492b6c7c66a9b..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/extract_kp_videos.py +++ /dev/null @@ -1,108 +0,0 @@ -import os -import cv2 -import time -import glob -import argparse -import face_alignment -import numpy as np -from PIL import Image -from tqdm import tqdm -from itertools import cycle - -from torch.multiprocessing import Pool, Process, set_start_method - -class KeypointExtractor(): - def __init__(self, device): - self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, - device=device) - - def extract_keypoint(self, images, name=None, info=True): - if isinstance(images, list): - keypoints = [] - if info: - i_range = tqdm(images,desc='landmark Det:') - else: - i_range = images - - for image in i_range: - current_kp = self.extract_keypoint(image) - if np.mean(current_kp) == -1 and keypoints: - keypoints.append(keypoints[-1]) - else: - keypoints.append(current_kp[None]) - - keypoints = np.concatenate(keypoints, 0) - np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) - return keypoints - else: - while True: - try: - keypoints = self.detector.get_landmarks_from_image(np.array(images))[0] - break - except RuntimeError as e: - if str(e).startswith('CUDA'): - print("Warning: out of memory, sleep for 1s") - time.sleep(1) - else: - print(e) - break - except TypeError: - print('No face detected in this image') - shape = [68, 2] - keypoints = -1. 
* np.ones(shape) - break - if name is not None: - np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) - return keypoints - -def read_video(filename): - frames = [] - cap = cv2.VideoCapture(filename) - while cap.isOpened(): - ret, frame = cap.read() - if ret: - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - frame = Image.fromarray(frame) - frames.append(frame) - else: - break - cap.release() - return frames - -def run(data): - filename, opt, device = data - os.environ['CUDA_VISIBLE_DEVICES'] = device - kp_extractor = KeypointExtractor() - images = read_video(filename) - name = filename.split('/')[-2:] - os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True) - kp_extractor.extract_keypoint( - images, - name=os.path.join(opt.output_dir, name[-2], name[-1]) - ) - -if __name__ == '__main__': - set_start_method('spawn') - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--input_dir', type=str, help='the folder of the input files') - parser.add_argument('--output_dir', type=str, help='the folder of the output files') - parser.add_argument('--device_ids', type=str, default='0,1') - parser.add_argument('--workers', type=int, default=4) - - opt = parser.parse_args() - filenames = list() - VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} - VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) - extensions = VIDEO_EXTENSIONS - - for ext in extensions: - os.listdir(f'{opt.input_dir}') - print(f'{opt.input_dir}/*.{ext}') - filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}')) - print('Total number of videos:', len(filenames)) - pool = Pool(opt.workers) - args_list = cycle([opt]) - device_ids = opt.device_ids.split(",") - device_ids = cycle(device_ids) - for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))): - None diff --git a/sadtalker_audio2pose/src/face3d/extract_kp_videos_safe.py b/sadtalker_audio2pose/src/face3d/extract_kp_videos_safe.py deleted file mode 100644 index bbe5a01151d3e3722b4a6e3e041fd4f352eee9e8..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/extract_kp_videos_safe.py +++ /dev/null @@ -1,146 +0,0 @@ -import os -import cv2 -import time -import glob -import argparse -import numpy as np -from PIL import Image -import torch -from tqdm import tqdm -from itertools import cycle -from torch.multiprocessing import Pool, Process, set_start_method - -from facexlib.alignment import landmark_98_to_68 -from facexlib.detection import init_detection_model - -from facexlib.utils import load_file_from_url -from facexlib.alignment.awing_arch import FAN - -def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None): - if model_name == 'awing_fan': - model = FAN(num_modules=4, num_landmarks=98, device=device) - model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth' - else: - raise NotImplementedError(f'{model_name} is not implemented.') - - model_path = load_file_from_url( - url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) - model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True) - model.eval() - model = model.to(device) - return model - - -class KeypointExtractor(): - def __init__(self, device='cuda'): - - ### gfpgan/weights - root_path = 'ckpts/gfpgan' - - self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path) - self.det_net = 
init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path) - - def extract_keypoint(self, images, name=None, info=True): - if isinstance(images, list): - keypoints = [] - if info: - i_range = tqdm(images,desc='landmark Det:') - else: - i_range = images - - for image in i_range: - current_kp = self.extract_keypoint(image) - # current_kp = self.detector.get_landmarks(np.array(image)) - if np.mean(current_kp) == -1 and keypoints: - keypoints.append(keypoints[-1]) - else: - keypoints.append(current_kp[None]) - - keypoints = np.concatenate(keypoints, 0) - np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) - return keypoints - else: - while True: - try: - with torch.no_grad(): - # face detection -> face alignment. - img = np.array(images) - bboxes = self.det_net.detect_faces(images, 0.97) - - bboxes = bboxes[0] - img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :] - - keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0] - - #### keypoints to the original location - keypoints[:,0] += int(bboxes[0]) - keypoints[:,1] += int(bboxes[1]) - - break - except RuntimeError as e: - if str(e).startswith('CUDA'): - print("Warning: out of memory, sleep for 1s") - time.sleep(1) - else: - print(e) - break - except TypeError: - print('No face detected in this image') - shape = [68, 2] - keypoints = -1. * np.ones(shape) - break - if name is not None: - np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) - return keypoints - -def read_video(filename): - frames = [] - cap = cv2.VideoCapture(filename) - while cap.isOpened(): - ret, frame = cap.read() - if ret: - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - frame = Image.fromarray(frame) - frames.append(frame) - else: - break - cap.release() - return frames - -def run(data): - filename, opt, device = data - os.environ['CUDA_VISIBLE_DEVICES'] = device - kp_extractor = KeypointExtractor() - images = read_video(filename) - name = filename.split('/')[-2:] - os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True) - kp_extractor.extract_keypoint( - images, - name=os.path.join(opt.output_dir, name[-2], name[-1]) - ) - -if __name__ == '__main__': - set_start_method('spawn') - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--input_dir', type=str, help='the folder of the input files') - parser.add_argument('--output_dir', type=str, help='the folder of the output files') - parser.add_argument('--device_ids', type=str, default='0,1') - parser.add_argument('--workers', type=int, default=4) - - opt = parser.parse_args() - filenames = list() - VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} - VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) - extensions = VIDEO_EXTENSIONS - - for ext in extensions: - os.listdir(f'{opt.input_dir}') - print(f'{opt.input_dir}/*.{ext}') - filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}')) - print('Total number of videos:', len(filenames)) - pool = Pool(opt.workers) - args_list = cycle([opt]) - device_ids = opt.device_ids.split(",") - device_ids = cycle(device_ids) - for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))): - None diff --git a/sadtalker_audio2pose/src/face3d/models/__init__.py b/sadtalker_audio2pose/src/face3d/models/__init__.py deleted file mode 100644 index ef6b5e399254bd42850f3385878f35d4acf90852..0000000000000000000000000000000000000000 --- 
a/sadtalker_audio2pose/src/face3d/models/__init__.py +++ /dev/null @@ -1,67 +0,0 @@ -"""This package contains modules related to objective functions, optimizations, and network architectures. - -To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. -You need to implement the following five functions: - -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). - -- : unpack data from dataset and apply preprocessing. - -- : produce intermediate results. - -- : calculate loss, gradients, and update network weights. - -- : (optionally) add model-specific options and set default options. - -In the function <__init__>, you need to define four lists: - -- self.loss_names (str list): specify the training losses that you want to plot and save. - -- self.model_names (str list): define networks used in our training. - -- self.visual_names (str list): specify the images that you want to display and save. - -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. - -Now you can use the model class by specifying flag '--model dummy'. -See our template model class 'template_model.py' for more details. -""" - -import importlib -from src.face3d.models.base_model import BaseModel - - -def find_model_using_name(model_name): - """Import the module "models/[model_name]_model.py". - - In the file, the class called DatasetNameModel() will - be instantiated. It has to be a subclass of BaseModel, - and it is case-insensitive. - """ - model_filename = "face3d.models." + model_name + "_model" - modellib = importlib.import_module(model_filename) - model = None - target_model_name = model_name.replace('_', '') + 'model' - for name, cls in modellib.__dict__.items(): - if name.lower() == target_model_name.lower() \ - and issubclass(cls, BaseModel): - model = cls - - if model is None: - print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) - exit(0) - - return model - - -def get_option_setter(model_name): - """Return the static method of the model class.""" - model_class = find_model_using_name(model_name) - return model_class.modify_commandline_options - - -def create_model(opt): - """Create a model given the option. - - This function warps the class CustomDatasetDataLoader. - This is the main interface between this package and 'train.py'/'test.py' - - Example: - >>> from models import create_model - >>> model = create_model(opt) - """ - model = find_model_using_name(opt.model) - instance = model(opt) - print("model [%s] was created" % type(instance).__name__) - return instance diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/README.md b/sadtalker_audio2pose/src/face3d/models/arcface_torch/README.md deleted file mode 100644 index cc7f1d45f2f5e4b752c42dc81d3e2879c1459c6e..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/README.md +++ /dev/null @@ -1,164 +0,0 @@ -# Distributed Arcface Training in Pytorch - -This is a deep learning library that makes face recognition efficient, and effective, which can train tens of millions -identity on a single server. 
- -## Requirements - -- Install [pytorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md). -- `pip install -r requirements.txt`. -- Download the dataset - from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_) - . - -## How to Training - -To train a model, run `train.py` with the path to the configs: - -### 1. Single node, 8 GPUs: - -```shell -python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 -``` - -### 2. Multiple nodes, each node 8 GPUs: - -Node 0: - -```shell -python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50 -``` - -Node 1: - -```shell -python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50 -``` - -### 3.Training resnet2060 with 8 GPUs: - -```shell -python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py -``` - -## Model Zoo - -- The models are available for non-commercial research purposes only. -- All models can be found in here. -- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw -- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d) - -### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/) - -ICCV2021-MFR testset consists of non-celebrities so we can ensure that it has very few overlap with public available face -recognition training set, such as MS1M and CASIA as they mostly collected from online celebrities. -As the result, we can evaluate the FAIR performance for different algorithms. - -For **ICCV2021-MFR-ALL** set, TAR is measured on all-to-all 1:1 protocal, with FAR less than 0.000001(e-6). The -globalised multi-racial testset contains 242,143 identities and 1,624,305 images. - -For **ICCV2021-MFR-MASK** set, TAR is measured on mask-to-nonmask 1:1 protocal, with FAR less than 0.0001(e-4). -Mask testset contains 6,964 identities, 6,964 masked images and 13,928 non-masked images. -There are totally 13,928 positive pairs and 96,983,824 negative pairs. 
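For orientation, the TAR@FAR figures in the tables below are standard 1:1 verification metrics; a minimal sketch of how such a number is computed from genuine (same-identity) and impostor (different-identity) similarity scores is shown here. This is an illustrative definition only, not the benchmark's official evaluation code.

```python
# Illustrative TAR @ FAR computation from 1:1 verification scores (sketch only).
import numpy as np

def tar_at_far(genuine_scores, impostor_scores, far=1e-6):
    # threshold chosen so that at most `far` of impostor pairs are falsely accepted
    thresh = np.quantile(impostor_scores, 1.0 - far)
    # TAR = fraction of genuine pairs accepted at that threshold
    return float((genuine_scores >= thresh).mean())

# far=1e-6 matches the ICCV2021-MFR-ALL protocol, far=1e-4 the MFR-MASK protocol
```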
- -| Datasets | backbone | Training throughput | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** | -| :---: | :--- | :--- | :--- |:--- |:--- | -| MS1MV3 | r18 | - | 91 | **47.85** | **68.33** | -| Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** | -| MS1MV3 | r34 | - | 130 | **58.72** | **77.36** | -| Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** | -| MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** | -| Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** | -| MS1MV3 | r100 | - | 248 | **69.09** | **84.31** | -| Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** | -| MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** | -| Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** | - -### Performance on IJB-C and Verification Datasets - -| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log | -| :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- | -| MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)| -| MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)| -| MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)| -| MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)| -| MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)| -| Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)| -| Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)| -| Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)| -| Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)| - -[comment]: <> (More details see [model.md](docs/modelzoo.md) in docs.) - - -## [Speed Benchmark](docs/speed_benchmark.md) - -**Arcface Torch** can train large-scale face recognition training sets efficiently and quickly. When the number of -classes in the training set is greater than 300K and training is sufficient, the Partial FC sampling strategy achieves the same -accuracy with several times faster training and a smaller GPU memory footprint. -Partial FC is a sparse variant of the model parallel architecture for large scale face recognition. Partial FC uses a -sparse softmax, where each batch dynamically samples a subset of class centers for training. In each iteration, only a -sparse part of the parameters is updated, which greatly reduces GPU memory usage and computation. With Partial FC, -we can scale to a training set of 29 million identities, the largest to date. Partial FC also supports multi-machine distributed -training and mixed precision training.
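As a reading aid, here is a minimal PyTorch sketch of the class-center sampling idea described above: keep the positive centers of the current batch plus random negatives, so only roughly `sample_rate` of the softmax weight matrix takes part in the forward/backward pass. This is an illustrative assumption, not the repository's Partial FC implementation.

```python
import torch


def sample_centers(weight, labels, sample_rate=0.1):
    """weight: [num_classes, dim] softmax centers; labels: [batch] class ids."""
    num_classes = weight.size(0)
    positive = labels.unique()
    num_sample = max(int(sample_rate * num_classes), positive.numel())
    # Randomly pick negatives that are not already positives.
    perm = torch.randperm(num_classes, device=weight.device)
    is_negative = torch.ones(num_classes, dtype=torch.bool, device=weight.device)
    is_negative[positive] = False
    negative = perm[is_negative[perm]][: num_sample - positive.numel()]
    index = torch.cat([positive, negative])
    # Remap batch labels into the sampled sub-center index space.
    remap = torch.full((num_classes,), -1, dtype=torch.long, device=weight.device)
    remap[index] = torch.arange(index.numel(), device=weight.device)
    return weight[index], remap[labels]


# Usage shape: the loss is computed over the sampled centers only.
weight = torch.randn(1000, 512, requires_grad=True)
embeddings, labels = torch.randn(8, 512), torch.randint(0, 1000, (8,))
sub_weight, sub_labels = sample_centers(weight, labels)
logits = embeddings @ sub_weight.t()
loss = torch.nn.functional.cross_entropy(logits, sub_labels)
```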
- -![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png) - -More details see -[speed_benchmark.md](docs/speed_benchmark.md) in docs. - -### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better) - -`-` means training failed because of gpu memory limitations. - -| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | -| :--- | :--- | :--- | :--- | -|125000 | 4681 | 4824 | 5004 | -|1400000 | **1672** | 3043 | 4738 | -|5500000 | **-** | **1389** | 3975 | -|8000000 | **-** | **-** | 3565 | -|16000000 | **-** | **-** | 2679 | -|29000000 | **-** | **-** | **1855** | - -### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better) - -| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | -| :--- | :--- | :--- | :--- | -|125000 | 7358 | 5306 | 4868 | -|1400000 | 32252 | 11178 | 6056 | -|5500000 | **-** | 32188 | 9854 | -|8000000 | **-** | **-** | 12310 | -|16000000 | **-** | **-** | 19950 | -|29000000 | **-** | **-** | 32324 | - -## Evaluation ICCV2021-MFR and IJB-C - -More details see [eval.md](docs/eval.md) in docs. - -## Test - -We tested many versions of PyTorch. Please create an issue if you are having trouble. - -- [x] torch 1.6.0 -- [x] torch 1.7.1 -- [x] torch 1.8.0 -- [x] torch 1.9.0 - -## Citation - -``` -@inproceedings{deng2019arcface, - title={Arcface: Additive angular margin loss for deep face recognition}, - author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos}, - booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, - pages={4690--4699}, - year={2019} -} -@inproceedings{an2020partical_fc, - title={Partial FC: Training 10 Million Identities on a Single Machine}, - author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and - Zhang, Debing and Fu Ying}, - booktitle={Arxiv 2010.05222}, - year={2020} -} -``` diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/__init__.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/__init__.py deleted file mode 100644 index 5650187b4fdea84c5a23e0445440901690ab682a..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200 -from .mobilefacenet import get_mbf - - -def get_model(name, **kwargs): - # resnet - if name == "r18": - return iresnet18(False, **kwargs) - elif name == "r34": - return iresnet34(False, **kwargs) - elif name == "r50": - return iresnet50(False, **kwargs) - elif name == "r100": - return iresnet100(False, **kwargs) - elif name == "r200": - return iresnet200(False, **kwargs) - elif name == "r2060": - from .iresnet2060 import iresnet2060 - return iresnet2060(False, **kwargs) - elif name == "mbf": - fp16 = kwargs.get("fp16", False) - num_features = kwargs.get("num_features", 512) - return get_mbf(fp16=fp16, num_features=num_features) - else: - raise ValueError() \ No newline at end of file diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/iresnet.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/iresnet.py deleted file mode 100644 index d29f5f2bfbd444273717c4bc8aa20ba7edd08f80..0000000000000000000000000000000000000000 --- 
a/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/iresnet.py +++ /dev/null @@ -1,187 +0,0 @@ -import torch -from torch import nn - -__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200'] - - -def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, - out_planes, - kernel_size=3, - stride=stride, - padding=dilation, - groups=groups, - bias=False, - dilation=dilation) - - -def conv1x1(in_planes, out_planes, stride=1): - """1x1 convolution""" - return nn.Conv2d(in_planes, - out_planes, - kernel_size=1, - stride=stride, - bias=False) - - -class IBasicBlock(nn.Module): - expansion = 1 - def __init__(self, inplanes, planes, stride=1, downsample=None, - groups=1, base_width=64, dilation=1): - super(IBasicBlock, self).__init__() - if groups != 1 or base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') - if dilation > 1: - raise NotImplementedError("Dilation > 1 not supported in BasicBlock") - self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) - self.conv1 = conv3x3(inplanes, planes) - self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) - self.prelu = nn.PReLU(planes) - self.conv2 = conv3x3(planes, planes, stride) - self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - identity = x - out = self.bn1(x) - out = self.conv1(out) - out = self.bn2(out) - out = self.prelu(out) - out = self.conv2(out) - out = self.bn3(out) - if self.downsample is not None: - identity = self.downsample(x) - out += identity - return out - - -class IResNet(nn.Module): - fc_scale = 7 * 7 - def __init__(self, - block, layers, dropout=0, num_features=512, zero_init_residual=False, - groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): - super(IResNet, self).__init__() - self.fp16 = fp16 - self.inplanes = 64 - self.dilation = 1 - if replace_stride_with_dilation is None: - replace_stride_with_dilation = [False, False, False] - if len(replace_stride_with_dilation) != 3: - raise ValueError("replace_stride_with_dilation should be None " - "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) - self.groups = groups - self.base_width = width_per_group - self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) - self.prelu = nn.PReLU(self.inplanes) - self.layer1 = self._make_layer(block, 64, layers[0], stride=2) - self.layer2 = self._make_layer(block, - 128, - layers[1], - stride=2, - dilate=replace_stride_with_dilation[0]) - self.layer3 = self._make_layer(block, - 256, - layers[2], - stride=2, - dilate=replace_stride_with_dilation[1]) - self.layer4 = self._make_layer(block, - 512, - layers[3], - stride=2, - dilate=replace_stride_with_dilation[2]) - self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) - self.dropout = nn.Dropout(p=dropout, inplace=True) - self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) - self.features = nn.BatchNorm1d(num_features, eps=1e-05) - nn.init.constant_(self.features.weight, 1.0) - self.features.weight.requires_grad = False - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.normal_(m.weight, 0, 0.1) - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - if zero_init_residual: - for m in self.modules(): - if isinstance(m, IBasicBlock): - 
nn.init.constant_(m.bn2.weight, 0) - - def _make_layer(self, block, planes, blocks, stride=1, dilate=False): - downsample = None - previous_dilation = self.dilation - if dilate: - self.dilation *= stride - stride = 1 - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - conv1x1(self.inplanes, planes * block.expansion, stride), - nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), - ) - layers = [] - layers.append( - block(self.inplanes, planes, stride, downsample, self.groups, - self.base_width, previous_dilation)) - self.inplanes = planes * block.expansion - for _ in range(1, blocks): - layers.append( - block(self.inplanes, - planes, - groups=self.groups, - base_width=self.base_width, - dilation=self.dilation)) - - return nn.Sequential(*layers) - - def forward(self, x): - with torch.cuda.amp.autocast(self.fp16): - x = self.conv1(x) - x = self.bn1(x) - x = self.prelu(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - x = self.bn2(x) - x = torch.flatten(x, 1) - x = self.dropout(x) - x = self.fc(x.float() if self.fp16 else x) - x = self.features(x) - return x - - -def _iresnet(arch, block, layers, pretrained, progress, **kwargs): - model = IResNet(block, layers, **kwargs) - if pretrained: - raise ValueError() - return model - - -def iresnet18(pretrained=False, progress=True, **kwargs): - return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, - progress, **kwargs) - - -def iresnet34(pretrained=False, progress=True, **kwargs): - return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, - progress, **kwargs) - - -def iresnet50(pretrained=False, progress=True, **kwargs): - return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, - progress, **kwargs) - - -def iresnet100(pretrained=False, progress=True, **kwargs): - return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, - progress, **kwargs) - - -def iresnet200(pretrained=False, progress=True, **kwargs): - return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, - progress, **kwargs) - diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/iresnet2060.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/iresnet2060.py deleted file mode 100644 index 39bb4335716b653bd5924e20d616d825ef48339f..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/iresnet2060.py +++ /dev/null @@ -1,176 +0,0 @@ -import torch -from torch import nn - -assert torch.__version__ >= "1.8.1" -from torch.utils.checkpoint import checkpoint_sequential - -__all__ = ['iresnet2060'] - - -def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, - out_planes, - kernel_size=3, - stride=stride, - padding=dilation, - groups=groups, - bias=False, - dilation=dilation) - - -def conv1x1(in_planes, out_planes, stride=1): - """1x1 convolution""" - return nn.Conv2d(in_planes, - out_planes, - kernel_size=1, - stride=stride, - bias=False) - - -class IBasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None, - groups=1, base_width=64, dilation=1): - super(IBasicBlock, self).__init__() - if groups != 1 or base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') - if dilation > 1: - raise NotImplementedError("Dilation > 1 not supported in BasicBlock") - self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, ) - 
self.conv1 = conv3x3(inplanes, planes) - self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, ) - self.prelu = nn.PReLU(planes) - self.conv2 = conv3x3(planes, planes, stride) - self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, ) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - identity = x - out = self.bn1(x) - out = self.conv1(out) - out = self.bn2(out) - out = self.prelu(out) - out = self.conv2(out) - out = self.bn3(out) - if self.downsample is not None: - identity = self.downsample(x) - out += identity - return out - - -class IResNet(nn.Module): - fc_scale = 7 * 7 - - def __init__(self, - block, layers, dropout=0, num_features=512, zero_init_residual=False, - groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): - super(IResNet, self).__init__() - self.fp16 = fp16 - self.inplanes = 64 - self.dilation = 1 - if replace_stride_with_dilation is None: - replace_stride_with_dilation = [False, False, False] - if len(replace_stride_with_dilation) != 3: - raise ValueError("replace_stride_with_dilation should be None " - "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) - self.groups = groups - self.base_width = width_per_group - self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) - self.prelu = nn.PReLU(self.inplanes) - self.layer1 = self._make_layer(block, 64, layers[0], stride=2) - self.layer2 = self._make_layer(block, - 128, - layers[1], - stride=2, - dilate=replace_stride_with_dilation[0]) - self.layer3 = self._make_layer(block, - 256, - layers[2], - stride=2, - dilate=replace_stride_with_dilation[1]) - self.layer4 = self._make_layer(block, - 512, - layers[3], - stride=2, - dilate=replace_stride_with_dilation[2]) - self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, ) - self.dropout = nn.Dropout(p=dropout, inplace=True) - self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) - self.features = nn.BatchNorm1d(num_features, eps=1e-05) - nn.init.constant_(self.features.weight, 1.0) - self.features.weight.requires_grad = False - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.normal_(m.weight, 0, 0.1) - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - if zero_init_residual: - for m in self.modules(): - if isinstance(m, IBasicBlock): - nn.init.constant_(m.bn2.weight, 0) - - def _make_layer(self, block, planes, blocks, stride=1, dilate=False): - downsample = None - previous_dilation = self.dilation - if dilate: - self.dilation *= stride - stride = 1 - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - conv1x1(self.inplanes, planes * block.expansion, stride), - nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), - ) - layers = [] - layers.append( - block(self.inplanes, planes, stride, downsample, self.groups, - self.base_width, previous_dilation)) - self.inplanes = planes * block.expansion - for _ in range(1, blocks): - layers.append( - block(self.inplanes, - planes, - groups=self.groups, - base_width=self.base_width, - dilation=self.dilation)) - - return nn.Sequential(*layers) - - def checkpoint(self, func, num_seg, x): - if self.training: - return checkpoint_sequential(func, num_seg, x) - else: - return func(x) - - def forward(self, x): - with torch.cuda.amp.autocast(self.fp16): - x = self.conv1(x) - x = self.bn1(x) - x = self.prelu(x) - x = self.layer1(x) - x = 
self.checkpoint(self.layer2, 20, x) - x = self.checkpoint(self.layer3, 100, x) - x = self.layer4(x) - x = self.bn2(x) - x = torch.flatten(x, 1) - x = self.dropout(x) - x = self.fc(x.float() if self.fp16 else x) - x = self.features(x) - return x - - -def _iresnet(arch, block, layers, pretrained, progress, **kwargs): - model = IResNet(block, layers, **kwargs) - if pretrained: - raise ValueError() - return model - - -def iresnet2060(pretrained=False, progress=True, **kwargs): - return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs) diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/mobilefacenet.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/mobilefacenet.py deleted file mode 100644 index c02c6c1e4fa6a6ddf09f5b01dec96971427cb110..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/backbones/mobilefacenet.py +++ /dev/null @@ -1,130 +0,0 @@ -''' -Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py -Original author cavalleria -''' - -import torch.nn as nn -from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module -import torch - - -class Flatten(Module): - def forward(self, x): - return x.view(x.size(0), -1) - - -class ConvBlock(Module): - def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): - super(ConvBlock, self).__init__() - self.layers = nn.Sequential( - Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False), - BatchNorm2d(num_features=out_c), - PReLU(num_parameters=out_c) - ) - - def forward(self, x): - return self.layers(x) - - -class LinearBlock(Module): - def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): - super(LinearBlock, self).__init__() - self.layers = nn.Sequential( - Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), - BatchNorm2d(num_features=out_c) - ) - - def forward(self, x): - return self.layers(x) - - -class DepthWise(Module): - def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): - super(DepthWise, self).__init__() - self.residual = residual - self.layers = nn.Sequential( - ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)), - ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride), - LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) - ) - - def forward(self, x): - short_cut = None - if self.residual: - short_cut = x - x = self.layers(x) - if self.residual: - output = short_cut + x - else: - output = x - return output - - -class Residual(Module): - def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): - super(Residual, self).__init__() - modules = [] - for _ in range(num_block): - modules.append(DepthWise(c, c, True, kernel, stride, padding, groups)) - self.layers = Sequential(*modules) - - def forward(self, x): - return self.layers(x) - - -class GDC(Module): - def __init__(self, embedding_size): - super(GDC, self).__init__() - self.layers = nn.Sequential( - LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)), - Flatten(), - Linear(512, embedding_size, bias=False), - BatchNorm1d(embedding_size)) - - def forward(self, x): - return self.layers(x) - - -class MobileFaceNet(Module): - def __init__(self, fp16=False, num_features=512): - 
super(MobileFaceNet, self).__init__() - scale = 2 - self.fp16 = fp16 - self.layers = nn.Sequential( - ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)), - ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64), - DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128), - Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), - DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256), - Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), - DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512), - Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), - ) - self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) - self.features = GDC(num_features) - self._initialize_weights() - - def _initialize_weights(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - if m.bias is not None: - m.bias.data.zero_() - - def forward(self, x): - with torch.cuda.amp.autocast(self.fp16): - x = self.layers(x) - x = self.conv_sep(x.float() if self.fp16 else x) - x = self.features(x) - return x - - -def get_mbf(fp16, num_features): - return MobileFaceNet(fp16, num_features) \ No newline at end of file diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/3millions.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/3millions.py deleted file mode 100644 index 3bee7cb4236e8b842a1bd1e8c26de7a11df0bf43..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/3millions.py +++ /dev/null @@ -1,23 +0,0 @@ -from easydict import EasyDict as edict - -# configs for test speed - -config = edict() -config.loss = "arcface" -config.network = "r50" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "synthetic" -config.num_classes = 300 * 10000 -config.num_epoch = 30 -config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] -config.val_targets = [] diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/3millions_pfc.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/3millions_pfc.py deleted file mode 100644 index bf7df5f04e2509e5dcc14adebbb9302a18f03f2b..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/3millions_pfc.py +++ /dev/null @@ -1,23 +0,0 @@ -from easydict import EasyDict as edict - -# configs for test speed - -config = edict() -config.loss = "arcface" -config.network = "r50" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 0.1 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "synthetic" -config.num_classes = 300 * 10000 -config.num_epoch = 30 -config.warmup_epoch = -1 
-config.decay_epoch = [10, 16, 22] -config.val_targets = [] diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/__init__.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/base.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/base.py deleted file mode 100644 index f98c62fed44afde276dcbacecd9da0a8f474963c..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/base.py +++ /dev/null @@ -1,56 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "arcface" -config.network = "r50" -config.resume = False -config.output = "ms1mv3_arcface_r50" - -config.dataset = "ms1m-retinaface-t1" -config.embedding_size = 512 -config.sample_rate = 1 -config.fp16 = False -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -if config.dataset == "emore": - config.rec = "/train_tmp/faces_emore" - config.num_classes = 85742 - config.num_image = 5822653 - config.num_epoch = 16 - config.warmup_epoch = -1 - config.decay_epoch = [8, 14, ] - config.val_targets = ["lfw", ] - -elif config.dataset == "ms1m-retinaface-t1": - config.rec = "/train_tmp/ms1m-retinaface-t1" - config.num_classes = 93431 - config.num_image = 5179510 - config.num_epoch = 25 - config.warmup_epoch = -1 - config.decay_epoch = [11, 17, 22] - config.val_targets = ["lfw", "cfp_fp", "agedb_30"] - -elif config.dataset == "glint360k": - config.rec = "/train_tmp/glint360k" - config.num_classes = 360232 - config.num_image = 17091657 - config.num_epoch = 20 - config.warmup_epoch = -1 - config.decay_epoch = [8, 12, 15, 18] - config.val_targets = ["lfw", "cfp_fp", "agedb_30"] - -elif config.dataset == "webface": - config.rec = "/train_tmp/faces_webface_112x112" - config.num_classes = 10572 - config.num_image = "forget" - config.num_epoch = 34 - config.warmup_epoch = -1 - config.decay_epoch = [20, 28, 32] - config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_mbf.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_mbf.py deleted file mode 100644 index 44ee5e8d96249d57196df43418f6fda4ab339877..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_mbf.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "cosface" -config.network = "mbf" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 0.1 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 2e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/glint360k" -config.num_classes = 360232 -config.num_image = 17091657 -config.num_epoch = 20 -config.warmup_epoch = -1 -config.decay_epoch = [8, 12, 15, 18] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r100.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r100.py 
deleted file mode 100644 index f8f8ef745c0efb9d5ea67409edc8c904def8a9d9..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r100.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "cosface" -config.network = "r100" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/glint360k" -config.num_classes = 360232 -config.num_image = 17091657 -config.num_epoch = 20 -config.warmup_epoch = -1 -config.decay_epoch = [8, 12, 15, 18] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r18.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r18.py deleted file mode 100644 index 473b59a954fffcaddca132fb6e0f32cbe70c70f4..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r18.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "cosface" -config.network = "r18" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/glint360k" -config.num_classes = 360232 -config.num_image = 17091657 -config.num_epoch = 20 -config.warmup_epoch = -1 -config.decay_epoch = [8, 12, 15, 18] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r34.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r34.py deleted file mode 100644 index d9c22ff0c82cc98bbbe81c9a1c26c9b3fc186105..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r34.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "cosface" -config.network = "r34" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/glint360k" -config.num_classes = 360232 -config.num_image = 17091657 -config.num_epoch = 20 -config.warmup_epoch = -1 -config.decay_epoch = [8, 12, 15, 18] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r50.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r50.py deleted file mode 100644 index 8ecbfda06730e3842e7b347db366e82f0714912f..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/glint360k_r50.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o 
size=140G tmpfs /train_tmp - -config = edict() -config.loss = "cosface" -config.network = "r50" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/glint360k" -config.num_classes = 360232 -config.num_image = 17091657 -config.num_epoch = 20 -config.warmup_epoch = -1 -config.decay_epoch = [8, 12, 15, 18] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py deleted file mode 100644 index 47c87a99867db55c7f689574c331c14cda23ea96..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "arcface" -config.network = "mbf" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 2e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/ms1m-retinaface-t1" -config.num_classes = 93431 -config.num_image = 5179510 -config.num_epoch = 30 -config.warmup_epoch = -1 -config.decay_epoch = [10, 20, 25] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py deleted file mode 100644 index 1aeb851b05ea22e01da87b3d387812f0253989f8..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "arcface" -config.network = "r18" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/ms1m-retinaface-t1" -config.num_classes = 93431 -config.num_image = 5179510 -config.num_epoch = 25 -config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py deleted file mode 100644 index 8693e67080dac7e7b84da08a62df326c7b12d465..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "arcface" -config.network = "r2060" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 64 -config.lr = 0.1 # batch size is 512 - -config.rec = 
"/train_tmp/ms1m-retinaface-t1" -config.num_classes = 93431 -config.num_image = 5179510 -config.num_epoch = 25 -config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py deleted file mode 100644 index 52bff483db179045c0e3acc8e2975477182b0756..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "arcface" -config.network = "r34" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/ms1m-retinaface-t1" -config.num_classes = 93431 -config.num_image = 5179510 -config.num_epoch = 25 -config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py deleted file mode 100644 index de81ffdd84edd6fcea7fcb4d3594db031b9e4e26..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "arcface" -config.network = "r50" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/ms1m-retinaface-t1" -config.num_classes = 93431 -config.num_image = 5179510 -config.num_epoch = 25 -config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/speed.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/speed.py deleted file mode 100644 index c172f9d44d39b534f2253630471e91cf78e6fba7..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/configs/speed.py +++ /dev/null @@ -1,23 +0,0 @@ -from easydict import EasyDict as edict - -# configs for test speed - -config = edict() -config.loss = "arcface" -config.network = "r50" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "synthetic" -config.num_classes = 100 * 10000 -config.num_epoch = 30 -config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] -config.val_targets = [] diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/dataset.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/dataset.py deleted file mode 100644 index 8bead250243237c650fa3138f6aa172d4f98535f..0000000000000000000000000000000000000000 --- 
a/sadtalker_audio2pose/src/face3d/models/arcface_torch/dataset.py +++ /dev/null @@ -1,124 +0,0 @@ -import numbers -import os -import queue as Queue -import threading - -import mxnet as mx -import numpy as np -import torch -from torch.utils.data import DataLoader, Dataset -from torchvision import transforms - - -class BackgroundGenerator(threading.Thread): - def __init__(self, generator, local_rank, max_prefetch=6): - super(BackgroundGenerator, self).__init__() - self.queue = Queue.Queue(max_prefetch) - self.generator = generator - self.local_rank = local_rank - self.daemon = True - self.start() - - def run(self): - torch.cuda.set_device(self.local_rank) - for item in self.generator: - self.queue.put(item) - self.queue.put(None) - - def next(self): - next_item = self.queue.get() - if next_item is None: - raise StopIteration - return next_item - - def __next__(self): - return self.next() - - def __iter__(self): - return self - - -class DataLoaderX(DataLoader): - - def __init__(self, local_rank, **kwargs): - super(DataLoaderX, self).__init__(**kwargs) - self.stream = torch.cuda.Stream(local_rank) - self.local_rank = local_rank - - def __iter__(self): - self.iter = super(DataLoaderX, self).__iter__() - self.iter = BackgroundGenerator(self.iter, self.local_rank) - self.preload() - return self - - def preload(self): - self.batch = next(self.iter, None) - if self.batch is None: - return None - with torch.cuda.stream(self.stream): - for k in range(len(self.batch)): - self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True) - - def __next__(self): - torch.cuda.current_stream().wait_stream(self.stream) - batch = self.batch - if batch is None: - raise StopIteration - self.preload() - return batch - - -class MXFaceDataset(Dataset): - def __init__(self, root_dir, local_rank): - super(MXFaceDataset, self).__init__() - self.transform = transforms.Compose( - [transforms.ToPILImage(), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - ]) - self.root_dir = root_dir - self.local_rank = local_rank - path_imgrec = os.path.join(root_dir, 'train.rec') - path_imgidx = os.path.join(root_dir, 'train.idx') - self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') - s = self.imgrec.read_idx(0) - header, _ = mx.recordio.unpack(s) - if header.flag > 0: - self.header0 = (int(header.label[0]), int(header.label[1])) - self.imgidx = np.array(range(1, int(header.label[0]))) - else: - self.imgidx = np.array(list(self.imgrec.keys)) - - def __getitem__(self, index): - idx = self.imgidx[index] - s = self.imgrec.read_idx(idx) - header, img = mx.recordio.unpack(s) - label = header.label - if not isinstance(label, numbers.Number): - label = label[0] - label = torch.tensor(label, dtype=torch.long) - sample = mx.image.imdecode(img).asnumpy() - if self.transform is not None: - sample = self.transform(sample) - return sample, label - - def __len__(self): - return len(self.imgidx) - - -class SyntheticDataset(Dataset): - def __init__(self, local_rank): - super(SyntheticDataset, self).__init__() - img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) - img = np.transpose(img, (2, 0, 1)) - img = torch.from_numpy(img).squeeze(0).float() - img = ((img / 255) - 0.5) / 0.5 - self.img = img - self.label = 1 - - def __getitem__(self, index): - return self.img, self.label - - def __len__(self): - return 1000000 diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/eval.md 
b/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/eval.md deleted file mode 100644 index 4d29c855fc6e4245ed264216c1f96ab2efc57248..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/eval.md +++ /dev/null @@ -1,31 +0,0 @@ -## Eval on ICCV2021-MFR - -coming soon. - - -## Eval IJBC -You can eval ijbc with pytorch or onnx. - - -1. Eval IJBC With Onnx -```shell -CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50 -``` - -2. Eval IJBC With Pytorch -```shell -CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \ ---model-prefix ms1mv3_arcface_r50/backbone.pth \ ---image-path IJB_release/IJBC \ ---result-dir ms1mv3_arcface_r50 \ ---batch-size 128 \ ---job ms1mv3_arcface_r50 \ ---target IJBC \ ---network iresnet50 -``` - -## Inference - -```shell -python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50 -``` diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/install.md b/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/install.md deleted file mode 100644 index b1b770a0d93dac1f160185b5bbf4da2f414f21f6..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/install.md +++ /dev/null @@ -1,51 +0,0 @@ -## v1.8.0 -### Linux and Windows -```shell -# CUDA 11.0 -pip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html - -# CUDA 10.2 -pip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 - -# CPU only -pip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html - -``` - - -## v1.7.1 -### Linux and Windows -```shell -# CUDA 11.0 -pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html - -# CUDA 10.2 -pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 - -# CUDA 10.1 -pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html - -# CUDA 9.2 -pip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html - -# CPU only -pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html -``` - - -## v1.6.0 - -### Linux and Windows -```shell -# CUDA 10.2 -pip install torch==1.6.0 torchvision==0.7.0 - -# CUDA 10.1 -pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html - -# CUDA 9.2 -pip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html - -# CPU only -pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html -``` \ No newline at end of file diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/modelzoo.md b/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/modelzoo.md deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/speed_benchmark.md b/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/speed_benchmark.md deleted file mode 100644 index 
d54904587df4e13784dc68d5709b4d7d97490890..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/docs/speed_benchmark.md +++ /dev/null @@ -1,93 +0,0 @@ -## Test Training Speed - -- Test Commands - -You need to use the following two commands to test the Partial FC training performance. -The number of identities is **3 million** (synthetic data), mixed precision training is turned on, the backbone is resnet50, and the -batch size is 1024. -```shell -# Model Parallel -python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions -# Partial FC 0.1 -python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc -``` - -- GPU Memory - -``` -# (Model Parallel) gpustat -i -[0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB -[1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB -[2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB -[3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB -[4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB -[5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB -[6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB -[7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB - -# (Partial FC 0.1) gpustat -i -[0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB │······················· -[1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB │······················· -[2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB │······················· -[3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB │······················· -[4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB │······················· -[5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB │······················· -[6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB │······················· -[7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB │······················· -``` - -- Training Speed - -```python -# (Model Parallel) training.log -Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100 -Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150 -Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200 -Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250 -Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300 - -# (Partial FC 0.1) training.log -Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100 -Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150 -Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200 -Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250 -Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300 -``` - -In this test case, Partial FC 0.1 only uses 1/3 of the GPU memory of model parallel, -and its training speed is 2.5 times faster than model parallel. - - -## Speed Benchmark - -1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8.
(Larger is better) - -| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | -| :--- | :--- | :--- | :--- | -|125000 | 4681 | 4824 | 5004 | -|250000 | 4047 | 4521 | 4976 | -|500000 | 3087 | 4013 | 4900 | -|1000000 | 2090 | 3449 | 4803 | -|1400000 | 1672 | 3043 | 4738 | -|2000000 | - | 2593 | 4626 | -|4000000 | - | 1748 | 4208 | -|5500000 | - | 1389 | 3975 | -|8000000 | - | - | 3565 | -|16000000 | - | - | 2679 | -|29000000 | - | - | 1855 | - -2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better) - -| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | -| :--- | :--- | :--- | :--- | -|125000 | 7358 | 5306 | 4868 | -|250000 | 9940 | 5826 | 5004 | -|500000 | 14220 | 7114 | 5202 | -|1000000 | 23708 | 9966 | 5620 | -|1400000 | 32252 | 11178 | 6056 | -|2000000 | - | 13978 | 6472 | -|4000000 | - | 23238 | 8284 | -|5500000 | - | 32188 | 9854 | -|8000000 | - | - | 12310 | -|16000000 | - | - | 19950 | -|29000000 | - | - | 32324 | diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/eval/__init__.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/eval/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/eval/verification.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/eval/verification.py deleted file mode 100644 index 5b1f5618184effae64895847af1a65d43d2e4418..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/eval/verification.py +++ /dev/null @@ -1,407 +0,0 @@ -"""Helper for evaluation on the Labeled Faces in the Wild dataset -""" - -# MIT License -# -# Copyright (c) 2016 David Sandberg -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE.
- - -import datetime -import os -import pickle - -import mxnet as mx -import numpy as np -import sklearn -import torch -from mxnet import ndarray as nd -from scipy import interpolate -from sklearn.decomposition import PCA -from sklearn.model_selection import KFold - - -class LFold: - def __init__(self, n_splits=2, shuffle=False): - self.n_splits = n_splits - if self.n_splits > 1: - self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle) - - def split(self, indices): - if self.n_splits > 1: - return self.k_fold.split(indices) - else: - return [(indices, indices)] - - -def calculate_roc(thresholds, - embeddings1, - embeddings2, - actual_issame, - nrof_folds=10, - pca=0): - assert (embeddings1.shape[0] == embeddings2.shape[0]) - assert (embeddings1.shape[1] == embeddings2.shape[1]) - nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) - nrof_thresholds = len(thresholds) - k_fold = LFold(n_splits=nrof_folds, shuffle=False) - - tprs = np.zeros((nrof_folds, nrof_thresholds)) - fprs = np.zeros((nrof_folds, nrof_thresholds)) - accuracy = np.zeros((nrof_folds)) - indices = np.arange(nrof_pairs) - - if pca == 0: - diff = np.subtract(embeddings1, embeddings2) - dist = np.sum(np.square(diff), 1) - - for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): - if pca > 0: - print('doing pca on', fold_idx) - embed1_train = embeddings1[train_set] - embed2_train = embeddings2[train_set] - _embed_train = np.concatenate((embed1_train, embed2_train), axis=0) - pca_model = PCA(n_components=pca) - pca_model.fit(_embed_train) - embed1 = pca_model.transform(embeddings1) - embed2 = pca_model.transform(embeddings2) - embed1 = sklearn.preprocessing.normalize(embed1) - embed2 = sklearn.preprocessing.normalize(embed2) - diff = np.subtract(embed1, embed2) - dist = np.sum(np.square(diff), 1) - - # Find the best threshold for the fold - acc_train = np.zeros((nrof_thresholds)) - for threshold_idx, threshold in enumerate(thresholds): - _, _, acc_train[threshold_idx] = calculate_accuracy( - threshold, dist[train_set], actual_issame[train_set]) - best_threshold_index = np.argmax(acc_train) - for threshold_idx, threshold in enumerate(thresholds): - tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy( - threshold, dist[test_set], - actual_issame[test_set]) - _, _, accuracy[fold_idx] = calculate_accuracy( - thresholds[best_threshold_index], dist[test_set], - actual_issame[test_set]) - - tpr = np.mean(tprs, 0) - fpr = np.mean(fprs, 0) - return tpr, fpr, accuracy - - -def calculate_accuracy(threshold, dist, actual_issame): - predict_issame = np.less(dist, threshold) - tp = np.sum(np.logical_and(predict_issame, actual_issame)) - fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame))) - tn = np.sum( - np.logical_and(np.logical_not(predict_issame), - np.logical_not(actual_issame))) - fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame)) - - tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn) - fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn) - acc = float(tp + tn) / dist.size - return tpr, fpr, acc - - -def calculate_val(thresholds, - embeddings1, - embeddings2, - actual_issame, - far_target, - nrof_folds=10): - assert (embeddings1.shape[0] == embeddings2.shape[0]) - assert (embeddings1.shape[1] == embeddings2.shape[1]) - nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) - nrof_thresholds = len(thresholds) - k_fold = LFold(n_splits=nrof_folds, shuffle=False) - - val = np.zeros(nrof_folds) - far = 
np.zeros(nrof_folds) - - diff = np.subtract(embeddings1, embeddings2) - dist = np.sum(np.square(diff), 1) - indices = np.arange(nrof_pairs) - - for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): - - # Find the threshold that gives FAR = far_target - far_train = np.zeros(nrof_thresholds) - for threshold_idx, threshold in enumerate(thresholds): - _, far_train[threshold_idx] = calculate_val_far( - threshold, dist[train_set], actual_issame[train_set]) - if np.max(far_train) >= far_target: - f = interpolate.interp1d(far_train, thresholds, kind='slinear') - threshold = f(far_target) - else: - threshold = 0.0 - - val[fold_idx], far[fold_idx] = calculate_val_far( - threshold, dist[test_set], actual_issame[test_set]) - - val_mean = np.mean(val) - far_mean = np.mean(far) - val_std = np.std(val) - return val_mean, val_std, far_mean - - -def calculate_val_far(threshold, dist, actual_issame): - predict_issame = np.less(dist, threshold) - true_accept = np.sum(np.logical_and(predict_issame, actual_issame)) - false_accept = np.sum( - np.logical_and(predict_issame, np.logical_not(actual_issame))) - n_same = np.sum(actual_issame) - n_diff = np.sum(np.logical_not(actual_issame)) - # print(true_accept, false_accept) - # print(n_same, n_diff) - val = float(true_accept) / float(n_same) - far = float(false_accept) / float(n_diff) - return val, far - - -def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0): - # Calculate evaluation metrics - thresholds = np.arange(0, 4, 0.01) - embeddings1 = embeddings[0::2] - embeddings2 = embeddings[1::2] - tpr, fpr, accuracy = calculate_roc(thresholds, - embeddings1, - embeddings2, - np.asarray(actual_issame), - nrof_folds=nrof_folds, - pca=pca) - thresholds = np.arange(0, 4, 0.001) - val, val_std, far = calculate_val(thresholds, - embeddings1, - embeddings2, - np.asarray(actual_issame), - 1e-3, - nrof_folds=nrof_folds) - return tpr, fpr, accuracy, val, val_std, far - -@torch.no_grad() -def load_bin(path, image_size): - try: - with open(path, 'rb') as f: - bins, issame_list = pickle.load(f) # py2 - except UnicodeDecodeError as e: - with open(path, 'rb') as f: - bins, issame_list = pickle.load(f, encoding='bytes') # py3 - data_list = [] - for flip in [0, 1]: - data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1])) - data_list.append(data) - for idx in range(len(issame_list) * 2): - _bin = bins[idx] - img = mx.image.imdecode(_bin) - if img.shape[1] != image_size[0]: - img = mx.image.resize_short(img, image_size[0]) - img = nd.transpose(img, axes=(2, 0, 1)) - for flip in [0, 1]: - if flip == 1: - img = mx.ndarray.flip(data=img, axis=2) - data_list[flip][idx][:] = torch.from_numpy(img.asnumpy()) - if idx % 1000 == 0: - print('loading bin', idx) - print(data_list[0].shape) - return data_list, issame_list - -@torch.no_grad() -def test(data_set, backbone, batch_size, nfolds=10): - print('testing verification..') - data_list = data_set[0] - issame_list = data_set[1] - embeddings_list = [] - time_consumed = 0.0 - for i in range(len(data_list)): - data = data_list[i] - embeddings = None - ba = 0 - while ba < data.shape[0]: - bb = min(ba + batch_size, data.shape[0]) - count = bb - ba - _data = data[bb - batch_size: bb] - time0 = datetime.datetime.now() - img = ((_data / 255) - 0.5) / 0.5 - net_out: torch.Tensor = backbone(img) - _embeddings = net_out.detach().cpu().numpy() - time_now = datetime.datetime.now() - diff = time_now - time0 - time_consumed += diff.total_seconds() - if embeddings is None: - embeddings = np.zeros((data.shape[0], 
_embeddings.shape[1])) - embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] - ba = bb - embeddings_list.append(embeddings) - - _xnorm = 0.0 - _xnorm_cnt = 0 - for embed in embeddings_list: - for i in range(embed.shape[0]): - _em = embed[i] - _norm = np.linalg.norm(_em) - _xnorm += _norm - _xnorm_cnt += 1 - _xnorm /= _xnorm_cnt - - acc1 = 0.0 - std1 = 0.0 - embeddings = embeddings_list[0] + embeddings_list[1] - embeddings = sklearn.preprocessing.normalize(embeddings) - print(embeddings.shape) - print('infer time', time_consumed) - _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds) - acc2, std2 = np.mean(accuracy), np.std(accuracy) - return acc1, std1, acc2, std2, _xnorm, embeddings_list - - -def dumpR(data_set, - backbone, - batch_size, - name='', - data_extra=None, - label_shape=None): - print('dump verification embedding..') - data_list = data_set[0] - issame_list = data_set[1] - embeddings_list = [] - time_consumed = 0.0 - for i in range(len(data_list)): - data = data_list[i] - embeddings = None - ba = 0 - while ba < data.shape[0]: - bb = min(ba + batch_size, data.shape[0]) - count = bb - ba - - _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb) - time0 = datetime.datetime.now() - if data_extra is None: - db = mx.io.DataBatch(data=(_data,), label=(_label,)) - else: - db = mx.io.DataBatch(data=(_data, _data_extra), - label=(_label,)) - model.forward(db, is_train=False) - net_out = model.get_outputs() - _embeddings = net_out[0].asnumpy() - time_now = datetime.datetime.now() - diff = time_now - time0 - time_consumed += diff.total_seconds() - if embeddings is None: - embeddings = np.zeros((data.shape[0], _embeddings.shape[1])) - embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] - ba = bb - embeddings_list.append(embeddings) - embeddings = embeddings_list[0] + embeddings_list[1] - embeddings = sklearn.preprocessing.normalize(embeddings) - actual_issame = np.asarray(issame_list) - outname = os.path.join('temp.bin') - with open(outname, 'wb') as f: - pickle.dump((embeddings, issame_list), - f, - protocol=pickle.HIGHEST_PROTOCOL) - - -# if __name__ == '__main__': -# -# parser = argparse.ArgumentParser(description='do verification') -# # general -# parser.add_argument('--data-dir', default='', help='') -# parser.add_argument('--model', -# default='../model/softmax,50', -# help='path to load model.') -# parser.add_argument('--target', -# default='lfw,cfp_ff,cfp_fp,agedb_30', -# help='test targets.') -# parser.add_argument('--gpu', default=0, type=int, help='gpu id') -# parser.add_argument('--batch-size', default=32, type=int, help='') -# parser.add_argument('--max', default='', type=str, help='') -# parser.add_argument('--mode', default=0, type=int, help='') -# parser.add_argument('--nfolds', default=10, type=int, help='') -# args = parser.parse_args() -# image_size = [112, 112] -# print('image_size', image_size) -# ctx = mx.gpu(args.gpu) -# nets = [] -# vec = args.model.split(',') -# prefix = args.model.split(',')[0] -# epochs = [] -# if len(vec) == 1: -# pdir = os.path.dirname(prefix) -# for fname in os.listdir(pdir): -# if not fname.endswith('.params'): -# continue -# _file = os.path.join(pdir, fname) -# if _file.startswith(prefix): -# epoch = int(fname.split('.')[0].split('-')[1]) -# epochs.append(epoch) -# epochs = sorted(epochs, reverse=True) -# if len(args.max) > 0: -# _max = [int(x) for x in args.max.split(',')] -# assert len(_max) == 2 -# if len(epochs) > _max[1]: -# epochs = epochs[_max[0]:_max[1]] -# -# else: -# 
epochs = [int(x) for x in vec[1].split('|')] -# print('model number', len(epochs)) -# time0 = datetime.datetime.now() -# for epoch in epochs: -# print('loading', prefix, epoch) -# sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) -# # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx) -# all_layers = sym.get_internals() -# sym = all_layers['fc1_output'] -# model = mx.mod.Module(symbol=sym, context=ctx, label_names=None) -# # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))]) -# model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], -# image_size[1]))]) -# model.set_params(arg_params, aux_params) -# nets.append(model) -# time_now = datetime.datetime.now() -# diff = time_now - time0 -# print('model loading time', diff.total_seconds()) -# -# ver_list = [] -# ver_name_list = [] -# for name in args.target.split(','): -# path = os.path.join(args.data_dir, name + ".bin") -# if os.path.exists(path): -# print('loading.. ', name) -# data_set = load_bin(path, image_size) -# ver_list.append(data_set) -# ver_name_list.append(name) -# -# if args.mode == 0: -# for i in range(len(ver_list)): -# results = [] -# for model in nets: -# acc1, std1, acc2, std2, xnorm, embeddings_list = test( -# ver_list[i], model, args.batch_size, args.nfolds) -# print('[%s]XNorm: %f' % (ver_name_list[i], xnorm)) -# print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1)) -# print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2)) -# results.append(acc2) -# print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results))) -# elif args.mode == 1: -# raise ValueError -# else: -# model = nets[0] -# dumpR(ver_list[0], model, args.batch_size, args.target) diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/eval_ijbc.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/eval_ijbc.py deleted file mode 100644 index 64844c4723a88b4b160d2fee9a7b626b987981d9..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/eval_ijbc.py +++ /dev/null @@ -1,483 +0,0 @@ -# coding: utf-8 - -import os -import pickle - -import matplotlib -import pandas as pd - -matplotlib.use('Agg') -import matplotlib.pyplot as plt -import timeit -import sklearn -import argparse -import cv2 -import numpy as np -import torch -from skimage import transform as trans -from backbones import get_model -from sklearn.metrics import roc_curve, auc - -from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap -from prettytable import PrettyTable -from pathlib import Path - -import sys -import warnings - -sys.path.insert(0, "../") -warnings.filterwarnings("ignore") - -parser = argparse.ArgumentParser(description='do ijb test') -# general -parser.add_argument('--model-prefix', default='', help='path to load model.') -parser.add_argument('--image-path', default='', type=str, help='') -parser.add_argument('--result-dir', default='.', type=str, help='') -parser.add_argument('--batch-size', default=128, type=int, help='') -parser.add_argument('--network', default='iresnet50', type=str, help='') -parser.add_argument('--job', default='insightface', type=str, help='job name') -parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB') -args = parser.parse_args() - -target = args.target -model_path = args.model_prefix -image_path = args.image_path -result_dir = args.result_dir -gpu_id = None -use_norm_score = True # if 
Ture, TestMode(N1) -use_detector_score = True # if Ture, TestMode(D1) -use_flip_test = True # if Ture, TestMode(F1) -job = args.job -batch_size = args.batch_size - - -class Embedding(object): - def __init__(self, prefix, data_shape, batch_size=1): - image_size = (112, 112) - self.image_size = image_size - weight = torch.load(prefix) - resnet = get_model(args.network, dropout=0, fp16=False).cuda() - resnet.load_state_dict(weight) - model = torch.nn.DataParallel(resnet) - self.model = model - self.model.eval() - src = np.array([ - [30.2946, 51.6963], - [65.5318, 51.5014], - [48.0252, 71.7366], - [33.5493, 92.3655], - [62.7299, 92.2041]], dtype=np.float32) - src[:, 0] += 8.0 - self.src = src - self.batch_size = batch_size - self.data_shape = data_shape - - def get(self, rimg, landmark): - - assert landmark.shape[0] == 68 or landmark.shape[0] == 5 - assert landmark.shape[1] == 2 - if landmark.shape[0] == 68: - landmark5 = np.zeros((5, 2), dtype=np.float32) - landmark5[0] = (landmark[36] + landmark[39]) / 2 - landmark5[1] = (landmark[42] + landmark[45]) / 2 - landmark5[2] = landmark[30] - landmark5[3] = landmark[48] - landmark5[4] = landmark[54] - else: - landmark5 = landmark - tform = trans.SimilarityTransform() - tform.estimate(landmark5, self.src) - M = tform.params[0:2, :] - img = cv2.warpAffine(rimg, - M, (self.image_size[1], self.image_size[0]), - borderValue=0.0) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img_flip = np.fliplr(img) - img = np.transpose(img, (2, 0, 1)) # 3*112*112, RGB - img_flip = np.transpose(img_flip, (2, 0, 1)) - input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8) - input_blob[0] = img - input_blob[1] = img_flip - return input_blob - - @torch.no_grad() - def forward_db(self, batch_data): - imgs = torch.Tensor(batch_data).cuda() - imgs.div_(255).sub_(0.5).div_(0.5) - feat = self.model(imgs) - feat = feat.reshape([self.batch_size, 2 * feat.shape[1]]) - return feat.cpu().numpy() - - -# 将一个list尽量均分成n份,限制len(list)==n,份数大于原list内元素个数则分配空list[] -def divideIntoNstrand(listTemp, n): - twoList = [[] for i in range(n)] - for i, e in enumerate(listTemp): - twoList[i % n].append(e) - return twoList - - -def read_template_media_list(path): - # ijb_meta = np.loadtxt(path, dtype=str) - ijb_meta = pd.read_csv(path, sep=' ', header=None).values - templates = ijb_meta[:, 1].astype(np.int) - medias = ijb_meta[:, 2].astype(np.int) - return templates, medias - - -# In[ ]: - - -def read_template_pair_list(path): - # pairs = np.loadtxt(path, dtype=str) - pairs = pd.read_csv(path, sep=' ', header=None).values - # print(pairs.shape) - # print(pairs[:, 0].astype(np.int)) - t1 = pairs[:, 0].astype(np.int) - t2 = pairs[:, 1].astype(np.int) - label = pairs[:, 2].astype(np.int) - return t1, t2, label - - -# In[ ]: - - -def read_image_feature(path): - with open(path, 'rb') as fid: - img_feats = pickle.load(fid) - return img_feats - - -# In[ ]: - - -def get_image_feature(img_path, files_list, model_path, epoch, gpu_id): - batch_size = args.batch_size - data_shape = (3, 112, 112) - - files = files_list - print('files:', len(files)) - rare_size = len(files) % batch_size - faceness_scores = [] - batch = 0 - img_feats = np.empty((len(files), 1024), dtype=np.float32) - - batch_data = np.empty((2 * batch_size, 3, 112, 112)) - embedding = Embedding(model_path, data_shape, batch_size) - for img_index, each_line in enumerate(files[:len(files) - rare_size]): - name_lmk_score = each_line.strip().split(' ') - img_name = os.path.join(img_path, name_lmk_score[0]) - img = 
cv2.imread(img_name) - lmk = np.array([float(x) for x in name_lmk_score[1:-1]], - dtype=np.float32) - lmk = lmk.reshape((5, 2)) - input_blob = embedding.get(img, lmk) - - batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0] - batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1] - if (img_index + 1) % batch_size == 0: - print('batch', batch) - img_feats[batch * batch_size:batch * batch_size + - batch_size][:] = embedding.forward_db(batch_data) - batch += 1 - faceness_scores.append(name_lmk_score[-1]) - - batch_data = np.empty((2 * rare_size, 3, 112, 112)) - embedding = Embedding(model_path, data_shape, rare_size) - for img_index, each_line in enumerate(files[len(files) - rare_size:]): - name_lmk_score = each_line.strip().split(' ') - img_name = os.path.join(img_path, name_lmk_score[0]) - img = cv2.imread(img_name) - lmk = np.array([float(x) for x in name_lmk_score[1:-1]], - dtype=np.float32) - lmk = lmk.reshape((5, 2)) - input_blob = embedding.get(img, lmk) - batch_data[2 * img_index][:] = input_blob[0] - batch_data[2 * img_index + 1][:] = input_blob[1] - if (img_index + 1) % rare_size == 0: - print('batch', batch) - img_feats[len(files) - - rare_size:][:] = embedding.forward_db(batch_data) - batch += 1 - faceness_scores.append(name_lmk_score[-1]) - faceness_scores = np.array(faceness_scores).astype(np.float32) - # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01 - # faceness_scores = np.ones( (len(files), ), dtype=np.float32 ) - return img_feats, faceness_scores - - -# In[ ]: - - -def image2template_feature(img_feats=None, templates=None, medias=None): - # ========================================================== - # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim] - # 2. compute media feature. - # 3. compute template feature. - # ========================================================== - unique_templates = np.unique(templates) - template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) - - for count_template, uqt in enumerate(unique_templates): - - (ind_t,) = np.where(templates == uqt) - face_norm_feats = img_feats[ind_t] - face_medias = medias[ind_t] - unique_medias, unique_media_counts = np.unique(face_medias, - return_counts=True) - media_norm_feats = [] - for u, ct in zip(unique_medias, unique_media_counts): - (ind_m,) = np.where(face_medias == u) - if ct == 1: - media_norm_feats += [face_norm_feats[ind_m]] - else: # image features from the same video will be aggregated into one feature - media_norm_feats += [ - np.mean(face_norm_feats[ind_m], axis=0, keepdims=True) - ] - media_norm_feats = np.array(media_norm_feats) - # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True)) - template_feats[count_template] = np.sum(media_norm_feats, axis=0) - if count_template % 2000 == 0: - print('Finish Calculating {} template features.'.format( - count_template)) - # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True)) - template_norm_feats = sklearn.preprocessing.normalize(template_feats) - # print(template_norm_feats.shape) - return template_norm_feats, unique_templates - - -# In[ ]: - - -def verification(template_norm_feats=None, - unique_templates=None, - p1=None, - p2=None): - # ========================================================== - # Compute set-to-set Similarity Score. 
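For reference, once image2template_feature has L2-normalized the template features, the set-to-set verification that follows reduces to a batched row-wise dot product. A minimal NumPy sketch, assuming feats is the [num_templates x dim] array of normalized features and idx1/idx2 are already row indices (the original routine first maps template IDs through a template2id lookup):

import numpy as np

def cosine_scores(feats, idx1, idx2, batch=100000):
    # cosine similarity == dot product, because rows of `feats` are unit-norm
    scores = np.zeros(len(idx1))
    for start in range(0, len(idx1), batch):
        s = slice(start, start + batch)
        scores[s] = np.sum(feats[idx1[s]] * feats[idx2[s]], axis=-1)
    return scores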
- # ========================================================== - template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) - for count_template, uqt in enumerate(unique_templates): - template2id[uqt] = count_template - - score = np.zeros((len(p1),)) # save cosine distance between pairs - - total_pairs = np.array(range(len(p1))) - batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation - sublists = [ - total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize) - ] - total_sublists = len(sublists) - for c, s in enumerate(sublists): - feat1 = template_norm_feats[template2id[p1[s]]] - feat2 = template_norm_feats[template2id[p2[s]]] - similarity_score = np.sum(feat1 * feat2, -1) - score[s] = similarity_score.flatten() - if c % 10 == 0: - print('Finish {}/{} pairs.'.format(c, total_sublists)) - return score - - -# In[ ]: -def verification2(template_norm_feats=None, - unique_templates=None, - p1=None, - p2=None): - template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) - for count_template, uqt in enumerate(unique_templates): - template2id[uqt] = count_template - score = np.zeros((len(p1),)) # save cosine distance between pairs - total_pairs = np.array(range(len(p1))) - batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation - sublists = [ - total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize) - ] - total_sublists = len(sublists) - for c, s in enumerate(sublists): - feat1 = template_norm_feats[template2id[p1[s]]] - feat2 = template_norm_feats[template2id[p2[s]]] - similarity_score = np.sum(feat1 * feat2, -1) - score[s] = similarity_score.flatten() - if c % 10 == 0: - print('Finish {}/{} pairs.'.format(c, total_sublists)) - return score - - -def read_score(path): - with open(path, 'rb') as fid: - img_feats = pickle.load(fid) - return img_feats - - -# # Step1: Load Meta Data - -# In[ ]: - -assert target == 'IJBC' or target == 'IJBB' - -# ============================================================= -# load image and template relationships for template feature embedding -# tid --> template id, mid --> media id -# format: -# image_name tid mid -# ============================================================= -start = timeit.default_timer() -templates, medias = read_template_media_list( - os.path.join('%s/meta' % image_path, - '%s_face_tid_mid.txt' % target.lower())) -stop = timeit.default_timer() -print('Time: %.2f s. ' % (stop - start)) - -# In[ ]: - -# ============================================================= -# load template pairs for template-to-template verification -# tid : template id, label : 1/0 -# format: -# tid_1 tid_2 label -# ============================================================= -start = timeit.default_timer() -p1, p2, label = read_template_pair_list( - os.path.join('%s/meta' % image_path, - '%s_template_pair_label.txt' % target.lower())) -stop = timeit.default_timer() -print('Time: %.2f s. 
' % (stop - start)) - -# # Step 2: Get Image Features - -# In[ ]: - -# ============================================================= -# load image features -# format: -# img_feats: [image_num x feats_dim] (227630, 512) -# ============================================================= -start = timeit.default_timer() -img_path = '%s/loose_crop' % image_path -img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower()) -img_list = open(img_list_path) -files = img_list.readlines() -# files_list = divideIntoNstrand(files, rank_size) -files_list = files - -# img_feats -# for i in range(rank_size): -img_feats, faceness_scores = get_image_feature(img_path, files_list, - model_path, 0, gpu_id) -stop = timeit.default_timer() -print('Time: %.2f s. ' % (stop - start)) -print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], - img_feats.shape[1])) - -# # Step3: Get Template Features - -# In[ ]: - -# ============================================================= -# compute template features from image features. -# ============================================================= -start = timeit.default_timer() -# ========================================================== -# Norm feature before aggregation into template feature? -# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face). -# ========================================================== -# 1. FaceScore (Feature Norm) -# 2. FaceScore (Detector) - -if use_flip_test: - # concat --- F1 - # img_input_feats = img_feats - # add --- F2 - img_input_feats = img_feats[:, 0:img_feats.shape[1] // - 2] + img_feats[:, img_feats.shape[1] // 2:] -else: - img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] - -if use_norm_score: - img_input_feats = img_input_feats -else: - # normalise features to remove norm information - img_input_feats = img_input_feats / np.sqrt( - np.sum(img_input_feats ** 2, -1, keepdims=True)) - -if use_detector_score: - print(img_input_feats.shape, faceness_scores.shape) - img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] -else: - img_input_feats = img_input_feats - -template_norm_feats, unique_templates = image2template_feature( - img_input_feats, templates, medias) -stop = timeit.default_timer() -print('Time: %.2f s. ' % (stop - start)) - -# # Step 4: Get Template Similarity Scores - -# In[ ]: - -# ============================================================= -# compute verification scores between template pairs. -# ============================================================= -start = timeit.default_timer() -score = verification(template_norm_feats, unique_templates, p1, p2) -stop = timeit.default_timer() -print('Time: %.2f s. 
' % (stop - start)) - -# In[ ]: -save_path = os.path.join(result_dir, args.job) -# save_path = result_dir + '/%s_result' % target - -if not os.path.exists(save_path): - os.makedirs(save_path) - -score_save_file = os.path.join(save_path, "%s.npy" % target.lower()) -np.save(score_save_file, score) - -# # Step 5: Get ROC Curves and TPR@FPR Table - -# In[ ]: - -files = [score_save_file] -methods = [] -scores = [] -for file in files: - methods.append(Path(file).stem) - scores.append(np.load(file)) - -methods = np.array(methods) -scores = dict(zip(methods, scores)) -colours = dict( - zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) -x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] -tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) -fig = plt.figure() -for method in methods: - fpr, tpr, _ = roc_curve(label, scores[method]) - roc_auc = auc(fpr, tpr) - fpr = np.flipud(fpr) - tpr = np.flipud(tpr) # select largest tpr at same fpr - plt.plot(fpr, - tpr, - color=colours[method], - lw=1, - label=('[%s (AUC = %0.4f %%)]' % - (method.split('-')[-1], roc_auc * 100))) - tpr_fpr_row = [] - tpr_fpr_row.append("%s-%s" % (method, target)) - for fpr_iter in np.arange(len(x_labels)): - _, min_index = min( - list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) - tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) - tpr_fpr_table.add_row(tpr_fpr_row) -plt.xlim([10 ** -6, 0.1]) -plt.ylim([0.3, 1.0]) -plt.grid(linestyle='--', linewidth=1) -plt.xticks(x_labels) -plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) -plt.xscale('log') -plt.xlabel('False Positive Rate') -plt.ylabel('True Positive Rate') -plt.title('ROC on IJB') -plt.legend(loc="lower right") -fig.savefig(os.path.join(save_path, '%s.pdf' % target.lower())) -print(tpr_fpr_table) diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/inference.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/inference.py deleted file mode 100644 index 1929d4abb640d040398dda57b491b9bd96deac9d..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/inference.py +++ /dev/null @@ -1,35 +0,0 @@ -import argparse - -import cv2 -import numpy as np -import torch - -from backbones import get_model - - -@torch.no_grad() -def inference(weight, name, img): - if img is None: - img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8) - else: - img = cv2.imread(img) - img = cv2.resize(img, (112, 112)) - - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img = np.transpose(img, (2, 0, 1)) - img = torch.from_numpy(img).unsqueeze(0).float() - img.div_(255).sub_(0.5).div_(0.5) - net = get_model(name, fp16=False) - net.load_state_dict(torch.load(weight)) - net.eval() - feat = net(img).numpy() - print(feat) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') - parser.add_argument('--network', type=str, default='r50', help='backbone network') - parser.add_argument('--weight', type=str, default='') - parser.add_argument('--img', type=str, default=None) - args = parser.parse_args() - inference(args.weight, args.network, args.img) diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/losses.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/losses.py deleted file mode 100644 index 7bfdd8c6b7f6b0d465928f19c554e62340e5ad7b..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/losses.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch -from torch 
import nn - - -def get_loss(name): - if name == "cosface": - return CosFace() - elif name == "arcface": - return ArcFace() - else: - raise ValueError() - - -class CosFace(nn.Module): - def __init__(self, s=64.0, m=0.40): - super(CosFace, self).__init__() - self.s = s - self.m = m - - def forward(self, cosine, label): - index = torch.where(label != -1)[0] - m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) - m_hot.scatter_(1, label[index, None], self.m) - cosine[index] -= m_hot - ret = cosine * self.s - return ret - - -class ArcFace(nn.Module): - def __init__(self, s=64.0, m=0.5): - super(ArcFace, self).__init__() - self.s = s - self.m = m - - def forward(self, cosine: torch.Tensor, label): - index = torch.where(label != -1)[0] - m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) - m_hot.scatter_(1, label[index, None], self.m) - cosine.acos_() - cosine[index] += m_hot - cosine.cos_().mul_(self.s) - return cosine diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/onnx_helper.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/onnx_helper.py deleted file mode 100644 index 4a01a46621dc0ea695bd903de5d1e212d424c860..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/onnx_helper.py +++ /dev/null @@ -1,250 +0,0 @@ -from __future__ import division -import datetime -import os -import os.path as osp -import glob -import numpy as np -import cv2 -import sys -import onnxruntime -import onnx -import argparse -from onnx import numpy_helper -from insightface.data import get_image - -class ArcFaceORT: - def __init__(self, model_path, cpu=False): - self.model_path = model_path - # providers = None will use available provider, for onnxruntime-gpu it will be "CUDAExecutionProvider" - self.providers = ['CPUExecutionProvider'] if cpu else None - - #input_size is (w,h), return error message, return None if success - def check(self, track='cfat', test_img = None): - #default is cfat - max_model_size_mb=1024 - max_feat_dim=512 - max_time_cost=15 - if track.startswith('ms1m'): - max_model_size_mb=1024 - max_feat_dim=512 - max_time_cost=10 - elif track.startswith('glint'): - max_model_size_mb=1024 - max_feat_dim=1024 - max_time_cost=20 - elif track.startswith('cfat'): - max_model_size_mb = 1024 - max_feat_dim = 512 - max_time_cost = 15 - elif track.startswith('unconstrained'): - max_model_size_mb=1024 - max_feat_dim=1024 - max_time_cost=30 - else: - return "track not found" - - if not os.path.exists(self.model_path): - return "model_path not exists" - if not os.path.isdir(self.model_path): - return "model_path should be directory" - onnx_files = [] - for _file in os.listdir(self.model_path): - if _file.endswith('.onnx'): - onnx_files.append(osp.join(self.model_path, _file)) - if len(onnx_files)==0: - return "do not have onnx files" - self.model_file = sorted(onnx_files)[-1] - print('use onnx-model:', self.model_file) - try: - session = onnxruntime.InferenceSession(self.model_file, providers=self.providers) - except: - return "load onnx failed" - input_cfg = session.get_inputs()[0] - input_shape = input_cfg.shape - print('input-shape:', input_shape) - if len(input_shape)!=4: - return "length of input_shape should be 4" - if not isinstance(input_shape[0], str): - #return "input_shape[0] should be str to support batch-inference" - print('reset input-shape[0] to None') - model = onnx.load(self.model_file) - model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' - new_model_file = 
osp.join(self.model_path, 'zzzzrefined.onnx') - onnx.save(model, new_model_file) - self.model_file = new_model_file - print('use new onnx-model:', self.model_file) - try: - session = onnxruntime.InferenceSession(self.model_file, providers=self.providers) - except: - return "load onnx failed" - input_cfg = session.get_inputs()[0] - input_shape = input_cfg.shape - print('new-input-shape:', input_shape) - - self.image_size = tuple(input_shape[2:4][::-1]) - #print('image_size:', self.image_size) - input_name = input_cfg.name - outputs = session.get_outputs() - output_names = [] - for o in outputs: - output_names.append(o.name) - #print(o.name, o.shape) - if len(output_names)!=1: - return "number of output nodes should be 1" - self.session = session - self.input_name = input_name - self.output_names = output_names - #print(self.output_names) - model = onnx.load(self.model_file) - graph = model.graph - if len(graph.node)<8: - return "too small onnx graph" - - input_size = (112,112) - self.crop = None - if track=='cfat': - crop_file = osp.join(self.model_path, 'crop.txt') - if osp.exists(crop_file): - lines = open(crop_file,'r').readlines() - if len(lines)!=6: - return "crop.txt should contain 6 lines" - lines = [int(x) for x in lines] - self.crop = lines[:4] - input_size = tuple(lines[4:6]) - if input_size!=self.image_size: - return "input-size is inconsistant with onnx model input, %s vs %s"%(input_size, self.image_size) - - self.model_size_mb = os.path.getsize(self.model_file) / float(1024*1024) - if self.model_size_mb > max_model_size_mb: - return "max model size exceed, given %.3f-MB"%self.model_size_mb - - input_mean = None - input_std = None - if track=='cfat': - pn_file = osp.join(self.model_path, 'pixel_norm.txt') - if osp.exists(pn_file): - lines = open(pn_file,'r').readlines() - if len(lines)!=2: - return "pixel_norm.txt should contain 2 lines" - input_mean = float(lines[0]) - input_std = float(lines[1]) - if input_mean is not None or input_std is not None: - if input_mean is None or input_std is None: - return "please set input_mean and input_std simultaneously" - else: - find_sub = False - find_mul = False - for nid, node in enumerate(graph.node[:8]): - print(nid, node.name) - if node.name.startswith('Sub') or node.name.startswith('_minus'): - find_sub = True - if node.name.startswith('Mul') or node.name.startswith('_mul') or node.name.startswith('Div'): - find_mul = True - if find_sub and find_mul: - print("find sub and mul") - #mxnet arcface model - input_mean = 0.0 - input_std = 1.0 - else: - input_mean = 127.5 - input_std = 127.5 - self.input_mean = input_mean - self.input_std = input_std - for initn in graph.initializer: - weight_array = numpy_helper.to_array(initn) - dt = weight_array.dtype - if dt.itemsize<4: - return 'invalid weight type - (%s:%s)' % (initn.name, dt.name) - if test_img is None: - test_img = get_image('Tom_Hanks_54745') - test_img = cv2.resize(test_img, self.image_size) - else: - test_img = cv2.resize(test_img, self.image_size) - feat, cost = self.benchmark(test_img) - batch_result = self.check_batch(test_img) - batch_result_sum = float(np.sum(batch_result)) - if batch_result_sum in [float('inf'), -float('inf')] or batch_result_sum != batch_result_sum: - print(batch_result) - print(batch_result_sum) - return "batch result output contains NaN!" 
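At inference time the checker above boils down to the usual ONNX Runtime pattern: build a blob with the resolved pixel mean/std, run the session, and L2-normalize the embedding. A condensed sketch under those assumptions (a single aligned 112x112 BGR crop; mean/std as determined by check()):

import cv2
import numpy as np
import onnxruntime

def embed_face(model_file, bgr_face, input_mean=127.5, input_std=127.5):
    session = onnxruntime.InferenceSession(model_file, providers=['CPUExecutionProvider'])
    input_name = session.get_inputs()[0].name
    # BGR -> RGB, NCHW, (x - mean) / std, matching how blobFromImages is used above
    blob = cv2.dnn.blobFromImage(bgr_face, 1.0 / input_std, (112, 112),
                                 (input_mean, input_mean, input_mean), swapRB=True)
    feat = session.run(None, {input_name: blob})[0]
    return feat / np.linalg.norm(feat)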
- - if len(feat.shape) < 2: - return "the shape of the feature must be two, but get {}".format(str(feat.shape)) - - if feat.shape[1] > max_feat_dim: - return "max feat dim exceed, given %d"%feat.shape[1] - self.feat_dim = feat.shape[1] - cost_ms = cost*1000 - if cost_ms>max_time_cost: - return "max time cost exceed, given %.4f"%cost_ms - self.cost_ms = cost_ms - print('check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f'%(self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std)) - return None - - def check_batch(self, img): - if not isinstance(img, list): - imgs = [img, ] * 32 - if self.crop is not None: - nimgs = [] - for img in imgs: - nimg = img[self.crop[1]:self.crop[3], self.crop[0]:self.crop[2], :] - if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]: - nimg = cv2.resize(nimg, self.image_size) - nimgs.append(nimg) - imgs = nimgs - blob = cv2.dnn.blobFromImages( - images=imgs, scalefactor=1.0 / self.input_std, size=self.image_size, - mean=(self.input_mean, self.input_mean, self.input_mean), swapRB=True) - net_out = self.session.run(self.output_names, {self.input_name: blob})[0] - return net_out - - - def meta_info(self): - return {'model-size-mb':self.model_size_mb, 'feature-dim':self.feat_dim, 'infer': self.cost_ms} - - - def forward(self, imgs): - if not isinstance(imgs, list): - imgs = [imgs] - input_size = self.image_size - if self.crop is not None: - nimgs = [] - for img in imgs: - nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:] - if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]: - nimg = cv2.resize(nimg, input_size) - nimgs.append(nimg) - imgs = nimgs - blob = cv2.dnn.blobFromImages(imgs, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) - net_out = self.session.run(self.output_names, {self.input_name : blob})[0] - return net_out - - def benchmark(self, img): - input_size = self.image_size - if self.crop is not None: - nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:] - if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]: - nimg = cv2.resize(nimg, input_size) - img = nimg - blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) - costs = [] - for _ in range(50): - ta = datetime.datetime.now() - net_out = self.session.run(self.output_names, {self.input_name : blob})[0] - tb = datetime.datetime.now() - cost = (tb-ta).total_seconds() - costs.append(cost) - costs = sorted(costs) - cost = costs[5] - return net_out, cost - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='') - # general - parser.add_argument('workdir', help='submitted work dir', type=str) - parser.add_argument('--track', help='track name, for different challenge', type=str, default='cfat') - args = parser.parse_args() - handler = ArcFaceORT(args.workdir) - err = handler.check(args.track) - print('err:', err) diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/onnx_ijbc.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/onnx_ijbc.py deleted file mode 100644 index aa96b96745e23d4d6642d99f71456c10af5e4e4e..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/onnx_ijbc.py +++ /dev/null @@ -1,267 +0,0 @@ -import argparse -import os -import pickle -import timeit - -import cv2 -import mxnet as mx -import numpy as np -import pandas as pd -import 
prettytable -import skimage.transform -from sklearn.metrics import roc_curve -from sklearn.preprocessing import normalize - -from onnx_helper import ArcFaceORT - -SRC = np.array( - [ - [30.2946, 51.6963], - [65.5318, 51.5014], - [48.0252, 71.7366], - [33.5493, 92.3655], - [62.7299, 92.2041]] - , dtype=np.float32) -SRC[:, 0] += 8.0 - - -class AlignedDataSet(mx.gluon.data.Dataset): - def __init__(self, root, lines, align=True): - self.lines = lines - self.root = root - self.align = align - - def __len__(self): - return len(self.lines) - - def __getitem__(self, idx): - each_line = self.lines[idx] - name_lmk_score = each_line.strip().split(' ') - name = os.path.join(self.root, name_lmk_score[0]) - img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB) - landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2)) - st = skimage.transform.SimilarityTransform() - st.estimate(landmark5, SRC) - img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0) - img_1 = np.expand_dims(img, 0) - img_2 = np.expand_dims(np.fliplr(img), 0) - output = np.concatenate((img_1, img_2), axis=0).astype(np.float32) - output = np.transpose(output, (0, 3, 1, 2)) - output = mx.nd.array(output) - return output - - -def extract(model_root, dataset): - model = ArcFaceORT(model_path=model_root) - model.check() - feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim)) - - def batchify_fn(data): - return mx.nd.concat(*data, dim=0) - - data_loader = mx.gluon.data.DataLoader( - dataset, 128, last_batch='keep', num_workers=4, - thread_pool=True, prefetch=16, batchify_fn=batchify_fn) - num_iter = 0 - for batch in data_loader: - batch = batch.asnumpy() - batch = (batch - model.input_mean) / model.input_std - feat = model.session.run(model.output_names, {model.input_name: batch})[0] - feat = np.reshape(feat, (-1, model.feat_dim * 2)) - feat_mat[128 * num_iter: 128 * num_iter + feat.shape[0], :] = feat - num_iter += 1 - if num_iter % 50 == 0: - print(num_iter) - return feat_mat - - -def read_template_media_list(path): - ijb_meta = pd.read_csv(path, sep=' ', header=None).values - templates = ijb_meta[:, 1].astype(np.int) - medias = ijb_meta[:, 2].astype(np.int) - return templates, medias - - -def read_template_pair_list(path): - pairs = pd.read_csv(path, sep=' ', header=None).values - t1 = pairs[:, 0].astype(np.int) - t2 = pairs[:, 1].astype(np.int) - label = pairs[:, 2].astype(np.int) - return t1, t2, label - - -def read_image_feature(path): - with open(path, 'rb') as fid: - img_feats = pickle.load(fid) - return img_feats - - -def image2template_feature(img_feats=None, - templates=None, - medias=None): - unique_templates = np.unique(templates) - template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) - for count_template, uqt in enumerate(unique_templates): - (ind_t,) = np.where(templates == uqt) - face_norm_feats = img_feats[ind_t] - face_medias = medias[ind_t] - unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True) - media_norm_feats = [] - for u, ct in zip(unique_medias, unique_media_counts): - (ind_m,) = np.where(face_medias == u) - if ct == 1: - media_norm_feats += [face_norm_feats[ind_m]] - else: # image features from the same video will be aggregated into one feature - media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True), ] - media_norm_feats = np.array(media_norm_feats) - template_feats[count_template] = np.sum(media_norm_feats, axis=0) - if count_template % 2000 == 0: - print('Finish Calculating {} 
template features.'.format( - count_template)) - template_norm_feats = normalize(template_feats) - return template_norm_feats, unique_templates - - -def verification(template_norm_feats=None, - unique_templates=None, - p1=None, - p2=None): - template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) - for count_template, uqt in enumerate(unique_templates): - template2id[uqt] = count_template - score = np.zeros((len(p1),)) - total_pairs = np.array(range(len(p1))) - batchsize = 100000 - sublists = [total_pairs[i: i + batchsize] for i in range(0, len(p1), batchsize)] - total_sublists = len(sublists) - for c, s in enumerate(sublists): - feat1 = template_norm_feats[template2id[p1[s]]] - feat2 = template_norm_feats[template2id[p2[s]]] - similarity_score = np.sum(feat1 * feat2, -1) - score[s] = similarity_score.flatten() - if c % 10 == 0: - print('Finish {}/{} pairs.'.format(c, total_sublists)) - return score - - -def verification2(template_norm_feats=None, - unique_templates=None, - p1=None, - p2=None): - template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) - for count_template, uqt in enumerate(unique_templates): - template2id[uqt] = count_template - score = np.zeros((len(p1),)) # save cosine distance between pairs - total_pairs = np.array(range(len(p1))) - batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation - sublists = [total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)] - total_sublists = len(sublists) - for c, s in enumerate(sublists): - feat1 = template_norm_feats[template2id[p1[s]]] - feat2 = template_norm_feats[template2id[p2[s]]] - similarity_score = np.sum(feat1 * feat2, -1) - score[s] = similarity_score.flatten() - if c % 10 == 0: - print('Finish {}/{} pairs.'.format(c, total_sublists)) - return score - - -def main(args): - use_norm_score = True # if Ture, TestMode(N1) - use_detector_score = True # if Ture, TestMode(D1) - use_flip_test = True # if Ture, TestMode(F1) - assert args.target == 'IJBC' or args.target == 'IJBB' - - start = timeit.default_timer() - templates, medias = read_template_media_list( - os.path.join('%s/meta' % args.image_path, '%s_face_tid_mid.txt' % args.target.lower())) - stop = timeit.default_timer() - print('Time: %.2f s. ' % (stop - start)) - - start = timeit.default_timer() - p1, p2, label = read_template_pair_list( - os.path.join('%s/meta' % args.image_path, - '%s_template_pair_label.txt' % args.target.lower())) - stop = timeit.default_timer() - print('Time: %.2f s. ' % (stop - start)) - - start = timeit.default_timer() - img_path = '%s/loose_crop' % args.image_path - img_list_path = '%s/meta/%s_name_5pts_score.txt' % (args.image_path, args.target.lower()) - img_list = open(img_list_path) - files = img_list.readlines() - dataset = AlignedDataSet(root=img_path, lines=files, align=True) - img_feats = extract(args.model_root, dataset) - - faceness_scores = [] - for each_line in files: - name_lmk_score = each_line.split() - faceness_scores.append(name_lmk_score[-1]) - faceness_scores = np.array(faceness_scores).astype(np.float32) - stop = timeit.default_timer() - print('Time: %.2f s. 
' % (stop - start)) - print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1])) - start = timeit.default_timer() - - if use_flip_test: - img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:] - else: - img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] - - if use_norm_score: - img_input_feats = img_input_feats - else: - img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True)) - - if use_detector_score: - print(img_input_feats.shape, faceness_scores.shape) - img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] - else: - img_input_feats = img_input_feats - - template_norm_feats, unique_templates = image2template_feature( - img_input_feats, templates, medias) - stop = timeit.default_timer() - print('Time: %.2f s. ' % (stop - start)) - - start = timeit.default_timer() - score = verification(template_norm_feats, unique_templates, p1, p2) - stop = timeit.default_timer() - print('Time: %.2f s. ' % (stop - start)) - save_path = os.path.join(args.result_dir, "{}_result".format(args.target)) - if not os.path.exists(save_path): - os.makedirs(save_path) - score_save_file = os.path.join(save_path, "{}.npy".format(args.model_root)) - np.save(score_save_file, score) - files = [score_save_file] - methods = [] - scores = [] - for file in files: - methods.append(os.path.basename(file)) - scores.append(np.load(file)) - methods = np.array(methods) - scores = dict(zip(methods, scores)) - x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] - tpr_fpr_table = prettytable.PrettyTable(['Methods'] + [str(x) for x in x_labels]) - for method in methods: - fpr, tpr, _ = roc_curve(label, scores[method]) - fpr = np.flipud(fpr) - tpr = np.flipud(tpr) - tpr_fpr_row = [] - tpr_fpr_row.append("%s-%s" % (method, args.target)) - for fpr_iter in np.arange(len(x_labels)): - _, min_index = min( - list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) - tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) - tpr_fpr_table.add_row(tpr_fpr_row) - print(tpr_fpr_table) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='do ijb test') - # general - parser.add_argument('--model-root', default='', help='path to load model.') - parser.add_argument('--image-path', default='', type=str, help='') - parser.add_argument('--result-dir', default='.', type=str, help='') - parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB') - main(parser.parse_args()) diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/partial_fc.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/partial_fc.py deleted file mode 100644 index e0286dd437319c920ecb61f4eb3a32333dcf49eb..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/partial_fc.py +++ /dev/null @@ -1,222 +0,0 @@ -import logging -import os - -import torch -import torch.distributed as dist -from torch.nn import Module -from torch.nn.functional import normalize, linear -from torch.nn.parameter import Parameter - - -class PartialFC(Module): - """ - Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint, - Partial FC: Training 10 Million Identities on a Single Machine - See the original paper: - https://arxiv.org/abs/2010.05222 - """ - - @torch.no_grad() - def __init__(self, rank, local_rank, world_size, batch_size, resume, - margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"): - """ - rank: int - Unique 
process(GPU) ID from 0 to world_size - 1. - local_rank: int - Unique process(GPU) ID within the server from 0 to 7. - world_size: int - Number of GPU. - batch_size: int - Batch size on current rank(GPU). - resume: bool - Select whether to restore the weight of softmax. - margin_softmax: callable - A function of margin softmax, eg: cosface, arcface. - num_classes: int - The number of class center storage in current rank(CPU/GPU), usually is total_classes // world_size, - required. - sample_rate: float - The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling - can greatly speed up training, and reduce a lot of GPU memory, default is 1.0. - embedding_size: int - The feature dimension, default is 512. - prefix: str - Path for save checkpoint, default is './'. - """ - super(PartialFC, self).__init__() - # - self.num_classes: int = num_classes - self.rank: int = rank - self.local_rank: int = local_rank - self.device: torch.device = torch.device("cuda:{}".format(self.local_rank)) - self.world_size: int = world_size - self.batch_size: int = batch_size - self.margin_softmax: callable = margin_softmax - self.sample_rate: float = sample_rate - self.embedding_size: int = embedding_size - self.prefix: str = prefix - self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size) - self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size) - self.num_sample: int = int(self.sample_rate * self.num_local) - - self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank)) - self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank)) - - if resume: - try: - self.weight: torch.Tensor = torch.load(self.weight_name) - self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name) - if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local: - raise IndexError - logging.info("softmax weight resume successfully!") - logging.info("softmax weight mom resume successfully!") - except (FileNotFoundError, KeyError, IndexError): - self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) - self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) - logging.info("softmax weight init!") - logging.info("softmax weight mom init!") - else: - self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) - self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) - logging.info("softmax weight init successfully!") - logging.info("softmax weight mom init successfully!") - self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank) - - self.index = None - if int(self.sample_rate) == 1: - self.update = lambda: 0 - self.sub_weight = Parameter(self.weight) - self.sub_weight_mom = self.weight_mom - else: - self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank)) - - def save_params(self): - """ Save softmax weight for each rank on prefix - """ - torch.save(self.weight.data, self.weight_name) - torch.save(self.weight_mom, self.weight_mom_name) - - @torch.no_grad() - def sample(self, total_label): - """ - Sample all positive class centers in each rank, and random select neg class centers to filling a fixed - `num_sample`. - - total_label: tensor - Label after all gather, which cross all GPUs. 
- """ - index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local) - total_label[~index_positive] = -1 - total_label[index_positive] -= self.class_start - if int(self.sample_rate) != 1: - positive = torch.unique(total_label[index_positive], sorted=True) - if self.num_sample - positive.size(0) >= 0: - perm = torch.rand(size=[self.num_local], device=self.device) - perm[positive] = 2.0 - index = torch.topk(perm, k=self.num_sample)[1] - index = index.sort()[0] - else: - index = positive - self.index = index - total_label[index_positive] = torch.searchsorted(index, total_label[index_positive]) - self.sub_weight = Parameter(self.weight[index]) - self.sub_weight_mom = self.weight_mom[index] - - def forward(self, total_features, norm_weight): - """ Partial fc forward, `logits = X * sample(W)` - """ - torch.cuda.current_stream().wait_stream(self.stream) - logits = linear(total_features, norm_weight) - return logits - - @torch.no_grad() - def update(self): - """ Set updated weight and weight_mom to memory bank. - """ - self.weight_mom[self.index] = self.sub_weight_mom - self.weight[self.index] = self.sub_weight - - def prepare(self, label, optimizer): - """ - get sampled class centers for cal softmax. - - label: tensor - Label tensor on each rank. - optimizer: opt - Optimizer for partial fc, which need to get weight mom. - """ - with torch.cuda.stream(self.stream): - total_label = torch.zeros( - size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long) - dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label) - self.sample(total_label) - optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None) - optimizer.param_groups[-1]['params'][0] = self.sub_weight - optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom - norm_weight = normalize(self.sub_weight) - return total_label, norm_weight - - def forward_backward(self, label, features, optimizer): - """ - Partial fc forward and backward with model parallel - - label: tensor - Label tensor on each rank(GPU) - features: tensor - Features tensor on each rank(GPU) - optimizer: optimizer - Optimizer for partial fc - - Returns: - -------- - x_grad: tensor - The gradient of features. - loss_v: tensor - Loss value for cross entropy. 
- """ - total_label, norm_weight = self.prepare(label, optimizer) - total_features = torch.zeros( - size=[self.batch_size * self.world_size, self.embedding_size], device=self.device) - dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data) - total_features.requires_grad = True - - logits = self.forward(total_features, norm_weight) - logits = self.margin_softmax(logits, total_label) - - with torch.no_grad(): - max_fc = torch.max(logits, dim=1, keepdim=True)[0] - dist.all_reduce(max_fc, dist.ReduceOp.MAX) - - # calculate exp(logits) and all-reduce - logits_exp = torch.exp(logits - max_fc) - logits_sum_exp = logits_exp.sum(dim=1, keepdims=True) - dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM) - - # calculate prob - logits_exp.div_(logits_sum_exp) - - # get one-hot - grad = logits_exp - index = torch.where(total_label != -1)[0] - one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device) - one_hot.scatter_(1, total_label[index, None], 1) - - # calculate loss - loss = torch.zeros(grad.size()[0], 1, device=grad.device) - loss[index] = grad[index].gather(1, total_label[index, None]) - dist.all_reduce(loss, dist.ReduceOp.SUM) - loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1) - - # calculate grad - grad[index] -= one_hot - grad.div_(self.batch_size * self.world_size) - - logits.backward(grad) - if total_features.grad is not None: - total_features.grad.detach_() - x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True) - # feature gradient all-reduce - dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0))) - x_grad = x_grad * self.world_size - # backward backbone - return x_grad, loss_v diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/requirement.txt b/sadtalker_audio2pose/src/face3d/models/arcface_torch/requirement.txt deleted file mode 100644 index 99aef673e30b99cbe56ce82a564c1df9df24ba21..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/requirement.txt +++ /dev/null @@ -1,5 +0,0 @@ -tensorboard -easydict -mxnet -onnx -sklearn diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/run.sh b/sadtalker_audio2pose/src/face3d/models/arcface_torch/run.sh deleted file mode 100644 index 67b25fd63ef3921733d81d5be844aacc5a5c84ed..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/run.sh +++ /dev/null @@ -1,2 +0,0 @@ -CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 -ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/torch2onnx.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/torch2onnx.py deleted file mode 100644 index 458660df7cc7f9a567aaf492c45f232e776a9ef0..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/torch2onnx.py +++ /dev/null @@ -1,59 +0,0 @@ -import numpy as np -import onnx -import torch - - -def convert_onnx(net, path_module, output, opset=11, simplify=False): - assert isinstance(net, torch.nn.Module) - img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) - img = img.astype(np.float) - img = (img / 255. 
- 0.5) / 0.5 # torch style norm - img = img.transpose((2, 0, 1)) - img = torch.from_numpy(img).unsqueeze(0).float() - - weight = torch.load(path_module) - net.load_state_dict(weight) - net.eval() - torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset) - model = onnx.load(output) - graph = model.graph - graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' - if simplify: - from onnxsim import simplify - model, check = simplify(model) - assert check, "Simplified ONNX model could not be validated" - onnx.save(model, output) - - -if __name__ == '__main__': - import os - import argparse - from backbones import get_model - - parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx') - parser.add_argument('input', type=str, help='input backbone.pth file or path') - parser.add_argument('--output', type=str, default=None, help='output onnx path') - parser.add_argument('--network', type=str, default=None, help='backbone network') - parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify') - args = parser.parse_args() - input_file = args.input - if os.path.isdir(input_file): - input_file = os.path.join(input_file, "backbone.pth") - assert os.path.exists(input_file) - model_name = os.path.basename(os.path.dirname(input_file)).lower() - params = model_name.split("_") - if len(params) >= 3 and params[1] in ('arcface', 'cosface'): - if args.network is None: - args.network = params[2] - assert args.network is not None - print(args) - backbone_onnx = get_model(args.network, dropout=0) - - output_path = args.output - if output_path is None: - output_path = os.path.join(os.path.dirname(__file__), 'onnx') - if not os.path.exists(output_path): - os.makedirs(output_path) - assert os.path.isdir(output_path) - output_file = os.path.join(output_path, "%s.onnx" % model_name) - convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify) diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/train.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/train.py deleted file mode 100644 index 0c5491de9af8fc7a2f3d0648c53b89584864f20e..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/train.py +++ /dev/null @@ -1,141 +0,0 @@ -import argparse -import logging -import os - -import torch -import torch.distributed as dist -import torch.nn.functional as F -import torch.utils.data.distributed -from torch.nn.utils import clip_grad_norm_ - -import losses -from backbones import get_model -from dataset import MXFaceDataset, SyntheticDataset, DataLoaderX -from partial_fc import PartialFC -from utils.utils_amp import MaxClipGradScaler -from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint -from utils.utils_config import get_config -from utils.utils_logging import AverageMeter, init_logging - - -def main(args): - cfg = get_config(args.config) - try: - world_size = int(os.environ['WORLD_SIZE']) - rank = int(os.environ['RANK']) - dist.init_process_group('nccl') - except KeyError: - world_size = 1 - rank = 0 - dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:12584", rank=rank, world_size=world_size) - - local_rank = args.local_rank - torch.cuda.set_device(local_rank) - os.makedirs(cfg.output, exist_ok=True) - init_logging(rank, cfg.output) - - if cfg.rec == "synthetic": - train_set = SyntheticDataset(local_rank=local_rank) - else: - train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank) 
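The convert_onnx() helper above obtains a dynamic batch dimension by patching graph.input[0] after export; roughly the same effect can be had at export time via dynamic_axes. A sketch with illustrative tensor names (not taken from the original script):

import torch

def export_backbone(net, out_path, opset=11):
    net.eval()
    dummy = torch.randn(1, 3, 112, 112)  # ArcFace-style 112x112 input
    torch.onnx.export(
        net, dummy, out_path,
        input_names=['input'], output_names=['embedding'],
        dynamic_axes={'input': {0: 'batch'}, 'embedding': {0: 'batch'}},
        opset_version=opset)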
- - train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True) - train_loader = DataLoaderX( - local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size, - sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True) - backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank) - - if cfg.resume: - try: - backbone_pth = os.path.join(cfg.output, "backbone.pth") - backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank))) - if rank == 0: - logging.info("backbone resume successfully!") - except (FileNotFoundError, KeyError, IndexError, RuntimeError): - if rank == 0: - logging.info("resume fail, backbone init successfully!") - - backbone = torch.nn.parallel.DistributedDataParallel( - module=backbone, broadcast_buffers=False, device_ids=[local_rank]) - backbone.train() - margin_softmax = losses.get_loss(cfg.loss) - module_partial_fc = PartialFC( - rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume, - batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes, - sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output) - - opt_backbone = torch.optim.SGD( - params=[{'params': backbone.parameters()}], - lr=cfg.lr / 512 * cfg.batch_size * world_size, - momentum=0.9, weight_decay=cfg.weight_decay) - opt_pfc = torch.optim.SGD( - params=[{'params': module_partial_fc.parameters()}], - lr=cfg.lr / 512 * cfg.batch_size * world_size, - momentum=0.9, weight_decay=cfg.weight_decay) - - num_image = len(train_set) - total_batch_size = cfg.batch_size * world_size - cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch - cfg.total_step = num_image // total_batch_size * cfg.num_epoch - - def lr_step_func(current_step): - cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch] - if current_step < cfg.warmup_step: - return current_step / cfg.warmup_step - else: - return 0.1 ** len([m for m in cfg.decay_step if m <= current_step]) - - scheduler_backbone = torch.optim.lr_scheduler.LambdaLR( - optimizer=opt_backbone, lr_lambda=lr_step_func) - scheduler_pfc = torch.optim.lr_scheduler.LambdaLR( - optimizer=opt_pfc, lr_lambda=lr_step_func) - - for key, value in cfg.items(): - num_space = 25 - len(key) - logging.info(": " + key + " " * num_space + str(value)) - - val_target = cfg.val_targets - callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec) - callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None) - callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output) - - loss = AverageMeter() - start_epoch = 0 - global_step = 0 - grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None - for epoch in range(start_epoch, cfg.num_epoch): - train_sampler.set_epoch(epoch) - for step, (img, label) in enumerate(train_loader): - global_step += 1 - features = F.normalize(backbone(img)) - x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc) - if cfg.fp16: - features.backward(grad_amp.scale(x_grad)) - grad_amp.unscale_(opt_backbone) - clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2) - grad_amp.step(opt_backbone) - grad_amp.update() - else: - features.backward(x_grad) - clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2) - opt_backbone.step() - - opt_pfc.step() - module_partial_fc.update() - opt_backbone.zero_grad() - 
opt_pfc.zero_grad() - loss.update(loss_v, 1) - callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp) - callback_verification(global_step, backbone) - scheduler_backbone.step() - scheduler_pfc.step() - callback_checkpoint(global_step, backbone, module_partial_fc) - dist.destroy_process_group() - - -if __name__ == "__main__": - torch.backends.cudnn.benchmark = True - parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') - parser.add_argument('config', type=str, help='py config file') - parser.add_argument('--local_rank', type=int, default=0, help='local_rank') - main(parser.parse_args()) diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/__init__.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/plot.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/plot.py deleted file mode 100644 index 4fce6cc0ae526d5aebc8e7a1550300ceae3a2034..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/plot.py +++ /dev/null @@ -1,72 +0,0 @@ -# coding: utf-8 - -import os -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap -from prettytable import PrettyTable -from sklearn.metrics import roc_curve, auc - -image_path = "/data/anxiang/IJB_release/IJBC" -files = [ - "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy" -] - - -def read_template_pair_list(path): - pairs = pd.read_csv(path, sep=' ', header=None).values - t1 = pairs[:, 0].astype(np.int) - t2 = pairs[:, 1].astype(np.int) - label = pairs[:, 2].astype(np.int) - return t1, t2, label - - -p1, p2, label = read_template_pair_list( - os.path.join('%s/meta' % image_path, - '%s_template_pair_label.txt' % 'ijbc')) - -methods = [] -scores = [] -for file in files: - methods.append(file.split('/')[-2]) - scores.append(np.load(file)) - -methods = np.array(methods) -scores = dict(zip(methods, scores)) -colours = dict( - zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) -x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] -tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) -fig = plt.figure() -for method in methods: - fpr, tpr, _ = roc_curve(label, scores[method]) - roc_auc = auc(fpr, tpr) - fpr = np.flipud(fpr) - tpr = np.flipud(tpr) # select largest tpr at same fpr - plt.plot(fpr, - tpr, - color=colours[method], - lw=1, - label=('[%s (AUC = %0.4f %%)]' % - (method.split('-')[-1], roc_auc * 100))) - tpr_fpr_row = [] - tpr_fpr_row.append("%s-%s" % (method, "IJBC")) - for fpr_iter in np.arange(len(x_labels)): - _, min_index = min( - list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) - tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) - tpr_fpr_table.add_row(tpr_fpr_row) -plt.xlim([10 ** -6, 0.1]) -plt.ylim([0.3, 1.0]) -plt.grid(linestyle='--', linewidth=1) -plt.xticks(x_labels) -plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) -plt.xscale('log') -plt.xlabel('False Positive Rate') -plt.ylabel('True Positive Rate') -plt.title('ROC on IJB') -plt.legend(loc="lower right") -print(tpr_fpr_table) diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_amp.py 
b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_amp.py deleted file mode 100644 index a6d5bcbb540ff8b04535e71c0057e124338df5bd..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_amp.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Dict, List - -import torch - -if torch.__version__ < '1.9': - Iterable = torch._six.container_abcs.Iterable -else: - import collections - - Iterable = collections.abc.Iterable -from torch.cuda.amp import GradScaler - - -class _MultiDeviceReplicator(object): - """ - Lazily serves copies of a tensor to requested devices. Copies are cached per-device. - """ - - def __init__(self, master_tensor: torch.Tensor) -> None: - assert master_tensor.is_cuda - self.master = master_tensor - self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} - - def get(self, device) -> torch.Tensor: - retval = self._per_device_tensors.get(device, None) - if retval is None: - retval = self.master.to(device=device, non_blocking=True, copy=True) - self._per_device_tensors[device] = retval - return retval - - -class MaxClipGradScaler(GradScaler): - def __init__(self, init_scale, max_scale: float, growth_interval=100): - GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval) - self.max_scale = max_scale - - def scale_clip(self): - if self.get_scale() == self.max_scale: - self.set_growth_factor(1) - elif self.get_scale() < self.max_scale: - self.set_growth_factor(2) - elif self.get_scale() > self.max_scale: - self._scale.fill_(self.max_scale) - self.set_growth_factor(1) - - def scale(self, outputs): - """ - Multiplies ('scales') a tensor or list of tensors by the scale factor. - - Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned - unmodified. - - Arguments: - outputs (Tensor or iterable of Tensors): Outputs to scale. - """ - if not self._enabled: - return outputs - self.scale_clip() - # Short-circuit for the common case. - if isinstance(outputs, torch.Tensor): - assert outputs.is_cuda - if self._scale is None: - self._lazy_init_scale_growth_tracker(outputs.device) - assert self._scale is not None - return outputs * self._scale.to(device=outputs.device, non_blocking=True) - - # Invoke the more complex machinery only if we're treating multiple outputs. 
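MaxClipGradScaler above is a GradScaler whose loss scale is capped at max_scale, and the training loop earlier in this diff consumes it with the usual unscale-then-clip sequence. A minimal sketch of that sequence using the stock torch.cuda.amp.GradScaler (assumes a CUDA device; the toy model and shapes are illustrative):

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

device = "cuda"
model = nn.Linear(128, 10).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

x = torch.randn(32, 128, device=device)
y = torch.randint(0, 10, (32,), device=device)

with autocast():
    loss = nn.functional.cross_entropy(model(x), y)
scaler.scale(loss).backward()                               # backward on the scaled loss
scaler.unscale_(opt)                                        # restore true gradient magnitudes
nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)  # clip like the loop earlier in this diff
scaler.step(opt)                                            # skipped automatically on overflow
scaler.update()
opt.zero_grad()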
- stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale - - def apply_scale(val): - if isinstance(val, torch.Tensor): - assert val.is_cuda - if len(stash) == 0: - if self._scale is None: - self._lazy_init_scale_growth_tracker(val.device) - assert self._scale is not None - stash.append(_MultiDeviceReplicator(self._scale)) - return val * stash[0].get(val.device) - elif isinstance(val, Iterable): - iterable = map(apply_scale, val) - if isinstance(val, list) or isinstance(val, tuple): - return type(val)(iterable) - else: - return iterable - else: - raise ValueError("outputs must be a Tensor or an iterable of Tensors") - - return apply_scale(outputs) diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_callbacks.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_callbacks.py deleted file mode 100644 index 748923b36358bd118efa0532a6f512b6ca96ff34..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_callbacks.py +++ /dev/null @@ -1,117 +0,0 @@ -import logging -import os -import time -from typing import List - -import torch - -from eval import verification -from utils.utils_logging import AverageMeter - - -class CallBackVerification(object): - def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)): - self.frequent: int = frequent - self.rank: int = rank - self.highest_acc: float = 0.0 - self.highest_acc_list: List[float] = [0.0] * len(val_targets) - self.ver_list: List[object] = [] - self.ver_name_list: List[str] = [] - if self.rank is 0: - self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size) - - def ver_test(self, backbone: torch.nn.Module, global_step: int): - results = [] - for i in range(len(self.ver_list)): - acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test( - self.ver_list[i], backbone, 10, 10) - logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm)) - logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2)) - if acc2 > self.highest_acc_list[i]: - self.highest_acc_list[i] = acc2 - logging.info( - '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i])) - results.append(acc2) - - def init_dataset(self, val_targets, data_dir, image_size): - for name in val_targets: - path = os.path.join(data_dir, name + ".bin") - if os.path.exists(path): - data_set = verification.load_bin(path, image_size) - self.ver_list.append(data_set) - self.ver_name_list.append(name) - - def __call__(self, num_update, backbone: torch.nn.Module): - if self.rank is 0 and num_update > 0 and num_update % self.frequent == 0: - backbone.eval() - self.ver_test(backbone, num_update) - backbone.train() - - -class CallBackLogging(object): - def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None): - self.frequent: int = frequent - self.rank: int = rank - self.time_start = time.time() - self.total_step: int = total_step - self.batch_size: int = batch_size - self.world_size: int = world_size - self.writer = writer - - self.init = False - self.tic = 0 - - def __call__(self, - global_step: int, - loss: AverageMeter, - epoch: int, - fp16: bool, - learning_rate: float, - grad_scaler: torch.cuda.amp.GradScaler): - if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0: - if self.init: - try: - speed: float = self.frequent * self.batch_size / (time.time() 
- self.tic) - speed_total = speed * self.world_size - except ZeroDivisionError: - speed_total = float('inf') - - time_now = (time.time() - self.time_start) / 3600 - time_total = time_now / ((global_step + 1) / self.total_step) - time_for_end = time_total - time_now - if self.writer is not None: - self.writer.add_scalar('time_for_end', time_for_end, global_step) - self.writer.add_scalar('learning_rate', learning_rate, global_step) - self.writer.add_scalar('loss', loss.avg, global_step) - if fp16: - msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \ - "Fp16 Grad Scale: %2.f Required: %1.f hours" % ( - speed_total, loss.avg, learning_rate, epoch, global_step, - grad_scaler.get_scale(), time_for_end - ) - else: - msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \ - "Required: %1.f hours" % ( - speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end - ) - logging.info(msg) - loss.reset() - self.tic = time.time() - else: - self.init = True - self.tic = time.time() - - -class CallBackModelCheckpoint(object): - def __init__(self, rank, output="./"): - self.rank: int = rank - self.output: str = output - - def __call__(self, global_step, backbone, partial_fc, ): - if global_step > 100 and self.rank == 0: - path_module = os.path.join(self.output, "backbone.pth") - torch.save(backbone.module.state_dict(), path_module) - logging.info("Pytorch Model Saved in '{}'".format(path_module)) - - if global_step > 100 and partial_fc is not None: - partial_fc.save_params() diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_config.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_config.py deleted file mode 100644 index b60a1e5a2e860ce5511a2d3863c8b57a4df292d7..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_config.py +++ /dev/null @@ -1,16 +0,0 @@ -import importlib -import os.path as osp - - -def get_config(config_file): - assert config_file.startswith('configs/'), 'config file setting must start with configs/' - temp_config_name = osp.basename(config_file) - temp_module_name = osp.splitext(temp_config_name)[0] - config = importlib.import_module("configs.base") - cfg = config.config - config = importlib.import_module("configs.%s" % temp_module_name) - job_cfg = config.config - cfg.update(job_cfg) - if cfg.output is None: - cfg.output = osp.join('work_dirs', temp_module_name) - return cfg \ No newline at end of file diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_logging.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_logging.py deleted file mode 100644 index f2b43b851c9e06230abd94c73a1f64cfa1b6f3ac..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_logging.py +++ /dev/null @@ -1,41 +0,0 @@ -import logging -import os -import sys - - -class AverageMeter(object): - """Computes and stores the average and current value - """ - - def __init__(self): - self.val = None - self.avg = None - self.sum = None - self.count = None - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - -def init_logging(rank, models_root): - if rank == 0: - log_root = logging.getLogger() - log_root.setLevel(logging.INFO) - formatter = logging.Formatter("Training: 
%(asctime)s-%(message)s") - handler_file = logging.FileHandler(os.path.join(models_root, "training.log")) - handler_stream = logging.StreamHandler(sys.stdout) - handler_file.setFormatter(formatter) - handler_stream.setFormatter(formatter) - log_root.addHandler(handler_file) - log_root.addHandler(handler_stream) - log_root.info('rank_id: %d' % rank) diff --git a/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_os.py b/sadtalker_audio2pose/src/face3d/models/arcface_torch/utils/utils_os.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sadtalker_audio2pose/src/face3d/models/base_model.py b/sadtalker_audio2pose/src/face3d/models/base_model.py deleted file mode 100644 index b975223f6148febfe32d20d63980583c97b61eb3..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/base_model.py +++ /dev/null @@ -1,316 +0,0 @@ -"""This script defines the base network model for Deep3DFaceRecon_pytorch -""" - -import os -import numpy as np -import torch -from collections import OrderedDict -from abc import ABC, abstractmethod -from . import networks - - -class BaseModel(ABC): - """This class is an abstract base class (ABC) for models. - To create a subclass, you need to implement the following five functions: - -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). - -- : unpack data from dataset and apply preprocessing. - -- : produce intermediate results. - -- : calculate losses, gradients, and update network weights. - -- : (optionally) add model-specific options and set default options. - """ - - def __init__(self, opt): - """Initialize the BaseModel class. - - Parameters: - opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions - - When creating your custom class, you need to implement your own initialization. - In this fucntion, you should first call - Then, you need to define four lists: - -- self.loss_names (str list): specify the training losses that you want to plot and save. - -- self.model_names (str list): specify the images that you want to display and save. - -- self.visual_names (str list): define networks used in our training. - -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. - """ - self.opt = opt - self.isTrain = False - self.device = torch.device('cpu') - self.save_dir = " " # os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir - self.loss_names = [] - self.model_names = [] - self.visual_names = [] - self.parallel_names = [] - self.optimizers = [] - self.image_paths = [] - self.metric = 0 # used for learning rate policy 'plateau' - - @staticmethod - def dict_grad_hook_factory(add_func=lambda x: x): - saved_dict = dict() - - def hook_gen(name): - def grad_hook(grad): - saved_vals = add_func(grad) - saved_dict[name] = saved_vals - return grad_hook - return hook_gen, saved_dict - - @staticmethod - def modify_commandline_options(parser, is_train): - """Add new model-specific options, and rewrite default values for existing options. - - Parameters: - parser -- original option parser - is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. - - Returns: - the modified parser. 
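As the rest of the class below shows, model_names lists network attribute names (used by save_networks and eval), visual_names lists image attribute names (used by get_current_visuals), and each entry in loss_names corresponds to an attribute named loss_<name> collected by get_current_losses. A minimal, self-contained sketch of that naming convention (toy class and values for illustration):

from collections import OrderedDict
import torch

class TinyModel:
    def __init__(self):
        self.model_names = ["net_recon"]            # network attributes
        self.loss_names = ["color", "lm"]           # each maps to an attribute "loss_<name>"
        self.net_recon = torch.nn.Linear(4, 2)
        self.loss_color = torch.tensor(0.25)
        self.loss_lm = torch.tensor(0.03)

    def get_current_losses(self):
        # Generic helpers can collect losses purely by reflection over loss_names.
        return OrderedDict((n, float(getattr(self, "loss_" + n))) for n in self.loss_names)

print(TinyModel().get_current_losses())             # OrderedDict([('color', 0.25), ('lm', 0.03)])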
- """ - return parser - - @abstractmethod - def set_input(self, input): - """Unpack input data from the dataloader and perform necessary pre-processing steps. - - Parameters: - input (dict): includes the data itself and its metadata information. - """ - pass - - @abstractmethod - def forward(self): - """Run forward pass; called by both functions and .""" - pass - - @abstractmethod - def optimize_parameters(self): - """Calculate losses, gradients, and update network weights; called in every training iteration""" - pass - - def setup(self, opt): - """Load and print networks; create schedulers - - Parameters: - opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions - """ - if self.isTrain: - self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] - - if not self.isTrain or opt.continue_train: - load_suffix = opt.epoch - self.load_networks(load_suffix) - - - # self.print_networks(opt.verbose) - - def parallelize(self, convert_sync_batchnorm=True): - if not self.opt.use_ddp: - for name in self.parallel_names: - if isinstance(name, str): - module = getattr(self, name) - setattr(self, name, module.to(self.device)) - else: - for name in self.model_names: - if isinstance(name, str): - module = getattr(self, name) - if convert_sync_batchnorm: - module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module) - setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device), - device_ids=[self.device.index], - find_unused_parameters=True, broadcast_buffers=True)) - - # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient. - for name in self.parallel_names: - if isinstance(name, str) and name not in self.model_names: - module = getattr(self, name) - setattr(self, name, module.to(self.device)) - - # put state_dict of optimizer to gpu device - if self.opt.phase != 'test': - if self.opt.continue_train: - for optim in self.optimizers: - for state in optim.state.values(): - for k, v in state.items(): - if isinstance(v, torch.Tensor): - state[k] = v.to(self.device) - - def data_dependent_initialize(self, data): - pass - - def train(self): - """Make models train mode""" - for name in self.model_names: - if isinstance(name, str): - net = getattr(self, name) - net.train() - - def eval(self): - """Make models eval mode""" - for name in self.model_names: - if isinstance(name, str): - net = getattr(self, name) - net.eval() - - def test(self): - """Forward function used in test time. - - This function wraps function in no_grad() so we don't save intermediate steps for backprop - It also calls to produce additional visualization results - """ - with torch.no_grad(): - self.forward() - self.compute_visuals() - - def compute_visuals(self): - """Calculate additional output images for visdom and HTML visualization""" - pass - - def get_image_paths(self, name='A'): - """ Return image paths that are used to load current data""" - return self.image_paths if name =='A' else self.image_paths_B - - def update_learning_rate(self): - """Update learning rates for all the networks; called at the end of every epoch""" - for scheduler in self.schedulers: - if self.opt.lr_policy == 'plateau': - scheduler.step(self.metric) - else: - scheduler.step() - - lr = self.optimizers[0].param_groups[0]['lr'] - print('learning rate = %.7f' % lr) - - def get_current_visuals(self): - """Return visualization images. 
train.py will display these images with visdom, and save the images to a HTML""" - visual_ret = OrderedDict() - for name in self.visual_names: - if isinstance(name, str): - visual_ret[name] = getattr(self, name)[:, :3, ...] - return visual_ret - - def get_current_losses(self): - """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" - errors_ret = OrderedDict() - for name in self.loss_names: - if isinstance(name, str): - errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number - return errors_ret - - def save_networks(self, epoch): - """Save all the networks to the disk. - - Parameters: - epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) - """ - if not os.path.isdir(self.save_dir): - os.makedirs(self.save_dir) - - save_filename = 'epoch_%s.pth' % (epoch) - save_path = os.path.join(self.save_dir, save_filename) - - save_dict = {} - for name in self.model_names: - if isinstance(name, str): - net = getattr(self, name) - if isinstance(net, torch.nn.DataParallel) or isinstance(net, - torch.nn.parallel.DistributedDataParallel): - net = net.module - save_dict[name] = net.state_dict() - - - for i, optim in enumerate(self.optimizers): - save_dict['opt_%02d'%i] = optim.state_dict() - - for i, sched in enumerate(self.schedulers): - save_dict['sched_%02d'%i] = sched.state_dict() - - torch.save(save_dict, save_path) - - def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): - """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" - key = keys[i] - if i + 1 == len(keys): # at the end, pointing to a parameter/buffer - if module.__class__.__name__.startswith('InstanceNorm') and \ - (key == 'running_mean' or key == 'running_var'): - if getattr(module, key) is None: - state_dict.pop('.'.join(keys)) - if module.__class__.__name__.startswith('InstanceNorm') and \ - (key == 'num_batches_tracked'): - state_dict.pop('.'.join(keys)) - else: - self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) - - def load_networks(self, epoch): - """Load all the networks from the disk. 
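save_networks above bundles every network, optimizer, and scheduler state into a single epoch_<n>.pth dict, which load_networks reads back by key. A minimal sketch of that checkpoint layout (toy module and file name for illustration):

import torch

net = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=10)

ckpt = {
    "net_recon": net.state_dict(),   # one entry per name in model_names
    "opt_00": opt.state_dict(),      # 'opt_%02d' per optimizer
    "sched_00": sched.state_dict(),  # 'sched_%02d' per scheduler
}
torch.save(ckpt, "epoch_20.pth")

state = torch.load("epoch_20.pth", map_location="cpu")
net.load_state_dict(state["net_recon"])
opt.load_state_dict(state["opt_00"])
sched.load_state_dict(state["sched_00"])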
- - Parameters: - epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) - """ - if self.opt.isTrain and self.opt.pretrained_name is not None: - load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name) - else: - load_dir = self.save_dir - load_filename = 'epoch_%s.pth' % (epoch) - load_path = os.path.join(load_dir, load_filename) - state_dict = torch.load(load_path, map_location=self.device) - print('loading the model from %s' % load_path) - - for name in self.model_names: - if isinstance(name, str): - net = getattr(self, name) - if isinstance(net, torch.nn.DataParallel): - net = net.module - net.load_state_dict(state_dict[name]) - - if self.opt.phase != 'test': - if self.opt.continue_train: - print('loading the optim from %s' % load_path) - for i, optim in enumerate(self.optimizers): - optim.load_state_dict(state_dict['opt_%02d'%i]) - - try: - print('loading the sched from %s' % load_path) - for i, sched in enumerate(self.schedulers): - sched.load_state_dict(state_dict['sched_%02d'%i]) - except: - print('Failed to load schedulers, set schedulers according to epoch count manually') - for i, sched in enumerate(self.schedulers): - sched.last_epoch = self.opt.epoch_count - 1 - - - - - def print_networks(self, verbose): - """Print the total number of parameters in the network and (if verbose) network architecture - - Parameters: - verbose (bool) -- if verbose: print the network architecture - """ - print('---------- Networks initialized -------------') - for name in self.model_names: - if isinstance(name, str): - net = getattr(self, name) - num_params = 0 - for param in net.parameters(): - num_params += param.numel() - if verbose: - print(net) - print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) - print('-----------------------------------------------') - - def set_requires_grad(self, nets, requires_grad=False): - """Set requies_grad=Fasle for all the networks to avoid unnecessary computations - Parameters: - nets (network list) -- a list of networks - requires_grad (bool) -- whether the networks require gradients or not - """ - if not isinstance(nets, list): - nets = [nets] - for net in nets: - if net is not None: - for param in net.parameters(): - param.requires_grad = requires_grad - - def generate_visuals_for_evaluation(self, data, mode): - return {} diff --git a/sadtalker_audio2pose/src/face3d/models/bfm.py b/sadtalker_audio2pose/src/face3d/models/bfm.py deleted file mode 100644 index 0cecaf589befac790cf9c124737ba01e27bc29e6..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/bfm.py +++ /dev/null @@ -1,331 +0,0 @@ -"""This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch -""" - -import numpy as np -import torch -import torch.nn.functional as F -from scipy.io import loadmat -from src.face3d.util.load_mats import transferBFM09 -import os - -def perspective_projection(focal, center): - # return p.T (N, 3) @ (3, 3) - return np.array([ - focal, 0, center, - 0, focal, center, - 0, 0, 1 - ]).reshape([3, 3]).astype(np.float32).transpose() - -class SH: - def __init__(self): - self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)] - self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) 
/ np.sqrt(12 * np.pi)] - - - -class ParametricFaceModel: - def __init__(self, - bfm_folder='./BFM', - recenter=True, - camera_distance=10., - init_lit=np.array([ - 0.8, 0, 0, 0, 0, 0, 0, 0, 0 - ]), - focal=1015., - center=112., - is_train=True, - default_name='BFM_model_front.mat'): - - if not os.path.isfile(os.path.join(bfm_folder, default_name)): - transferBFM09(bfm_folder) - - model = loadmat(os.path.join(bfm_folder, default_name)) - # mean face shape. [3*N,1] - self.mean_shape = model['meanshape'].astype(np.float32) - # identity basis. [3*N,80] - self.id_base = model['idBase'].astype(np.float32) - # expression basis. [3*N,64] - self.exp_base = model['exBase'].astype(np.float32) - # mean face texture. [3*N,1] (0-255) - self.mean_tex = model['meantex'].astype(np.float32) - # texture basis. [3*N,80] - self.tex_base = model['texBase'].astype(np.float32) - # face indices for each vertex that lies in. starts from 0. [N,8] - self.point_buf = model['point_buf'].astype(np.int64) - 1 - # vertex indices for each face. starts from 0. [F,3] - self.face_buf = model['tri'].astype(np.int64) - 1 - # vertex indices for 68 landmarks. starts from 0. [68,1] - self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1 - - if is_train: - # vertex indices for small face region to compute photometric error. starts from 0. - self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1 - # vertex indices for each face from small face region. starts from 0. [f,3] - self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1 - # vertex indices for pre-defined skin region to compute reflectance loss - self.skin_mask = np.squeeze(model['skinmask']) - - if recenter: - mean_shape = self.mean_shape.reshape([-1, 3]) - mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True) - self.mean_shape = mean_shape.reshape([-1, 1]) - - self.persc_proj = perspective_projection(focal, center) - self.device = 'cpu' - self.camera_distance = camera_distance - self.SH = SH() - self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32) - - - def to(self, device): - self.device = device - for key, value in self.__dict__.items(): - if type(value).__module__ == np.__name__: - setattr(self, key, torch.tensor(value).to(device)) - - - def compute_shape(self, id_coeff, exp_coeff): - """ - Return: - face_shape -- torch.tensor, size (B, N, 3) - - Parameters: - id_coeff -- torch.tensor, size (B, 80), identity coeffs - exp_coeff -- torch.tensor, size (B, 64), expression coeffs - """ - batch_size = id_coeff.shape[0] - id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff) - exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff) - face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1]) - return face_shape.reshape([batch_size, -1, 3]) - - - def compute_texture(self, tex_coeff, normalize=True): - """ - Return: - face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.) - - Parameters: - tex_coeff -- torch.tensor, size (B, 80) - """ - batch_size = tex_coeff.shape[0] - face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex - if normalize: - face_texture = face_texture / 255. 
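compute_shape above is the standard linear morphable-model equation: a mean shape plus identity and expression bases weighted by their coefficients. A minimal sketch with toy dimensions (the real BFM bases are 3*N x 80 and 3*N x 64, as the comments above note):

import torch

N, B = 5, 2                                    # toy vertex count and batch size
mean_shape = torch.zeros(3 * N, 1)
id_base = torch.randn(3 * N, 80)               # identity basis
exp_base = torch.randn(3 * N, 64)              # expression basis

id_coeff = torch.randn(B, 80)
exp_coeff = torch.randn(B, 64)

id_part = torch.einsum("ij,aj->ai", id_base, id_coeff)
exp_part = torch.einsum("ij,aj->ai", exp_base, exp_coeff)
face_shape = (id_part + exp_part + mean_shape.reshape(1, -1)).reshape(B, -1, 3)
print(face_shape.shape)                        # torch.Size([2, 5, 3])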
- return face_texture.reshape([batch_size, -1, 3]) - - - def compute_norm(self, face_shape): - """ - Return: - vertex_norm -- torch.tensor, size (B, N, 3) - - Parameters: - face_shape -- torch.tensor, size (B, N, 3) - """ - - v1 = face_shape[:, self.face_buf[:, 0]] - v2 = face_shape[:, self.face_buf[:, 1]] - v3 = face_shape[:, self.face_buf[:, 2]] - e1 = v1 - v2 - e2 = v2 - v3 - face_norm = torch.cross(e1, e2, dim=-1) - face_norm = F.normalize(face_norm, dim=-1, p=2) - face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1) - - vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2) - vertex_norm = F.normalize(vertex_norm, dim=-1, p=2) - return vertex_norm - - - def compute_color(self, face_texture, face_norm, gamma): - """ - Return: - face_color -- torch.tensor, size (B, N, 3), range (0, 1.) - - Parameters: - face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.) - face_norm -- torch.tensor, size (B, N, 3), rotated face normal - gamma -- torch.tensor, size (B, 27), SH coeffs - """ - batch_size = gamma.shape[0] - v_num = face_texture.shape[1] - a, c = self.SH.a, self.SH.c - gamma = gamma.reshape([batch_size, 3, 9]) - gamma = gamma + self.init_lit - gamma = gamma.permute(0, 2, 1) - Y = torch.cat([ - a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device), - -a[1] * c[1] * face_norm[..., 1:2], - a[1] * c[1] * face_norm[..., 2:], - -a[1] * c[1] * face_norm[..., :1], - a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2], - -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:], - 0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1), - -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:], - 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2) - ], dim=-1) - r = Y @ gamma[..., :1] - g = Y @ gamma[..., 1:2] - b = Y @ gamma[..., 2:] - face_color = torch.cat([r, g, b], dim=-1) * face_texture - return face_color - - - def compute_rotation(self, angles): - """ - Return: - rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat - - Parameters: - angles -- torch.tensor, size (B, 3), radian - """ - - batch_size = angles.shape[0] - ones = torch.ones([batch_size, 1]).to(self.device) - zeros = torch.zeros([batch_size, 1]).to(self.device) - x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:], - - rot_x = torch.cat([ - ones, zeros, zeros, - zeros, torch.cos(x), -torch.sin(x), - zeros, torch.sin(x), torch.cos(x) - ], dim=1).reshape([batch_size, 3, 3]) - - rot_y = torch.cat([ - torch.cos(y), zeros, torch.sin(y), - zeros, ones, zeros, - -torch.sin(y), zeros, torch.cos(y) - ], dim=1).reshape([batch_size, 3, 3]) - - rot_z = torch.cat([ - torch.cos(z), -torch.sin(z), zeros, - torch.sin(z), torch.cos(z), zeros, - zeros, zeros, ones - ], dim=1).reshape([batch_size, 3, 3]) - - rot = rot_z @ rot_y @ rot_x - return rot.permute(0, 2, 1) - - - def to_camera(self, face_shape): - face_shape[..., -1] = self.camera_distance - face_shape[..., -1] - return face_shape - - def to_image(self, face_shape): - """ - Return: - face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction - - Parameters: - face_shape -- torch.tensor, size (B, N, 3) - """ - # to image_plane - face_proj = face_shape @ self.persc_proj - face_proj = face_proj[..., :2] / face_proj[..., 2:] - - return face_proj - - - def transform(self, face_shape, rot, trans): - """ - Return: - face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans - - Parameters: - face_shape -- torch.tensor, size (B, N, 3) - rot -- 
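compute_rotation above builds per-axis rotation matrices from Euler angles, composes them as R_z R_y R_x, and transposes the result so vertices can be rotated by right-multiplication (pts @ rot). A minimal sketch of the same construction:

import torch

def rotation_from_angles(angles):                      # angles: (B, 3), radians
    x, y, z = angles[:, 0], angles[:, 1], angles[:, 2]
    cx, sx, cy, sy, cz, sz = x.cos(), x.sin(), y.cos(), y.sin(), z.cos(), z.sin()
    zero, one = torch.zeros_like(x), torch.ones_like(x)
    rot_x = torch.stack([one, zero, zero, zero, cx, -sx, zero, sx, cx], dim=1).reshape(-1, 3, 3)
    rot_y = torch.stack([cy, zero, sy, zero, one, zero, -sy, zero, cy], dim=1).reshape(-1, 3, 3)
    rot_z = torch.stack([cz, -sz, zero, sz, cz, zero, zero, zero, one], dim=1).reshape(-1, 3, 3)
    return (rot_z @ rot_y @ rot_x).permute(0, 2, 1)    # transpose so that pts @ rot applies the rotation

R = rotation_from_angles(torch.tensor([[0.1, -0.2, 0.3]]))
print(torch.allclose(R @ R.transpose(1, 2), torch.eye(3).expand(1, 3, 3), atol=1e-6))  # True: orthonormal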
torch.tensor, size (B, 3, 3) - trans -- torch.tensor, size (B, 3) - """ - return face_shape @ rot + trans.unsqueeze(1) - - - def get_landmarks(self, face_proj): - """ - Return: - face_lms -- torch.tensor, size (B, 68, 2) - - Parameters: - face_proj -- torch.tensor, size (B, N, 2) - """ - return face_proj[:, self.keypoints] - - def split_coeff(self, coeffs): - """ - Return: - coeffs_dict -- a dict of torch.tensors - - Parameters: - coeffs -- torch.tensor, size (B, 256) - """ - id_coeffs = coeffs[:, :80] - exp_coeffs = coeffs[:, 80: 144] - tex_coeffs = coeffs[:, 144: 224] - angles = coeffs[:, 224: 227] - gammas = coeffs[:, 227: 254] - translations = coeffs[:, 254:] - return { - 'id': id_coeffs, - 'exp': exp_coeffs, - 'tex': tex_coeffs, - 'angle': angles, - 'gamma': gammas, - 'trans': translations - } - def compute_for_render(self, coeffs): - """ - Return: - face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate - face_color -- torch.tensor, size (B, N, 3), in RGB order - landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction - Parameters: - coeffs -- torch.tensor, size (B, 257) - """ - coef_dict = self.split_coeff(coeffs) - face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp']) - rotation = self.compute_rotation(coef_dict['angle']) - - - face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans']) - face_vertex = self.to_camera(face_shape_transformed) - - face_proj = self.to_image(face_vertex) - landmark = self.get_landmarks(face_proj) - - face_texture = self.compute_texture(coef_dict['tex']) - face_norm = self.compute_norm(face_shape) - face_norm_roted = face_norm @ rotation - face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma']) - - return face_vertex, face_texture, face_color, landmark - - def compute_for_render_woRotation(self, coeffs): - """ - Return: - face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate - face_color -- torch.tensor, size (B, N, 3), in RGB order - landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction - Parameters: - coeffs -- torch.tensor, size (B, 257) - """ - coef_dict = self.split_coeff(coeffs) - face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp']) - #rotation = self.compute_rotation(coef_dict['angle']) - - - #face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans']) - face_vertex = self.to_camera(face_shape) - - face_proj = self.to_image(face_vertex) - landmark = self.get_landmarks(face_proj) - - face_texture = self.compute_texture(coef_dict['tex']) - face_norm = self.compute_norm(face_shape) - face_norm_roted = face_norm # @ rotation - face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma']) - - return face_vertex, face_texture, face_color, landmark - - -if __name__ == '__main__': - transferBFM09() \ No newline at end of file diff --git a/sadtalker_audio2pose/src/face3d/models/facerecon_model.py b/sadtalker_audio2pose/src/face3d/models/facerecon_model.py deleted file mode 100644 index 6a8a701f4771fc337aa9b456310f4af4a6f86a69..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/facerecon_model.py +++ /dev/null @@ -1,220 +0,0 @@ -"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch -""" - -import numpy as np -import torch -from src.face3d.models.base_model import BaseModel -from src.face3d.models import networks -from src.face3d.models.bfm import ParametricFaceModel -from src.face3d.models.losses import 
perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss -from src.face3d.util import util -from src.face3d.util.nvdiffrast import MeshRenderer -# from src.face3d.util.preprocess import estimate_norm_torch - -import trimesh -from scipy.io import savemat - -class FaceReconModel(BaseModel): - - @staticmethod - def modify_commandline_options(parser, is_train=False): - """ Configures options specific for CUT model - """ - # net structure and parameters - parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure') - parser.add_argument('--init_path', type=str, default='./ckpts/sad_talker/init_model/resnet50-0676ba61.pth') - parser.add_argument('--use_last_fc', type=util.str2bool, nargs='?', const=True, default=False, help='zero initialize the last fc') - parser.add_argument('--bfm_folder', type=str, default='./ckpts/sad_talker/BFM_Fitting/') - parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model') - - # renderer parameters - parser.add_argument('--focal', type=float, default=1015.) - parser.add_argument('--center', type=float, default=112.) - parser.add_argument('--camera_d', type=float, default=10.) - parser.add_argument('--z_near', type=float, default=5.) - parser.add_argument('--z_far', type=float, default=15.) - - if is_train: - # training parameters - parser.add_argument('--net_recog', type=str, default='r50', choices=['r18', 'r43', 'r50'], help='face recog network structure') - parser.add_argument('--net_recog_path', type=str, default='checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth') - parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss') - parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face') - - - # augmentation parameters - parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels') - parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor') - parser.add_argument('--rot_angle', type=float, default=10., help='rot angles, degree') - - # loss weights - parser.add_argument('--w_feat', type=float, default=0.2, help='weight for feat loss') - parser.add_argument('--w_color', type=float, default=1.92, help='weight for loss loss') - parser.add_argument('--w_reg', type=float, default=3.0e-4, help='weight for reg loss') - parser.add_argument('--w_id', type=float, default=1.0, help='weight for id_reg loss') - parser.add_argument('--w_exp', type=float, default=0.8, help='weight for exp_reg loss') - parser.add_argument('--w_tex', type=float, default=1.7e-2, help='weight for tex_reg loss') - parser.add_argument('--w_gamma', type=float, default=10.0, help='weight for gamma loss') - parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss') - parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss') - - opt, _ = parser.parse_known_args() - parser.set_defaults( - focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15. - ) - if is_train: - parser.set_defaults( - use_crop_face=True, use_predef_M=False - ) - return parser - - def __init__(self, opt): - """Initialize this model class. - - Parameters: - opt -- training/test options - - A few things can be done here. 
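With the renderer defaults above (focal 1015 and center 112, i.e. half of a 224-pixel crop), the field of view that __init__ passes to the MeshRenderer just below works out as 2 * atan(center / focal). A quick check:

import numpy as np

focal, center = 1015.0, 112.0                  # the option defaults above
fov_deg = 2 * np.arctan(center / focal) * 180 / np.pi
print(f"{fov_deg:.2f} deg")                    # roughly 12.6 degrees, the rasterizer's vertical FOV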
- - (required) call the initialization function of BaseModel - - define loss function, visualization images, model names, and optimizers - """ - BaseModel.__init__(self, opt) # call the initialization method of BaseModel - - self.visual_names = ['output_vis'] - self.model_names = ['net_recon'] - self.parallel_names = self.model_names + ['renderer'] - - self.facemodel = ParametricFaceModel( - bfm_folder=opt.bfm_folder, camera_distance=opt.camera_d, focal=opt.focal, center=opt.center, - is_train=self.isTrain, default_name=opt.bfm_model - ) - - fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi - self.renderer = MeshRenderer( - rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center) - ) - - if self.isTrain: - self.loss_names = ['all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc'] - - self.net_recog = networks.define_net_recog( - net_recog=opt.net_recog, pretrained_path=opt.net_recog_path - ) - # loss func name: (compute_%s_loss) % loss_name - self.compute_feat_loss = perceptual_loss - self.comupte_color_loss = photo_loss - self.compute_lm_loss = landmark_loss - self.compute_reg_loss = reg_loss - self.compute_reflc_loss = reflectance_loss - - self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr) - self.optimizers = [self.optimizer] - self.parallel_names += ['net_recog'] - # Our program will automatically call to define schedulers, load networks, and print networks - - def set_input(self, input): - """Unpack input data from the dataloader and perform necessary pre-processing steps. - - Parameters: - input: a dictionary that contains the data itself and its metadata information. - """ - self.input_img = input['imgs'].to(self.device) - self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None - self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None - self.trans_m = input['M'].to(self.device) if 'M' in input else None - self.image_paths = input['im_paths'] if 'im_paths' in input else None - - def forward(self, output_coeff, device): - self.facemodel.to(device) - self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \ - self.facemodel.compute_for_render(output_coeff) - self.pred_mask, _, self.pred_face = self.renderer( - self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color) - - self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff) - - - def compute_losses(self): - """Calculate losses, gradients, and update network weights; called in every training iteration""" - - assert self.net_recog.training == False - trans_m = self.trans_m - if not self.opt.use_predef_M: - trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2]) - - pred_feat = self.net_recog(self.pred_face, trans_m) - gt_feat = self.net_recog(self.input_img, self.trans_m) - self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat) - - face_mask = self.pred_mask - if self.opt.use_crop_face: - face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf) - - face_mask = face_mask.detach() - self.loss_color = self.opt.w_color * self.comupte_color_loss( - self.pred_face, self.input_img, self.atten_mask * face_mask) - - loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt) - self.loss_reg = self.opt.w_reg * loss_reg - self.loss_gamma = self.opt.w_gamma * loss_gamma - - self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm) - - self.loss_reflc = self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, 
self.facemodel.skin_mask) - - self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \ - + self.loss_lm + self.loss_reflc - - - def optimize_parameters(self, isTrain=True): - self.forward() - self.compute_losses() - """Update network weights; it will be called in every training iteration.""" - if isTrain: - self.optimizer.zero_grad() - self.loss_all.backward() - self.optimizer.step() - - def compute_visuals(self): - with torch.no_grad(): - input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy() - output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img - output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy() - - if self.gt_lm is not None: - gt_lm_numpy = self.gt_lm.cpu().numpy() - pred_lm_numpy = self.pred_lm.detach().cpu().numpy() - output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b') - output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r') - - output_vis_numpy = np.concatenate((input_img_numpy, - output_vis_numpy_raw, output_vis_numpy), axis=-2) - else: - output_vis_numpy = np.concatenate((input_img_numpy, - output_vis_numpy_raw), axis=-2) - - self.output_vis = torch.tensor( - output_vis_numpy / 255., dtype=torch.float32 - ).permute(0, 3, 1, 2).to(self.device) - - def save_mesh(self, name): - - recon_shape = self.pred_vertex # get reconstructed shape - recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space - recon_shape = recon_shape.cpu().numpy()[0] - recon_color = self.pred_color - recon_color = recon_color.cpu().numpy()[0] - tri = self.facemodel.face_buf.cpu().numpy() - mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8)) - mesh.export(name) - - def save_coeff(self,name): - - pred_coeffs = {key:self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict} - pred_lm = self.pred_lm.cpu().numpy() - pred_lm = np.stack([pred_lm[:,:,0],self.input_img.shape[2]-1-pred_lm[:,:,1]],axis=2) # transfer to image coordinate - pred_coeffs['lm68'] = pred_lm - savemat(name,pred_coeffs) - - - diff --git a/sadtalker_audio2pose/src/face3d/models/losses.py b/sadtalker_audio2pose/src/face3d/models/losses.py deleted file mode 100644 index 01d9da84f28d54e772bebd2385ae5a7fedd10f7d..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/losses.py +++ /dev/null @@ -1,113 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -from kornia.geometry import warp_affine -import torch.nn.functional as F - -def resize_n_crop(image, M, dsize=112): - # image: (b, c, h, w) - # M : (b, 2, 3) - return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True) - -### perceptual level loss -class PerceptualLoss(nn.Module): - def __init__(self, recog_net, input_size=112): - super(PerceptualLoss, self).__init__() - self.recog_net = recog_net - self.preprocess = lambda x: 2 * x - 1 - self.input_size=input_size - def forward(imageA, imageB, M): - """ - 1 - cosine distance - Parameters: - imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order - imageB --same as imageA - """ - - imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size)) - imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size)) - - # freeze bn - self.recog_net.eval() - - id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2) - id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2) - cosine_d = 
torch.sum(id_featureA * id_featureB, dim=-1) - # assert torch.sum((cosine_d > 1).float()) == 0 - return torch.sum(1 - cosine_d) / cosine_d.shape[0] - -def perceptual_loss(id_featureA, id_featureB): - cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) - # assert torch.sum((cosine_d > 1).float()) == 0 - return torch.sum(1 - cosine_d) / cosine_d.shape[0] - -### image level loss -def photo_loss(imageA, imageB, mask, eps=1e-6): - """ - l2 norm (with sqrt, to ensure backward stabililty, use eps, otherwise Nan may occur) - Parameters: - imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order - imageB --same as imageA - """ - loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask - loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device)) - return loss - -def landmark_loss(predict_lm, gt_lm, weight=None): - """ - weighted mse loss - Parameters: - predict_lm --torch.tensor (B, 68, 2) - gt_lm --torch.tensor (B, 68, 2) - weight --numpy.array (1, 68) - """ - if not weight: - weight = np.ones([68]) - weight[28:31] = 20 - weight[-8:] = 20 - weight = np.expand_dims(weight, 0) - weight = torch.tensor(weight).to(predict_lm.device) - loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight - loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1]) - return loss - - -### regulization -def reg_loss(coeffs_dict, opt=None): - """ - l2 norm without the sqrt, from yu's implementation (mse) - tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss - Parameters: - coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans - - """ - # coefficient regularization to ensure plausible 3d faces - if opt: - w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex - else: - w_id, w_exp, w_tex = 1, 1, 1, 1 - creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \ - w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \ - w_tex * torch.sum(coeffs_dict['tex'] ** 2) - creg_loss = creg_loss / coeffs_dict['id'].shape[0] - - # gamma regularization to ensure a nearly-monochromatic light - gamma = coeffs_dict['gamma'].reshape([-1, 3, 9]) - gamma_mean = torch.mean(gamma, dim=1, keepdims=True) - gamma_loss = torch.mean((gamma - gamma_mean) ** 2) - - return creg_loss, gamma_loss - -def reflectance_loss(texture, mask): - """ - minimize texture variance (mse), albedo regularization to ensure an uniform skin albedo - Parameters: - texture --torch.tensor, (B, N, 3) - mask --torch.tensor, (N), 1 or 0 - - """ - mask = mask.reshape([1, mask.shape[0], 1]) - texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask) - loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask)) - return loss - diff --git a/sadtalker_audio2pose/src/face3d/models/networks.py b/sadtalker_audio2pose/src/face3d/models/networks.py deleted file mode 100644 index 1e69eba1ade2e6431e7e7fd526ea68b8f63e7152..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/networks.py +++ /dev/null @@ -1,521 +0,0 @@ -"""This script defines deep neural networks for Deep3DFaceRecon_pytorch -""" - -import os -import numpy as np -import torch.nn.functional as F -from torch.nn import init -import functools -from torch.optim import lr_scheduler -import torch -from torch import Tensor -import torch.nn as nn -try: - from torch.hub import load_state_dict_from_url -except ImportError: - from torch.utils.model_zoo import load_url as load_state_dict_from_url -from typing import Type, Any, Callable, 
Union, List, Optional -from .arcface_torch.backbones import get_model -from kornia.geometry import warp_affine - -def resize_n_crop(image, M, dsize=112): - # image: (b, c, h, w) - # M : (b, 2, 3) - return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True) - -def filter_state_dict(state_dict, remove_name='fc'): - new_state_dict = {} - for key in state_dict: - if remove_name in key: - continue - new_state_dict[key] = state_dict[key] - return new_state_dict - -def get_scheduler(optimizer, opt): - """Return a learning rate scheduler - - Parameters: - optimizer -- the optimizer of the network - opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  - opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine - - For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. - See https://pytorch.org/docs/stable/optim.html for more details. - """ - if opt.lr_policy == 'linear': - def lambda_rule(epoch): - lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1) - return lr_l - scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) - elif opt.lr_policy == 'step': - scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.2) - elif opt.lr_policy == 'plateau': - scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) - elif opt.lr_policy == 'cosine': - scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) - else: - return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) - return scheduler - - -def define_net_recon(net_recon, use_last_fc=False, init_path=None): - return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path) - -def define_net_recog(net_recog, pretrained_path=None): - net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path) - net.eval() - return net - -class ReconNetWrapper(nn.Module): - fc_dim=257 - def __init__(self, net_recon, use_last_fc=False, init_path=None): - super(ReconNetWrapper, self).__init__() - self.use_last_fc = use_last_fc - if net_recon not in func_dict: - return NotImplementedError('network [%s] is not implemented', net_recon) - func, last_dim = func_dict[net_recon] - backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim) - if init_path and os.path.isfile(init_path): - state_dict = filter_state_dict(torch.load(init_path, map_location='cpu')) - backbone.load_state_dict(state_dict) - print("loading init net_recon %s from %s" %(net_recon, init_path)) - self.backbone = backbone - if not use_last_fc: - self.final_layers = nn.ModuleList([ - conv1x1(last_dim, 80, bias=True), # id layer - conv1x1(last_dim, 64, bias=True), # exp layer - conv1x1(last_dim, 80, bias=True), # tex layer - conv1x1(last_dim, 3, bias=True), # angle layer - conv1x1(last_dim, 27, bias=True), # gamma layer - conv1x1(last_dim, 2, bias=True), # tx, ty - conv1x1(last_dim, 1, bias=True) # tz - ]) - for m in self.final_layers: - nn.init.constant_(m.weight, 0.) - nn.init.constant_(m.bias, 0.) 
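ReconNetWrapper above replaces a single fully connected head with one zero-initialized 1x1 convolution per coefficient group (identity, expression, texture, angles, lighting, translation), and its forward below concatenates them into the 257-value vector that split_coeff expects. A minimal sketch of that head layout (last_dim of 2048 assumes a ResNet-50 backbone):

import torch
import torch.nn as nn

last_dim = 2048                                     # ResNet-50's pooled feature width
heads = nn.ModuleList(
    nn.Conv2d(last_dim, c, kernel_size=1, bias=True)
    for c in (80, 64, 80, 3, 27, 2, 1)              # id, exp, tex, angle, gamma, tx/ty, tz
)
feat = torch.randn(4, last_dim, 1, 1)               # pooled backbone output (B, C, 1, 1)
coeffs = torch.flatten(torch.cat([h(feat) for h in heads], dim=1), 1)
print(coeffs.shape)                                  # torch.Size([4, 257])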
- - def forward(self, x): - x = self.backbone(x) - if not self.use_last_fc: - output = [] - for layer in self.final_layers: - output.append(layer(x)) - x = torch.flatten(torch.cat(output, dim=1), 1) - return x - - -class RecogNetWrapper(nn.Module): - def __init__(self, net_recog, pretrained_path=None, input_size=112): - super(RecogNetWrapper, self).__init__() - net = get_model(name=net_recog, fp16=False) - if pretrained_path: - state_dict = torch.load(pretrained_path, map_location='cpu') - net.load_state_dict(state_dict) - print("loading pretrained net_recog %s from %s" %(net_recog, pretrained_path)) - for param in net.parameters(): - param.requires_grad = False - self.net = net - self.preprocess = lambda x: 2 * x - 1 - self.input_size=input_size - - def forward(self, image, M): - image = self.preprocess(resize_n_crop(image, M, self.input_size)) - id_feature = F.normalize(self.net(image), dim=-1, p=2) - return id_feature - - -# adapted from https://github.com/pytorch/vision/edit/master/torchvision/models/resnet.py -__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', - 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', - 'wide_resnet50_2', 'wide_resnet101_2'] - - -model_urls = { - 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', - 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', - 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', - 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', - 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', - 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', - 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', - 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', - 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', -} - - -def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=dilation, groups=groups, bias=False, dilation=dilation) - - -def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d: - """1x1 convolution""" - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias) - - -class BasicBlock(nn.Module): - expansion: int = 1 - - def __init__( - self, - inplanes: int, - planes: int, - stride: int = 1, - downsample: Optional[nn.Module] = None, - groups: int = 1, - base_width: int = 64, - dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None - ) -> None: - super(BasicBlock, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - if groups != 1 or base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') - if dilation > 1: - raise NotImplementedError("Dilation > 1 not supported in BasicBlock") - # Both self.conv1 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = norm_layer(planes) - self.relu = nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = norm_layer(planes) - self.downsample = downsample - self.stride = stride - - def forward(self, x: Tensor) -> Tensor: - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = 
self.bn2(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - -class Bottleneck(nn.Module): - # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) - # while original implementation places the stride at the first 1x1 convolution(self.conv1) - # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. - # This variant is also known as ResNet V1.5 and improves accuracy according to - # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. - - expansion: int = 4 - - def __init__( - self, - inplanes: int, - planes: int, - stride: int = 1, - downsample: Optional[nn.Module] = None, - groups: int = 1, - base_width: int = 64, - dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None - ) -> None: - super(Bottleneck, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - width = int(planes * (base_width / 64.)) * groups - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv1x1(inplanes, width) - self.bn1 = norm_layer(width) - self.conv2 = conv3x3(width, width, stride, groups, dilation) - self.bn2 = norm_layer(width) - self.conv3 = conv1x1(width, planes * self.expansion) - self.bn3 = norm_layer(planes * self.expansion) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x: Tensor) -> Tensor: - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - -class ResNet(nn.Module): - - def __init__( - self, - block: Type[Union[BasicBlock, Bottleneck]], - layers: List[int], - num_classes: int = 1000, - zero_init_residual: bool = False, - use_last_fc: bool = False, - groups: int = 1, - width_per_group: int = 64, - replace_stride_with_dilation: Optional[List[bool]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None - ) -> None: - super(ResNet, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - self._norm_layer = norm_layer - - self.inplanes = 64 - self.dilation = 1 - if replace_stride_with_dilation is None: - # each element in the tuple indicates if we should replace - # the 2x2 stride with a dilated convolution instead - replace_stride_with_dilation = [False, False, False] - if len(replace_stride_with_dilation) != 3: - raise ValueError("replace_stride_with_dilation should be None " - "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) - self.use_last_fc = use_last_fc - self.groups = groups - self.base_width = width_per_group - self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, - bias=False) - self.bn1 = norm_layer(self.inplanes) - self.relu = nn.ReLU(inplace=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer(block, 64, layers[0]) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, - dilate=replace_stride_with_dilation[0]) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, - dilate=replace_stride_with_dilation[1]) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, - dilate=replace_stride_with_dilation[2]) - self.avgpool = 
nn.AdaptiveAvgPool2d((1, 1)) - - if self.use_last_fc: - self.fc = nn.Linear(512 * block.expansion, num_classes) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - - - # Zero-initialize the last BN in each residual branch, - # so that the residual branch starts with zeros, and each residual block behaves like an identity. - # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 - if zero_init_residual: - for m in self.modules(): - if isinstance(m, Bottleneck): - nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] - elif isinstance(m, BasicBlock): - nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] - - def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, - stride: int = 1, dilate: bool = False) -> nn.Sequential: - norm_layer = self._norm_layer - downsample = None - previous_dilation = self.dilation - if dilate: - self.dilation *= stride - stride = 1 - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - conv1x1(self.inplanes, planes * block.expansion, stride), - norm_layer(planes * block.expansion), - ) - - layers = [] - layers.append(block(self.inplanes, planes, stride, downsample, self.groups, - self.base_width, previous_dilation, norm_layer)) - self.inplanes = planes * block.expansion - for _ in range(1, blocks): - layers.append(block(self.inplanes, planes, groups=self.groups, - base_width=self.base_width, dilation=self.dilation, - norm_layer=norm_layer)) - - return nn.Sequential(*layers) - - def _forward_impl(self, x: Tensor) -> Tensor: - # See note [TorchScript super()] - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - x = self.avgpool(x) - if self.use_last_fc: - x = torch.flatten(x, 1) - x = self.fc(x) - return x - - def forward(self, x: Tensor) -> Tensor: - return self._forward_impl(x) - - -def _resnet( - arch: str, - block: Type[Union[BasicBlock, Bottleneck]], - layers: List[int], - pretrained: bool, - progress: bool, - **kwargs: Any -) -> ResNet: - model = ResNet(block, layers, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], - progress=progress) - model.load_state_dict(state_dict) - return model - - -def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNet-18 model from - `"Deep Residual Learning for Image Recognition" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, - **kwargs) - - -def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNet-34 model from - `"Deep Residual Learning for Image Recognition" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, - **kwargs) - - -def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNet-50 model from - `"Deep Residual Learning for Image Recognition" `_. 
- - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, - **kwargs) - - -def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNet-101 model from - `"Deep Residual Learning for Image Recognition" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, - **kwargs) - - -def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNet-152 model from - `"Deep Residual Learning for Image Recognition" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, - **kwargs) - - -def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNeXt-50 32x4d model from - `"Aggregated Residual Transformation for Deep Neural Networks" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - kwargs['groups'] = 32 - kwargs['width_per_group'] = 4 - return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], - pretrained, progress, **kwargs) - - -def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNeXt-101 32x8d model from - `"Aggregated Residual Transformation for Deep Neural Networks" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - kwargs['groups'] = 32 - kwargs['width_per_group'] = 8 - return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], - pretrained, progress, **kwargs) - - -def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""Wide ResNet-50-2 model from - `"Wide Residual Networks" `_. - - The model is the same as ResNet except for the bottleneck number of channels - which is twice larger in every block. The number of channels in outer 1x1 - convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 - channels, and in Wide ResNet-50-2 has 2048-1024-2048. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - kwargs['width_per_group'] = 64 * 2 - return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], - pretrained, progress, **kwargs) - - -def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""Wide ResNet-101-2 model from - `"Wide Residual Networks" `_. - - The model is the same as ResNet except for the bottleneck number of channels - which is twice larger in every block. The number of channels in outer 1x1 - convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 - channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
- - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - kwargs['width_per_group'] = 64 * 2 - return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], - pretrained, progress, **kwargs) - - -func_dict = { - 'resnet18': (resnet18, 512), - 'resnet50': (resnet50, 2048) -} diff --git a/sadtalker_audio2pose/src/face3d/models/template_model.py b/sadtalker_audio2pose/src/face3d/models/template_model.py deleted file mode 100644 index 75860272a06312bfa4de382729dce5136a480a7f..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/models/template_model.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Model class template - -This module provides a template for users to implement custom models. -You can specify '--model template' to use this model. -The class name should be consistent with both the filename and its model option. -The filename should be _dataset.py -The class name should be Dataset.py -It implements a simple image-to-image translation baseline based on regression loss. -Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss: - min_ ||netG(data_A) - data_B||_1 -You need to implement the following functions: - : Add model-specific options and rewrite default values for existing options. - <__init__>: Initialize this model class. - : Unpack input data and perform data pre-processing. - : Run forward pass. This will be called by both and . - : Update network weights; it will be called in every training iteration. -""" -import numpy as np -import torch -from .base_model import BaseModel -from . import networks - - -class TemplateModel(BaseModel): - @staticmethod - def modify_commandline_options(parser, is_train=True): - """Add new model-specific options and rewrite default values for existing options. - - Parameters: - parser -- the option parser - is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options. - - Returns: - the modified parser. - """ - parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset. - if is_train: - parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model. - - return parser - - def __init__(self, opt): - """Initialize this model class. - - Parameters: - opt -- training/test options - - A few things can be done here. - - (required) call the initialization function of BaseModel - - define loss function, visualization images, model names, and optimizers - """ - BaseModel.__init__(self, opt) # call the initialization method of BaseModel - # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk. - self.loss_names = ['loss_G'] - # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images. - self.visual_names = ['data_A', 'data_B', 'output'] - # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks. - # you can use opt.isTrain to specify different behaviors for training and test. 
For example, some networks will not be used during test, and you don't need to load them. - self.model_names = ['G'] - # define networks; you can use opt.isTrain to specify different behaviors for training and test. - self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids) - if self.isTrain: # only defined during training time - # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss. - # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device) - self.criterionLoss = torch.nn.L1Loss() - # define and initialize optimizers. You can define one optimizer for each network. - # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. - self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) - self.optimizers = [self.optimizer] - - # Our program will automatically call to define schedulers, load networks, and print networks - - def set_input(self, input): - """Unpack input data from the dataloader and perform necessary pre-processing steps. - - Parameters: - input: a dictionary that contains the data itself and its metadata information. - """ - AtoB = self.opt.direction == 'AtoB' # use to swap data_A and data_B - self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A - self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B - self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths - - def forward(self): - """Run forward pass. This will be called by both functions and .""" - self.output = self.netG(self.data_A) # generate output image given the input data_A - - def backward(self): - """Calculate losses, gradients, and update network weights; called in every training iteration""" - # caculate the intermediate results if necessary; here self.output has been computed during function - # calculate loss given the input and intermediate results - self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression - self.loss_G.backward() # calculate gradients of network G w.r.t. 
loss_G - - def optimize_parameters(self): - """Update network weights; it will be called in every training iteration.""" - self.forward() # first call forward to calculate intermediate results - self.optimizer.zero_grad() # clear network G's existing gradients - self.backward() # calculate gradients for network G - self.optimizer.step() # update gradients for network G diff --git a/sadtalker_audio2pose/src/face3d/options/__init__.py b/sadtalker_audio2pose/src/face3d/options/__init__.py deleted file mode 100644 index 06559aa558cf178b946c4523b28b098d1dfad606..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/options/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""This package options includes option modules: training options, test options, and basic options (used in both training and test).""" diff --git a/sadtalker_audio2pose/src/face3d/options/base_options.py b/sadtalker_audio2pose/src/face3d/options/base_options.py deleted file mode 100644 index 616a2e63f57e033a0a37e01a9b41babf93f6c3dd..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/options/base_options.py +++ /dev/null @@ -1,169 +0,0 @@ -"""This script contains base options for Deep3DFaceRecon_pytorch -""" - -import argparse -import os -from util import util -import numpy as np -import torch -import face3d.models as models -import face3d.data as data - - -class BaseOptions(): - """This class defines options used during both training and test time. - - It also implements several helper functions such as parsing, printing, and saving the options. - It also gathers additional options defined in functions in both dataset class and model class. - """ - - def __init__(self, cmd_line=None): - """Reset the class; indicates the class hasn't been initailized""" - self.initialized = False - self.cmd_line = None - if cmd_line is not None: - self.cmd_line = cmd_line.split() - - def initialize(self, parser): - """Define the common options that are used in both training and test.""" - # basic parameters - parser.add_argument('--name', type=str, default='face_recon', help='name of the experiment. It decides where to store samples and models') - parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') - parser.add_argument('--checkpoints_dir', type=str, default='./ckpts/sad_talker', help='models are saved here') - parser.add_argument('--vis_batch_nums', type=float, default=1, help='batch nums of images for visulization') - parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='batch nums of images for evaluation') - parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether use distributed data parallel') - parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port') - parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether use batch to show losses') - parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether add image to tensorboard') - parser.add_argument('--world_size', type=int, default=1, help='batch nums of images for evaluation') - - # model parameters - parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.') - - # additional parameters - parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model') - parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') - parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') - - self.initialized = True - return parser - - def gather_options(self): - """Initialize our parser with basic options(only once). - Add additional model-specific and dataset-specific options. - These options are defined in the function - in model and dataset classes. - """ - if not self.initialized: # check if it has been initialized - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser = self.initialize(parser) - - # get the basic options - if self.cmd_line is None: - opt, _ = parser.parse_known_args() - else: - opt, _ = parser.parse_known_args(self.cmd_line) - - # set cuda visible devices - os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids - - # modify model-related parser options - model_name = opt.model - model_option_setter = models.get_option_setter(model_name) - parser = model_option_setter(parser, self.isTrain) - if self.cmd_line is None: - opt, _ = parser.parse_known_args() # parse again with new defaults - else: - opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults - - # modify dataset-related parser options - if opt.dataset_mode: - dataset_name = opt.dataset_mode - dataset_option_setter = data.get_option_setter(dataset_name) - parser = dataset_option_setter(parser, self.isTrain) - - # save and return the parser - self.parser = parser - if self.cmd_line is None: - return parser.parse_args() - else: - return parser.parse_args(self.cmd_line) - - def print_options(self, opt): - """Print and save options - - It will print both current options and default values(if different). 
- It will save options into a text file / [checkpoints_dir] / opt.txt - """ - message = '' - message += '----------------- Options ---------------\n' - for k, v in sorted(vars(opt).items()): - comment = '' - default = self.parser.get_default(k) - if v != default: - comment = '\t[default: %s]' % str(default) - message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) - message += '----------------- End -------------------' - print(message) - - # save to the disk - expr_dir = os.path.join(opt.checkpoints_dir, opt.name) - util.mkdirs(expr_dir) - file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) - try: - with open(file_name, 'wt') as opt_file: - opt_file.write(message) - opt_file.write('\n') - except PermissionError as error: - print("permission error {}".format(error)) - pass - - def parse(self): - """Parse our options, create checkpoints directory suffix, and set up gpu device.""" - opt = self.gather_options() - opt.isTrain = self.isTrain # train or test - - # process opt.suffix - if opt.suffix: - suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' - opt.name = opt.name + suffix - - - # set gpu ids - str_ids = opt.gpu_ids.split(',') - gpu_ids = [] - for str_id in str_ids: - id = int(str_id) - if id >= 0: - gpu_ids.append(id) - opt.world_size = len(gpu_ids) - # if len(opt.gpu_ids) > 0: - # torch.cuda.set_device(gpu_ids[0]) - if opt.world_size == 1: - opt.use_ddp = False - - if opt.phase != 'test': - # set continue_train automatically - if opt.pretrained_name is None: - model_dir = os.path.join(opt.checkpoints_dir, opt.name) - else: - model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name) - if os.path.isdir(model_dir): - model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')] - if os.path.isdir(model_dir) and len(model_pths) != 0: - opt.continue_train= True - - # update the latest epoch count - if opt.continue_train: - if opt.epoch == 'latest': - epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i] - if len(epoch_counts) != 0: - opt.epoch_count = max(epoch_counts) + 1 - else: - opt.epoch_count = int(opt.epoch) + 1 - - - self.print_options(opt) - self.opt = opt - return self.opt diff --git a/sadtalker_audio2pose/src/face3d/options/inference_options.py b/sadtalker_audio2pose/src/face3d/options/inference_options.py deleted file mode 100644 index 80b9466776e120e0fe3d164217df5071c2114cef..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/options/inference_options.py +++ /dev/null @@ -1,23 +0,0 @@ -from face3d.options.base_options import BaseOptions - - -class InferenceOptions(BaseOptions): - """This class includes test options. - - It also includes shared options defined in BaseOptions. - """ - - def initialize(self, parser): - parser = BaseOptions.initialize(self, parser) # define shared options - parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') - parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. 
[None | flist]') - - parser.add_argument('--input_dir', type=str, help='the folder of the input files') - parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files') - parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients') - parser.add_argument('--save_split_files', action='store_true', help='save split files or not') - parser.add_argument('--inference_batch_size', type=int, default=8) - - # Dropout and Batchnorm has different behavior during training and test. - self.isTrain = False - return parser diff --git a/sadtalker_audio2pose/src/face3d/options/test_options.py b/sadtalker_audio2pose/src/face3d/options/test_options.py deleted file mode 100644 index f81c0c6eee0549e6fa8762dc4fc4b8573b887fe4..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/options/test_options.py +++ /dev/null @@ -1,21 +0,0 @@ -"""This script contains the test options for Deep3DFaceRecon_pytorch -""" - -from .base_options import BaseOptions - - -class TestOptions(BaseOptions): - """This class includes test options. - - It also includes shared options defined in BaseOptions. - """ - - def initialize(self, parser): - parser = BaseOptions.initialize(self, parser) # define shared options - parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') - parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') - parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.') - - # Dropout and Batchnorm has different behavior during training and test. - self.isTrain = False - return parser diff --git a/sadtalker_audio2pose/src/face3d/options/train_options.py b/sadtalker_audio2pose/src/face3d/options/train_options.py deleted file mode 100644 index 1100b0e35cc8ef563f41f6b8219510edbef53233..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/options/train_options.py +++ /dev/null @@ -1,53 +0,0 @@ -"""This script contains the training options for Deep3DFaceRecon_pytorch -""" - -from .base_options import BaseOptions -from util import util - -class TrainOptions(BaseOptions): - """This class includes training options. - - It also includes shared options defined in BaseOptions. - """ - - def initialize(self, parser): - parser = BaseOptions.initialize(self, parser) - # dataset parameters - # for train - parser.add_argument('--data_root', type=str, default='./', help='dataset root') - parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set') - parser.add_argument('--batch_size', type=int, default=32) - parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]') - parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') - parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') - parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') - parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]') - parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether use data augmentation') - - # for val - parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set') - parser.add_argument('--batch_size_val', type=int, default=32) - - - # visualization parameters - parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen') - parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') - - # network saving and loading parameters - parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') - parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') - parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq') - parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') - parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') - parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') - parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') - parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint') - - # training parameters - parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate') - parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') - parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. 
[linear | step | plateau | cosine]') - parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epoches') - - self.isTrain = True - return parser diff --git a/sadtalker_audio2pose/src/face3d/util/BBRegressorParam_r.mat b/sadtalker_audio2pose/src/face3d/util/BBRegressorParam_r.mat deleted file mode 100644 index a0da99af145c400a5216d9f6fb251d9412565921..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/util/BBRegressorParam_r.mat +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:3a5a07b8ce75a39d96b918dc0fc6e110a72e090da16f5f056a0ef7bfbc3f4560 -size 22019 diff --git a/sadtalker_audio2pose/src/face3d/util/__init__.py b/sadtalker_audio2pose/src/face3d/util/__init__.py deleted file mode 100644 index 1c67833cc634a2ca310b883ae253b08687665f40..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/util/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""This package includes a miscellaneous collection of useful helper functions.""" -from src.face3d.util import * - diff --git a/sadtalker_audio2pose/src/face3d/util/detect_lm68.py b/sadtalker_audio2pose/src/face3d/util/detect_lm68.py deleted file mode 100644 index 8a2cfd22b342de5c872ff07fc1c2a9920c2985b7..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/util/detect_lm68.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -import cv2 -import numpy as np -from scipy.io import loadmat -import tensorflow as tf -from util.preprocess import align_for_lm -from shutil import move - -mean_face = np.loadtxt('util/test_mean_face.txt') -mean_face = mean_face.reshape([68, 2]) - -def save_label(labels, save_path): - np.savetxt(save_path, labels) - -def draw_landmarks(img, landmark, save_name): - landmark = landmark - lm_img = np.zeros([img.shape[0], img.shape[1], 3]) - lm_img[:] = img.astype(np.float32) - landmark = np.round(landmark).astype(np.int32) - - for i in range(len(landmark)): - for j in range(-1, 1): - for k in range(-1, 1): - if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \ - img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \ - landmark[i, 0]+k > 0 and \ - landmark[i, 0]+k < img.shape[1]: - lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k, - :] = np.array([0, 0, 255]) - lm_img = lm_img.astype(np.uint8) - - cv2.imwrite(save_name, lm_img) - - -def load_data(img_name, txt_name): - return cv2.imread(img_name), np.loadtxt(txt_name) - -# create tensorflow graph for landmark detector -def load_lm_graph(graph_filename): - with tf.gfile.GFile(graph_filename, 'rb') as f: - graph_def = tf.GraphDef() - graph_def.ParseFromString(f.read()) - - with tf.Graph().as_default() as graph: - tf.import_graph_def(graph_def, name='net') - img_224 = graph.get_tensor_by_name('net/input_imgs:0') - output_lm = graph.get_tensor_by_name('net/lm:0') - lm_sess = tf.Session(graph=graph) - - return lm_sess,img_224,output_lm - -# landmark detection -def detect_68p(img_path,sess,input_op,output_op): - print('detecting landmarks......') - names = [i for i in sorted(os.listdir( - img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] - vis_path = os.path.join(img_path, 'vis') - remove_path = os.path.join(img_path, 'remove') - save_path = os.path.join(img_path, 'landmarks') - if not os.path.isdir(vis_path): - os.makedirs(vis_path) - if not os.path.isdir(remove_path): - os.makedirs(remove_path) - if not os.path.isdir(save_path): - os.makedirs(save_path) - - for i in range(0, len(names)): - name = names[i] 
- print('%05d' % (i), ' ', name) - full_image_name = os.path.join(img_path, name) - txt_name = '.'.join(name.split('.')[:-1]) + '.txt' - full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image - - # if an image does not have detected 5 facial landmarks, remove it from the training list - if not os.path.isfile(full_txt_name): - move(full_image_name, os.path.join(remove_path, name)) - continue - - # load data - img, five_points = load_data(full_image_name, full_txt_name) - input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection - - # if the alignment fails, remove corresponding image from the training list - if scale == 0: - move(full_txt_name, os.path.join( - remove_path, txt_name)) - move(full_image_name, os.path.join(remove_path, name)) - continue - - # detect landmarks - input_img = np.reshape( - input_img, [1, 224, 224, 3]).astype(np.float32) - landmark = sess.run( - output_op, feed_dict={input_op: input_img}) - - # transform back to original image coordinate - landmark = landmark.reshape([68, 2]) + mean_face - landmark[:, 1] = 223 - landmark[:, 1] - landmark = landmark / scale - landmark[:, 0] = landmark[:, 0] + bbox[0] - landmark[:, 1] = landmark[:, 1] + bbox[1] - landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1] - - if i % 100 == 0: - draw_landmarks(img, landmark, os.path.join(vis_path, name)) - save_label(landmark, os.path.join(save_path, txt_name)) diff --git a/sadtalker_audio2pose/src/face3d/util/generate_list.py b/sadtalker_audio2pose/src/face3d/util/generate_list.py deleted file mode 100644 index ebe93fcc5c61fbc79f4cd004a8d1bdd10ece16eb..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/util/generate_list.py +++ /dev/null @@ -1,34 +0,0 @@ -"""This script is to generate training list files for Deep3DFaceRecon_pytorch -""" - -import os - -# save path to training data -def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''): - save_path = os.path.join(save_folder, mode) - if not os.path.isdir(save_path): - os.makedirs(save_path) - with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd: - fd.writelines([i + '\n' for i in lms_list]) - - with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd: - fd.writelines([i + '\n' for i in imgs_list]) - - with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd: - fd.writelines([i + '\n' for i in msks_list]) - -# check if the path is valid -def check_list(rlms_list, rimgs_list, rmsks_list): - lms_list, imgs_list, msks_list = [], [], [] - for i in range(len(rlms_list)): - flag = 'false' - lm_path = rlms_list[i] - im_path = rimgs_list[i] - msk_path = rmsks_list[i] - if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path): - flag = 'true' - lms_list.append(rlms_list[i]) - imgs_list.append(rimgs_list[i]) - msks_list.append(rmsks_list[i]) - print(i, rlms_list[i], flag) - return lms_list, imgs_list, msks_list diff --git a/sadtalker_audio2pose/src/face3d/util/html.py b/sadtalker_audio2pose/src/face3d/util/html.py deleted file mode 100644 index c0c4e6a66ba5a34e30cee3beb13e21465c72ef38..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/util/html.py +++ /dev/null @@ -1,86 +0,0 @@ -import dominate -from dominate.tags import meta, h3, table, tr, td, p, a, img, br -import os - - -class HTML: - """This HTML class allows us to save images and write texts into a single HTML file. 
- - It consists of functions such as (add a text header to the HTML file), - (add a row of images to the HTML file), and (save the HTML to the disk). - It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. - """ - - def __init__(self, web_dir, title, refresh=0): - """Initialize the HTML classes - - Parameters: - web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0: - with self.doc.head: - meta(http_equiv="refresh", content=str(refresh)) - - def get_image_dir(self): - """Return the directory that stores images""" - return self.img_dir - - def add_header(self, text): - """Insert a header to the HTML file - - Parameters: - text (str) -- the header text - """ - with self.doc: - h3(text) - - def add_images(self, ims, txts, links, width=400): - """add images to the HTML file - - Parameters: - ims (str list) -- a list of image paths - txts (str list) -- a list of image names shown on the website - links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page - """ - self.t = table(border=1, style="table-layout: fixed;") # Insert a table - self.doc.add(self.t) - with self.t: - with tr(): - for im, txt, link in zip(ims, txts, links): - with td(style="word-wrap: break-word;", halign="center", valign="top"): - with p(): - with a(href=os.path.join('images', link)): - img(style="width:%dpx" % width, src=os.path.join('images', im)) - br() - p(txt) - - def save(self): - """save the current content to the HMTL file""" - html_file = '%s/index.html' % self.web_dir - f = open(html_file, 'wt') - f.write(self.doc.render()) - f.close() - - -if __name__ == '__main__': # we show an example usage here. 
- html = HTML('web/', 'test_html') - html.add_header('hello world') - - ims, txts, links = [], [], [] - for n in range(4): - ims.append('image_%d.png' % n) - txts.append('text_%d' % n) - links.append('image_%d.png' % n) - html.add_images(ims, txts, links) - html.save() diff --git a/sadtalker_audio2pose/src/face3d/util/load_mats.py b/sadtalker_audio2pose/src/face3d/util/load_mats.py deleted file mode 100644 index b7ea0a7877e80035883138415c102910d896bb61..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/util/load_mats.py +++ /dev/null @@ -1,120 +0,0 @@ -"""This script is to load 3D face model for Deep3DFaceRecon_pytorch -""" - -import numpy as np -from PIL import Image -from scipy.io import loadmat, savemat -from array import array -import os.path as osp - -# load expression basis -def LoadExpBasis(bfm_folder='BFM'): - n_vertex = 53215 - Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb') - exp_dim = array('i') - exp_dim.fromfile(Expbin, 1) - expMU = array('f') - expPC = array('f') - expMU.fromfile(Expbin, 3*n_vertex) - expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex) - Expbin.close() - - expPC = np.array(expPC) - expPC = np.reshape(expPC, [exp_dim[0], -1]) - expPC = np.transpose(expPC) - - expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt')) - - return expPC, expEV - - -# transfer original BFM09 to our face model -def transferBFM09(bfm_folder='BFM'): - print('Transfer BFM09 to BFM_model_front......') - original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat')) - shapePC = original_BFM['shapePC'] # shape basis - shapeEV = original_BFM['shapeEV'] # corresponding eigen value - shapeMU = original_BFM['shapeMU'] # mean face - texPC = original_BFM['texPC'] # texture basis - texEV = original_BFM['texEV'] # eigen value - texMU = original_BFM['texMU'] # mean texture - - expPC, expEV = LoadExpBasis(bfm_folder) - - # transfer BFM09 to our face model - - idBase = shapePC*np.reshape(shapeEV, [-1, 199]) - idBase = idBase/1e5 # unify the scale to decimeter - idBase = idBase[:, :80] # use only first 80 basis - - exBase = expPC*np.reshape(expEV, [-1, 79]) - exBase = exBase/1e5 # unify the scale to decimeter - exBase = exBase[:, :64] # use only first 64 basis - - texBase = texPC*np.reshape(texEV, [-1, 199]) - texBase = texBase[:, :80] # use only first 80 basis - - # our face model is cropped along face landmarks and contains only 35709 vertex. - # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex. - # thus we select corresponding vertex to get our face model. 
- - index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat')) - index_exp = index_exp['idx'].astype(np.int32) - 1 # starts from 0 (to 53215) - - index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat')) - index_shape = index_shape['trimIndex'].astype( - np.int32) - 1 # starts from 0 (to 53490) - index_shape = index_shape[index_exp] - - idBase = np.reshape(idBase, [-1, 3, 80]) - idBase = idBase[index_shape, :, :] - idBase = np.reshape(idBase, [-1, 80]) - - texBase = np.reshape(texBase, [-1, 3, 80]) - texBase = texBase[index_shape, :, :] - texBase = np.reshape(texBase, [-1, 80]) - - exBase = np.reshape(exBase, [-1, 3, 64]) - exBase = exBase[index_exp, :, :] - exBase = np.reshape(exBase, [-1, 64]) - - meanshape = np.reshape(shapeMU, [-1, 3])/1e5 - meanshape = meanshape[index_shape, :] - meanshape = np.reshape(meanshape, [1, -1]) - - meantex = np.reshape(texMU, [-1, 3]) - meantex = meantex[index_shape, :] - meantex = np.reshape(meantex, [1, -1]) - - # other info contains triangles, region used for computing photometric loss, - # region used for skin texture regularization, and 68 landmarks index etc. - other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat')) - frontmask2_idx = other_info['frontmask2_idx'] - skinmask = other_info['skinmask'] - keypoints = other_info['keypoints'] - point_buf = other_info['point_buf'] - tri = other_info['tri'] - tri_mask2 = other_info['tri_mask2'] - - # save our face model - savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase, - 'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask}) - - -# load landmarks for standard face, which is used for image preprocessing -def load_lm3d(bfm_folder): - - Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat')) - Lm3D = Lm3D['lm'] - - # calculate 5 facial landmarks using 68 landmarks - lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 - Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean( - Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0) - Lm3D = Lm3D[[1, 2, 0, 3, 4], :] - - return Lm3D - - -if __name__ == '__main__': - transferBFM09() \ No newline at end of file diff --git a/sadtalker_audio2pose/src/face3d/util/nvdiffrast.py b/sadtalker_audio2pose/src/face3d/util/nvdiffrast.py deleted file mode 100644 index 4b345db30085de501b6718ad5b49bb5f9144dd29..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/util/nvdiffrast.py +++ /dev/null @@ -1,126 +0,0 @@ -"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch - Attention, antialiasing step is missing in current version. 
-""" -import pytorch3d.ops -import torch -import torch.nn.functional as F -import kornia -from kornia.geometry.camera import pixel2cam -import numpy as np -from typing import List -from scipy.io import loadmat -from torch import nn - -from pytorch3d.structures import Meshes -from pytorch3d.renderer import ( - look_at_view_transform, - FoVPerspectiveCameras, - DirectionalLights, - RasterizationSettings, - MeshRenderer, - MeshRasterizer, - SoftPhongShader, - TexturesUV, -) - -# def ndc_projection(x=0.1, n=1.0, f=50.0): -# return np.array([[n/x, 0, 0, 0], -# [ 0, n/-x, 0, 0], -# [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], -# [ 0, 0, -1, 0]]).astype(np.float32) - -class MeshRenderer(nn.Module): - def __init__(self, - rasterize_fov, - znear=0.1, - zfar=10, - rasterize_size=224): - super(MeshRenderer, self).__init__() - - # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear - # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul( - # torch.diag(torch.tensor([1., -1, -1, 1]))) - self.rasterize_size = rasterize_size - self.fov = rasterize_fov - self.znear = znear - self.zfar = zfar - - self.rasterizer = None - - def forward(self, vertex, tri, feat=None): - """ - Return: - mask -- torch.tensor, size (B, 1, H, W) - depth -- torch.tensor, size (B, 1, H, W) - features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None - - Parameters: - vertex -- torch.tensor, size (B, N, 3) - tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles - feat(optional) -- torch.tensor, size (B, N ,C), features - """ - device = vertex.device - rsize = int(self.rasterize_size) - # ndc_proj = self.ndc_proj.to(device) - # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v - if vertex.shape[-1] == 3: - vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1) - vertex[..., 0] = -vertex[..., 0] - - - # vertex_ndc = vertex @ ndc_proj.t() - if self.rasterizer is None: - self.rasterizer = MeshRasterizer() - print("create rasterizer on device cuda:%d"%device.index) - - # ranges = None - # if isinstance(tri, List) or len(tri.shape) == 3: - # vum = vertex_ndc.shape[1] - # fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device) - # fstartidx = torch.cumsum(fnum, dim=0) - fnum - # ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu() - # for i in range(tri.shape[0]): - # tri[i] = tri[i] + i*vum - # vertex_ndc = torch.cat(vertex_ndc, dim=0) - # tri = torch.cat(tri, dim=0) - - # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3] - tri = tri.type(torch.int32).contiguous() - - # rasterize - cameras = FoVPerspectiveCameras( - device=device, - fov=self.fov, - znear=self.znear, - zfar=self.zfar, - ) - - raster_settings = RasterizationSettings( - image_size=rsize - ) - - # print(vertex.shape, tri.shape) - mesh = Meshes(vertex.contiguous()[...,:3], tri.unsqueeze(0).repeat((vertex.shape[0],1,1))) - - fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings) - rast_out = fragments.pix_to_face.squeeze(-1) - depth = fragments.zbuf - - # render depth - depth = depth.permute(0, 3, 1, 2) - mask = (rast_out > 0).float().unsqueeze(1) - depth = mask * depth - - - image = None - if feat is not None: - attributes = feat.reshape(-1,3)[mesh.faces_packed()] - image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face, - fragments.bary_coords, - attributes) - # print(image.shape) - image = image.squeeze(-2).permute(0, 3, 1, 2) - image = mask * image - - 
return mask, depth, image - diff --git a/sadtalker_audio2pose/src/face3d/util/preprocess.py b/sadtalker_audio2pose/src/face3d/util/preprocess.py deleted file mode 100644 index 82b36443fe4c84c1ad6366897a8e7d4e8b63b2b6..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/util/preprocess.py +++ /dev/null @@ -1,134 +0,0 @@ -"""This script contains the image preprocessing code for Deep3DFaceRecon_pytorch -""" - -import numpy as np -from scipy.io import loadmat -from PIL import Image -import cv2 -import os -from skimage import transform as trans -import torch -import warnings -warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) -warnings.filterwarnings("ignore", category=FutureWarning) - - -# calculating least square problem for image alignment -def POS(xp, x): - npts = xp.shape[1] - - A = np.zeros([2*npts, 8]) - - A[0:2*npts-1:2, 0:3] = x.transpose() - A[0:2*npts-1:2, 3] = 1 - - A[1:2*npts:2, 4:7] = x.transpose() - A[1:2*npts:2, 7] = 1 - - b = np.reshape(xp.transpose(), [2*npts, 1]) - - k, _, _, _ = np.linalg.lstsq(A, b) - - R1 = k[0:3] - R2 = k[4:7] - sTx = k[3] - sTy = k[7] - s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2 - t = np.stack([sTx, sTy], axis=0) - - return t, s - -# # resize and crop images for face reconstruction -# def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None): -# w0, h0 = img.size -# w = (w0*s).astype(np.int32) -# h = (h0*s).astype(np.int32) -# left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32) -# right = left + target_size -# up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32) -# below = up + target_size - -# img = img.resize((w, h), resample=Image.BICUBIC) -# img = img.crop((left, up, right, below)) - -# if mask is not None: -# mask = mask.resize((w, h), resample=Image.BICUBIC) -# mask = mask.crop((left, up, right, below)) - -# lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] - -# t[1] + h0/2], axis=1)*s -# lm = lm - np.reshape( -# np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2]) - -# return img, lm, mask - - -# resize and crop images for face reconstruction -def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None): - w0, h0 = img.size - w = (w0*s).astype(np.int32) - h = (h0*s).astype(np.int32) - left = np.round(w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32) - right = left + target_size - up = np.round(h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32) - below = up + target_size - - img = img.resize((w, h), resample=Image.BICUBIC) - img = img.crop((left, up, right, below)) - # import pdb; pdb.set_trace() - if mask is not None: - mask = mask.resize((w, h), resample=Image.BICUBIC) - mask = mask.crop((left, up, right, below)) - - lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] - - t[1] + h0/2], axis=1)*s - lm = lm - np.reshape( - np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2]) - - # orig_left, orig_up, orig_crop_size = (left,up,target_size)/s - - return img, lm, mask, left, up, target_size - -# utils for face reconstruction -def extract_5p(lm): - lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 - lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean( - lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0) - lm5p = lm5p[[1, 2, 0, 3, 4], :] - return lm5p - -# utils for face reconstruction -def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.): - """ - Return: - transparams --numpy.array (raw_W, raw_H, scale, tx, ty) - img_new --PIL.Image 
(target_size, target_size, 3) - lm_new --numpy.array (68, 2), y direction is opposite to v direction - mask_new --PIL.Image (target_size, target_size) - - Parameters: - img --PIL.Image (raw_H, raw_W, 3) - lm --numpy.array (68, 2), y direction is opposite to v direction - lm3D --numpy.array (5, 3) - mask --PIL.Image (raw_H, raw_W, 3) - """ - - w0, h0 = img.size - if lm.shape[0] != 5: - lm5p = extract_5p(lm) - else: - lm5p = lm - - # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face - t, s = POS(lm5p.transpose(), lm3D.transpose()) - s = rescale_factor/s - - # processing the image - - # processing the image - img_new, lm_new, mask_new, orig_left, orig_up, orig_crop_size = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask) - trans_params = np.array([w0, h0, s, t[0], t[1], orig_left, orig_up, orig_crop_size]) - # img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask) - # trans_params = np.array([w0, h0, s, t[0], t[1]]) - - return trans_params, img_new, lm_new, mask_new diff --git a/sadtalker_audio2pose/src/face3d/util/skin_mask.py b/sadtalker_audio2pose/src/face3d/util/skin_mask.py deleted file mode 100644 index ed764759038f77b35d45448b344d4347498ca427..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/util/skin_mask.py +++ /dev/null @@ -1,125 +0,0 @@ -"""This script is to generate skin attention mask for Deep3DFaceRecon_pytorch -""" - -import math -import numpy as np -import os -import cv2 - -class GMM: - def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv): - self.dim = dim # feature dimension - self.num = num # number of Gaussian components - self.w = w # weights of Gaussian components (a list of scalars) - self.mu= mu # mean of Gaussian components (a list of 1xdim vectors) - self.cov = cov # covariance matrix of Gaussian components (a list of dimxdim matrices) - self.cov_det = cov_det # pre-computed determinet of covariance matrices (a list of scalars) - self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices) - - self.factor = [0]*num - for i in range(self.num): - self.factor[i] = (2*math.pi)**(self.dim/2) * self.cov_det[i]**0.5 - - def likelihood(self, data): - assert(data.shape[1] == self.dim) - N = data.shape[0] - lh = np.zeros(N) - - for i in range(self.num): - data_ = data - self.mu[i] - - tmp = np.matmul(data_,self.cov_inv[i]) * data_ - tmp = np.sum(tmp,axis=1) - power = -0.5 * tmp - - p = np.array([math.exp(power[j]) for j in range(N)]) - p = p/self.factor[i] - lh += p*self.w[i] - - return lh - - -def _rgb2ycbcr(rgb): - m = np.array([[65.481, 128.553, 24.966], - [-37.797, -74.203, 112], - [112, -93.786, -18.214]]) - shape = rgb.shape - rgb = rgb.reshape((shape[0] * shape[1], 3)) - ycbcr = np.dot(rgb, m.transpose() / 255.) - ycbcr[:, 0] += 16. - ycbcr[:, 1:] += 128. - return ycbcr.reshape(shape) - - -def _bgr2ycbcr(bgr): - rgb = bgr[..., ::-1] - return _rgb2ycbcr(rgb) - - -gmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415] -gmm_skin_mu = [np.array([113.71862, 103.39613, 164.08226]), - np.array([150.19858, 105.18467, 155.51428]), - np.array([183.92976, 107.62468, 152.71820]), - np.array([114.90524, 113.59782, 151.38217])] -gmm_skin_cov_det = [5692842.5, 5851930.5, 2329131., 1585971.] 
-gmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998],[0.0020450759, 0.017700525, 0.0051420014],[-0.00060243998, 0.0051420014, 0.0081308950]]), - np.array([[0.0027110141, 0.0011036990, 0.0023122299],[0.0011036990, 0.010707724, 0.010742856],[0.0023122299, 0.010742856, 0.017481629]]), - np.array([[0.0048026871, 0.00022935172, 0.0077668377],[0.00022935172, 0.011729696, 0.0081661865],[0.0077668377, 0.0081661865, 0.025374353]]), - np.array([[0.0011989699, 0.0022453172, -0.0010748957],[0.0022453172, 0.047758564, 0.020332102],[-0.0010748957, 0.020332102, 0.024502251]])] - -gmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv) - -gmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393] -gmm_nonskin_mu = [np.array([99.200851, 112.07533, 140.20602]), - np.array([110.91392, 125.52969, 130.19237]), - np.array([129.75864, 129.96107, 126.96808]), - np.array([112.29587, 128.85121, 129.05431])] -gmm_nonskin_cov_det = [458703648., 6466488., 90611376., 133097.63] -gmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916],[0.00071197288, 0.0025935620, 0.00076557708],[0.00023958916, 0.00076557708, 0.0015042332]]), - np.array([[0.00024650150, 0.00045542428, 0.00015019422],[0.00045542428, 0.026412144, 0.018419769],[0.00015019422, 0.018419769, 0.037497383]]), - np.array([[0.00037054974, 0.00038146760, 0.00040408765],[0.00038146760, 0.0085505722, 0.0079136286],[0.00040408765, 0.0079136286, 0.010982352]]), - np.array([[0.00013709733, 0.00051228428, 0.00012777430],[0.00051228428, 0.28237113, 0.10528370],[0.00012777430, 0.10528370, 0.23468947]])] - -gmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv) - -prior_skin = 0.8 -prior_nonskin = 1 - prior_skin - - -# calculate skin attention mask -def skinmask(imbgr): - im = _bgr2ycbcr(imbgr) - - data = im.reshape((-1,3)) - - lh_skin = gmm_skin.likelihood(data) - lh_nonskin = gmm_nonskin.likelihood(data) - - tmp1 = prior_skin * lh_skin - tmp2 = prior_nonskin * lh_nonskin - post_skin = tmp1 / (tmp1+tmp2) # posterior probability - - post_skin = post_skin.reshape((im.shape[0],im.shape[1])) - - post_skin = np.round(post_skin*255) - post_skin = post_skin.astype(np.uint8) - post_skin = np.tile(np.expand_dims(post_skin,2),[1,1,3]) # reshape to H*W*3 - - return post_skin - - -def get_skin_mask(img_path): - print('generating skin masks......') - names = [i for i in sorted(os.listdir( - img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] - save_path = os.path.join(img_path, 'mask') - if not os.path.isdir(save_path): - os.makedirs(save_path) - - for i in range(0, len(names)): - name = names[i] - print('%05d' % (i), ' ', name) - full_image_name = os.path.join(img_path, name) - img = cv2.imread(full_image_name).astype(np.float32) - skin_img = skinmask(img) - cv2.imwrite(os.path.join(save_path, name), skin_img.astype(np.uint8)) diff --git a/sadtalker_audio2pose/src/face3d/util/test_mean_face.txt b/sadtalker_audio2pose/src/face3d/util/test_mean_face.txt deleted file mode 100644 index 1637648acf5a61cbc71b317c845414bb16d0150c..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/util/test_mean_face.txt +++ /dev/null @@ -1,136 +0,0 @@ --5.228591537475585938e+01 -2.078247070312500000e-01 --5.064269638061523438e+01 --1.315765380859375000e+01 --4.952939224243164062e+01 --2.592591094970703125e+01 --4.793047332763671875e+01 --3.832135772705078125e+01 --4.512159729003906250e+01 --5.059623336791992188e+01 
--3.917720794677734375e+01 --6.043736648559570312e+01 --2.929953765869140625e+01 --6.861183166503906250e+01 --1.719801330566406250e+01 --7.572736358642578125e+01 --1.961936950683593750e+00 --7.862001037597656250e+01 -1.467941284179687500e+01 --7.607844543457031250e+01 -2.744073486328125000e+01 --6.915261840820312500e+01 -3.855677795410156250e+01 --5.950350570678710938e+01 -4.478240966796875000e+01 --4.867547225952148438e+01 -4.714337158203125000e+01 --3.800830078125000000e+01 -4.940315246582031250e+01 --2.496297454833984375e+01 -5.117234802246093750e+01 --1.241538238525390625e+01 -5.190507507324218750e+01 -8.244247436523437500e-01 --4.150688934326171875e+01 -2.386329650878906250e+01 --3.570307159423828125e+01 -3.017010498046875000e+01 --2.790358734130859375e+01 -3.212951660156250000e+01 --1.941773223876953125e+01 -3.156523132324218750e+01 --1.138106536865234375e+01 -2.841992187500000000e+01 -5.993263244628906250e+00 -2.895182800292968750e+01 -1.343590545654296875e+01 -3.189880371093750000e+01 -2.203153991699218750e+01 -3.302221679687500000e+01 -2.992478942871093750e+01 -3.099150085449218750e+01 -3.628388977050781250e+01 -2.765748596191406250e+01 --1.933914184570312500e+00 -1.405374145507812500e+01 --2.153038024902343750e+00 -5.772636413574218750e+00 --2.270050048828125000e+00 --2.121643066406250000e+00 --2.218330383300781250e+00 --1.068978118896484375e+01 --1.187252044677734375e+01 --1.997912597656250000e+01 --6.879402160644531250e+00 --2.143579864501953125e+01 --1.227821350097656250e+00 --2.193494415283203125e+01 -4.623237609863281250e+00 --2.152721405029296875e+01 -9.721397399902343750e+00 --1.953671264648437500e+01 --3.648714447021484375e+01 -9.811126708984375000e+00 --3.130242919921875000e+01 -1.422447967529296875e+01 --2.212834930419921875e+01 -1.493019866943359375e+01 --1.500880432128906250e+01 -1.073588562011718750e+01 --2.095037078857421875e+01 -9.054298400878906250e+00 --3.050099182128906250e+01 -8.704177856445312500e+00 -1.173237609863281250e+01 -1.054329681396484375e+01 -1.856353759765625000e+01 -1.535009765625000000e+01 -2.893331909179687500e+01 -1.451992797851562500e+01 -3.452944946289062500e+01 -1.065280151367187500e+01 -2.875990295410156250e+01 -8.654792785644531250e+00 -1.942100524902343750e+01 -9.422447204589843750e+00 --2.204488372802734375e+01 --3.983994293212890625e+01 --1.324458312988281250e+01 --3.467377471923828125e+01 --6.749649047851562500e+00 --3.092894744873046875e+01 --9.183349609375000000e-01 --3.196458435058593750e+01 -4.220649719238281250e+00 --3.090406036376953125e+01 -1.089889526367187500e+01 --3.497008514404296875e+01 -1.874589538574218750e+01 --4.065438079833984375e+01 -1.124106597900390625e+01 --4.438417816162109375e+01 -5.181709289550781250e+00 --4.649170684814453125e+01 --1.158607482910156250e+00 --4.680406951904296875e+01 --7.918922424316406250e+00 --4.671575164794921875e+01 --1.452505493164062500e+01 --4.416526031494140625e+01 --2.005007171630859375e+01 --3.997841644287109375e+01 --1.054919433593750000e+01 --3.849683380126953125e+01 --1.051826477050781250e+00 --3.794863128662109375e+01 -6.412681579589843750e+00 --3.804645538330078125e+01 -1.627674865722656250e+01 --4.039697265625000000e+01 -6.373878479003906250e+00 --4.087213897705078125e+01 --8.551712036132812500e-01 --4.157129669189453125e+01 --1.014953613281250000e+01 --4.128469085693359375e+01 diff --git a/sadtalker_audio2pose/src/face3d/util/util.py b/sadtalker_audio2pose/src/face3d/util/util.py deleted file mode 100644 index 
79c7517ee66c8830a73fa86ab5e5c3513f11d869..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/util/util.py +++ /dev/null @@ -1,208 +0,0 @@ -"""This script contains basic utilities for Deep3DFaceRecon_pytorch -""" -from __future__ import print_function -import numpy as np -import torch -from PIL import Image -import os -import importlib -import argparse -from argparse import Namespace -import torchvision - - -def str2bool(v): - if isinstance(v, bool): - return v - if v.lower() in ('yes', 'true', 't', 'y', '1'): - return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): - return False - else: - raise argparse.ArgumentTypeError('Boolean value expected.') - - -def copyconf(default_opt, **kwargs): - conf = Namespace(**vars(default_opt)) - for key in kwargs: - setattr(conf, key, kwargs[key]) - return conf - -def genvalconf(train_opt, **kwargs): - conf = Namespace(**vars(train_opt)) - attr_dict = train_opt.__dict__ - for key, value in attr_dict.items(): - if 'val' in key and key.split('_')[0] in attr_dict: - setattr(conf, key.split('_')[0], value) - - for key in kwargs: - setattr(conf, key, kwargs[key]) - - return conf - -def find_class_in_module(target_cls_name, module): - target_cls_name = target_cls_name.replace('_', '').lower() - clslib = importlib.import_module(module) - cls = None - for name, clsobj in clslib.__dict__.items(): - if name.lower() == target_cls_name: - cls = clsobj - - assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name) - - return cls - - -def tensor2im(input_image, imtype=np.uint8): - """"Converts a Tensor array into a numpy image array. - - Parameters: - input_image (tensor) -- the input image tensor array, range(0, 1) - imtype (type) -- the desired type of the converted numpy array - """ - if not isinstance(input_image, np.ndarray): - if isinstance(input_image, torch.Tensor): # get the data from a variable - image_tensor = input_image.data - else: - return input_image - image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy() # convert it into a numpy array - if image_numpy.shape[0] == 1: # grayscale to RGB - image_numpy = np.tile(image_numpy, (3, 1, 1)) - image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: tranpose and scaling - else: # if it is a numpy array, do nothing - image_numpy = input_image - return image_numpy.astype(imtype) - - -def diagnose_network(net, name='network'): - """Calculate and print the mean of average absolute(gradients) - - Parameters: - net (torch network) -- Torch network - name (str) -- the name of the network - """ - mean = 0.0 - count = 0 - for param in net.parameters(): - if param.grad is not None: - mean += torch.mean(torch.abs(param.grad.data)) - count += 1 - if count > 0: - mean = mean / count - print(name) - print(mean) - - -def save_image(image_numpy, image_path, aspect_ratio=1.0): - """Save a numpy image to the disk - - Parameters: - image_numpy (numpy array) -- input numpy array - image_path (str) -- the path of the image - """ - - image_pil = Image.fromarray(image_numpy) - h, w, _ = image_numpy.shape - - if aspect_ratio is None: - pass - elif aspect_ratio > 1.0: - image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) - elif aspect_ratio < 1.0: - image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) - image_pil.save(image_path) - - -def print_numpy(x, val=True, shp=False): - """Print the mean, min, max, median, std, and size of a numpy array - - 
Parameters: - val (bool) -- if print the values of the numpy array - shp (bool) -- if print the shape of the numpy array - """ - x = x.astype(np.float64) - if shp: - print('shape,', x.shape) - if val: - x = x.flatten() - print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( - np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) - - -def mkdirs(paths): - """create empty directories if they don't exist - - Parameters: - paths (str list) -- a list of directory paths - """ - if isinstance(paths, list) and not isinstance(paths, str): - for path in paths: - mkdir(path) - else: - mkdir(paths) - - -def mkdir(path): - """create a single empty directory if it didn't exist - - Parameters: - path (str) -- a single directory path - """ - if not os.path.exists(path): - os.makedirs(path) - - -def correct_resize_label(t, size): - device = t.device - t = t.detach().cpu() - resized = [] - for i in range(t.size(0)): - one_t = t[i, :1] - one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0)) - one_np = one_np[:, :, 0] - one_image = Image.fromarray(one_np).resize(size, Image.NEAREST) - resized_t = torch.from_numpy(np.array(one_image)).long() - resized.append(resized_t) - return torch.stack(resized, dim=0).to(device) - - -def correct_resize(t, size, mode=Image.BICUBIC): - device = t.device - t = t.detach().cpu() - resized = [] - for i in range(t.size(0)): - one_t = t[i:i + 1] - one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC) - resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0 - resized.append(resized_t) - return torch.stack(resized, dim=0).to(device) - -def draw_landmarks(img, landmark, color='r', step=2): - """ - Return: - img -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255) - - - Parameters: - img -- numpy.array, (B, H, W, 3), RGB order, range (0, 255) - landmark -- numpy.array, (B, 68, 2), y direction is opposite to v direction - color -- str, 'r' or 'b' (red or blue) - """ - if color =='r': - c = np.array([255., 0, 0]) - else: - c = np.array([0, 0, 255.]) - - _, H, W, _ = img.shape - img, landmark = img.copy(), landmark.copy() - landmark[..., 1] = H - 1 - landmark[..., 1] - landmark = np.round(landmark).astype(np.int32) - for i in range(landmark.shape[1]): - x, y = landmark[:, i, 0], landmark[:, i, 1] - for j in range(-step, step): - for k in range(-step, step): - u = np.clip(x + j, 0, W - 1) - v = np.clip(y + k, 0, H - 1) - for m in range(landmark.shape[0]): - img[m, v[m], u[m]] = c - return img diff --git a/sadtalker_audio2pose/src/face3d/util/visualizer.py b/sadtalker_audio2pose/src/face3d/util/visualizer.py deleted file mode 100644 index c4a8b755e054a4a34d003962a723ef189726a7a0..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/util/visualizer.py +++ /dev/null @@ -1,227 +0,0 @@ -"""This script defines the visualizer for Deep3DFaceRecon_pytorch -""" - -import numpy as np -import os -import sys -import ntpath -import time -from . import util, html -from subprocess import Popen, PIPE -from torch.utils.tensorboard import SummaryWriter - -def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): - """Save images to the disk. 
- - Parameters: - webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) - visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs - image_path (str) -- the string is used to create image paths - aspect_ratio (float) -- the aspect ratio of saved images - width (int) -- the images will be resized to width x width - - This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. - """ - image_dir = webpage.get_image_dir() - short_path = ntpath.basename(image_path[0]) - name = os.path.splitext(short_path)[0] - - webpage.add_header(name) - ims, txts, links = [], [], [] - - for label, im_data in visuals.items(): - im = util.tensor2im(im_data) - image_name = '%s/%s.png' % (label, name) - os.makedirs(os.path.join(image_dir, label), exist_ok=True) - save_path = os.path.join(image_dir, image_name) - util.save_image(im, save_path, aspect_ratio=aspect_ratio) - ims.append(image_name) - txts.append(label) - links.append(image_name) - webpage.add_images(ims, txts, links, width=width) - - -class Visualizer(): - """This class includes several functions that can display/save images and print/save logging information. - - It uses a Python library tensprboardX for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. - """ - - def __init__(self, opt): - """Initialize the Visualizer class - - Parameters: - opt -- stores all the experiment flags; needs to be a subclass of BaseOptions - Step 1: Cache the training/test options - Step 2: create a tensorboard writer - Step 3: create an HTML object for saveing HTML filters - Step 4: create a logging file to store training losses - """ - self.opt = opt # cache the option - self.use_html = opt.isTrain and not opt.no_html - self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, 'logs', opt.name)) - self.win_size = opt.display_winsize - self.name = opt.name - self.saved = False - if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/ - self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') - self.img_dir = os.path.join(self.web_dir, 'images') - print('create web directory %s...' % self.web_dir) - util.mkdirs([self.web_dir, self.img_dir]) - # create a logging file to store training losses - self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') - with open(self.log_name, "a") as log_file: - now = time.strftime("%c") - log_file.write('================ Training Loss (%s) ================\n' % now) - - def reset(self): - """Reset the self.saved status""" - self.saved = False - - - def display_current_results(self, visuals, total_iters, epoch, save_result): - """Display current results on tensorboad; save current results to an HTML file. - - Parameters: - visuals (OrderedDict) - - dictionary of images to display or save - total_iters (int) -- total iterations - epoch (int) - - the current epoch - save_result (bool) - - if save the current results to an HTML file - """ - for label, image in visuals.items(): - self.writer.add_image(label, util.tensor2im(image), total_iters, dataformats='HWC') - - if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. 
- self.saved = True - # save images to the disk - for label, image in visuals.items(): - image_numpy = util.tensor2im(image) - img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) - util.save_image(image_numpy, img_path) - - # update website - webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0) - for n in range(epoch, 0, -1): - webpage.add_header('epoch [%d]' % n) - ims, txts, links = [], [], [] - - for label, image_numpy in visuals.items(): - image_numpy = util.tensor2im(image) - img_path = 'epoch%.3d_%s.png' % (n, label) - ims.append(img_path) - txts.append(label) - links.append(img_path) - webpage.add_images(ims, txts, links, width=self.win_size) - webpage.save() - - def plot_current_losses(self, total_iters, losses): - # G_loss_collection = {} - # D_loss_collection = {} - # for name, value in losses.items(): - # if 'G' in name or 'NCE' in name or 'idt' in name: - # G_loss_collection[name] = value - # else: - # D_loss_collection[name] = value - # self.writer.add_scalars('G_collec', G_loss_collection, total_iters) - # self.writer.add_scalars('D_collec', D_loss_collection, total_iters) - for name, value in losses.items(): - self.writer.add_scalar(name, value, total_iters) - - # losses: same format as |losses| of plot_current_losses - def print_current_losses(self, epoch, iters, losses, t_comp, t_data): - """print current losses on console; also save the losses to the disk - - Parameters: - epoch (int) -- current epoch - iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) - losses (OrderedDict) -- training losses stored in the format of (name, float) pairs - t_comp (float) -- computational time per data point (normalized by batch_size) - t_data (float) -- data loading time per data point (normalized by batch_size) - """ - message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) - for k, v in losses.items(): - message += '%s: %.3f ' % (k, v) - - print(message) # print the message - with open(self.log_name, "a") as log_file: - log_file.write('%s\n' % message) # save the message - - -class MyVisualizer: - def __init__(self, opt): - """Initialize the Visualizer class - - Parameters: - opt -- stores all the experiment flags; needs to be a subclass of BaseOptions - Step 1: Cache the training/test options - Step 2: create a tensorboard writer - Step 3: create an HTML object for saveing HTML filters - Step 4: create a logging file to store training losses - """ - self.opt = opt # cache the optio - self.name = opt.name - self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'results') - - if opt.phase != 'test': - self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'logs')) - # create a logging file to store training losses - self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') - with open(self.log_name, "a") as log_file: - now = time.strftime("%c") - log_file.write('================ Training Loss (%s) ================\n' % now) - - - def display_current_results(self, visuals, total_iters, epoch, dataset='train', save_results=False, count=0, name=None, - add_image=True): - """Display current results on tensorboad; save current results to an HTML file. 
- - Parameters: - visuals (OrderedDict) - - dictionary of images to display or save - total_iters (int) -- total iterations - epoch (int) - - the current epoch - dataset (str) - - 'train' or 'val' or 'test' - """ - # if (not add_image) and (not save_results): return - - for label, image in visuals.items(): - for i in range(image.shape[0]): - image_numpy = util.tensor2im(image[i]) - if add_image: - self.writer.add_image(label + '%s_%02d'%(dataset, i + count), - image_numpy, total_iters, dataformats='HWC') - - if save_results: - save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d'%(epoch, total_iters)) - if not os.path.isdir(save_path): - os.makedirs(save_path) - - if name is not None: - img_path = os.path.join(save_path, '%s.png' % name) - else: - img_path = os.path.join(save_path, '%s_%03d.png' % (label, i + count)) - util.save_image(image_numpy, img_path) - - - def plot_current_losses(self, total_iters, losses, dataset='train'): - for name, value in losses.items(): - self.writer.add_scalar(name + '/%s'%dataset, value, total_iters) - - # losses: same format as |losses| of plot_current_losses - def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset='train'): - """print current losses on console; also save the losses to the disk - - Parameters: - epoch (int) -- current epoch - iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) - losses (OrderedDict) -- training losses stored in the format of (name, float) pairs - t_comp (float) -- computational time per data point (normalized by batch_size) - t_data (float) -- data loading time per data point (normalized by batch_size) - """ - message = '(dataset: %s, epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % ( - dataset, epoch, iters, t_comp, t_data) - for k, v in losses.items(): - message += '%s: %.3f ' % (k, v) - - print(message) # print the message - with open(self.log_name, "a") as log_file: - log_file.write('%s\n' % message) # save the message diff --git a/sadtalker_audio2pose/src/face3d/visualize.py b/sadtalker_audio2pose/src/face3d/visualize.py deleted file mode 100644 index cb8791ec30fb8f748aefc82cf4385444754825a4..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/visualize.py +++ /dev/null @@ -1,133 +0,0 @@ -# check the sync of 3dmm feature and the audio -import shutil -import cv2 -import numpy as np -from src.face3d.models.bfm import ParametricFaceModel -from src.face3d.models.facerecon_model import FaceReconModel -import torch -import subprocess, platform -import scipy.io as scio -from tqdm import tqdm - - -def draw_landmarks(image, landmarks): - for i, point in enumerate(landmarks): - cv2.circle(image, (int(point[0]), int(point[1])), 2, (0, 255, 0), -1) - cv2.putText(image, str(i), (int(point[0]), int(point[1])), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1) - return image - -# draft -def gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, save_lmk_path, crop_info, extended_crop = False): - - coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm'] - info = scio.loadmat(first_frame_coeff)['trans_params'][0] - print(info) - - coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm'] - - # print(coeff_pred.shape) - # print(coeff_pred[1:, 64:].shape) - - if args.still: - coeff_pred[1:, 64:] = np.stack([coeff_pred[0, 64:]]*coeff_pred[1:, 64:].shape[0]) - - # assert False - - coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257 - - coeff_full[:, 80:144] = coeff_pred[:, 0:64] - 
coeff_full[:, 224:227] = coeff_pred[:, 64:67] # 3 dim translation - coeff_full[:, 254:] = coeff_pred[:, 67:] # 3 dim translation - - if len(crop_info) != 3: - print("you didn't crop the image") - return - else: - r_w, r_h = crop_info[0] - clx, cly, crx, cry = crop_info[1] - lx, ly, rx, ry = crop_info[2] - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - if extended_crop: - oy1, oy2, ox1, ox2 = cly, cry, clx, crx - else: - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - tmp_video_path = '/tmp/face3dtmp.mp4' - facemodel = FaceReconModel(args) - im0 = cv2.imread(args.source_image) - - video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224)) - - # since we resize the video, we first need to resize the landmark to the cropped size resolution - # then, we need to add it back to the original video - x_scale, y_scale = (ox2 - ox1)/256 , (oy2 - oy1)/256 - - W, H = im0.shape[0], im0.shape[1] - - _, _, s, _, _, orig_left, orig_up, orig_crop_size =(info[0], info[1], info[2], info[3], info[4], info[5], info[6], info[7]) - orig_left, orig_up, orig_crop_size = [int(x) for x in (orig_left, orig_up, orig_crop_size)] - - landmark_scale = np.array([[x_scale, y_scale]]) - landmark_shift = np.array([[orig_left, orig_up]]) - landmark_shift2 = np.array([[ox1, oy1]]) - - - landmarks = [] - - for k in tqdm(range(coeff_first.shape[0]), '1st:'): - cur_coeff_full = torch.tensor(coeff_first, device=device) - - facemodel.forward(cur_coeff_full, device) - - predicted_landmark = facemodel.pred_lm # TODO. - predicted_landmark = predicted_landmark.cpu().numpy().squeeze() - - predicted_landmark[:, 1] = 224 - predicted_landmark[:, 1] - - predicted_landmark = ((predicted_landmark + landmark_shift) / s[0] * landmark_scale) + landmark_shift2 - - landmarks.append(predicted_landmark) - - print(orig_up, orig_left, orig_crop_size, s) - - for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'): - cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device) - - facemodel.forward(cur_coeff_full, device) - - predicted_landmark = facemodel.pred_lm # TODO. - predicted_landmark = predicted_landmark.cpu().numpy().squeeze() - - predicted_landmark[:, 1] = 224 - predicted_landmark[:, 1] - - predicted_landmark = ((predicted_landmark + landmark_shift) / s[0] * landmark_scale) + landmark_shift2 - - landmarks.append(predicted_landmark) - - rendered_img = facemodel.pred_face - rendered_img = 255. 
* rendered_img.cpu().numpy().squeeze().transpose(1,2,0) - out_img = rendered_img[:, :, :3].astype(np.uint8) - - video.write(np.uint8(out_img[:,:,::-1])) - - video.release() - - # visualize landmarks - video = cv2.VideoWriter(save_lmk_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (im0.shape[0], im0.shape[1])) - - for k in tqdm(range(len(landmarks)), 'face3d vis:'): - # im = draw_landmarks(im0.copy(), landmarks[k]) - im = draw_landmarks(np.uint8(np.ones_like(im0)*255), landmarks[k]) - video.write(im) - video.release() - - shutil.copyfile(args.source_image, save_lmk_path.replace('.mp4', '.png')) - - np.save(save_lmk_path.replace('.mp4', '.npy'), landmarks) - - command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path) - subprocess.call(command, shell=platform.system() != 'Windows') - diff --git a/sadtalker_audio2pose/src/face3d/visualize_old.py b/sadtalker_audio2pose/src/face3d/visualize_old.py deleted file mode 100644 index b4a37b388320344fd96b4778b60679440fe584c3..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/face3d/visualize_old.py +++ /dev/null @@ -1,110 +0,0 @@ -# check the sync of 3dmm feature and the audio -import shutil -import cv2 -import numpy as np -from src.face3d.models.bfm import ParametricFaceModel -from src.face3d.models.facerecon_model import FaceReconModel -import torch -import subprocess, platform -import scipy.io as scio -from tqdm import tqdm - - -def draw_landmarks(image, landmarks): - for i, point in enumerate(landmarks): - cv2.circle(image, (int(point[0]), int(point[1])), 2, (0, 255, 0), -1) - cv2.putText(image, str(i), (int(point[0]), int(point[1])), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1) - return image - -# draft -def gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, save_lmk_path, crop_info, extended_crop = False): - - coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm'] - info = scio.loadmat(first_frame_coeff)['trans_params'][0] - print(info) - - coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm'] - - coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257 - - coeff_full[:, 80:144] = coeff_pred[:, 0:64] - coeff_full[:, 224:227] = coeff_pred[:, 64:67] # 3 dim translation - coeff_full[:, 254:] = coeff_pred[:, 67:] # 3 dim translation - - if len(crop_info) != 3: - print("you didn't crop the image") - return - else: - r_w, r_h = crop_info[0] - clx, cly, crx, cry = crop_info[1] - lx, ly, rx, ry = crop_info[2] - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - if extended_crop: - oy1, oy2, ox1, ox2 = cly, cry, clx, crx - else: - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - tmp_video_path = '/tmp/face3dtmp.mp4' - facemodel = FaceReconModel(args) - im0 = cv2.imread(args.source_image) - - video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224)) - - # since we resize the video, we first need to resize the landmark to the cropped size resolution - # then, we need to add it back to the original video - x_scale, y_scale = (ox2 - ox1)/256 , (oy2 - oy1)/256 - - W, H = im0.shape[0], im0.shape[1] - - _, _, s, _, _, orig_left, orig_up, orig_crop_size =(info[0], info[1], info[2], info[3], info[4], info[5], info[6], info[7]) - orig_left, orig_up, orig_crop_size = [int(x) for x in (orig_left, orig_up, orig_crop_size)] - - landmark_scale = np.array([[x_scale, y_scale]]) - landmark_shift 
= np.array([[orig_left, orig_up]]) - landmark_shift2 = np.array([[ox1, oy1]]) - - landmarks = [] - - print(orig_up, orig_left, orig_crop_size, s) - - for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'): - cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device) - - facemodel.forward(cur_coeff_full, device) - - predicted_landmark = facemodel.pred_lm # TODO. - predicted_landmark = predicted_landmark.cpu().numpy().squeeze() - - predicted_landmark[:, 1] = 224 - predicted_landmark[:, 1] - - predicted_landmark = ((predicted_landmark + landmark_shift) / s[0] * landmark_scale) + landmark_shift2 - - landmarks.append(predicted_landmark) - - rendered_img = facemodel.pred_face - rendered_img = 255. * rendered_img.cpu().numpy().squeeze().transpose(1,2,0) - out_img = rendered_img[:, :, :3].astype(np.uint8) - - video.write(np.uint8(out_img[:,:,::-1])) - - video.release() - - # visualize landmarks - video = cv2.VideoWriter(save_lmk_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (im0.shape[0], im0.shape[1])) - - for k in tqdm(range(len(landmarks)), 'face3d vis:'): - # im = draw_landmarks(im0.copy(), landmarks[k]) - im = draw_landmarks(np.uint8(np.ones_like(im0)*255), landmarks[k]) - video.write(im) - video.release() - - shutil.copyfile(args.source_image, save_lmk_path.replace('.mp4', '.png')) - - np.save(save_lmk_path.replace('.mp4', '.npy'), landmarks) - - command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path) - subprocess.call(command, shell=platform.system() != 'Windows') - diff --git a/sadtalker_audio2pose/src/facerender/animate.py b/sadtalker_audio2pose/src/facerender/animate.py deleted file mode 100644 index 45fcb45edb4169166b851a066c8aaf08063ed1c6..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/facerender/animate.py +++ /dev/null @@ -1,261 +0,0 @@ -import os -import cv2 -import yaml -import numpy as np -import warnings -from skimage import img_as_ubyte -import safetensors -import safetensors.torch -warnings.filterwarnings('ignore') - - -import imageio -import torch -import torchvision - - -from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector -from src.facerender.modules.mapping import MappingNet -from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator -from src.facerender.modules.make_animation import make_animation - -from pydub import AudioSegment -from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list -from src.utils.paste_pic import paste_pic -from src.utils.videoio import save_video_with_watermark - -try: - import webui # in webui - in_webui = True -except: - in_webui = False - -class AnimateFromCoeff(): - - def __init__(self, sadtalker_path, device): - - with open(sadtalker_path['facerender_yaml']) as f: - config = yaml.safe_load(f) - - generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'], - **config['model_params']['common_params']) - kp_extractor = KPDetector(**config['model_params']['kp_detector_params'], - **config['model_params']['common_params']) - he_estimator = HEEstimator(**config['model_params']['he_estimator_params'], - **config['model_params']['common_params']) - mapping = MappingNet(**config['model_params']['mapping_params']) - - generator.to(device) - kp_extractor.to(device) - he_estimator.to(device) - mapping.to(device) - for param in generator.parameters(): - param.requires_grad = False - for param in kp_extractor.parameters(): - param.requires_grad = False - for 
param in he_estimator.parameters(): - param.requires_grad = False - for param in mapping.parameters(): - param.requires_grad = False - - if sadtalker_path is not None: - if 'checkpoint' in sadtalker_path: # use safe tensor - self.load_cpk_facevid2vid_safetensor(sadtalker_path['checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=None) - else: - self.load_cpk_facevid2vid(sadtalker_path['free_view_checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator) - else: - raise AttributeError("Checkpoint should be specified for video head pose estimator.") - - if sadtalker_path['mappingnet_checkpoint'] is not None: - self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping) - else: - raise AttributeError("Checkpoint should be specified for video head pose estimator.") - - self.kp_extractor = kp_extractor - self.generator = generator - self.he_estimator = he_estimator - self.mapping = mapping - - self.kp_extractor.eval() - self.generator.eval() - self.he_estimator.eval() - self.mapping.eval() - - self.device = device - - def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=None, - kp_detector=None, he_estimator=None, - device="cpu"): - - checkpoint = safetensors.torch.load_file(checkpoint_path) - - if generator is not None: - x_generator = {} - for k,v in checkpoint.items(): - if 'generator' in k: - x_generator[k.replace('generator.', '')] = v - generator.load_state_dict(x_generator) - if kp_detector is not None: - x_generator = {} - for k,v in checkpoint.items(): - if 'kp_extractor' in k: - x_generator[k.replace('kp_extractor.', '')] = v - kp_detector.load_state_dict(x_generator) - if he_estimator is not None: - x_generator = {} - for k,v in checkpoint.items(): - if 'he_estimator' in k: - x_generator[k.replace('he_estimator.', '')] = v - he_estimator.load_state_dict(x_generator) - - return None - - def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None, - kp_detector=None, he_estimator=None, optimizer_generator=None, - optimizer_discriminator=None, optimizer_kp_detector=None, - optimizer_he_estimator=None, device="cpu"): - checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) - if generator is not None: - generator.load_state_dict(checkpoint['generator']) - if kp_detector is not None: - kp_detector.load_state_dict(checkpoint['kp_detector']) - if he_estimator is not None: - he_estimator.load_state_dict(checkpoint['he_estimator']) - if discriminator is not None: - try: - discriminator.load_state_dict(checkpoint['discriminator']) - except: - print ('No discriminator in the state-dict. Dicriminator will be randomly initialized') - if optimizer_generator is not None: - optimizer_generator.load_state_dict(checkpoint['optimizer_generator']) - if optimizer_discriminator is not None: - try: - optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) - except RuntimeError as e: - print ('No discriminator optimizer in the state-dict. 
Optimizer will be not initialized') - if optimizer_kp_detector is not None: - optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector']) - if optimizer_he_estimator is not None: - optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator']) - - return checkpoint['epoch'] - - def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None, - optimizer_mapping=None, optimizer_discriminator=None, device='cpu'): - checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) - if mapping is not None: - mapping.load_state_dict(checkpoint['mapping']) - if discriminator is not None: - discriminator.load_state_dict(checkpoint['discriminator']) - if optimizer_mapping is not None: - optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping']) - if optimizer_discriminator is not None: - optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) - - return checkpoint['epoch'] - - def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256): - - source_image=x['source_image'].type(torch.FloatTensor) - source_semantics=x['source_semantics'].type(torch.FloatTensor) - target_semantics=x['target_semantics_list'].type(torch.FloatTensor) - source_image=source_image.to(self.device) - source_semantics=source_semantics.to(self.device) - target_semantics=target_semantics.to(self.device) - if 'yaw_c_seq' in x: - yaw_c_seq = x['yaw_c_seq'].type(torch.FloatTensor) - yaw_c_seq = x['yaw_c_seq'].to(self.device) - else: - yaw_c_seq = None - if 'pitch_c_seq' in x: - pitch_c_seq = x['pitch_c_seq'].type(torch.FloatTensor) - pitch_c_seq = x['pitch_c_seq'].to(self.device) - else: - pitch_c_seq = None - if 'roll_c_seq' in x: - roll_c_seq = x['roll_c_seq'].type(torch.FloatTensor) - roll_c_seq = x['roll_c_seq'].to(self.device) - else: - roll_c_seq = None - - frame_num = x['frame_num'] - - predictions_video = make_animation(source_image, source_semantics, target_semantics, - self.generator, self.kp_extractor, self.he_estimator, self.mapping, - yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp = True) - - predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:]) - predictions_video = predictions_video[:frame_num] - - video = [] - for idx in range(predictions_video.shape[0]): - image = predictions_video[idx] - image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32) - video.append(image) - result = img_as_ubyte(video) - - ### the generated video is 256x256, so we keep the aspect ratio, - original_size = crop_info[0] - if original_size: - result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ] - - video_name = x['video_name'] + '.mp4' - path = os.path.join(video_save_dir, 'temp_'+video_name) - - # print(path) - - imageio.mimsave(path, result, fps=float(25)) - - av_path = os.path.join(video_save_dir, video_name) - return_path = av_path - - audio_path = x['audio_path'] - audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0] - new_audio_path = os.path.join(video_save_dir, audio_name+'.wav') - start_time = 0 - # cog will not keep the .mp3 filename - sound = AudioSegment.from_file(audio_path) - frames = frame_num - end_time = start_time + frames*1/25*1000 - word1=sound.set_frame_rate(16000) - word = word1[start_time:end_time] - word.export(new_audio_path, format="wav") - - save_video_with_watermark(path, new_audio_path, av_path, watermark= False) - print(f'The generated 
video is named {video_save_dir}/{video_name}') - - if 'full' in preprocess.lower(): - # only add watermark to the full image. - video_name_full = x['video_name'] + '_full.mp4' - full_video_path = os.path.join(video_save_dir, video_name_full) - return_path = full_video_path - paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False) - print(f'The generated video is named {video_save_dir}/{video_name_full}') - else: - full_video_path = av_path - - #### paste back then enhancers - if enhancer: - video_name_enhancer = x['video_name'] + '_enhanced.mp4' - enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer) - av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer) - return_path = av_path_enhancer - - try: - enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer) - imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25)) - except: - enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer) - imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25)) - - save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False) - print(f'The generated video is named {video_save_dir}/{video_name_enhancer}') - - - # os.remove(enhanced_path) - - # os.remove(path) - # os.remove(new_audio_path) - - return return_path - diff --git a/sadtalker_audio2pose/src/facerender/modules/dense_motion.py b/sadtalker_audio2pose/src/facerender/modules/dense_motion.py deleted file mode 100644 index 4c30417870e79bc005ea47a8f383c3aa406df563..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/facerender/modules/dense_motion.py +++ /dev/null @@ -1,121 +0,0 @@ -from torch import nn -import torch.nn.functional as F -import torch -from src.facerender.modules.util import Hourglass, make_coordinate_grid, kp2gaussian - -from src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d - - -class DenseMotionNetwork(nn.Module): - """ - Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving - """ - - def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, - estimate_occlusion_map=False): - super(DenseMotionNetwork, self).__init__() - # self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(feature_channel+1), max_features=max_features, num_blocks=num_blocks) - self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) - - self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) - - self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) - self.norm = BatchNorm3d(compress, affine=True) - - if estimate_occlusion_map: - # self.occlusion = nn.Conv2d(reshape_channel*reshape_depth, 1, kernel_size=7, padding=3) - self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3) - else: - self.occlusion = None - - self.num_kp = num_kp - - - def create_sparse_motions(self, feature, kp_driving, kp_source): - bs, _, d, h, w = feature.shape - identity_grid = make_coordinate_grid((d, h, w), type=kp_source['value'].type()) - identity_grid = identity_grid.view(1, 1, d, h, w, 3) - coordinate_grid = identity_grid - 
kp_driving['value'].view(bs, self.num_kp, 1, 1, 1, 3) - - # if 'jacobian' in kp_driving: - if 'jacobian' in kp_driving and kp_driving['jacobian'] is not None: - jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian'])) - jacobian = jacobian.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3) - jacobian = jacobian.repeat(1, 1, d, h, w, 1, 1) - coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1)) - coordinate_grid = coordinate_grid.squeeze(-1) - - - driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3) - - #adding background feature - identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1) - sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) #bs num_kp+1 d h w 3 - - # sparse_motions = driving_to_source - - return sparse_motions - - def create_deformed_feature(self, feature, sparse_motions): - bs, _, d, h, w = feature.shape - feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w) - feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w) - sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3) !!!! - sparse_deformed = F.grid_sample(feature_repeat, sparse_motions) - sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w) - return sparse_deformed - - def create_heatmap_representations(self, feature, kp_driving, kp_source): - spatial_size = feature.shape[3:] - gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) - gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) - heatmap = gaussian_driving - gaussian_source - - # adding background feature - zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type()) - heatmap = torch.cat([zeros, heatmap], dim=1) - heatmap = heatmap.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w) - return heatmap - - def forward(self, feature, kp_driving, kp_source): - bs, _, d, h, w = feature.shape - - feature = self.compress(feature) - feature = self.norm(feature) - feature = F.relu(feature) - - out_dict = dict() - sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) - deformed_feature = self.create_deformed_feature(feature, sparse_motion) - - heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source) - - input_ = torch.cat([heatmap, deformed_feature], dim=2) - input_ = input_.view(bs, -1, d, h, w) - - # input = deformed_feature.view(bs, -1, d, h, w) # (bs, num_kp+1 * c, d, h, w) - - prediction = self.hourglass(input_) - - - mask = self.mask(prediction) - mask = F.softmax(mask, dim=1) - out_dict['mask'] = mask - mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w) - - zeros_mask = torch.zeros_like(mask) - mask = torch.where(mask < 1e-3, zeros_mask, mask) - - sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w) - deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w) - deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3) - - out_dict['deformation'] = deformation - - if self.occlusion: - bs, c, d, h, w = prediction.shape - prediction = prediction.view(bs, -1, h, w) - occlusion_map = torch.sigmoid(self.occlusion(prediction)) - out_dict['occlusion_map'] = occlusion_map - - return out_dict diff --git 
a/sadtalker_audio2pose/src/facerender/modules/discriminator.py b/sadtalker_audio2pose/src/facerender/modules/discriminator.py deleted file mode 100644 index cc0a2b460d2175a958d7b230b7e5233d7d7c7f92..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/facerender/modules/discriminator.py +++ /dev/null @@ -1,90 +0,0 @@ -from torch import nn -import torch.nn.functional as F -from facerender.modules.util import kp2gaussian -import torch - - -class DownBlock2d(nn.Module): - """ - Simple block for processing video (encoder). - """ - - def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False): - super(DownBlock2d, self).__init__() - self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) - - if sn: - self.conv = nn.utils.spectral_norm(self.conv) - - if norm: - self.norm = nn.InstanceNorm2d(out_features, affine=True) - else: - self.norm = None - self.pool = pool - - def forward(self, x): - out = x - out = self.conv(out) - if self.norm: - out = self.norm(out) - out = F.leaky_relu(out, 0.2) - if self.pool: - out = F.avg_pool2d(out, (2, 2)) - return out - - -class Discriminator(nn.Module): - """ - Discriminator similar to Pix2Pix - """ - - def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512, - sn=False, **kwargs): - super(Discriminator, self).__init__() - - down_blocks = [] - for i in range(num_blocks): - down_blocks.append( - DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)), - min(max_features, block_expansion * (2 ** (i + 1))), - norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn)) - - self.down_blocks = nn.ModuleList(down_blocks) - self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1) - if sn: - self.conv = nn.utils.spectral_norm(self.conv) - - def forward(self, x): - feature_maps = [] - out = x - - for down_block in self.down_blocks: - feature_maps.append(down_block(out)) - out = feature_maps[-1] - prediction_map = self.conv(out) - - return feature_maps, prediction_map - - -class MultiScaleDiscriminator(nn.Module): - """ - Multi-scale (scale) discriminator - """ - - def __init__(self, scales=(), **kwargs): - super(MultiScaleDiscriminator, self).__init__() - self.scales = scales - discs = {} - for scale in scales: - discs[str(scale).replace('.', '-')] = Discriminator(**kwargs) - self.discs = nn.ModuleDict(discs) - - def forward(self, x): - out_dict = {} - for scale, disc in self.discs.items(): - scale = str(scale).replace('-', '.') - key = 'prediction_' + scale - feature_maps, prediction_map = disc(x[key]) - out_dict['feature_maps_' + scale] = feature_maps - out_dict['prediction_map_' + scale] = prediction_map - return out_dict diff --git a/sadtalker_audio2pose/src/facerender/modules/generator.py b/sadtalker_audio2pose/src/facerender/modules/generator.py deleted file mode 100644 index 2b94dde7a37c5ddf0f74dd0317a5db3507ab0729..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/facerender/modules/generator.py +++ /dev/null @@ -1,255 +0,0 @@ -import torch -from torch import nn -import torch.nn.functional as F -from src.facerender.modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, ResBlock3d, SPADEResnetBlock -from src.facerender.modules.dense_motion import DenseMotionNetwork - - -class OcclusionAwareGenerator(nn.Module): - """ - Generator follows NVIDIA architecture. 
- """ - - def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth, - num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): - super(OcclusionAwareGenerator, self).__init__() - - if dense_motion_params is not None: - self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel, - estimate_occlusion_map=estimate_occlusion_map, - **dense_motion_params) - else: - self.dense_motion_network = None - - self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(7, 7), padding=(3, 3)) - - down_blocks = [] - for i in range(num_down_blocks): - in_features = min(max_features, block_expansion * (2 ** i)) - out_features = min(max_features, block_expansion * (2 ** (i + 1))) - down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) - self.down_blocks = nn.ModuleList(down_blocks) - - self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) - - self.reshape_channel = reshape_channel - self.reshape_depth = reshape_depth - - self.resblocks_3d = torch.nn.Sequential() - for i in range(num_resblocks): - self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) - - out_features = block_expansion * (2 ** (num_down_blocks)) - self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True) - self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1) - - self.resblocks_2d = torch.nn.Sequential() - for i in range(num_resblocks): - self.resblocks_2d.add_module('2dr' + str(i), ResBlock2d(out_features, kernel_size=3, padding=1)) - - up_blocks = [] - for i in range(num_down_blocks): - in_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i))) - out_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i - 1))) - up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) - self.up_blocks = nn.ModuleList(up_blocks) - - self.final = nn.Conv2d(block_expansion, image_channel, kernel_size=(7, 7), padding=(3, 3)) - self.estimate_occlusion_map = estimate_occlusion_map - self.image_channel = image_channel - - def deform_input(self, inp, deformation): - _, d_old, h_old, w_old, _ = deformation.shape - _, _, d, h, w = inp.shape - if d_old != d or h_old != h or w_old != w: - deformation = deformation.permute(0, 4, 1, 2, 3) - deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear') - deformation = deformation.permute(0, 2, 3, 4, 1) - return F.grid_sample(inp, deformation) - - def forward(self, source_image, kp_driving, kp_source): - # Encoding (downsampling) part - out = self.first(source_image) - for i in range(len(self.down_blocks)): - out = self.down_blocks[i](out) - out = self.second(out) - bs, c, h, w = out.shape - # print(out.shape) - feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w) - feature_3d = self.resblocks_3d(feature_3d) - - # Transforming feature representation according to deformation and occlusion - output_dict = {} - if self.dense_motion_network is not None: - dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving, - kp_source=kp_source) - output_dict['mask'] = dense_motion['mask'] - - if 'occlusion_map' in dense_motion: - occlusion_map = dense_motion['occlusion_map'] - output_dict['occlusion_map'] = 
occlusion_map - else: - occlusion_map = None - deformation = dense_motion['deformation'] - out = self.deform_input(feature_3d, deformation) - - bs, c, d, h, w = out.shape - out = out.view(bs, c*d, h, w) - out = self.third(out) - out = self.fourth(out) - - if occlusion_map is not None: - if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: - occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') - out = out * occlusion_map - - # output_dict["deformed"] = self.deform_input(source_image, deformation) # 3d deformation cannot deform 2d image - - # Decoding part - out = self.resblocks_2d(out) - for i in range(len(self.up_blocks)): - out = self.up_blocks[i](out) - out = self.final(out) - out = F.sigmoid(out) - - output_dict["prediction"] = out - - return output_dict - - -class SPADEDecoder(nn.Module): - def __init__(self): - super().__init__() - ic = 256 - oc = 64 - norm_G = 'spadespectralinstance' - label_nc = 256 - - self.fc = nn.Conv2d(ic, 2 * ic, 3, padding=1) - self.G_middle_0 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) - self.G_middle_1 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) - self.G_middle_2 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) - self.G_middle_3 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) - self.G_middle_4 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) - self.G_middle_5 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) - self.up_0 = SPADEResnetBlock(2 * ic, ic, norm_G, label_nc) - self.up_1 = SPADEResnetBlock(ic, oc, norm_G, label_nc) - self.conv_img = nn.Conv2d(oc, 3, 3, padding=1) - self.up = nn.Upsample(scale_factor=2) - - def forward(self, feature): - seg = feature - x = self.fc(feature) - x = self.G_middle_0(x, seg) - x = self.G_middle_1(x, seg) - x = self.G_middle_2(x, seg) - x = self.G_middle_3(x, seg) - x = self.G_middle_4(x, seg) - x = self.G_middle_5(x, seg) - x = self.up(x) - x = self.up_0(x, seg) # 256, 128, 128 - x = self.up(x) - x = self.up_1(x, seg) # 64, 256, 256 - - x = self.conv_img(F.leaky_relu(x, 2e-1)) - # x = torch.tanh(x) - x = F.sigmoid(x) - - return x - - -class OcclusionAwareSPADEGenerator(nn.Module): - - def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth, - num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): - super(OcclusionAwareSPADEGenerator, self).__init__() - - if dense_motion_params is not None: - self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel, - estimate_occlusion_map=estimate_occlusion_map, - **dense_motion_params) - else: - self.dense_motion_network = None - - self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1)) - - down_blocks = [] - for i in range(num_down_blocks): - in_features = min(max_features, block_expansion * (2 ** i)) - out_features = min(max_features, block_expansion * (2 ** (i + 1))) - down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) - self.down_blocks = nn.ModuleList(down_blocks) - - self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) - - self.reshape_channel = reshape_channel - self.reshape_depth = reshape_depth - - self.resblocks_3d = torch.nn.Sequential() - for i in range(num_resblocks): - self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) - - out_features = 
block_expansion * (2 ** (num_down_blocks)) - self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True) - self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1) - - self.estimate_occlusion_map = estimate_occlusion_map - self.image_channel = image_channel - - self.decoder = SPADEDecoder() - - def deform_input(self, inp, deformation): - _, d_old, h_old, w_old, _ = deformation.shape - _, _, d, h, w = inp.shape - if d_old != d or h_old != h or w_old != w: - deformation = deformation.permute(0, 4, 1, 2, 3) - deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear') - deformation = deformation.permute(0, 2, 3, 4, 1) - return F.grid_sample(inp, deformation) - - def forward(self, source_image, kp_driving, kp_source): - # Encoding (downsampling) part - out = self.first(source_image) - for i in range(len(self.down_blocks)): - out = self.down_blocks[i](out) - out = self.second(out) - bs, c, h, w = out.shape - # print(out.shape) - feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w) - feature_3d = self.resblocks_3d(feature_3d) - - # Transforming feature representation according to deformation and occlusion - output_dict = {} - if self.dense_motion_network is not None: - dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving, - kp_source=kp_source) - output_dict['mask'] = dense_motion['mask'] - - # import pdb; pdb.set_trace() - - if 'occlusion_map' in dense_motion: - occlusion_map = dense_motion['occlusion_map'] - output_dict['occlusion_map'] = occlusion_map - else: - occlusion_map = None - deformation = dense_motion['deformation'] - out = self.deform_input(feature_3d, deformation) - - bs, c, d, h, w = out.shape - out = out.view(bs, c*d, h, w) - out = self.third(out) - out = self.fourth(out) - - # occlusion_map = torch.where(occlusion_map < 0.95, 0, occlusion_map) - - if occlusion_map is not None: - if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: - occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') - out = out * occlusion_map - - # Decoding part - out = self.decoder(out) - - output_dict["prediction"] = out - - return output_dict \ No newline at end of file diff --git a/sadtalker_audio2pose/src/facerender/modules/keypoint_detector.py b/sadtalker_audio2pose/src/facerender/modules/keypoint_detector.py deleted file mode 100644 index e56800c7b1e94bb3cbf97200cd3f059ce9d29cf3..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/facerender/modules/keypoint_detector.py +++ /dev/null @@ -1,179 +0,0 @@ -from torch import nn -import torch -import torch.nn.functional as F - -from src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d -from src.facerender.modules.util import KPHourglass, make_coordinate_grid, AntiAliasInterpolation2d, ResBottleneck - - -class KPDetector(nn.Module): - """ - Detecting canonical keypoints. Return keypoint position and jacobian near each keypoint. 
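The warping step shared by both deleted generators (deform_input() in facerender/modules/generator.py) resizes the predicted deformation field to the feature volume's resolution and then samples the volume with F.grid_sample. A minimal standalone sketch of that step, with warp_feature_volume as an illustrative name and align_corners made explicit (the deleted code relies on the PyTorch defaults):

import torch
import torch.nn.functional as F

def warp_feature_volume(feature, deformation):
    # feature:     (B, C, D, H, W) source feature volume
    # deformation: (B, d, h, w, 3) sampling coordinates in [-1, 1]
    _, d_old, h_old, w_old, _ = deformation.shape
    _, _, d, h, w = feature.shape
    if (d_old, h_old, w_old) != (d, h, w):
        # grid_sample needs the grid at the feature resolution, so resize it first.
        deformation = deformation.permute(0, 4, 1, 2, 3)
        deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear', align_corners=False)
        deformation = deformation.permute(0, 2, 3, 4, 1)
    return F.grid_sample(feature, deformation, align_corners=False)

# e.g. warp_feature_volume(torch.randn(1, 4, 16, 64, 64), torch.rand(1, 16, 64, 64, 3) * 2 - 1)

In the deleted forward passes the warped volume is then flattened along depth and, when an occlusion map is predicted, multiplied by it before decoding.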
- """ - - def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, reshape_channel, reshape_depth, - num_blocks, temperature, estimate_jacobian=False, scale_factor=1, single_jacobian_map=False): - super(KPDetector, self).__init__() - - self.predictor = KPHourglass(block_expansion, in_features=image_channel, - max_features=max_features, reshape_features=reshape_channel, reshape_depth=reshape_depth, num_blocks=num_blocks) - - # self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=7, padding=3) - self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=3, padding=1) - - if estimate_jacobian: - self.num_jacobian_maps = 1 if single_jacobian_map else num_kp - # self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=7, padding=3) - self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=3, padding=1) - ''' - initial as: - [[1 0 0] - [0 1 0] - [0 0 1]] - ''' - self.jacobian.weight.data.zero_() - self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) - else: - self.jacobian = None - - self.temperature = temperature - self.scale_factor = scale_factor - if self.scale_factor != 1: - self.down = AntiAliasInterpolation2d(image_channel, self.scale_factor) - - def gaussian2kp(self, heatmap): - """ - Extract the mean from a heatmap - """ - shape = heatmap.shape - heatmap = heatmap.unsqueeze(-1) - grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) - value = (heatmap * grid).sum(dim=(2, 3, 4)) - kp = {'value': value} - - return kp - - def forward(self, x): - if self.scale_factor != 1: - x = self.down(x) - - feature_map = self.predictor(x) - prediction = self.kp(feature_map) - - final_shape = prediction.shape - heatmap = prediction.view(final_shape[0], final_shape[1], -1) - heatmap = F.softmax(heatmap / self.temperature, dim=2) - heatmap = heatmap.view(*final_shape) - - out = self.gaussian2kp(heatmap) - - if self.jacobian is not None: - jacobian_map = self.jacobian(feature_map) - jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 9, final_shape[2], - final_shape[3], final_shape[4]) - heatmap = heatmap.unsqueeze(2) - - jacobian = heatmap * jacobian_map - jacobian = jacobian.view(final_shape[0], final_shape[1], 9, -1) - jacobian = jacobian.sum(dim=-1) - jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 3, 3) - out['jacobian'] = jacobian - - return out - - -class HEEstimator(nn.Module): - """ - Estimating head pose and expression. 
- """ - - def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, num_bins=66, estimate_jacobian=True): - super(HEEstimator, self).__init__() - - self.conv1 = nn.Conv2d(in_channels=image_channel, out_channels=block_expansion, kernel_size=7, padding=3, stride=2) - self.norm1 = BatchNorm2d(block_expansion, affine=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - - self.conv2 = nn.Conv2d(in_channels=block_expansion, out_channels=256, kernel_size=1) - self.norm2 = BatchNorm2d(256, affine=True) - - self.block1 = nn.Sequential() - for i in range(3): - self.block1.add_module('b1_'+ str(i), ResBottleneck(in_features=256, stride=1)) - - self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1) - self.norm3 = BatchNorm2d(512, affine=True) - self.block2 = ResBottleneck(in_features=512, stride=2) - - self.block3 = nn.Sequential() - for i in range(3): - self.block3.add_module('b3_'+ str(i), ResBottleneck(in_features=512, stride=1)) - - self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1) - self.norm4 = BatchNorm2d(1024, affine=True) - self.block4 = ResBottleneck(in_features=1024, stride=2) - - self.block5 = nn.Sequential() - for i in range(5): - self.block5.add_module('b5_'+ str(i), ResBottleneck(in_features=1024, stride=1)) - - self.conv5 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1) - self.norm5 = BatchNorm2d(2048, affine=True) - self.block6 = ResBottleneck(in_features=2048, stride=2) - - self.block7 = nn.Sequential() - for i in range(2): - self.block7.add_module('b7_'+ str(i), ResBottleneck(in_features=2048, stride=1)) - - self.fc_roll = nn.Linear(2048, num_bins) - self.fc_pitch = nn.Linear(2048, num_bins) - self.fc_yaw = nn.Linear(2048, num_bins) - - self.fc_t = nn.Linear(2048, 3) - - self.fc_exp = nn.Linear(2048, 3*num_kp) - - def forward(self, x): - out = self.conv1(x) - out = self.norm1(out) - out = F.relu(out) - out = self.maxpool(out) - - out = self.conv2(out) - out = self.norm2(out) - out = F.relu(out) - - out = self.block1(out) - - out = self.conv3(out) - out = self.norm3(out) - out = F.relu(out) - out = self.block2(out) - - out = self.block3(out) - - out = self.conv4(out) - out = self.norm4(out) - out = F.relu(out) - out = self.block4(out) - - out = self.block5(out) - - out = self.conv5(out) - out = self.norm5(out) - out = F.relu(out) - out = self.block6(out) - - out = self.block7(out) - - out = F.adaptive_avg_pool2d(out, 1) - out = out.view(out.shape[0], -1) - - yaw = self.fc_roll(out) - pitch = self.fc_pitch(out) - roll = self.fc_yaw(out) - t = self.fc_t(out) - exp = self.fc_exp(out) - - return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp} - diff --git a/sadtalker_audio2pose/src/facerender/modules/make_animation.py b/sadtalker_audio2pose/src/facerender/modules/make_animation.py deleted file mode 100644 index 42c8c53dcc04da8354d05c98c2bc0d88bf067fb2..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/facerender/modules/make_animation.py +++ /dev/null @@ -1,170 +0,0 @@ -from scipy.spatial import ConvexHull -import torch -import torch.nn.functional as F -import numpy as np -from tqdm import tqdm - -def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False, - use_relative_movement=False, use_relative_jacobian=False): - if adapt_movement_scale: - source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume - driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume - 
adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) - else: - adapt_movement_scale = 1 - - kp_new = {k: v for k, v in kp_driving.items()} - - if use_relative_movement: - kp_value_diff = (kp_driving['value'] - kp_driving_initial['value']) - kp_value_diff *= adapt_movement_scale - kp_new['value'] = kp_value_diff + kp_source['value'] - - if use_relative_jacobian: - jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian'])) - kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian']) - - return kp_new - -def headpose_pred_to_degree(pred): - device = pred.device - idx_tensor = [idx for idx in range(66)] - idx_tensor = torch.FloatTensor(idx_tensor).type_as(pred).to(device) - pred = F.softmax(pred) - degree = torch.sum(pred*idx_tensor, 1) * 3 - 99 - return degree - -def get_rotation_matrix(yaw, pitch, roll): - yaw = yaw / 180 * 3.14 - pitch = pitch / 180 * 3.14 - roll = roll / 180 * 3.14 - - roll = roll.unsqueeze(1) - pitch = pitch.unsqueeze(1) - yaw = yaw.unsqueeze(1) - - pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch), - torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch), - torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1) - pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3) - - yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw), - torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw), - -torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1) - yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3) - - roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll), - torch.sin(roll), torch.cos(roll), torch.zeros_like(roll), - torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1) - roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3) - - rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat) - - return rot_mat - -def keypoint_transformation(kp_canonical, he, wo_exp=False): - kp = kp_canonical['value'] # (bs, k, 3) - yaw, pitch, roll= he['yaw'], he['pitch'], he['roll'] - yaw = headpose_pred_to_degree(yaw) - pitch = headpose_pred_to_degree(pitch) - roll = headpose_pred_to_degree(roll) - - if 'yaw_in' in he: - yaw = he['yaw_in'] - if 'pitch_in' in he: - pitch = he['pitch_in'] - if 'roll_in' in he: - roll = he['roll_in'] - - rot_mat = get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3) - - t, exp = he['t'], he['exp'] - if wo_exp: - exp = exp*0 - - # keypoint rotation - kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp) - - # keypoint translation - t[:, 0] = t[:, 0]*0 - t[:, 2] = t[:, 2]*0 - t = t.unsqueeze(1).repeat(1, kp.shape[1], 1) - kp_t = kp_rotated + t - - # add expression deviation - exp = exp.view(exp.shape[0], -1, 3) - kp_transformed = kp_t + exp - - return {'value': kp_transformed} - - - -def make_animation(source_image, source_semantics, target_semantics, - generator, kp_detector, he_estimator, mapping, - yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None, - use_exp=True, use_half=False): - with torch.no_grad(): - predictions = [] - - kp_canonical = kp_detector(source_image) - he_source = mapping(source_semantics) - kp_source = keypoint_transformation(kp_canonical, he_source) - - for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'): - # still check the dimension - # print(target_semantics.shape, source_semantics.shape) - target_semantics_frame = target_semantics[:, frame_idx] - he_driving = mapping(target_semantics_frame) - if 
yaw_c_seq is not None: - he_driving['yaw_in'] = yaw_c_seq[:, frame_idx] - if pitch_c_seq is not None: - he_driving['pitch_in'] = pitch_c_seq[:, frame_idx] - if roll_c_seq is not None: - he_driving['roll_in'] = roll_c_seq[:, frame_idx] - - kp_driving = keypoint_transformation(kp_canonical, he_driving) - - kp_norm = kp_driving - out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm) - ''' - source_image_new = out['prediction'].squeeze(1) - kp_canonical_new = kp_detector(source_image_new) - he_source_new = he_estimator(source_image_new) - kp_source_new = keypoint_transformation(kp_canonical_new, he_source_new, wo_exp=True) - kp_driving_new = keypoint_transformation(kp_canonical_new, he_driving, wo_exp=True) - out = generator(source_image_new, kp_source=kp_source_new, kp_driving=kp_driving_new) - ''' - predictions.append(out['prediction']) - predictions_ts = torch.stack(predictions, dim=1) - return predictions_ts - -class AnimateModel(torch.nn.Module): - """ - Merge all generator related updates into single model for better multi-gpu usage - """ - - def __init__(self, generator, kp_extractor, mapping): - super(AnimateModel, self).__init__() - self.kp_extractor = kp_extractor - self.generator = generator - self.mapping = mapping - - self.kp_extractor.eval() - self.generator.eval() - self.mapping.eval() - - def forward(self, x): - - source_image = x['source_image'] - source_semantics = x['source_semantics'] - target_semantics = x['target_semantics'] - yaw_c_seq = x['yaw_c_seq'] - pitch_c_seq = x['pitch_c_seq'] - roll_c_seq = x['roll_c_seq'] - - predictions_video = make_animation(source_image, source_semantics, target_semantics, - self.generator, self.kp_extractor, - self.mapping, use_exp = True, - yaw_c_seq=yaw_c_seq, pitch_c_seq=pitch_c_seq, roll_c_seq=roll_c_seq) - - return predictions_video \ No newline at end of file diff --git a/sadtalker_audio2pose/src/facerender/modules/mapping.py b/sadtalker_audio2pose/src/facerender/modules/mapping.py deleted file mode 100644 index 5ac98dd9e177b949f71f8f47029b66d67ece05b4..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/facerender/modules/mapping.py +++ /dev/null @@ -1,47 +0,0 @@ -import numpy as np - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class MappingNet(nn.Module): - def __init__(self, coeff_nc, descriptor_nc, layer, num_kp, num_bins): - super( MappingNet, self).__init__() - - self.layer = layer - nonlinearity = nn.LeakyReLU(0.1) - - self.first = nn.Sequential( - torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True)) - - for i in range(layer): - net = nn.Sequential(nonlinearity, - torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3)) - setattr(self, 'encoder' + str(i), net) - - self.pooling = nn.AdaptiveAvgPool1d(1) - self.output_nc = descriptor_nc - - self.fc_roll = nn.Linear(descriptor_nc, num_bins) - self.fc_pitch = nn.Linear(descriptor_nc, num_bins) - self.fc_yaw = nn.Linear(descriptor_nc, num_bins) - self.fc_t = nn.Linear(descriptor_nc, 3) - self.fc_exp = nn.Linear(descriptor_nc, 3*num_kp) - - def forward(self, input_3dmm): - out = self.first(input_3dmm) - for i in range(self.layer): - model = getattr(self, 'encoder' + str(i)) - out = model(out) + out[:,:,3:-3] - out = self.pooling(out) - out = out.view(out.shape[0], -1) - #print('out:', out.shape) - - yaw = self.fc_yaw(out) - pitch = self.fc_pitch(out) - roll = self.fc_roll(out) - t = self.fc_t(out) - exp = self.fc_exp(out) - - return {'yaw': yaw, 'pitch': pitch, 
'roll': roll, 't': t, 'exp': exp} \ No newline at end of file diff --git a/sadtalker_audio2pose/src/facerender/modules/util.py b/sadtalker_audio2pose/src/facerender/modules/util.py deleted file mode 100644 index f3bfb1f26427b491f032ca9952db41cdeb793d70..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/facerender/modules/util.py +++ /dev/null @@ -1,564 +0,0 @@ -from torch import nn - -import torch.nn.functional as F -import torch - -from src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d -from src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d - -import torch.nn.utils.spectral_norm as spectral_norm - - -def kp2gaussian(kp, spatial_size, kp_variance): - """ - Transform a keypoint into gaussian like representation - """ - mean = kp['value'] - - coordinate_grid = make_coordinate_grid(spatial_size, mean.type()) - number_of_leading_dimensions = len(mean.shape) - 1 - shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape - coordinate_grid = coordinate_grid.view(*shape) - repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1) - coordinate_grid = coordinate_grid.repeat(*repeats) - - # Preprocess kp shape - shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3) - mean = mean.view(*shape) - - mean_sub = (coordinate_grid - mean) - - out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance) - - return out - -def make_coordinate_grid_2d(spatial_size, type): - """ - Create a meshgrid [-1,1] x [-1,1] of given spatial_size. - """ - h, w = spatial_size - x = torch.arange(w).type(type) - y = torch.arange(h).type(type) - - x = (2 * (x / (w - 1)) - 1) - y = (2 * (y / (h - 1)) - 1) - - yy = y.view(-1, 1).repeat(1, w) - xx = x.view(1, -1).repeat(h, 1) - - meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) - - return meshed - - -def make_coordinate_grid(spatial_size, type): - d, h, w = spatial_size - x = torch.arange(w).type(type) - y = torch.arange(h).type(type) - z = torch.arange(d).type(type) - - x = (2 * (x / (w - 1)) - 1) - y = (2 * (y / (h - 1)) - 1) - z = (2 * (z / (d - 1)) - 1) - - yy = y.view(1, -1, 1).repeat(d, 1, w) - xx = x.view(1, 1, -1).repeat(d, h, 1) - zz = z.view(-1, 1, 1).repeat(1, h, w) - - meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3) - - return meshed - - -class ResBottleneck(nn.Module): - def __init__(self, in_features, stride): - super(ResBottleneck, self).__init__() - self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features//4, kernel_size=1) - self.conv2 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features//4, kernel_size=3, padding=1, stride=stride) - self.conv3 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features, kernel_size=1) - self.norm1 = BatchNorm2d(in_features//4, affine=True) - self.norm2 = BatchNorm2d(in_features//4, affine=True) - self.norm3 = BatchNorm2d(in_features, affine=True) - - self.stride = stride - if self.stride != 1: - self.skip = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=1, stride=stride) - self.norm4 = BatchNorm2d(in_features, affine=True) - - def forward(self, x): - out = self.conv1(x) - out = self.norm1(out) - out = F.relu(out) - out = self.conv2(out) - out = self.norm2(out) - out = F.relu(out) - out = self.conv3(out) - out = self.norm3(out) - if self.stride != 1: - x = self.skip(x) - x = self.norm4(x) - out += x - out = F.relu(out) - return out - - -class ResBlock2d(nn.Module): - """ - Res block, preserve spatial resolution. 
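Note on keypoint_transformation / get_rotation_matrix in the make_animation.py hunk above: each canonical keypoint is rotated by the Euler-angle matrix (pitch, then yaw, then roll), shifted by the translation with its x and z components zeroed, and offset by the expression deviation. A toy sketch of the same sequence of operations; the values are made up, and the sketch uses math.pi where the deleted code hard-codes 3.14:

import math
import torch

def euler_to_rot(yaw, pitch, roll):
    # degrees in, (bs, 3, 3) out; same axis convention and composition order as the source
    yaw, pitch, roll = [a * math.pi / 180 for a in (yaw, pitch, roll)]
    one, zero = torch.ones_like(yaw), torch.zeros_like(yaw)
    pitch_mat = torch.stack([one, zero, zero,
                             zero, pitch.cos(), -pitch.sin(),
                             zero, pitch.sin(), pitch.cos()], dim=1).view(-1, 3, 3)
    yaw_mat = torch.stack([yaw.cos(), zero, yaw.sin(),
                           zero, one, zero,
                           -yaw.sin(), zero, yaw.cos()], dim=1).view(-1, 3, 3)
    roll_mat = torch.stack([roll.cos(), -roll.sin(), zero,
                            roll.sin(), roll.cos(), zero,
                            zero, zero, one], dim=1).view(-1, 3, 3)
    return pitch_mat @ yaw_mat @ roll_mat

bs, k = 1, 15
kp = torch.randn(bs, k, 3)                                        # canonical keypoints
rot = euler_to_rot(torch.tensor([20.0]), torch.tensor([-5.0]), torch.tensor([3.0]))
t = torch.tensor([[0.1, 0.2, -0.3]]) * torch.tensor([0.0, 1.0, 0.0])  # x/z zeroed, as in the source
exp = 0.01 * torch.randn(bs, k, 3)                                # expression deviation
kp_driven = torch.einsum('bmp,bkp->bkm', rot, kp) + t.unsqueeze(1) + exp
print(kp_driven.shape)                                            # torch.Size([1, 15, 3])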
- """ - - def __init__(self, in_features, kernel_size, padding): - super(ResBlock2d, self).__init__() - self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, - padding=padding) - self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, - padding=padding) - self.norm1 = BatchNorm2d(in_features, affine=True) - self.norm2 = BatchNorm2d(in_features, affine=True) - - def forward(self, x): - out = self.norm1(x) - out = F.relu(out) - out = self.conv1(out) - out = self.norm2(out) - out = F.relu(out) - out = self.conv2(out) - out += x - return out - - -class ResBlock3d(nn.Module): - """ - Res block, preserve spatial resolution. - """ - - def __init__(self, in_features, kernel_size, padding): - super(ResBlock3d, self).__init__() - self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, - padding=padding) - self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, - padding=padding) - self.norm1 = BatchNorm3d(in_features, affine=True) - self.norm2 = BatchNorm3d(in_features, affine=True) - - def forward(self, x): - out = self.norm1(x) - out = F.relu(out) - out = self.conv1(out) - out = self.norm2(out) - out = F.relu(out) - out = self.conv2(out) - out += x - return out - - -class UpBlock2d(nn.Module): - """ - Upsampling block for use in decoder. - """ - - def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): - super(UpBlock2d, self).__init__() - - self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, - padding=padding, groups=groups) - self.norm = BatchNorm2d(out_features, affine=True) - - def forward(self, x): - out = F.interpolate(x, scale_factor=2) - out = self.conv(out) - out = self.norm(out) - out = F.relu(out) - return out - -class UpBlock3d(nn.Module): - """ - Upsampling block for use in decoder. - """ - - def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): - super(UpBlock3d, self).__init__() - - self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, - padding=padding, groups=groups) - self.norm = BatchNorm3d(out_features, affine=True) - - def forward(self, x): - # out = F.interpolate(x, scale_factor=(1, 2, 2), mode='trilinear') - out = F.interpolate(x, scale_factor=(1, 2, 2)) - out = self.conv(out) - out = self.norm(out) - out = F.relu(out) - return out - - -class DownBlock2d(nn.Module): - """ - Downsampling block for use in encoder. - """ - - def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): - super(DownBlock2d, self).__init__() - self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, - padding=padding, groups=groups) - self.norm = BatchNorm2d(out_features, affine=True) - self.pool = nn.AvgPool2d(kernel_size=(2, 2)) - - def forward(self, x): - out = self.conv(x) - out = self.norm(out) - out = F.relu(out) - out = self.pool(out) - return out - - -class DownBlock3d(nn.Module): - """ - Downsampling block for use in encoder. 
- """ - - def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): - super(DownBlock3d, self).__init__() - ''' - self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, - padding=padding, groups=groups, stride=(1, 2, 2)) - ''' - self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, - padding=padding, groups=groups) - self.norm = BatchNorm3d(out_features, affine=True) - self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2)) - - def forward(self, x): - out = self.conv(x) - out = self.norm(out) - out = F.relu(out) - out = self.pool(out) - return out - - -class SameBlock2d(nn.Module): - """ - Simple block, preserve spatial resolution. - """ - - def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False): - super(SameBlock2d, self).__init__() - self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, - kernel_size=kernel_size, padding=padding, groups=groups) - self.norm = BatchNorm2d(out_features, affine=True) - if lrelu: - self.ac = nn.LeakyReLU() - else: - self.ac = nn.ReLU() - - def forward(self, x): - out = self.conv(x) - out = self.norm(out) - out = self.ac(out) - return out - - -class Encoder(nn.Module): - """ - Hourglass Encoder - """ - - def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): - super(Encoder, self).__init__() - - down_blocks = [] - for i in range(num_blocks): - down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), - min(max_features, block_expansion * (2 ** (i + 1))), - kernel_size=3, padding=1)) - self.down_blocks = nn.ModuleList(down_blocks) - - def forward(self, x): - outs = [x] - for down_block in self.down_blocks: - outs.append(down_block(outs[-1])) - return outs - - -class Decoder(nn.Module): - """ - Hourglass Decoder - """ - - def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): - super(Decoder, self).__init__() - - up_blocks = [] - - for i in range(num_blocks)[::-1]: - in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) - out_filters = min(max_features, block_expansion * (2 ** i)) - up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1)) - - self.up_blocks = nn.ModuleList(up_blocks) - # self.out_filters = block_expansion - self.out_filters = block_expansion + in_features - - self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1) - self.norm = BatchNorm3d(self.out_filters, affine=True) - - def forward(self, x): - out = x.pop() - # for up_block in self.up_blocks[:-1]: - for up_block in self.up_blocks: - out = up_block(out) - skip = x.pop() - out = torch.cat([out, skip], dim=1) - # out = self.up_blocks[-1](out) - out = self.conv(out) - out = self.norm(out) - out = F.relu(out) - return out - - -class Hourglass(nn.Module): - """ - Hourglass architecture. - """ - - def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): - super(Hourglass, self).__init__() - self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) - self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) - self.out_filters = self.decoder.out_filters - - def forward(self, x): - return self.decoder(self.encoder(x)) - - -class KPHourglass(nn.Module): - """ - Hourglass architecture. 
- """ - - def __init__(self, block_expansion, in_features, reshape_features, reshape_depth, num_blocks=3, max_features=256): - super(KPHourglass, self).__init__() - - self.down_blocks = nn.Sequential() - for i in range(num_blocks): - self.down_blocks.add_module('down'+ str(i), DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), - min(max_features, block_expansion * (2 ** (i + 1))), - kernel_size=3, padding=1)) - - in_filters = min(max_features, block_expansion * (2 ** num_blocks)) - self.conv = nn.Conv2d(in_channels=in_filters, out_channels=reshape_features, kernel_size=1) - - self.up_blocks = nn.Sequential() - for i in range(num_blocks): - in_filters = min(max_features, block_expansion * (2 ** (num_blocks - i))) - out_filters = min(max_features, block_expansion * (2 ** (num_blocks - i - 1))) - self.up_blocks.add_module('up'+ str(i), UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1)) - - self.reshape_depth = reshape_depth - self.out_filters = out_filters - - def forward(self, x): - out = self.down_blocks(x) - out = self.conv(out) - bs, c, h, w = out.shape - out = out.view(bs, c//self.reshape_depth, self.reshape_depth, h, w) - out = self.up_blocks(out) - - return out - - - -class AntiAliasInterpolation2d(nn.Module): - """ - Band-limited downsampling, for better preservation of the input signal. - """ - def __init__(self, channels, scale): - super(AntiAliasInterpolation2d, self).__init__() - sigma = (1 / scale - 1) / 2 - kernel_size = 2 * round(sigma * 4) + 1 - self.ka = kernel_size // 2 - self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka - - kernel_size = [kernel_size, kernel_size] - sigma = [sigma, sigma] - # The gaussian kernel is the product of the - # gaussian function of each dimension. - kernel = 1 - meshgrids = torch.meshgrid( - [ - torch.arange(size, dtype=torch.float32) - for size in kernel_size - ] - ) - for size, std, mgrid in zip(kernel_size, sigma, meshgrids): - mean = (size - 1) / 2 - kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) - - # Make sure sum of values in gaussian kernel equals 1. 
- kernel = kernel / torch.sum(kernel) - # Reshape to depthwise convolutional weight - kernel = kernel.view(1, 1, *kernel.size()) - kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) - - self.register_buffer('weight', kernel) - self.groups = channels - self.scale = scale - inv_scale = 1 / scale - self.int_inv_scale = int(inv_scale) - - def forward(self, input): - if self.scale == 1.0: - return input - - out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) - out = F.conv2d(out, weight=self.weight, groups=self.groups) - out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale] - - return out - - -class SPADE(nn.Module): - def __init__(self, norm_nc, label_nc): - super().__init__() - - self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) - nhidden = 128 - - self.mlp_shared = nn.Sequential( - nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1), - nn.ReLU()) - self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) - self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) - - def forward(self, x, segmap): - normalized = self.param_free_norm(x) - segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') - actv = self.mlp_shared(segmap) - gamma = self.mlp_gamma(actv) - beta = self.mlp_beta(actv) - out = normalized * (1 + gamma) + beta - return out - - -class SPADEResnetBlock(nn.Module): - def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1): - super().__init__() - # Attributes - self.learned_shortcut = (fin != fout) - fmiddle = min(fin, fout) - self.use_se = use_se - # create conv layers - self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation) - self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation) - if self.learned_shortcut: - self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) - # apply spectral norm if specified - if 'spectral' in norm_G: - self.conv_0 = spectral_norm(self.conv_0) - self.conv_1 = spectral_norm(self.conv_1) - if self.learned_shortcut: - self.conv_s = spectral_norm(self.conv_s) - # define normalization layers - self.norm_0 = SPADE(fin, label_nc) - self.norm_1 = SPADE(fmiddle, label_nc) - if self.learned_shortcut: - self.norm_s = SPADE(fin, label_nc) - - def forward(self, x, seg1): - x_s = self.shortcut(x, seg1) - dx = self.conv_0(self.actvn(self.norm_0(x, seg1))) - dx = self.conv_1(self.actvn(self.norm_1(dx, seg1))) - out = x_s + dx - return out - - def shortcut(self, x, seg1): - if self.learned_shortcut: - x_s = self.conv_s(self.norm_s(x, seg1)) - else: - x_s = x - return x_s - - def actvn(self, x): - return F.leaky_relu(x, 2e-1) - -class audio2image(nn.Module): - def __init__(self, generator, kp_extractor, he_estimator_video, he_estimator_audio, train_params): - super().__init__() - # Attributes - self.generator = generator - self.kp_extractor = kp_extractor - self.he_estimator_video = he_estimator_video - self.he_estimator_audio = he_estimator_audio - self.train_params = train_params - - def headpose_pred_to_degree(self, pred): - device = pred.device - idx_tensor = [idx for idx in range(66)] - idx_tensor = torch.FloatTensor(idx_tensor).to(device) - pred = F.softmax(pred) - degree = torch.sum(pred*idx_tensor, 1) * 3 - 99 - - return degree - - def get_rotation_matrix(self, yaw, pitch, roll): - yaw = yaw / 180 * 3.14 - pitch = pitch / 180 * 3.14 - roll = roll / 180 * 3.14 - - roll = roll.unsqueeze(1) - pitch = pitch.unsqueeze(1) - yaw = yaw.unsqueeze(1) - - roll_mat = torch.cat([torch.ones_like(roll), 
torch.zeros_like(roll), torch.zeros_like(roll), - torch.zeros_like(roll), torch.cos(roll), -torch.sin(roll), - torch.zeros_like(roll), torch.sin(roll), torch.cos(roll)], dim=1) - roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3) - - pitch_mat = torch.cat([torch.cos(pitch), torch.zeros_like(pitch), torch.sin(pitch), - torch.zeros_like(pitch), torch.ones_like(pitch), torch.zeros_like(pitch), - -torch.sin(pitch), torch.zeros_like(pitch), torch.cos(pitch)], dim=1) - pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3) - - yaw_mat = torch.cat([torch.cos(yaw), -torch.sin(yaw), torch.zeros_like(yaw), - torch.sin(yaw), torch.cos(yaw), torch.zeros_like(yaw), - torch.zeros_like(yaw), torch.zeros_like(yaw), torch.ones_like(yaw)], dim=1) - yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3) - - rot_mat = torch.einsum('bij,bjk,bkm->bim', roll_mat, pitch_mat, yaw_mat) - - return rot_mat - - def keypoint_transformation(self, kp_canonical, he): - kp = kp_canonical['value'] # (bs, k, 3) - yaw, pitch, roll = he['yaw'], he['pitch'], he['roll'] - t, exp = he['t'], he['exp'] - - yaw = self.headpose_pred_to_degree(yaw) - pitch = self.headpose_pred_to_degree(pitch) - roll = self.headpose_pred_to_degree(roll) - - rot_mat = self.get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3) - - # keypoint rotation - kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp) - - - - # keypoint translation - t = t.unsqueeze_(1).repeat(1, kp.shape[1], 1) - kp_t = kp_rotated + t - - # add expression deviation - exp = exp.view(exp.shape[0], -1, 3) - kp_transformed = kp_t + exp - - return {'value': kp_transformed} - - def forward(self, source_image, target_audio): - pose_source = self.he_estimator_video(source_image) - pose_generated = self.he_estimator_audio(target_audio) - kp_canonical = self.kp_extractor(source_image) - kp_source = self.keypoint_transformation(kp_canonical, pose_source) - kp_transformed_generated = self.keypoint_transformation(kp_canonical, pose_generated) - generated = self.generator(source_image, kp_source=kp_source, kp_driving=kp_transformed_generated) - return generated \ No newline at end of file diff --git a/sadtalker_audio2pose/src/facerender/pirender/base_function.py b/sadtalker_audio2pose/src/facerender/pirender/base_function.py deleted file mode 100644 index 650fb7de1b95fc34e4b7c17b2526c1f450a577a0..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/facerender/pirender/base_function.py +++ /dev/null @@ -1,368 +0,0 @@ -import sys -import math - -import torch -from torch import nn -from torch.nn import functional as F -from torch.autograd import Function -from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm - - -class LayerNorm2d(nn.Module): - def __init__(self, n_out, affine=True): - super(LayerNorm2d, self).__init__() - self.n_out = n_out - self.affine = affine - - if self.affine: - self.weight = nn.Parameter(torch.ones(n_out, 1, 1)) - self.bias = nn.Parameter(torch.zeros(n_out, 1, 1)) - - def forward(self, x): - normalized_shape = x.size()[1:] - if self.affine: - return F.layer_norm(x, normalized_shape, \ - self.weight.expand(normalized_shape), - self.bias.expand(normalized_shape)) - - else: - return F.layer_norm(x, normalized_shape) - -class ADAINHourglass(nn.Module): - def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers, decoder_layers, nonlinearity, use_spect): - super(ADAINHourglass, self).__init__() - self.encoder = ADAINEncoder(image_nc, pose_nc, ngf, img_f, encoder_layers, nonlinearity, use_spect) - self.decoder = ADAINDecoder(pose_nc, ngf, img_f, 
encoder_layers, decoder_layers, True, nonlinearity, use_spect) - self.output_nc = self.decoder.output_nc - - def forward(self, x, z): - return self.decoder(self.encoder(x, z), z) - - - -class ADAINEncoder(nn.Module): - def __init__(self, image_nc, pose_nc, ngf, img_f, layers, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(ADAINEncoder, self).__init__() - self.layers = layers - self.input_layer = nn.Conv2d(image_nc, ngf, kernel_size=7, stride=1, padding=3) - for i in range(layers): - in_channels = min(ngf * (2**i), img_f) - out_channels = min(ngf *(2**(i+1)), img_f) - model = ADAINEncoderBlock(in_channels, out_channels, pose_nc, nonlinearity, use_spect) - setattr(self, 'encoder' + str(i), model) - self.output_nc = out_channels - - def forward(self, x, z): - out = self.input_layer(x) - out_list = [out] - for i in range(self.layers): - model = getattr(self, 'encoder' + str(i)) - out = model(out, z) - out_list.append(out) - return out_list - -class ADAINDecoder(nn.Module): - """docstring for ADAINDecoder""" - def __init__(self, pose_nc, ngf, img_f, encoder_layers, decoder_layers, skip_connect=True, - nonlinearity=nn.LeakyReLU(), use_spect=False): - - super(ADAINDecoder, self).__init__() - self.encoder_layers = encoder_layers - self.decoder_layers = decoder_layers - self.skip_connect = skip_connect - use_transpose = True - - for i in range(encoder_layers-decoder_layers, encoder_layers)[::-1]: - in_channels = min(ngf * (2**(i+1)), img_f) - in_channels = in_channels*2 if i != (encoder_layers-1) and self.skip_connect else in_channels - out_channels = min(ngf * (2**i), img_f) - model = ADAINDecoderBlock(in_channels, out_channels, out_channels, pose_nc, use_transpose, nonlinearity, use_spect) - setattr(self, 'decoder' + str(i), model) - - self.output_nc = out_channels*2 if self.skip_connect else out_channels - - def forward(self, x, z): - out = x.pop() if self.skip_connect else x - for i in range(self.encoder_layers-self.decoder_layers, self.encoder_layers)[::-1]: - model = getattr(self, 'decoder' + str(i)) - out = model(out, z) - out = torch.cat([out, x.pop()], 1) if self.skip_connect else out - return out - -class ADAINEncoderBlock(nn.Module): - def __init__(self, input_nc, output_nc, feature_nc, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(ADAINEncoderBlock, self).__init__() - kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1} - kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1} - - self.conv_0 = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_down), use_spect) - self.conv_1 = spectral_norm(nn.Conv2d(output_nc, output_nc, **kwargs_fine), use_spect) - - - self.norm_0 = ADAIN(input_nc, feature_nc) - self.norm_1 = ADAIN(output_nc, feature_nc) - self.actvn = nonlinearity - - def forward(self, x, z): - x = self.conv_0(self.actvn(self.norm_0(x, z))) - x = self.conv_1(self.actvn(self.norm_1(x, z))) - return x - -class ADAINDecoderBlock(nn.Module): - def __init__(self, input_nc, output_nc, hidden_nc, feature_nc, use_transpose=True, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(ADAINDecoderBlock, self).__init__() - # Attributes - self.actvn = nonlinearity - hidden_nc = min(input_nc, output_nc) if hidden_nc is None else hidden_nc - - kwargs_fine = {'kernel_size':3, 'stride':1, 'padding':1} - if use_transpose: - kwargs_up = {'kernel_size':3, 'stride':2, 'padding':1, 'output_padding':1} - else: - kwargs_up = {'kernel_size':3, 'stride':1, 'padding':1} - - # create conv layers - self.conv_0 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, **kwargs_fine), 
use_spect) - if use_transpose: - self.conv_1 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, **kwargs_up), use_spect) - self.conv_s = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, **kwargs_up), use_spect) - else: - self.conv_1 = nn.Sequential(spectral_norm(nn.Conv2d(hidden_nc, output_nc, **kwargs_up), use_spect), - nn.Upsample(scale_factor=2)) - self.conv_s = nn.Sequential(spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_up), use_spect), - nn.Upsample(scale_factor=2)) - # define normalization layers - self.norm_0 = ADAIN(input_nc, feature_nc) - self.norm_1 = ADAIN(hidden_nc, feature_nc) - self.norm_s = ADAIN(input_nc, feature_nc) - - def forward(self, x, z): - x_s = self.shortcut(x, z) - dx = self.conv_0(self.actvn(self.norm_0(x, z))) - dx = self.conv_1(self.actvn(self.norm_1(dx, z))) - out = x_s + dx - return out - - def shortcut(self, x, z): - x_s = self.conv_s(self.actvn(self.norm_s(x, z))) - return x_s - - -def spectral_norm(module, use_spect=True): - """use spectral normal layer to stable the training process""" - if use_spect: - return SpectralNorm(module) - else: - return module - - -class ADAIN(nn.Module): - def __init__(self, norm_nc, feature_nc): - super().__init__() - - self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) - - nhidden = 128 - use_bias=True - - self.mlp_shared = nn.Sequential( - nn.Linear(feature_nc, nhidden, bias=use_bias), - nn.ReLU() - ) - self.mlp_gamma = nn.Linear(nhidden, norm_nc, bias=use_bias) - self.mlp_beta = nn.Linear(nhidden, norm_nc, bias=use_bias) - - def forward(self, x, feature): - - # Part 1. generate parameter-free normalized activations - normalized = self.param_free_norm(x) - - # Part 2. produce scaling and bias conditioned on feature - feature = feature.view(feature.size(0), -1) - actv = self.mlp_shared(feature) - gamma = self.mlp_gamma(actv) - beta = self.mlp_beta(actv) - - # apply scale and bias - gamma = gamma.view(*gamma.size()[:2], 1,1) - beta = beta.view(*beta.size()[:2], 1,1) - out = normalized * (1 + gamma) + beta - return out - - -class FineEncoder(nn.Module): - """docstring for Encoder""" - def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(FineEncoder, self).__init__() - self.layers = layers - self.first = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect) - for i in range(layers): - in_channels = min(ngf*(2**i), img_f) - out_channels = min(ngf*(2**(i+1)), img_f) - model = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect) - setattr(self, 'down' + str(i), model) - self.output_nc = out_channels - - def forward(self, x): - x = self.first(x) - out=[x] - for i in range(self.layers): - model = getattr(self, 'down'+str(i)) - x = model(x) - out.append(x) - return out - -class FineDecoder(nn.Module): - """docstring for FineDecoder""" - def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(FineDecoder, self).__init__() - self.layers = layers - for i in range(layers)[::-1]: - in_channels = min(ngf*(2**(i+1)), img_f) - out_channels = min(ngf*(2**i), img_f) - up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect) - res = FineADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect) - jump = Jump(out_channels, norm_layer, nonlinearity, use_spect) - - setattr(self, 'up' + str(i), up) - setattr(self, 'res' + str(i), res) - setattr(self, 'jump' + 
str(i), jump) - - self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'tanh') - - self.output_nc = out_channels - - def forward(self, x, z): - out = x.pop() - for i in range(self.layers)[::-1]: - res_model = getattr(self, 'res' + str(i)) - up_model = getattr(self, 'up' + str(i)) - jump_model = getattr(self, 'jump' + str(i)) - out = res_model(out, z) - out = up_model(out) - out = jump_model(x.pop()) + out - out_image = self.final(out) - return out_image - -class FirstBlock2d(nn.Module): - """ - Downsampling block for use in encoder. - """ - def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(FirstBlock2d, self).__init__() - kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3} - conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect) - - if type(norm_layer) == type(None): - self.model = nn.Sequential(conv, nonlinearity) - else: - self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity) - - - def forward(self, x): - out = self.model(x) - return out - -class DownBlock2d(nn.Module): - def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(DownBlock2d, self).__init__() - - - kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1} - conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect) - pool = nn.AvgPool2d(kernel_size=(2, 2)) - - if type(norm_layer) == type(None): - self.model = nn.Sequential(conv, nonlinearity, pool) - else: - self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity, pool) - - def forward(self, x): - out = self.model(x) - return out - -class UpBlock2d(nn.Module): - def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(UpBlock2d, self).__init__() - kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1} - conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect) - if type(norm_layer) == type(None): - self.model = nn.Sequential(conv, nonlinearity) - else: - self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity) - - def forward(self, x): - out = self.model(F.interpolate(x, scale_factor=2)) - return out - -class FineADAINResBlocks(nn.Module): - def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(FineADAINResBlocks, self).__init__() - self.num_block = num_block - for i in range(num_block): - model = FineADAINResBlock2d(input_nc, feature_nc, norm_layer, nonlinearity, use_spect) - setattr(self, 'res'+str(i), model) - - def forward(self, x, z): - for i in range(self.num_block): - model = getattr(self, 'res'+str(i)) - x = model(x, z) - return x - -class Jump(nn.Module): - def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(Jump, self).__init__() - kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1} - conv = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect) - - if type(norm_layer) == type(None): - self.model = nn.Sequential(conv, nonlinearity) - else: - self.model = nn.Sequential(conv, norm_layer(input_nc), nonlinearity) - - def forward(self, x): - out = self.model(x) - return out - -class FineADAINResBlock2d(nn.Module): - """ - Define an Residual block for different types - """ - def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(FineADAINResBlock2d, 
self).__init__() - - kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1} - - self.conv1 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect) - self.conv2 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect) - self.norm1 = ADAIN(input_nc, feature_nc) - self.norm2 = ADAIN(input_nc, feature_nc) - - self.actvn = nonlinearity - - - def forward(self, x, z): - dx = self.actvn(self.norm1(self.conv1(x), z)) - dx = self.norm2(self.conv2(x), z) - out = dx + x - return out - -class FinalBlock2d(nn.Module): - """ - Define the output layer - """ - def __init__(self, input_nc, output_nc, use_spect=False, tanh_or_sigmoid='tanh'): - super(FinalBlock2d, self).__init__() - - kwargs = {'kernel_size': 7, 'stride': 1, 'padding':3} - conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect) - - if tanh_or_sigmoid == 'sigmoid': - out_nonlinearity = nn.Sigmoid() - else: - out_nonlinearity = nn.Tanh() - - self.model = nn.Sequential(conv, out_nonlinearity) - def forward(self, x): - out = self.model(x) - return out \ No newline at end of file diff --git a/sadtalker_audio2pose/src/facerender/pirender/config.py b/sadtalker_audio2pose/src/facerender/pirender/config.py deleted file mode 100644 index 29dc2d1b9008dbf2dc3c0a307212471621bae8da..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/facerender/pirender/config.py +++ /dev/null @@ -1,211 +0,0 @@ -import collections -import functools -import os -import re - -import yaml - -class AttrDict(dict): - """Dict as attribute trick.""" - - def __init__(self, *args, **kwargs): - super(AttrDict, self).__init__(*args, **kwargs) - self.__dict__ = self - for key, value in self.__dict__.items(): - if isinstance(value, dict): - self.__dict__[key] = AttrDict(value) - elif isinstance(value, (list, tuple)): - if isinstance(value[0], dict): - self.__dict__[key] = [AttrDict(item) for item in value] - else: - self.__dict__[key] = value - - def yaml(self): - """Convert object to yaml dict and return.""" - yaml_dict = {} - for key, value in self.__dict__.items(): - if isinstance(value, AttrDict): - yaml_dict[key] = value.yaml() - elif isinstance(value, list): - if isinstance(value[0], AttrDict): - new_l = [] - for item in value: - new_l.append(item.yaml()) - yaml_dict[key] = new_l - else: - yaml_dict[key] = value - else: - yaml_dict[key] = value - return yaml_dict - - def __repr__(self): - """Print all variables.""" - ret_str = [] - for key, value in self.__dict__.items(): - if isinstance(value, AttrDict): - ret_str.append('{}:'.format(key)) - child_ret_str = value.__repr__().split('\n') - for item in child_ret_str: - ret_str.append(' ' + item) - elif isinstance(value, list): - if isinstance(value[0], AttrDict): - ret_str.append('{}:'.format(key)) - for item in value: - # Treat as AttrDict above. - child_ret_str = item.__repr__().split('\n') - for item in child_ret_str: - ret_str.append(' ' + item) - else: - ret_str.append('{}: {}'.format(key, value)) - else: - ret_str.append('{}: {}'.format(key, value)) - return '\n'.join(ret_str) - - -class Config(AttrDict): - r"""Configuration class. This should include every human specifiable - hyperparameter values for your training.""" - - def __init__(self, filename=None, args=None, verbose=False, is_train=True): - super(Config, self).__init__() - # Set default parameters. - # Logging. 
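Note on the ADAIN module in the base_function.py hunk above: activations are instance-normalised and then re-scaled and re-shifted per channel with a gamma/beta pair regressed from the driving descriptor vector, which is how the PIRender blocks are conditioned. A standalone sketch with toy sizes (channel counts are assumptions, not the repo's config):

import torch
import torch.nn as nn

norm_nc, feature_nc, nhidden = 64, 1024, 128
x = torch.randn(2, norm_nc, 32, 32)           # feature map to modulate
z = torch.randn(2, feature_nc)                # descriptor, e.g. from a mapping network

norm = nn.InstanceNorm2d(norm_nc, affine=False)
mlp_shared = nn.Sequential(nn.Linear(feature_nc, nhidden), nn.ReLU())
mlp_gamma, mlp_beta = nn.Linear(nhidden, norm_nc), nn.Linear(nhidden, norm_nc)

actv = mlp_shared(z)
gamma = mlp_gamma(actv).view(2, norm_nc, 1, 1)
beta = mlp_beta(actv).view(2, norm_nc, 1, 1)
out = norm(x) * (1 + gamma) + beta
print(out.shape)                              # torch.Size([2, 64, 32, 32])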
- - large_number = 1000000000 - self.snapshot_save_iter = large_number - self.snapshot_save_epoch = large_number - self.snapshot_save_start_iter = 0 - self.snapshot_save_start_epoch = 0 - self.image_save_iter = large_number - self.eval_epoch = large_number - self.start_eval_epoch = large_number - self.eval_epoch = large_number - self.max_epoch = large_number - self.max_iter = large_number - self.logging_iter = 100 - self.image_to_tensorboard=False - self.which_iter = 0 # args.which_iter - self.resume = False - - self.checkpoints_dir = '/Users/shadowcun/Downloads/' - self.name = 'face' - self.phase = 'train' if is_train else 'test' - - # Networks. - self.gen = AttrDict(type='generators.dummy') - self.dis = AttrDict(type='discriminators.dummy') - - # Optimizers. - self.gen_optimizer = AttrDict(type='adam', - lr=0.0001, - adam_beta1=0.0, - adam_beta2=0.999, - eps=1e-8, - lr_policy=AttrDict(iteration_mode=False, - type='step', - step_size=large_number, - gamma=1)) - self.dis_optimizer = AttrDict(type='adam', - lr=0.0001, - adam_beta1=0.0, - adam_beta2=0.999, - eps=1e-8, - lr_policy=AttrDict(iteration_mode=False, - type='step', - step_size=large_number, - gamma=1)) - # Data. - self.data = AttrDict(name='dummy', - type='datasets.images', - num_workers=0) - self.test_data = AttrDict(name='dummy', - type='datasets.images', - num_workers=0, - test=AttrDict(is_lmdb=False, - roots='', - batch_size=1)) - self.trainer = AttrDict( - model_average=False, - model_average_beta=0.9999, - model_average_start_iteration=1000, - model_average_batch_norm_estimation_iteration=30, - model_average_remove_sn=True, - image_to_tensorboard=False, - hparam_to_tensorboard=False, - distributed_data_parallel='pytorch', - delay_allreduce=True, - gan_relativistic=False, - gen_step=1, - dis_step=1) - - # # Cudnn. - self.cudnn = AttrDict(deterministic=False, - benchmark=True) - - # Others. - self.pretrained_weight = '' - self.inference_args = AttrDict() - - - # Update with given configurations. - assert os.path.exists(filename), 'File {} not exist.'.format(filename) - loader = yaml.SafeLoader - loader.add_implicit_resolver( - u'tag:yaml.org,2002:float', - re.compile(u'''^(?: - [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? - |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) - |\\.[0-9_]+(?:[eE][-+][0-9]+)? - |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* - |[-+]?\\.(?:inf|Inf|INF) - |\\.(?:nan|NaN|NAN))$''', re.X), - list(u'-+0123456789.')) - try: - with open(filename, 'r') as f: - cfg_dict = yaml.load(f, Loader=loader) - except EnvironmentError: - print('Please check the file with name of "%s"', filename) - recursive_update(self, cfg_dict) - - # Put common opts in both gen and dis. 
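Note on the loader patch in Config.__init__ above: stock PyYAML follows YAML 1.1 float rules, so scientific notation without a dot (e.g. lr: 1e-4) loads as a string; the added implicit resolver makes it parse as a float. A small demonstration of the behaviour (like the deleted code, it mutates yaml.SafeLoader globally):

import re
import yaml

print(type(yaml.safe_load("lr: 1e-4")["lr"]))             # <class 'str'> with stock SafeLoader

loader = yaml.SafeLoader                                   # patched in place, as in Config.__init__
loader.add_implicit_resolver(
    u'tag:yaml.org,2002:float',
    re.compile(u'''^(?:
        [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
        |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
        |\\.[0-9_]+(?:[eE][-+][0-9]+)?
        |[-+]?\\.(?:inf|Inf|INF)
        |\\.(?:nan|NaN|NAN))$''', re.X),
    list(u'-+0123456789.'))
print(type(yaml.load("lr: 1e-4", Loader=loader)["lr"]))    # <class 'float'>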
- if 'common' in cfg_dict: - self.common = AttrDict(**cfg_dict['common']) - self.gen.common = self.common - self.dis.common = self.common - - - if verbose: - print(' config '.center(80, '-')) - print(self.__repr__()) - print(''.center(80, '-')) - - -def rsetattr(obj, attr, val): - """Recursively find object and set value""" - pre, _, post = attr.rpartition('.') - return setattr(rgetattr(obj, pre) if pre else obj, post, val) - - -def rgetattr(obj, attr, *args): - """Recursively find object and return value""" - - def _getattr(obj, attr): - r"""Get attribute.""" - return getattr(obj, attr, *args) - - return functools.reduce(_getattr, [obj] + attr.split('.')) - - -def recursive_update(d, u): - """Recursively update AttrDict d with AttrDict u""" - for key, value in u.items(): - if isinstance(value, collections.abc.Mapping): - d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value) - elif isinstance(value, (list, tuple)): - if isinstance(value[0], dict): - d.__dict__[key] = [AttrDict(item) for item in value] - else: - d.__dict__[key] = value - else: - d.__dict__[key] = value - return d diff --git a/sadtalker_audio2pose/src/facerender/pirender/face_model.py b/sadtalker_audio2pose/src/facerender/pirender/face_model.py deleted file mode 100644 index 0f83e2fc5d8c66cf9bd2e2c5549773e11e0f8a44..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/facerender/pirender/face_model.py +++ /dev/null @@ -1,178 +0,0 @@ -import functools -import torch -import torch.nn as nn -from .base_function import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder - -def convert_flow_to_deformation(flow): - r"""convert flow fields to deformations. - - Args: - flow (tensor): Flow field obtained by the model - Returns: - deformation (tensor): The deformation used for warpping - """ - b,c,h,w = flow.shape - flow_norm = 2 * torch.cat([flow[:,:1,...]/(w-1),flow[:,1:,...]/(h-1)], 1) - grid = make_coordinate_grid(flow) - deformation = grid + flow_norm.permute(0,2,3,1) - return deformation - -def make_coordinate_grid(flow): - r"""obtain coordinate grid with the same size as the flow filed. 
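Note on convert_flow_to_deformation / warp_image in the face_model.py hunk above: the pixel-unit flow is rescaled by 2/(size - 1) into grid_sample's [-1, 1] convention and added to an identity sampling grid. The sketch below reproduces that with align_corners=True so that a zero flow is exactly the identity warp (the deleted warp_image leaves grid_sample at its default align_corners); shapes and values are toys:

import torch
import torch.nn.functional as F

def flow_to_grid(flow):                      # flow: (b, 2, h, w) in pixels, channels (dx, dy)
    b, _, h, w = flow.shape
    norm = torch.tensor([2.0 / (w - 1), 2.0 / (h - 1)]).view(1, 2, 1, 1)
    ys, xs = torch.meshgrid(torch.linspace(-1, 1, h),
                            torch.linspace(-1, 1, w), indexing="ij")
    identity = torch.stack([xs, ys], dim=-1).expand(b, -1, -1, -1)   # (b, h, w, 2)
    return identity + (flow * norm).permute(0, 2, 3, 1)

img = torch.rand(1, 3, 8, 8)
grid = flow_to_grid(torch.zeros(1, 2, 8, 8))
warped = F.grid_sample(img, grid, align_corners=True)
print(torch.allclose(warped, img))           # True: zero flow is the identity warp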
- - Args: - flow (tensor): Flow field obtained by the model - Returns: - grid (tensor): The grid with the same size as the input flow - """ - b,c,h,w = flow.shape - - x = torch.arange(w).to(flow) - y = torch.arange(h).to(flow) - - x = (2 * (x / (w - 1)) - 1) - y = (2 * (y / (h - 1)) - 1) - - yy = y.view(-1, 1).repeat(1, w) - xx = x.view(1, -1).repeat(h, 1) - - meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) - meshed = meshed.expand(b, -1, -1, -1) - return meshed - - -def warp_image(source_image, deformation): - r"""warp the input image according to the deformation - - Args: - source_image (tensor): source images to be warpped - deformation (tensor): deformations used to warp the images; value in range (-1, 1) - Returns: - output (tensor): the warpped images - """ - _, h_old, w_old, _ = deformation.shape - _, _, h, w = source_image.shape - if h_old != h or w_old != w: - deformation = deformation.permute(0, 3, 1, 2) - deformation = torch.nn.functional.interpolate(deformation, size=(h, w), mode='bilinear') - deformation = deformation.permute(0, 2, 3, 1) - return torch.nn.functional.grid_sample(source_image, deformation) - - -class FaceGenerator(nn.Module): - def __init__( - self, - mapping_net, - warpping_net, - editing_net, - common - ): - super(FaceGenerator, self).__init__() - self.mapping_net = MappingNet(**mapping_net) - self.warpping_net = WarpingNet(**warpping_net, **common) - self.editing_net = EditingNet(**editing_net, **common) - - def forward( - self, - input_image, - driving_source, - stage=None - ): - if stage == 'warp': - descriptor = self.mapping_net(driving_source) - output = self.warpping_net(input_image, descriptor) - else: - descriptor = self.mapping_net(driving_source) - output = self.warpping_net(input_image, descriptor) - output['fake_image'] = self.editing_net(input_image, output['warp_image'], descriptor) - return output - -class MappingNet(nn.Module): - def __init__(self, coeff_nc, descriptor_nc, layer): - super( MappingNet, self).__init__() - - self.layer = layer - nonlinearity = nn.LeakyReLU(0.1) - - self.first = nn.Sequential( - torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True)) - - for i in range(layer): - net = nn.Sequential(nonlinearity, - torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3)) - setattr(self, 'encoder' + str(i), net) - - self.pooling = nn.AdaptiveAvgPool1d(1) - self.output_nc = descriptor_nc - - def forward(self, input_3dmm): - out = self.first(input_3dmm) - for i in range(self.layer): - model = getattr(self, 'encoder' + str(i)) - out = model(out) + out[:,:,3:-3] - out = self.pooling(out) - return out - -class WarpingNet(nn.Module): - def __init__( - self, - image_nc, - descriptor_nc, - base_nc, - max_nc, - encoder_layer, - decoder_layer, - use_spect - ): - super( WarpingNet, self).__init__() - - nonlinearity = nn.LeakyReLU(0.1) - norm_layer = functools.partial(LayerNorm2d, affine=True) - kwargs = {'nonlinearity':nonlinearity, 'use_spect':use_spect} - - self.descriptor_nc = descriptor_nc - self.hourglass = ADAINHourglass(image_nc, self.descriptor_nc, base_nc, - max_nc, encoder_layer, decoder_layer, **kwargs) - - self.flow_out = nn.Sequential(norm_layer(self.hourglass.output_nc), - nonlinearity, - nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3)) - - self.pool = nn.AdaptiveAvgPool2d(1) - - def forward(self, input_image, descriptor): - final_output={} - output = self.hourglass(input_image, descriptor) - final_output['flow_field'] = self.flow_out(output) 
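Note on MappingNet above (the same structure also appears in the SadTalker mapping.py hunk earlier): every Conv1d is unpadded, so the first kernel-7 layer and each dilated kernel-3 layer (dilation 3, effective span 7) trim 3 samples from both ends of the coefficient window, which is exactly why the residual reads out[:, :, 3:-3]. A quick shape check with assumed sizes (70 coefficients, a 27-frame window, 3 layers; a single conv is reused here purely to check shapes):

import torch
import torch.nn as nn

coeff_nc, descriptor_nc, layer = 70, 1024, 3       # assumed, for illustration only
win = 27                                           # assumed temporal window of 3DMM coefficients
x = torch.randn(2, coeff_nc, win)

first = nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0)
enc = nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3)

out = first(x)
print(out.shape[-1])                               # 27 - 6 = 21
for _ in range(layer):
    out = enc(out) + out[:, :, 3:-3]               # same residual trick as the deleted forward()
    print(out.shape[-1])                           # 15, 9, 3
print(nn.AdaptiveAvgPool1d(1)(out).shape)          # torch.Size([2, 1024, 1])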
- - deformation = convert_flow_to_deformation(final_output['flow_field']) - final_output['warp_image'] = warp_image(input_image, deformation) - return final_output - - -class EditingNet(nn.Module): - def __init__( - self, - image_nc, - descriptor_nc, - layer, - base_nc, - max_nc, - num_res_blocks, - use_spect): - super(EditingNet, self).__init__() - - nonlinearity = nn.LeakyReLU(0.1) - norm_layer = functools.partial(LayerNorm2d, affine=True) - kwargs = {'norm_layer':norm_layer, 'nonlinearity':nonlinearity, 'use_spect':use_spect} - self.descriptor_nc = descriptor_nc - - # encoder part - self.encoder = FineEncoder(image_nc*2, base_nc, max_nc, layer, **kwargs) - self.decoder = FineDecoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs) - - def forward(self, input_image, warp_image, descriptor): - x = torch.cat([input_image, warp_image], 1) - x = self.encoder(x) - gen_image = self.decoder(x, descriptor) - return gen_image diff --git a/sadtalker_audio2pose/src/facerender/pirender_animate.py b/sadtalker_audio2pose/src/facerender/pirender_animate.py deleted file mode 100644 index 07d4ccf0918f09dcfa422a85694bd17bf42d11ff..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/facerender/pirender_animate.py +++ /dev/null @@ -1,266 +0,0 @@ -import os -import uuid -import cv2 -from tqdm import tqdm -import yaml -import numpy as np -import warnings -from skimage import img_as_ubyte -import safetensors -import safetensors.torch -warnings.filterwarnings('ignore') - - -import imageio -import torch -import torchvision - -from src.facerender.pirender.config import Config -from src.facerender.pirender.face_model import FaceGenerator - -from pydub import AudioSegment -from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list -from src.utils.paste_pic import paste_pic -from src.utils.videoio import save_video_with_watermark -from src.utils.flow_util import vis_flow -from scipy.io import savemat,loadmat - -try: - import webui # in webui - in_webui = True -except: - in_webui = False - -expession = loadmat('expression.mat') -control_dict = {} -for item in ['expression_center', 'expression_mouth', 'expression_eyebrow', 'expression_eyes']: - control_dict[item] = torch.tensor(expession[item])[0] - -class AnimateFromCoeff_PIRender(): - - def __init__(self, sadtalker_path, device): - - opt = Config(sadtalker_path['pirender_yaml_path'], None, is_train=False) - opt.device = device - self.net_G_ema = FaceGenerator(**opt.gen.param).to(opt.device) - checkpoint_path = sadtalker_path['pirender_checkpoint'] - checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) - self.net_G_ema.load_state_dict(checkpoint['net_G_ema'], strict=False) - print('load [net_G] and [net_G_ema] from {}'.format(checkpoint_path)) - self.net_G = self.net_G_ema.eval() - self.device = device - - - def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256): - - source_image=x['source_image'].type(torch.FloatTensor) - source_semantics=x['source_semantics'].type(torch.FloatTensor) - target_semantics=x['target_semantics_list'].type(torch.FloatTensor) - - num = 16 - - # import pdb; pdb.set_trace() - # target_semantics_ - current = target_semantics[0, 0, :64, 0] - for control_k in range(len(control_dict.keys())): - listx = list(control_dict.keys()) - control_v = control_dict[listx[control_k]] - for i in range(num): - expression = (control_v-current)*i/(num-1)+current - target_semantics[:, 
(control_k*num + i):(control_k*num + i+1), :64, :] = expression[None, None, :, None] - - source_image=source_image.to(self.device) - source_semantics=source_semantics.to(self.device) - target_semantics=target_semantics.to(self.device) - frame_num = x['frame_num'] - - with torch.no_grad(): - predictions_video = [] - for i in tqdm(range(target_semantics.shape[1]), 'FaceRender:'): - predictions_video.append(self.net_G(source_image, target_semantics[:, i])['fake_image']) - - predictions_video = torch.stack(predictions_video, dim=1) - predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:]) - - video = [] - for idx in range(len(predictions_video)): - image = predictions_video[idx] - image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32) - video.append(image) - result = img_as_ubyte(video) - - ### the generated video is 256x256, so we keep the aspect ratio, - original_size = crop_info[0] - if original_size: - result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ] - - video_name = x['video_name'] + '.mp4' - path = os.path.join(video_save_dir, 'temp_'+video_name) - - imageio.mimsave(path, result, fps=float(25)) - - av_path = os.path.join(video_save_dir, video_name) - return_path = av_path - - audio_path = x['audio_path'] - audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0] - new_audio_path = os.path.join(video_save_dir, audio_name+'.wav') - start_time = 0 - # cog will not keep the .mp3 filename - sound = AudioSegment.from_file(audio_path) - frames = frame_num - end_time = start_time + frames*1/25*1000 - word1=sound.set_frame_rate(16000) - word = word1[start_time:end_time] - word.export(new_audio_path, format="wav") - - save_video_with_watermark(path, new_audio_path, av_path, watermark= False) - print(f'The generated video is named {video_save_dir}/{video_name}') - - if 'full' in preprocess.lower(): - # only add watermark to the full image. 
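Note on the control loop at the top of generate() above: for each entry of control_dict loaded from expression.mat, num = 16 consecutive frames linearly blend the first frame's expression coefficients toward that control expression, so frame 0 of each block equals the original expression and frame 15 equals the control target. A one-dimensional sketch of the schedule, with hypothetical coefficients:

import torch

num = 16
current = torch.zeros(4)                     # stand-in for target_semantics[0, 0, :64, 0]
control_v = torch.tensor([1.0, -1.0, 0.5, 2.0])
frames = torch.stack([(control_v - current) * i / (num - 1) + current for i in range(num)])
print(frames[0], frames[-1])                 # starts at `current`, ends exactly at `control_v`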
- video_name_full = x['video_name'] + '_full.mp4' - full_video_path = os.path.join(video_save_dir, video_name_full) - return_path = full_video_path - paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False) - print(f'The generated video is named {video_save_dir}/{video_name_full}') - else: - full_video_path = av_path - - #### paste back then enhancers - if enhancer: - video_name_enhancer = x['video_name'] + '_enhanced.mp4' - enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer) - av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer) - return_path = av_path_enhancer - - try: - enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer) - imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25)) - except: - enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer) - imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25)) - - save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False) - print(f'The generated video is named {video_save_dir}/{video_name_enhancer}') - os.remove(enhanced_path) - - os.remove(path) - os.remove(new_audio_path) - - return return_path - - def generate_flow(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256): - - source_image=x['source_image'].type(torch.FloatTensor) - source_semantics=x['source_semantics'].type(torch.FloatTensor) - target_semantics=x['target_semantics_list'].type(torch.FloatTensor) - - - num = 16 - - current = target_semantics[0, 0, :64, 0] - for control_k in range(len(control_dict.keys())): - listx = list(control_dict.keys()) - control_v = control_dict[listx[control_k]] - for i in range(num): - expression = (control_v-current)*i/(num-1)+current - target_semantics[:, (control_k*num + i):(control_k*num + i+1), :64, :] = expression[None, None, :, None] - - source_image=source_image.to(self.device) - source_semantics=source_semantics.to(self.device) - target_semantics=target_semantics.to(self.device) - frame_num = x['frame_num'] - - with torch.no_grad(): - predictions_video = [] - for i in tqdm(range(target_semantics.shape[1]), 'FaceRender:'): - predictions_video.append(self.net_G(source_image, target_semantics[:, i])['flow_field']) - - predictions_video = torch.stack(predictions_video, dim=1) - predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:]) - - video = [] - for idx in range(len(predictions_video)): - image = predictions_video[idx] - image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32) - video.append(image) - - results = np.stack(video, axis=0) - - ### the generated video is 256x256, so we keep the aspect ratio, - # original_size = crop_info[0] - # if original_size: - # result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ] - # results = np.stack(result, axis=0) - - x_name = os.path.basename(pic_path) - save_name = os.path.join(video_save_dir, x_name + '.flo') - save_name_flow_vis = os.path.join(video_save_dir, x_name + '.mp4') - - flow_full = paste_flow(results, pic_path, save_name, crop_info, extended_crop= True if 'ext' in preprocess.lower() else False) - - flow_viz = [] - for kk in range(flow_full.shape[0]): - tmp = vis_flow(flow_full[kk]) - flow_viz.append(tmp) - 
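Note on paste_flow below: the 256x256 per-crop flow is resized to the crop's footprint in the original frame and written into a zero-initialised full-frame canvas at the offsets taken from crop_info. A toy sketch of that paste-back step; the crop box and sizes are made up:

import cv2
import numpy as np

frame_h, frame_w = 480, 640
oy1, oy2, ox1, ox2 = 100, 356, 200, 456           # hypothetical crop box in the full frame

crop_flow = np.random.randn(256, 256, 2).astype(np.float32)
p = cv2.resize(crop_flow, (ox2 - ox1, oy2 - oy1), interpolation=cv2.INTER_LANCZOS4)
full = np.zeros((frame_h, frame_w, 2), dtype=np.float32)
full[oy1:oy2, ox1:ox2] = p
print(full.shape, np.abs(full[:oy1]).sum())       # (480, 640, 2) 0.0, untouched outside the crop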
flow_viz = np.stack(flow_viz) - - torchvision.io.write_video(save_name_flow_vis, flow_viz, fps=20, video_codec='h264', options={'crf': '10'}) - - return save_name_flow_vis - - -def paste_flow(flows, pic_path, save_name, crop_info, extended_crop=False): - - if not os.path.isfile(pic_path): - raise ValueError('pic_path must be a valid path to video/image file') - elif pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']: - # loader for first frame - full_img = cv2.imread(pic_path) - else: - # loader for videos - video_stream = cv2.VideoCapture(pic_path) - fps = video_stream.get(cv2.CAP_PROP_FPS) - full_frames = [] - while 1: - still_reading, frame = video_stream.read() - if not still_reading: - video_stream.release() - break - break - full_img = frame - frame_h = full_img.shape[0] - frame_w = full_img.shape[1] - - # full images, we only use it as reference for zero init image. - - if len(crop_info) != 3: - print("you didn't crop the image") - return - else: - r_w, r_h = crop_info[0] - clx, cly, crx, cry = crop_info[1] - lx, ly, rx, ry = crop_info[2] - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - if extended_crop: - oy1, oy2, ox1, ox2 = cly, cry, clx, crx - else: - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - # out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h)) - # template = np.zeros((frame_h, frame_w, 2)) # full flows - out_tmp = [] - for crop_frame in tqdm(flows, 'seamlessClone:'): - p = cv2.resize(crop_frame, (ox2-ox1, oy2 - oy1), interpolation=cv2.INTER_LANCZOS4) - - gen_img = np.zeros((frame_h, frame_w, 2)) - # gen_img = cv2.seamlessClone(p, template, mask, location, cv2.NORMAL_CLONE) - gen_img[oy1:oy2,ox1:ox2] = p - out_tmp.append(gen_img) - - np.save(save_name, np.stack(out_tmp)) - return np.stack(out_tmp) \ No newline at end of file diff --git a/sadtalker_audio2pose/src/facerender/pirender_animate_control.py b/sadtalker_audio2pose/src/facerender/pirender_animate_control.py deleted file mode 100644 index 1c357f35577816c8d6731627afd505c6dd8efdca..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/facerender/pirender_animate_control.py +++ /dev/null @@ -1,251 +0,0 @@ -import os -import uuid -import cv2 -from tqdm import tqdm -import yaml -import numpy as np -import warnings -from skimage import img_as_ubyte -import safetensors -import safetensors.torch -warnings.filterwarnings('ignore') - - -import imageio -import torch -import torchvision - -from src.facerender.pirender.config import Config -from src.facerender.pirender.face_model import FaceGenerator - -from pydub import AudioSegment -from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list -from src.utils.paste_pic import paste_pic -from src.utils.videoio import save_video_with_watermark -from src.utils.flow_util import vis_flow - -from scipy.io import savemat,loadmat - -try: - import webui # in webui - in_webui = True -except: - in_webui = False - -expession = loadmat('expression.mat') -control_dict = {} -for item in ['expression_center', 'expression_mouth', 'expression_eyebrow', 'expression_eyes']: - control_dict[item] = torch.tensor(expession[item])[0] - -class AnimateFromCoeff_PIRender(): - - def __init__(self, sadtalker_path, device): - - opt = Config(sadtalker_path['pirender_yaml_path'], None, is_train=False) - opt.device = device - self.net_G_ema = FaceGenerator(**opt.gen.param).to(opt.device) - checkpoint_path = 
sadtalker_path['pirender_checkpoint'] - checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) - self.net_G_ema.load_state_dict(checkpoint['net_G_ema'], strict=False) - print('load [net_G] and [net_G_ema] from {}'.format(checkpoint_path)) - self.net_G = self.net_G_ema.eval() - self.device = device - - - def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256): - - source_image=x['source_image'].type(torch.FloatTensor) - source_semantics=x['source_semantics'].type(torch.FloatTensor) - target_semantics=x['target_semantics_list'].type(torch.FloatTensor) - num = 10 - - # target_semantics_ - current = target_semantics['target_semantics_list'][0, :64, 0] - for control in control_dict: - for i in range(num): - expression = (control_dict[control]-current)*i/(num-1)+current - target_semantics['target_semantics_list'][:, :64, :] = expression[None, :, None] - - source_image=source_image.to(self.device) - source_semantics=source_semantics.to(self.device) - target_semantics=target_semantics.to(self.device) - frame_num = x['frame_num'] - - with torch.no_grad(): - predictions_video = [] - for i in tqdm(range(target_semantics.shape[1]), 'FaceRender:'): - predictions_video.append(self.net_G(source_image, target_semantics[:, i])['fake_image']) - - predictions_video = torch.stack(predictions_video, dim=1) - predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:]) - - video = [] - for idx in range(len(predictions_video)): - image = predictions_video[idx] - image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32) - video.append(image) - result = img_as_ubyte(video) - - ### the generated video is 256x256, so we keep the aspect ratio, - original_size = crop_info[0] - if original_size: - result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ] - - video_name = x['video_name'] + '.mp4' - path = os.path.join(video_save_dir, 'temp_'+video_name) - - imageio.mimsave(path, result, fps=float(25)) - - av_path = os.path.join(video_save_dir, video_name) - return_path = av_path - - audio_path = x['audio_path'] - audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0] - new_audio_path = os.path.join(video_save_dir, audio_name+'.wav') - start_time = 0 - # cog will not keep the .mp3 filename - sound = AudioSegment.from_file(audio_path) - frames = frame_num - end_time = start_time + frames*1/25*1000 - word1=sound.set_frame_rate(16000) - word = word1[start_time:end_time] - word.export(new_audio_path, format="wav") - - save_video_with_watermark(path, new_audio_path, av_path, watermark= False) - print(f'The generated video is named {video_save_dir}/{video_name}') - - if 'full' in preprocess.lower(): - # only add watermark to the full image. 
- video_name_full = x['video_name'] + '_full.mp4' - full_video_path = os.path.join(video_save_dir, video_name_full) - return_path = full_video_path - paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False) - print(f'The generated video is named {video_save_dir}/{video_name_full}') - else: - full_video_path = av_path - - #### paste back then enhancers - if enhancer: - video_name_enhancer = x['video_name'] + '_enhanced.mp4' - enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer) - av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer) - return_path = av_path_enhancer - - try: - enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer) - imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25)) - except: - enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer) - imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25)) - - save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False) - print(f'The generated video is named {video_save_dir}/{video_name_enhancer}') - os.remove(enhanced_path) - - os.remove(path) - os.remove(new_audio_path) - - return return_path - - def generate_flow(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256): - - source_image=x['source_image'].type(torch.FloatTensor) - source_semantics=x['source_semantics'].type(torch.FloatTensor) - target_semantics=x['target_semantics_list'].type(torch.FloatTensor) - source_image=source_image.to(self.device) - source_semantics=source_semantics.to(self.device) - target_semantics=target_semantics.to(self.device) - frame_num = x['frame_num'] - - with torch.no_grad(): - predictions_video = [] - for i in tqdm(range(target_semantics.shape[1]), 'FaceRender:'): - predictions_video.append(self.net_G(source_image, target_semantics[:, i])['flow_field']) - - predictions_video = torch.stack(predictions_video, dim=1) - predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:]) - - video = [] - for idx in range(len(predictions_video)): - image = predictions_video[idx] - image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32) - video.append(image) - - results = np.stack(video, axis=0) - - ### the generated video is 256x256, so we keep the aspect ratio, - # original_size = crop_info[0] - # if original_size: - # result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ] - # results = np.stack(result, axis=0) - - x_name = os.path.basename(pic_path) - save_name = os.path.join(video_save_dir, x_name + '.flo') - save_name_flow_vis = os.path.join(video_save_dir, x_name + '.mp4') - - flow_full = paste_flow(results, pic_path, save_name, crop_info, extended_crop= True if 'ext' in preprocess.lower() else False) - - flow_viz = [] - for kk in range(flow_full.shape[0]): - tmp = vis_flow(flow_full[kk]) - flow_viz.append(tmp) - flow_viz = np.stack(flow_viz) - - torchvision.io.write_video(save_name_flow_vis, flow_viz, fps=20, video_codec='h264', options={'crf': '10'}) - - return save_name_flow_vis - - -def paste_flow(flows, pic_path, save_name, crop_info, extended_crop=False): - - if not os.path.isfile(pic_path): - raise ValueError('pic_path must be a valid path to video/image file') - elif 
pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']: - # loader for first frame - full_img = cv2.imread(pic_path) - else: - # loader for videos - video_stream = cv2.VideoCapture(pic_path) - fps = video_stream.get(cv2.CAP_PROP_FPS) - full_frames = [] - while 1: - still_reading, frame = video_stream.read() - if not still_reading: - video_stream.release() - break - break - full_img = frame - frame_h = full_img.shape[0] - frame_w = full_img.shape[1] - - # full images, we only use it as reference for zero init image. - - if len(crop_info) != 3: - print("you didn't crop the image") - return - else: - r_w, r_h = crop_info[0] - clx, cly, crx, cry = crop_info[1] - lx, ly, rx, ry = crop_info[2] - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - if extended_crop: - oy1, oy2, ox1, ox2 = cly, cry, clx, crx - else: - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - # out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h)) - # template = np.zeros((frame_h, frame_w, 2)) # full flows - out_tmp = [] - for crop_frame in tqdm(flows, 'seamlessClone:'): - p = cv2.resize(crop_frame, (ox2-ox1, oy2 - oy1), interpolation=cv2.INTER_LANCZOS4) - - gen_img = np.zeros((frame_h, frame_w, 2)) - # gen_img = cv2.seamlessClone(p, template, mask, location, cv2.NORMAL_CLONE) - gen_img[oy1:oy2,ox1:ox2] = p - out_tmp.append(gen_img) - - np.save(save_name, np.stack(out_tmp)) - return np.stack(out_tmp) \ No newline at end of file diff --git a/sadtalker_audio2pose/src/facerender/sync_batchnorm/__init__.py b/sadtalker_audio2pose/src/facerender/sync_batchnorm/__init__.py deleted file mode 100644 index 48871cdcdc882c903501ecc6d70fcb1b50bd7e9f..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/facerender/sync_batchnorm/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# -*- coding: utf-8 -*- -# File : __init__.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. - -from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d -from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/sadtalker_audio2pose/src/facerender/sync_batchnorm/batchnorm.py b/sadtalker_audio2pose/src/facerender/sync_batchnorm/batchnorm.py deleted file mode 100644 index b4cc2ccd2f0c904cbe433fb6136f443f0fa86fa6..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/facerender/sync_batchnorm/batchnorm.py +++ /dev/null @@ -1,315 +0,0 @@ -# -*- coding: utf-8 -*- -# File : batchnorm.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. 
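A note on the paste_flow helper deleted just above (it appears in both removed facerender animate modules): rather than seamless-cloning pixels, it resizes each cropped 2-channel flow map to the detected crop rectangle and drops it onto an all-zero full-frame canvas, so flow outside the face crop stays zero. A minimal standalone sketch of that step, with illustrative names that are not taken from the removed code:

import numpy as np
import cv2

def paste_crop_back(crop_flow, frame_h, frame_w, box):
    """Resize a cropped 2-channel flow map to the crop rectangle and place it
    on an all-zero full-frame canvas (mirrors the deleted paste_flow loop)."""
    ox1, oy1, ox2, oy2 = box                      # crop rectangle in the full frame
    resized = cv2.resize(crop_flow, (ox2 - ox1, oy2 - oy1),
                         interpolation=cv2.INTER_LANCZOS4)
    canvas = np.zeros((frame_h, frame_w, 2), dtype=np.float32)
    canvas[oy1:oy2, ox1:ox2] = resized
    return canvas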
- -import collections - -import torch -import torch.nn.functional as F - -from torch.nn.modules.batchnorm import _BatchNorm -from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast - -from .comm import SyncMaster - -__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] - - -def _sum_ft(tensor): - """sum over the first and last dimention""" - return tensor.sum(dim=0).sum(dim=-1) - - -def _unsqueeze_ft(tensor): - """add new dementions at the front and the tail""" - return tensor.unsqueeze(0).unsqueeze(-1) - - -_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) -_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) - - -class _SynchronizedBatchNorm(_BatchNorm): - def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): - super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) - - self._sync_master = SyncMaster(self._data_parallel_master) - - self._is_parallel = False - self._parallel_id = None - self._slave_pipe = None - - def forward(self, input): - # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. - if not (self._is_parallel and self.training): - return F.batch_norm( - input, self.running_mean, self.running_var, self.weight, self.bias, - self.training, self.momentum, self.eps) - - # Resize the input to (B, C, -1). - input_shape = input.size() - input = input.view(input.size(0), self.num_features, -1) - - # Compute the sum and square-sum. - sum_size = input.size(0) * input.size(2) - input_sum = _sum_ft(input) - input_ssum = _sum_ft(input ** 2) - - # Reduce-and-broadcast the statistics. - if self._parallel_id == 0: - mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) - else: - mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) - - # Compute the output. - if self.affine: - # MJY:: Fuse the multiplication for speed. - output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) - else: - output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) - - # Reshape it. - return output.view(input_shape) - - def __data_parallel_replicate__(self, ctx, copy_id): - self._is_parallel = True - self._parallel_id = copy_id - - # parallel_id == 0 means master device. - if self._parallel_id == 0: - ctx.sync_master = self._sync_master - else: - self._slave_pipe = ctx.sync_master.register_slave(copy_id) - - def _data_parallel_master(self, intermediates): - """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" - - # Always using same "device order" makes the ReduceAdd operation faster. 
- # Thanks to:: Tete Xiao (http://tetexiao.com/) - intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) - - to_reduce = [i[1][:2] for i in intermediates] - to_reduce = [j for i in to_reduce for j in i] # flatten - target_gpus = [i[1].sum.get_device() for i in intermediates] - - sum_size = sum([i[1].sum_size for i in intermediates]) - sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) - mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) - - broadcasted = Broadcast.apply(target_gpus, mean, inv_std) - - outputs = [] - for i, rec in enumerate(intermediates): - outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) - - return outputs - - def _compute_mean_std(self, sum_, ssum, size): - """Compute the mean and standard-deviation with sum and square-sum. This method - also maintains the moving average on the master device.""" - assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' - mean = sum_ / size - sumvar = ssum - sum_ * mean - unbias_var = sumvar / (size - 1) - bias_var = sumvar / size - - self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data - self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data - - return mean, bias_var.clamp(self.eps) ** -0.5 - - -class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): - r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a - mini-batch. - - .. math:: - - y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta - - This module differs from the built-in PyTorch BatchNorm1d as the mean and - standard-deviation are reduced across all devices during training. - - For example, when one uses `nn.DataParallel` to wrap the network during - training, PyTorch's implementation normalize the tensor on each device using - the statistics only on that device, which accelerated the computation and - is also easy to implement, but the statistics might be inaccurate. - Instead, in this synchronized version, the statistics will be computed - over all training samples distributed on multiple devices. - - Note that, for one-GPU or CPU-only case, this module behaves exactly same - as the built-in PyTorch implementation. - - The mean and standard-deviation are calculated per-dimension over - the mini-batches and gamma and beta are learnable parameter vectors - of size C (where C is the input size). - - During training, this layer keeps a running estimate of its computed mean - and variance. The running sum is kept with a default momentum of 0.1. - - During evaluation, this running mean/variance is used for normalization. - - Because the BatchNorm is done over the `C` dimension, computing statistics - on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm - - Args: - num_features: num_features from an expected input of size - `batch_size x num_features [x width]` - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Default: 0.1 - affine: a boolean value that when set to ``True``, gives the layer learnable - affine parameters. 
Default: ``True`` - - Shape: - - Input: :math:`(N, C)` or :math:`(N, C, L)` - - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) - - Examples: - >>> # With Learnable Parameters - >>> m = SynchronizedBatchNorm1d(100) - >>> # Without Learnable Parameters - >>> m = SynchronizedBatchNorm1d(100, affine=False) - >>> input = torch.autograd.Variable(torch.randn(20, 100)) - >>> output = m(input) - """ - - def _check_input_dim(self, input): - if input.dim() != 2 and input.dim() != 3: - raise ValueError('expected 2D or 3D input (got {}D input)' - .format(input.dim())) - super(SynchronizedBatchNorm1d, self)._check_input_dim(input) - - -class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): - r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch - of 3d inputs - - .. math:: - - y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta - - This module differs from the built-in PyTorch BatchNorm2d as the mean and - standard-deviation are reduced across all devices during training. - - For example, when one uses `nn.DataParallel` to wrap the network during - training, PyTorch's implementation normalize the tensor on each device using - the statistics only on that device, which accelerated the computation and - is also easy to implement, but the statistics might be inaccurate. - Instead, in this synchronized version, the statistics will be computed - over all training samples distributed on multiple devices. - - Note that, for one-GPU or CPU-only case, this module behaves exactly same - as the built-in PyTorch implementation. - - The mean and standard-deviation are calculated per-dimension over - the mini-batches and gamma and beta are learnable parameter vectors - of size C (where C is the input size). - - During training, this layer keeps a running estimate of its computed mean - and variance. The running sum is kept with a default momentum of 0.1. - - During evaluation, this running mean/variance is used for normalization. - - Because the BatchNorm is done over the `C` dimension, computing statistics - on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm - - Args: - num_features: num_features from an expected input of - size batch_size x num_features x height x width - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Default: 0.1 - affine: a boolean value that when set to ``True``, gives the layer learnable - affine parameters. Default: ``True`` - - Shape: - - Input: :math:`(N, C, H, W)` - - Output: :math:`(N, C, H, W)` (same shape as input) - - Examples: - >>> # With Learnable Parameters - >>> m = SynchronizedBatchNorm2d(100) - >>> # Without Learnable Parameters - >>> m = SynchronizedBatchNorm2d(100, affine=False) - >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) - >>> output = m(input) - """ - - def _check_input_dim(self, input): - if input.dim() != 4: - raise ValueError('expected 4D input (got {}D input)' - .format(input.dim())) - super(SynchronizedBatchNorm2d, self)._check_input_dim(input) - - -class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): - r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch - of 4d inputs - - .. math:: - - y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta - - This module differs from the built-in PyTorch BatchNorm3d as the mean and - standard-deviation are reduced across all devices during training. 
- - For example, when one uses `nn.DataParallel` to wrap the network during - training, PyTorch's implementation normalize the tensor on each device using - the statistics only on that device, which accelerated the computation and - is also easy to implement, but the statistics might be inaccurate. - Instead, in this synchronized version, the statistics will be computed - over all training samples distributed on multiple devices. - - Note that, for one-GPU or CPU-only case, this module behaves exactly same - as the built-in PyTorch implementation. - - The mean and standard-deviation are calculated per-dimension over - the mini-batches and gamma and beta are learnable parameter vectors - of size C (where C is the input size). - - During training, this layer keeps a running estimate of its computed mean - and variance. The running sum is kept with a default momentum of 0.1. - - During evaluation, this running mean/variance is used for normalization. - - Because the BatchNorm is done over the `C` dimension, computing statistics - on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm - or Spatio-temporal BatchNorm - - Args: - num_features: num_features from an expected input of - size batch_size x num_features x depth x height x width - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Default: 0.1 - affine: a boolean value that when set to ``True``, gives the layer learnable - affine parameters. Default: ``True`` - - Shape: - - Input: :math:`(N, C, D, H, W)` - - Output: :math:`(N, C, D, H, W)` (same shape as input) - - Examples: - >>> # With Learnable Parameters - >>> m = SynchronizedBatchNorm3d(100) - >>> # Without Learnable Parameters - >>> m = SynchronizedBatchNorm3d(100, affine=False) - >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) - >>> output = m(input) - """ - - def _check_input_dim(self, input): - if input.dim() != 5: - raise ValueError('expected 5D input (got {}D input)' - .format(input.dim())) - super(SynchronizedBatchNorm3d, self)._check_input_dim(input) diff --git a/sadtalker_audio2pose/src/facerender/sync_batchnorm/comm.py b/sadtalker_audio2pose/src/facerender/sync_batchnorm/comm.py deleted file mode 100644 index b66ec4aea213edf4330beda0a8c8b93d6db77a60..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/facerender/sync_batchnorm/comm.py +++ /dev/null @@ -1,137 +0,0 @@ -# -*- coding: utf-8 -*- -# File : comm.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. - -import queue -import collections -import threading - -__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] - - -class FutureResult(object): - """A thread-safe future implementation. Used only as one-to-one pipe.""" - - def __init__(self): - self._result = None - self._lock = threading.Lock() - self._cond = threading.Condition(self._lock) - - def put(self, result): - with self._lock: - assert self._result is None, 'Previous result has\'t been fetched.' 
- self._result = result - self._cond.notify() - - def get(self): - with self._lock: - if self._result is None: - self._cond.wait() - - res = self._result - self._result = None - return res - - -_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) -_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) - - -class SlavePipe(_SlavePipeBase): - """Pipe for master-slave communication.""" - - def run_slave(self, msg): - self.queue.put((self.identifier, msg)) - ret = self.result.get() - self.queue.put(True) - return ret - - -class SyncMaster(object): - """An abstract `SyncMaster` object. - - - During the replication, as the data parallel will trigger an callback of each module, all slave devices should - call `register(id)` and obtain an `SlavePipe` to communicate with the master. - - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, - and passed to a registered callback. - - After receiving the messages, the master device should gather the information and determine to message passed - back to each slave devices. - """ - - def __init__(self, master_callback): - """ - - Args: - master_callback: a callback to be invoked after having collected messages from slave devices. - """ - self._master_callback = master_callback - self._queue = queue.Queue() - self._registry = collections.OrderedDict() - self._activated = False - - def __getstate__(self): - return {'master_callback': self._master_callback} - - def __setstate__(self, state): - self.__init__(state['master_callback']) - - def register_slave(self, identifier): - """ - Register an slave device. - - Args: - identifier: an identifier, usually is the device id. - - Returns: a `SlavePipe` object which can be used to communicate with the master device. - - """ - if self._activated: - assert self._queue.empty(), 'Queue is not clean before next initialization.' - self._activated = False - self._registry.clear() - future = FutureResult() - self._registry[identifier] = _MasterRegistry(future) - return SlavePipe(identifier, self._queue, future) - - def run_master(self, master_msg): - """ - Main entry for the master device in each forward pass. - The messages were first collected from each devices (including the master device), and then - an callback will be invoked to compute the message to be sent back to each devices - (including the master device). - - Args: - master_msg: the message that the master want to send to itself. This will be placed as the first - message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. - - Returns: the message to be sent back to the master device. - - """ - self._activated = True - - intermediates = [(0, master_msg)] - for i in range(self.nr_slaves): - intermediates.append(self._queue.get()) - - results = self._master_callback(intermediates) - assert results[0][0] == 0, 'The first result should belongs to the master.' 
- - for i, res in results: - if i == 0: - continue - self._registry[i].result.put(res) - - for i in range(self.nr_slaves): - assert self._queue.get() is True - - return results[0][1] - - @property - def nr_slaves(self): - return len(self._registry) diff --git a/sadtalker_audio2pose/src/facerender/sync_batchnorm/replicate.py b/sadtalker_audio2pose/src/facerender/sync_batchnorm/replicate.py deleted file mode 100644 index 9b97380d9c5fbe75c4b3583d3668ccd6a2848699..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/facerender/sync_batchnorm/replicate.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- -# File : replicate.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. - -import functools - -from torch.nn.parallel.data_parallel import DataParallel - -__all__ = [ - 'CallbackContext', - 'execute_replication_callbacks', - 'DataParallelWithCallback', - 'patch_replication_callback' -] - - -class CallbackContext(object): - pass - - -def execute_replication_callbacks(modules): - """ - Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. - - The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` - - Note that, as all modules are isomorphism, we assign each sub-module with a context - (shared among multiple copies of this module on different devices). - Through this context, different copies can share some information. - - We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback - of any slave copies. - """ - master_copy = modules[0] - nr_modules = len(list(master_copy.modules())) - ctxs = [CallbackContext() for _ in range(nr_modules)] - - for i, module in enumerate(modules): - for j, m in enumerate(module.modules()): - if hasattr(m, '__data_parallel_replicate__'): - m.__data_parallel_replicate__(ctxs[j], i) - - -class DataParallelWithCallback(DataParallel): - """ - Data Parallel with a replication callback. - - An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by - original `replicate` function. - The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` - - Examples: - > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) - > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) - # sync_bn.__data_parallel_replicate__ will be invoked. - """ - - def replicate(self, module, device_ids): - modules = super(DataParallelWithCallback, self).replicate(module, device_ids) - execute_replication_callbacks(modules) - return modules - - -def patch_replication_callback(data_parallel): - """ - Monkey-patch an existing `DataParallel` object. Add the replication callback. - Useful when you have customized `DataParallel` implementation. 
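The replicate.py module removed here exists only to fire the __data_parallel_replicate__ callback that the synchronized BatchNorm layers above rely on. Its docstrings already show the intended wiring; pulled together into one runnable sketch for convenience (the import path is assumed to match the package's __init__, and device ids are illustrative; with a single GPU or CPU the layer simply behaves like ordinary BatchNorm2d):

import torch
# assumed import path; adjust to where the sync_batchnorm package lives in the project
from sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback

# Toy module using synchronized BN: with >1 GPU the statistics are reduced
# across devices, otherwise it falls back to the standard single-device path.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3, padding=1),
    SynchronizedBatchNorm2d(16),
    torch.nn.ReLU(),
)

if torch.cuda.device_count() > 1:
    # DataParallelWithCallback invokes __data_parallel_replicate__ on each replica,
    # which registers the slave pipes with the SyncMaster.
    model = DataParallelWithCallback(model.cuda(),
                                     device_ids=list(range(torch.cuda.device_count())))
    out = model(torch.randn(8, 3, 32, 32).cuda())
else:
    out = model(torch.randn(8, 3, 32, 32))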
- - Examples: - > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) - > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) - > patch_replication_callback(sync_bn) - # this is equivalent to - > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) - > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) - """ - - assert isinstance(data_parallel, DataParallel) - - old_replicate = data_parallel.replicate - - @functools.wraps(old_replicate) - def new_replicate(module, device_ids): - modules = old_replicate(module, device_ids) - execute_replication_callbacks(modules) - return modules - - data_parallel.replicate = new_replicate diff --git a/sadtalker_audio2pose/src/facerender/sync_batchnorm/unittest.py b/sadtalker_audio2pose/src/facerender/sync_batchnorm/unittest.py deleted file mode 100644 index 9716d035495097fb086ec050ab0bc9b76b9d28a0..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/facerender/sync_batchnorm/unittest.py +++ /dev/null @@ -1,29 +0,0 @@ -# -*- coding: utf-8 -*- -# File : unittest.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. - -import unittest - -import numpy as np -from torch.autograd import Variable - - -def as_numpy(v): - if isinstance(v, Variable): - v = v.data - return v.cpu().numpy() - - -class TorchTestCase(unittest.TestCase): - def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): - npa, npb = as_numpy(a), as_numpy(b) - self.assertTrue( - np.allclose(npa, npb, atol=atol), - 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) - ) diff --git a/sadtalker_audio2pose/src/generate_batch.py b/sadtalker_audio2pose/src/generate_batch.py deleted file mode 100644 index 2fcaff51276d489aa76c15e4979864a4d4f74aa4..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/generate_batch.py +++ /dev/null @@ -1,120 +0,0 @@ -import os - -from tqdm import tqdm -import torch -import numpy as np -import random -import scipy.io as scio -import src.utils.audio as audio - -def crop_pad_audio(wav, audio_length): - if len(wav) > audio_length: - wav = wav[:audio_length] - elif len(wav) < audio_length: - wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0) - return wav - -def parse_audio_length(audio_length, sr, fps): - bit_per_frames = sr / fps - - num_frames = int(audio_length / bit_per_frames) - audio_length = int(num_frames * bit_per_frames) - - return audio_length, num_frames - -def generate_blink_seq(num_frames): - ratio = np.zeros((num_frames,1)) - frame_id = 0 - while frame_id in range(num_frames): - start = 80 - if frame_id+start+9<=num_frames - 1: - ratio[frame_id+start:frame_id+start+9, 0] = [0.5,0.6,0.7,0.9,1, 0.9, 0.7,0.6,0.5] - frame_id = frame_id+start+9 - else: - break - return ratio - -def generate_blink_seq_randomly(num_frames): - ratio = np.zeros((num_frames,1)) - if num_frames<=20: - return ratio - frame_id = 0 - while frame_id in range(num_frames): - start = random.choice(range(min(10,num_frames), min(int(num_frames/2), 70))) - if frame_id+start+5<=num_frames - 1: - ratio[frame_id+start:frame_id+start+5, 0] = [0.5, 0.9, 1.0, 0.9, 0.5] - frame_id = frame_id+start+5 - else: - break - return ratio - -def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, 
length_of_audio=False, use_blink=True): - - syncnet_mel_step_size = 16 - fps = 25 - - pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0] - audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0] - - - if idlemode: - num_frames = int(length_of_audio * 25) - indiv_mels = np.zeros((num_frames, 80, 16)) - else: - wav = audio.load_wav(audio_path, 16000) - wav_length, num_frames = parse_audio_length(len(wav), 16000, 25) - wav = crop_pad_audio(wav, wav_length) - orig_mel = audio.melspectrogram(wav).T - spec = orig_mel.copy() # nframes 80 - indiv_mels = [] - - for i in tqdm(range(num_frames), 'mel:'): - start_frame_num = i-2 - start_idx = int(80. * (start_frame_num / float(fps))) - end_idx = start_idx + syncnet_mel_step_size - seq = list(range(start_idx, end_idx)) - seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ] - m = spec[seq, :] - indiv_mels.append(m.T) - indiv_mels = np.asarray(indiv_mels) # T 80 16 - - ratio = generate_blink_seq_randomly(num_frames) # T - source_semantics_path = first_coeff_path - source_semantics_dict = scio.loadmat(source_semantics_path) - ref_coeff = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70 - ref_coeff = np.repeat(ref_coeff, num_frames, axis=0) - - if ref_eyeblink_coeff_path is not None: - ratio[:num_frames] = 0 - refeyeblink_coeff_dict = scio.loadmat(ref_eyeblink_coeff_path) - refeyeblink_coeff = refeyeblink_coeff_dict['coeff_3dmm'][:,:64] - refeyeblink_num_frames = refeyeblink_coeff.shape[0] - if refeyeblink_num_frames frame_num: - new_degree_list = new_degree_list[:frame_num] - elif len(new_degree_list) < frame_num: - for _ in range(frame_num-len(new_degree_list)): - new_degree_list.append(new_degree_list[-1]) - print(len(new_degree_list)) - print(frame_num) - - remainder = frame_num%batch_size - if remainder!=0: - for _ in range(batch_size-remainder): - new_degree_list.append(new_degree_list[-1]) - new_degree_np = np.array(new_degree_list).reshape(batch_size, -1) - return new_degree_np - diff --git a/sadtalker_audio2pose/src/gradio_demo.py b/sadtalker_audio2pose/src/gradio_demo.py deleted file mode 100644 index 9a2399fc44704b544ef39bb908d32a21da9fae17..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/gradio_demo.py +++ /dev/null @@ -1,170 +0,0 @@ -import torch, uuid -import os, sys, shutil, platform -from src.facerender.pirender_animate import AnimateFromCoeff_PIRender -from src.utils.preprocess import CropAndExtract -from src.test_audio2coeff import Audio2Coeff -from src.facerender.animate import AnimateFromCoeff -from src.generate_batch import get_data -from src.generate_facerender_batch import get_facerender_data - -from src.utils.init_path import init_path - -from pydub import AudioSegment - - -def mp3_to_wav(mp3_filename,wav_filename,frame_rate): - mp3_file = AudioSegment.from_file(file=mp3_filename) - mp3_file.set_frame_rate(frame_rate).export(wav_filename,format="wav") - - -class SadTalker(): - - def __init__(self, checkpoint_path='checkpoints', config_path='src/config', lazy_load=False): - - if torch.cuda.is_available(): - device = "cuda" - elif platform.system() == 'Darwin': # macos - device = "mps" - else: - device = "cpu" - - self.device = device - - os.environ['TORCH_HOME']= checkpoint_path - - self.checkpoint_path = checkpoint_path - self.config_path = config_path - - - def test(self, source_image, driven_audio, preprocess='crop', - still_mode=False, use_enhancer=False, batch_size=1, size=256, - pose_style = 0, - facerender='facevid2vid', - exp_scale=1.0, - use_ref_video = False, - 
ref_video = None, - ref_info = None, - use_idle_mode = False, - length_of_audio = 0, use_blink=True, - result_dir='./results/'): - - self.sadtalker_paths = init_path(self.checkpoint_path, self.config_path, size, False, preprocess) - print(self.sadtalker_paths) - - self.audio_to_coeff = Audio2Coeff(self.sadtalker_paths, self.device) - self.preprocess_model = CropAndExtract(self.sadtalker_paths, self.device) - - if facerender == 'facevid2vid' and self.device != 'mps': - self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device) - elif facerender == 'pirender' or self.device == 'mps': - self.animate_from_coeff = AnimateFromCoeff_PIRender(self.sadtalker_paths, self.device) - facerender = 'pirender' - else: - raise(RuntimeError('Unknown model: {}'.format(facerender))) - - - time_tag = str(uuid.uuid4()) - save_dir = os.path.join(result_dir, time_tag) - os.makedirs(save_dir, exist_ok=True) - - input_dir = os.path.join(save_dir, 'input') - os.makedirs(input_dir, exist_ok=True) - - print(source_image) - pic_path = os.path.join(input_dir, os.path.basename(source_image)) - shutil.move(source_image, input_dir) - - if driven_audio is not None and os.path.isfile(driven_audio): - audio_path = os.path.join(input_dir, os.path.basename(driven_audio)) - - #### mp3 to wav - if '.mp3' in audio_path: - mp3_to_wav(driven_audio, audio_path.replace('.mp3', '.wav'), 16000) - audio_path = audio_path.replace('.mp3', '.wav') - else: - shutil.move(driven_audio, input_dir) - - elif use_idle_mode: - audio_path = os.path.join(input_dir, 'idlemode_'+str(length_of_audio)+'.wav') ## generate audio from this new audio_path - from pydub import AudioSegment - one_sec_segment = AudioSegment.silent(duration=1000*length_of_audio) #duration in milliseconds - one_sec_segment.export(audio_path, format="wav") - else: - print(use_ref_video, ref_info) - assert use_ref_video == True and ref_info == 'all' - - if use_ref_video and ref_info == 'all': # full ref mode - ref_video_videoname = os.path.basename(ref_video) - audio_path = os.path.join(save_dir, ref_video_videoname+'.wav') - print('new audiopath:',audio_path) - # if ref_video contains audio, set the audio from ref_video. 
- cmd = r"ffmpeg -y -hide_banner -loglevel error -i %s %s"%(ref_video, audio_path) - os.system(cmd) - - os.makedirs(save_dir, exist_ok=True) - - #crop image and extract 3dmm from image - first_frame_dir = os.path.join(save_dir, 'first_frame_dir') - os.makedirs(first_frame_dir, exist_ok=True) - first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir, preprocess, True, size) - - if first_coeff_path is None: - raise AttributeError("No face is detected") - - if use_ref_video: - print('using ref video for genreation') - ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0] - ref_video_frame_dir = os.path.join(save_dir, ref_video_videoname) - os.makedirs(ref_video_frame_dir, exist_ok=True) - print('3DMM Extraction for the reference video providing pose') - ref_video_coeff_path, _, _ = self.preprocess_model.generate(ref_video, ref_video_frame_dir, preprocess, source_image_flag=False) - else: - ref_video_coeff_path = None - - if use_ref_video: - if ref_info == 'pose': - ref_pose_coeff_path = ref_video_coeff_path - ref_eyeblink_coeff_path = None - elif ref_info == 'blink': - ref_pose_coeff_path = None - ref_eyeblink_coeff_path = ref_video_coeff_path - elif ref_info == 'pose+blink': - ref_pose_coeff_path = ref_video_coeff_path - ref_eyeblink_coeff_path = ref_video_coeff_path - elif ref_info == 'all': - ref_pose_coeff_path = None - ref_eyeblink_coeff_path = None - else: - raise('error in refinfo') - else: - ref_pose_coeff_path = None - ref_eyeblink_coeff_path = None - - #audio2ceoff - if use_ref_video and ref_info == 'all': - coeff_path = ref_video_coeff_path # self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path) - else: - batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path, still=still_mode, \ - idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink) # longer audio? 
- coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path) - - #coeff2video - data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, \ - preprocess=preprocess, size=size, expression_scale = exp_scale, facemodel=facerender) - return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess, img_size=size) - video_name = data['video_name'] - print(f'The generated video is named {video_name} in {save_dir}') - - del self.preprocess_model - del self.audio_to_coeff - del self.animate_from_coeff - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - - import gc; gc.collect() - - return return_path - - \ No newline at end of file diff --git a/sadtalker_audio2pose/src/test_audio2coeff.py b/sadtalker_audio2pose/src/test_audio2coeff.py deleted file mode 100644 index d0f5ca9195bbc980c93fa3e37c6d06cc32953aee..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/test_audio2coeff.py +++ /dev/null @@ -1,123 +0,0 @@ -import os -import torch -import numpy as np -from scipy.io import savemat, loadmat -from yacs.config import CfgNode as CN -from scipy.signal import savgol_filter - -import safetensors -import safetensors.torch - -from src.audio2pose_models.audio2pose import Audio2Pose -from src.audio2exp_models.networks import SimpleWrapperV2 -from src.audio2exp_models.audio2exp import Audio2Exp -from src.utils.safetensor_helper import load_x_from_safetensor - -def load_cpk(checkpoint_path, model=None, optimizer=None, device="cpu"): - checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) - if model is not None: - model.load_state_dict(checkpoint['model']) - if optimizer is not None: - optimizer.load_state_dict(checkpoint['optimizer']) - - return checkpoint['epoch'] - -class Audio2Coeff(): - - def __init__(self, sadtalker_path, device): - #load config - fcfg_pose = open(sadtalker_path['audio2pose_yaml_path']) - cfg_pose = CN.load_cfg(fcfg_pose) - cfg_pose.freeze() - fcfg_exp = open(sadtalker_path['audio2exp_yaml_path']) - cfg_exp = CN.load_cfg(fcfg_exp) - cfg_exp.freeze() - - # load audio2pose_model - self.audio2pose_model = Audio2Pose(cfg_pose, None, device=device) - self.audio2pose_model = self.audio2pose_model.to(device) - self.audio2pose_model.eval() - for param in self.audio2pose_model.parameters(): - param.requires_grad = False - - try: - if sadtalker_path['use_safetensor']: - checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint']) - self.audio2pose_model.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2pose')) - else: - load_cpk(sadtalker_path['audio2pose_checkpoint'], model=self.audio2pose_model, device=device) - except: - raise Exception("Failed in loading audio2pose_checkpoint") - - # load audio2exp_model - netG = SimpleWrapperV2() - netG = netG.to(device) - for param in netG.parameters(): - netG.requires_grad = False - netG.eval() - try: - if sadtalker_path['use_safetensor']: - checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint']) - netG.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2exp')) - else: - load_cpk(sadtalker_path['audio2exp_checkpoint'], model=netG, device=device) - except: - raise Exception("Failed in loading audio2exp_checkpoint") - self.audio2exp_model = Audio2Exp(netG, cfg_exp, device=device, prepare_training_loss=False) - self.audio2exp_model = 
self.audio2exp_model.to(device) - for param in self.audio2exp_model.parameters(): - param.requires_grad = False - self.audio2exp_model.eval() - - self.device = device - - def generate(self, batch, coeff_save_dir, pose_style, ref_pose_coeff_path=None): - - with torch.no_grad(): - #test - results_dict_exp= self.audio2exp_model.test(batch) - exp_pred = results_dict_exp['exp_coeff_pred'] #bs T 64 - - #for class_id in range(1): - #class_id = 0#(i+10)%45 - #class_id = random.randint(0,46) #46 styles can be selected - batch['class'] = torch.LongTensor([pose_style]).to(self.device) - results_dict_pose = self.audio2pose_model.test(batch) - pose_pred = results_dict_pose['pose_pred'] #bs T 6 - - pose_len = pose_pred.shape[1] - if pose_len<13: - pose_len = int((pose_len-1)/2)*2+1 - pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), pose_len, 2, axis=1)).to(self.device) - else: - pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), 13, 2, axis=1)).to(self.device) - - coeffs_pred = torch.cat((exp_pred, pose_pred), dim=-1) #bs T 70 - - coeffs_pred_numpy = coeffs_pred[0].clone().detach().cpu().numpy() - - if ref_pose_coeff_path is not None: - coeffs_pred_numpy = self.using_refpose(coeffs_pred_numpy, ref_pose_coeff_path) - - savemat(os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])), - {'coeff_3dmm': coeffs_pred_numpy}) - - return os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])) - - def using_refpose(self, coeffs_pred_numpy, ref_pose_coeff_path): - num_frames = coeffs_pred_numpy.shape[0] - refpose_coeff_dict = loadmat(ref_pose_coeff_path) - refpose_coeff = refpose_coeff_dict['coeff_3dmm'][:,64:70] - refpose_num_frames = refpose_coeff.shape[0] - if refpose_num_frames= 0 - if hp.symmetric_mels: - return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value - else: - return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) - -def _denormalize(D): - if hp.allow_clipping_in_normalization: - if hp.symmetric_mels: - return (((np.clip(D, -hp.max_abs_value, - hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) - + hp.min_level_db) - else: - return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) - - if hp.symmetric_mels: - return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) - else: - return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) diff --git a/sadtalker_audio2pose/src/utils/croper.py b/sadtalker_audio2pose/src/utils/croper.py deleted file mode 100644 index 578372debdb8d2b99fe93d3d2ba2dfacf7cbb0ad..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/utils/croper.py +++ /dev/null @@ -1,145 +0,0 @@ -import os -import cv2 -import time -import glob -import argparse -import scipy -import numpy as np -from PIL import Image -import torch -from tqdm import tqdm -from itertools import cycle - -from src.face3d.extract_kp_videos_safe import KeypointExtractor -from facexlib.alignment import landmark_98_to_68 - -import numpy as np -from PIL import Image - -class Preprocesser: - def __init__(self, device='cuda'): - self.predictor = KeypointExtractor(device) - - def get_landmark(self, img_np): - """get landmark with dlib - :return: np.array shape=(68, 2) - """ - with torch.no_grad(): - dets = self.predictor.det_net.detect_faces(img_np, 0.97) - - if len(dets) == 0: - return None - det = dets[0] - - img = 
img_np[int(det[1]):int(det[3]), int(det[0]):int(det[2]), :] - lm = landmark_98_to_68(self.predictor.detector.get_landmarks(img)) # [0] - - #### keypoints to the original location - lm[:,0] += int(det[0]) - lm[:,1] += int(det[1]) - - return lm - - def align_face(self, img, lm, output_size=1024): - """ - :param filepath: str - :return: PIL Image - """ - lm_chin = lm[0: 17] # left-right - lm_eyebrow_left = lm[17: 22] # left-right - lm_eyebrow_right = lm[22: 27] # left-right - lm_nose = lm[27: 31] # top-down - lm_nostrils = lm[31: 36] # top-down - lm_eye_left = lm[36: 42] # left-clockwise - lm_eye_right = lm[42: 48] # left-clockwise - lm_mouth_outer = lm[48: 60] # left-clockwise - lm_mouth_inner = lm[60: 68] # left-clockwise - - # Calculate auxiliary vectors. - eye_left = np.mean(lm_eye_left, axis=0) - eye_right = np.mean(lm_eye_right, axis=0) - eye_avg = (eye_left + eye_right) * 0.5 - eye_to_eye = eye_right - eye_left - mouth_left = lm_mouth_outer[0] - mouth_right = lm_mouth_outer[6] - mouth_avg = (mouth_left + mouth_right) * 0.5 - eye_to_mouth = mouth_avg - eye_avg - - # Choose oriented crop rectangle. - x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] # Addition of binocular difference and double mouth difference - x /= np.hypot(*x) # hypot函数计算直角三角形的斜边长,用斜边长对三角形两条直边做归一化 - x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) # 双眼差和眼嘴差,选较大的作为基准尺度 - y = np.flipud(x) * [-1, 1] - c = eye_avg + eye_to_mouth * 0.1 - quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) # 定义四边形,以面部基准位置为中心上下左右平移得到四个顶点 - qsize = np.hypot(*x) * 2 # 定义四边形的大小(边长),为基准尺度的2倍 - - # Shrink. - # 如果计算出的四边形太大了,就按比例缩小它 - shrink = int(np.floor(qsize / output_size * 0.5)) - if shrink > 1: - rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) - img = img.resize(rsize, Image.ANTIALIAS) - quad /= shrink - qsize /= shrink - else: - rsize = (int(np.rint(float(img.size[0]))), int(np.rint(float(img.size[1])))) - - # Crop. - border = max(int(np.rint(qsize * 0.1)), 3) - crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), - int(np.ceil(max(quad[:, 1])))) - crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), - min(crop[3] + border, img.size[1])) - if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: - # img = img.crop(crop) - quad -= crop[0:2] - - # Pad. - pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), - int(np.ceil(max(quad[:, 1])))) - pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), - max(pad[3] - img.size[1] + border, 0)) - # if enable_padding and max(pad) > border - 4: - # pad = np.maximum(pad, int(np.rint(qsize * 0.3))) - # img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') - # h, w, _ = img.shape - # y, x, _ = np.ogrid[:h, :w, :1] - # mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), - # 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) - # blur = qsize * 0.02 - # img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) - # img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) - # img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') - # quad += pad[:2] - - # Transform. 
- quad = (quad + 0.5).flatten() - lx = max(min(quad[0], quad[2]), 0) - ly = max(min(quad[1], quad[7]), 0) - rx = min(max(quad[4], quad[6]), img.size[0]) - ry = min(max(quad[3], quad[5]), img.size[0]) - - # Save aligned image. - return rsize, crop, [lx, ly, rx, ry] - - def crop(self, img_np_list, still=False, xsize=512): # first frame for all video - # print(img_np_list) - img_np = img_np_list[0] - lm = self.get_landmark(img_np) - - if lm is None: - raise 'can not detect the landmark from source image' - rsize, crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=xsize) - clx, cly, crx, cry = crop - lx, ly, rx, ry = quad - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - for _i in range(len(img_np_list)): - _inp = img_np_list[_i] - _inp = cv2.resize(_inp, (rsize[0], rsize[1])) - _inp = _inp[cly:cry, clx:crx] - if not still: - _inp = _inp[ly:ry, lx:rx] - img_np_list[_i] = _inp - return img_np_list, crop, quad - diff --git a/sadtalker_audio2pose/src/utils/face_enhancer.py b/sadtalker_audio2pose/src/utils/face_enhancer.py deleted file mode 100644 index 2664560a1d7199e81f1a50093f29d02de91d4bcc..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/utils/face_enhancer.py +++ /dev/null @@ -1,123 +0,0 @@ -import os -import torch - -from gfpgan import GFPGANer - -from tqdm import tqdm - -from src.utils.videoio import load_video_to_cv2 - -import cv2 - - -class GeneratorWithLen(object): - """ From https://stackoverflow.com/a/7460929 """ - - def __init__(self, gen, length): - self.gen = gen - self.length = length - - def __len__(self): - return self.length - - def __iter__(self): - return self.gen - -def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'): - gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler) - return list(gen) - -def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'): - """ Provide a generator with a __len__ method so that it can passed to functions that - call len()""" - - if os.path.isfile(images): # handle video to images - # TODO: Create a generator version of load_video_to_cv2 - images = load_video_to_cv2(images) - - gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler) - gen_with_len = GeneratorWithLen(gen, len(images)) - return gen_with_len - -def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'): - """ Provide a generator function so that all of the enhanced images don't need - to be stored in memory at the same time. This can save tons of RAM compared to - the enhancer function. 
""" - - print('face enhancer....') - if not isinstance(images, list) and os.path.isfile(images): # handle video to images - images = load_video_to_cv2(images) - - # ------------------------ set up GFPGAN restorer ------------------------ - if method == 'gfpgan': - arch = 'clean' - channel_multiplier = 2 - model_name = 'GFPGANv1.4' - url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth' - elif method == 'RestoreFormer': - arch = 'RestoreFormer' - channel_multiplier = 2 - model_name = 'RestoreFormer' - url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth' - elif method == 'codeformer': # TODO: - arch = 'CodeFormer' - channel_multiplier = 2 - model_name = 'CodeFormer' - url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' - else: - raise ValueError(f'Wrong model version {method}.') - - - # ------------------------ set up background upsampler ------------------------ - if bg_upsampler == 'realesrgan': - if not torch.cuda.is_available(): # CPU - import warnings - warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. ' - 'If you really want to use it, please modify the corresponding codes.') - bg_upsampler = None - else: - from basicsr.archs.rrdbnet_arch import RRDBNet - from realesrgan import RealESRGANer - model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) - bg_upsampler = RealESRGANer( - scale=2, - model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', - model=model, - tile=400, - tile_pad=10, - pre_pad=0, - half=True) # need to set False in CPU mode - else: - bg_upsampler = None - - # determine model paths - model_path = os.path.join('gfpgan/weights', model_name + '.pth') - - if not os.path.isfile(model_path): - model_path = os.path.join('checkpoints', model_name + '.pth') - - if not os.path.isfile(model_path): - # download pre-trained models from url - model_path = url - - restorer = GFPGANer( - model_path=model_path, - upscale=2, - arch=arch, - channel_multiplier=channel_multiplier, - bg_upsampler=bg_upsampler) - - # ------------------------ restore ------------------------ - for idx in tqdm(range(len(images)), 'Face Enhancer:'): - - img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR) - - # restore faces and background if necessary - cropped_faces, restored_faces, r_img = restorer.enhance( - img, - has_aligned=False, - only_center_face=False, - paste_back=True) - - r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB) - yield r_img diff --git a/sadtalker_audio2pose/src/utils/flow_util.py b/sadtalker_audio2pose/src/utils/flow_util.py deleted file mode 100644 index f25046bab67cc8fbbb59efd02f48d7b6f22fc580..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/utils/flow_util.py +++ /dev/null @@ -1,221 +0,0 @@ -import torch -import sys - - -def convert_flow_to_deformation(flow): - r"""convert flow fields to deformations. - - Args: - flow (tensor): Flow field obtained by the model - Returns: - deformation (tensor): The deformation used for warpping - """ - b,c,h,w = flow.shape - flow_norm = 2 * torch.cat([flow[:,:1,...]/(w-1),flow[:,1:,...]/(h-1)], 1) - grid = make_coordinate_grid(flow) - # print(grid.shape, flow_norm.shape) - deformation = grid + flow_norm.permute(0,2,3,1) - return deformation - -def make_coordinate_grid(flow): - r"""obtain coordinate grid with the same size as the flow filed. 
- - Args: - flow (tensor): Flow field obtained by the model - Returns: - grid (tensor): The grid with the same size as the input flow - """ - b,c,h,w = flow.shape - - x = torch.arange(w).to(flow) - y = torch.arange(h).to(flow) - - x = (2 * (x / (w - 1)) - 1) - y = (2 * (y / (h - 1)) - 1) - - yy = y.view(-1, 1).repeat(1, w) - xx = x.view(1, -1).repeat(h, 1) - - meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) - meshed = meshed.expand(b, -1, -1, -1) - return meshed - - -def warp_image(source_image, deformation): - r"""warp the input image according to the deformation - - Args: - source_image (tensor): source images to be warpped - deformation (tensor): deformations used to warp the images; value in range (-1, 1) - Returns: - output (tensor): the warpped images - """ - _, h_old, w_old, _ = deformation.shape - _, _, h, w = source_image.shape - if h_old != h or w_old != w: - deformation = deformation.permute(0, 3, 1, 2) - deformation = torch.nn.functional.interpolate(deformation, size=(h, w), mode='bilinear') - deformation = deformation.permute(0, 2, 3, 1) - return torch.nn.functional.grid_sample(source_image, deformation) - - - -# visualize flow -import numpy as np - -__all__ = ['load_flow', 'save_flow', 'vis_flow'] - - -def load_flow(path): - with open(path, 'rb') as f: - magic = float(np.fromfile(f, np.float32, count=1)[0]) - if magic == 202021.25: - w, h = np.fromfile(f, np.int32, count=1)[0], np.fromfile(f, np.int32, count=1)[0] - data = np.fromfile(f, np.float32, count=h * w * 2) - data.resize((h, w, 2)) - return data - return None - - -def save_flow(path, flow): - magic = np.array([202021.25], np.float32) - h, w = flow.shape[:2] - h, w = np.array([h], np.int32), np.array([w], np.int32) - - with open(path, 'wb') as f: - magic.tofile(f) - w.tofile(f) - h.tofile(f) - flow.tofile(f) - - - -def makeColorwheel(): - # color encoding scheme - - # adapted from the color circle idea described at - # http://members.shaw.ca/quadibloc/other/colint.htm - - RY = 15 - YG = 6 - GC = 4 - CB = 11 - BM = 13 - MR = 6 - - ncols = RY + YG + GC + CB + BM + MR - - colorwheel = np.zeros([ncols, 3]) # r g b - - col = 0 - # RY - colorwheel[0:RY, 0] = 255 - colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY, 1) / RY) - col += RY - - # YG - colorwheel[col:YG + col, 0] = 255 - np.floor(255 * np.arange(0, YG, 1) / YG) - colorwheel[col:YG + col, 1] = 255 - col += YG - - # GC - colorwheel[col:GC + col, 1] = 255 - colorwheel[col:GC + col, 2] = np.floor(255 * np.arange(0, GC, 1) / GC) - col += GC - - # CB - colorwheel[col:CB + col, 1] = 255 - np.floor(255 * np.arange(0, CB, 1) / CB) - colorwheel[col:CB + col, 2] = 255 - col += CB - - # BM - colorwheel[col:BM + col, 2] = 255 - colorwheel[col:BM + col, 0] = np.floor(255 * np.arange(0, BM, 1) / BM) - col += BM - - # MR - colorwheel[col:MR + col, 2] = 255 - np.floor(255 * np.arange(0, MR, 1) / MR) - colorwheel[col:MR + col, 0] = 255 - return colorwheel - - -def computeColor(u, v): - colorwheel = makeColorwheel() - nan_u = np.isnan(u) - nan_v = np.isnan(v) - nan_u = np.where(nan_u) - nan_v = np.where(nan_v) - - u[nan_u] = 0 - u[nan_v] = 0 - v[nan_u] = 0 - v[nan_v] = 0 - - ncols = colorwheel.shape[0] - radius = np.sqrt(u ** 2 + v ** 2) - a = np.arctan2(-v, -u) / np.pi - fk = (a + 1) / 2 * (ncols - 1) # -1~1 maped to 1~ncols - k0 = fk.astype(np.uint8) # 1, 2, ..., ncols - k1 = k0 + 1 - k1[k1 == ncols] = 0 - f = fk - k0 - - img = np.empty([k1.shape[0], k1.shape[1], 3]) - ncolors = colorwheel.shape[1] - for i in range(ncolors): - tmp = colorwheel[:, i] - col0 = tmp[k0] 
/ 255 - col1 = tmp[k1] / 255 - col = (1 - f) * col0 + f * col1 - idx = radius <= 1 - col[idx] = 1 - radius[idx] * (1 - col[idx]) # increase saturation with radius - col[~idx] *= 0.75 # out of range - img[:, :, 2 - i] = np.floor(255 * col).astype(np.uint8) - - return img.astype(np.uint8) - - -def vis_flow(flow): - eps = sys.float_info.epsilon - UNKNOWN_FLOW_THRESH = 1e9 - UNKNOWN_FLOW = 1e10 - - u = flow[:, :, 0] - v = flow[:, :, 1] - - maxu = -999 - maxv = -999 - - minu = 999 - minv = 999 - - maxrad = -1 - # fix unknown flow - greater_u = np.where(u > UNKNOWN_FLOW_THRESH) - greater_v = np.where(v > UNKNOWN_FLOW_THRESH) - u[greater_u] = 0 - u[greater_v] = 0 - v[greater_u] = 0 - v[greater_v] = 0 - - maxu = max([maxu, np.amax(u)]) - minu = min([minu, np.amin(u)]) - - maxv = max([maxv, np.amax(v)]) - minv = min([minv, np.amin(v)]) - rad = np.sqrt(np.multiply(u, u) + np.multiply(v, v)) - maxrad = max([maxrad, np.amax(rad)]) - # print('max flow: %.4f flow range: u = %.3f .. %.3f; v = %.3f .. %.3f\n' % (maxrad, minu, maxu, minv, maxv)) - - u = u / (maxrad + eps) - v = v / (maxrad + eps) - img = computeColor(u, v) - return img[:, :, [2, 1, 0]] - - -def test_visualize_flow(): - flow = load_flow('out.flo') - img = vis_flow(flow) - - import cv2 - cv2.imwrite("img.png", img) diff --git a/sadtalker_audio2pose/src/utils/hparams.py b/sadtalker_audio2pose/src/utils/hparams.py deleted file mode 100644 index 83c312d767c35b9adc988157243efc02129fdb84..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/utils/hparams.py +++ /dev/null @@ -1,160 +0,0 @@ -from glob import glob -import os - -class HParams: - def __init__(self, **kwargs): - self.data = {} - - for key, value in kwargs.items(): - self.data[key] = value - - def __getattr__(self, key): - if key not in self.data: - raise AttributeError("'HParams' object has no attribute %s" % key) - return self.data[key] - - def set_hparam(self, key, value): - self.data[key] = value - - -# Default hyperparameters -hparams = HParams( - num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality - # network - rescale=True, # Whether to rescale audio prior to preprocessing - rescaling_max=0.9, # Rescaling value - - # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction - # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder - # Does not work if n_ffit is not multiple of hop_size!! - use_lws=False, - - n_fft=800, # Extra window size is filled with 0 paddings to match this parameter - hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) - win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) - sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) - - frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) - - # Mel and Linear spectrograms normalization/scaling and clipping - signal_normalization=True, - # Whether to normalize mel spectrograms to some predefined range (following below parameters) - allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True - symmetric_mels=True, - # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, - # faster and cleaner convergence) - max_abs_value=4., - # max absolute value of data. 
If symmetric, data will be [-max, max] else [0, max] (Must not - # be too big to avoid gradient explosion, - # not too small for fast convergence) - # Contribution by @begeekmyfriend - # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude - # levels. Also allows for better G&L phase reconstruction) - preemphasize=True, # whether to apply filter - preemphasis=0.97, # filter coefficient. - - # Limits - min_level_db=-100, - ref_level_db=20, - fmin=55, - # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To - # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) - fmax=7600, # To be increased/reduced depending on data. - - ###################### Our training parameters ################################# - img_size=96, - fps=25, - - batch_size=16, - initial_learning_rate=1e-4, - nepochs=300000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs - num_workers=20, - checkpoint_interval=3000, - eval_interval=3000, - writer_interval=300, - save_optimizer_state=True, - - syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. - syncnet_batch_size=64, - syncnet_lr=1e-4, - syncnet_eval_interval=1000, - syncnet_checkpoint_interval=10000, - - disc_wt=0.07, - disc_initial_learning_rate=1e-4, -) - - - -# Default hyperparameters -hparamsdebug = HParams( - num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality - # network - rescale=True, # Whether to rescale audio prior to preprocessing - rescaling_max=0.9, # Rescaling value - - # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction - # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder - # Does not work if n_ffit is not multiple of hop_size!! - use_lws=False, - - n_fft=800, # Extra window size is filled with 0 paddings to match this parameter - hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) - win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) - sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) - - frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) - - # Mel and Linear spectrograms normalization/scaling and clipping - signal_normalization=True, - # Whether to normalize mel spectrograms to some predefined range (following below parameters) - allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True - symmetric_mels=True, - # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, - # faster and cleaner convergence) - max_abs_value=4., - # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not - # be too big to avoid gradient explosion, - # not too small for fast convergence) - # Contribution by @begeekmyfriend - # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude - # levels. Also allows for better G&L phase reconstruction) - preemphasize=True, # whether to apply filter - preemphasis=0.97, # filter coefficient. - - # Limits - min_level_db=-100, - ref_level_db=20, - fmin=55, - # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To - # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) - fmax=7600, # To be increased/reduced depending on data. 
- - ###################### Our training parameters ################################# - img_size=96, - fps=25, - - batch_size=2, - initial_learning_rate=1e-3, - nepochs=100000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs - num_workers=0, - checkpoint_interval=10000, - eval_interval=10, - writer_interval=5, - save_optimizer_state=True, - - syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. - syncnet_batch_size=64, - syncnet_lr=1e-4, - syncnet_eval_interval=10000, - syncnet_checkpoint_interval=10000, - - disc_wt=0.07, - disc_initial_learning_rate=1e-4, -) - - -def hparams_debug_string(): - values = hparams.values() - hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"] - return "Hyperparameters:\n" + "\n".join(hp) diff --git a/sadtalker_audio2pose/src/utils/init_path.py b/sadtalker_audio2pose/src/utils/init_path.py deleted file mode 100644 index 65239fe3281798b2472f7ca0557a96157d9de930..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/utils/init_path.py +++ /dev/null @@ -1,49 +0,0 @@ -import os -import glob - -def init_path(checkpoint_dir, config_dir, size=512, old_version=False, preprocess='crop'): - - if old_version: - #### load all the checkpoint of `pth` - sadtalker_paths = { - 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'), - 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'), - 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'), - 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'), - 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth') - } - - use_safetensor = False - elif len(glob.glob(os.path.join(checkpoint_dir, '*.safetensors'))): - print('using safetensor as default') - sadtalker_paths = { - "checkpoint":os.path.join(checkpoint_dir, 'SadTalker_V0.0.2_'+str(size)+'.safetensors'), - } - use_safetensor = True - else: - print("WARNING: The new version of the model will be updated by safetensor, you may need to download it mannully. 
We run the old version of the checkpoint this time!") - use_safetensor = False - - sadtalker_paths = { - 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'), - 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'), - 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'), - 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'), - 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth') - } - - sadtalker_paths['dir_of_BFM_fitting'] = os.path.join(config_dir) # , 'BFM_Fitting' - sadtalker_paths['audio2pose_yaml_path'] = os.path.join(config_dir, 'auido2pose.yaml') - sadtalker_paths['audio2exp_yaml_path'] = os.path.join(config_dir, 'auido2exp.yaml') - sadtalker_paths['pirender_yaml_path'] = os.path.join(config_dir, 'facerender_pirender.yaml') - sadtalker_paths['pirender_checkpoint'] = os.path.join(checkpoint_dir, 'epoch_00190_iteration_000400000_checkpoint.pt') - sadtalker_paths['use_safetensor'] = use_safetensor # os.path.join(config_dir, 'auido2exp.yaml') - - if 'full' in preprocess: - sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00109-model.pth.tar') - sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender_still.yaml') - else: - sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00229-model.pth.tar') - sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender.yaml') - - return sadtalker_paths \ No newline at end of file diff --git a/sadtalker_audio2pose/src/utils/model2safetensor.py b/sadtalker_audio2pose/src/utils/model2safetensor.py deleted file mode 100644 index c5b76e3d67a06fdbf6646590d44b8c225bc73d79..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/utils/model2safetensor.py +++ /dev/null @@ -1,141 +0,0 @@ -import torch -import yaml -import os - -import safetensors -from safetensors.torch import save_file -from yacs.config import CfgNode as CN -import sys - -sys.path.append('/apdcephfs/private_shadowcun/SadTalker') - -from src.face3d.models import networks - -from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector -from src.facerender.modules.mapping import MappingNet -from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator - -from src.audio2pose_models.audio2pose import Audio2Pose -from src.audio2exp_models.networks import SimpleWrapperV2 -from src.test_audio2coeff import load_cpk - -size = 256 -############ face vid2vid -config_path = os.path.join('src', 'config', 'facerender.yaml') -current_root_path = '.' 
- -path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth') -net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='') -checkpoint = torch.load(path_of_net_recon_model, map_location='cpu') -net_recon.load_state_dict(checkpoint['net_recon']) - -with open(config_path) as f: - config = yaml.safe_load(f) - -generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'], - **config['model_params']['common_params']) -kp_extractor = KPDetector(**config['model_params']['kp_detector_params'], - **config['model_params']['common_params']) -he_estimator = HEEstimator(**config['model_params']['he_estimator_params'], - **config['model_params']['common_params']) -mapping = MappingNet(**config['model_params']['mapping_params']) - -def load_cpk_facevid2vid(checkpoint_path, generator=None, discriminator=None, - kp_detector=None, he_estimator=None, optimizer_generator=None, - optimizer_discriminator=None, optimizer_kp_detector=None, - optimizer_he_estimator=None, device="cpu"): - - checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) - if generator is not None: - generator.load_state_dict(checkpoint['generator']) - if kp_detector is not None: - kp_detector.load_state_dict(checkpoint['kp_detector']) - if he_estimator is not None: - he_estimator.load_state_dict(checkpoint['he_estimator']) - if discriminator is not None: - try: - discriminator.load_state_dict(checkpoint['discriminator']) - except: - print ('No discriminator in the state-dict. Dicriminator will be randomly initialized') - if optimizer_generator is not None: - optimizer_generator.load_state_dict(checkpoint['optimizer_generator']) - if optimizer_discriminator is not None: - try: - optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) - except RuntimeError as e: - print ('No discriminator optimizer in the state-dict. 
Optimizer will be not initialized') - if optimizer_kp_detector is not None: - optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector']) - if optimizer_he_estimator is not None: - optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator']) - - return checkpoint['epoch'] - - -def load_cpk_facevid2vid_safetensor(checkpoint_path, generator=None, - kp_detector=None, he_estimator=None, - device="cpu"): - - checkpoint = safetensors.torch.load_file(checkpoint_path) - - if generator is not None: - x_generator = {} - for k,v in checkpoint.items(): - if 'generator' in k: - x_generator[k.replace('generator.', '')] = v - generator.load_state_dict(x_generator) - if kp_detector is not None: - x_generator = {} - for k,v in checkpoint.items(): - if 'kp_extractor' in k: - x_generator[k.replace('kp_extractor.', '')] = v - kp_detector.load_state_dict(x_generator) - if he_estimator is not None: - x_generator = {} - for k,v in checkpoint.items(): - if 'he_estimator' in k: - x_generator[k.replace('he_estimator.', '')] = v - he_estimator.load_state_dict(x_generator) - - return None - -free_view_checkpoint = '/apdcephfs/private_shadowcun/SadTalker/checkpoints/facevid2vid_'+str(size)+'-model.pth.tar' -load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator) - -wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth') - -audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth') -audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml') - -audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth') -audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml') - -fcfg_pose = open(audio2pose_yaml_path) -cfg_pose = CN.load_cfg(fcfg_pose) -cfg_pose.freeze() -audio2pose_model = Audio2Pose(cfg_pose, wav2lip_checkpoint) -audio2pose_model.eval() -load_cpk(audio2pose_checkpoint, model=audio2pose_model, device='cpu') - -# load audio2exp_model -netG = SimpleWrapperV2() -netG.eval() -load_cpk(audio2exp_checkpoint, model=netG, device='cpu') - -class SadTalker(torch.nn.Module): - def __init__(self, kp_extractor, generator, netG, audio2pose, face_3drecon): - super(SadTalker, self).__init__() - self.kp_extractor = kp_extractor - self.generator = generator - self.audio2exp = netG - self.audio2pose = audio2pose - self.face_3drecon = face_3drecon - - -model = SadTalker(kp_extractor, generator, netG, audio2pose_model, net_recon) - -# here, we want to convert it to safetensor -save_file(model.state_dict(), "checkpoints/SadTalker_V0.0.2_"+str(size)+".safetensors") - -### test -load_cpk_facevid2vid_safetensor('checkpoints/SadTalker_V0.0.2_'+str(size)+'.safetensors', kp_detector=kp_extractor, generator=generator, he_estimator=None) \ No newline at end of file diff --git a/sadtalker_audio2pose/src/utils/paste_pic.py b/sadtalker_audio2pose/src/utils/paste_pic.py deleted file mode 100644 index 4da8952e6933698fec6c7cf35042cb5b1f0dcba5..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/utils/paste_pic.py +++ /dev/null @@ -1,69 +0,0 @@ -import cv2, os -import numpy as np -from tqdm import tqdm -import uuid - -from src.utils.videoio import save_video_with_watermark - -def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop=False): - - if not os.path.isfile(pic_path): - raise ValueError('pic_path must be a valid path to video/image 
file') - elif pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']: - # loader for first frame - full_img = cv2.imread(pic_path) - else: - # loader for videos - video_stream = cv2.VideoCapture(pic_path) - fps = video_stream.get(cv2.CAP_PROP_FPS) - full_frames = [] - while 1: - still_reading, frame = video_stream.read() - if not still_reading: - video_stream.release() - break - break - full_img = frame - frame_h = full_img.shape[0] - frame_w = full_img.shape[1] - - video_stream = cv2.VideoCapture(video_path) - fps = video_stream.get(cv2.CAP_PROP_FPS) - crop_frames = [] - while 1: - still_reading, frame = video_stream.read() - if not still_reading: - video_stream.release() - break - crop_frames.append(frame) - - if len(crop_info) != 3: - print("you didn't crop the image") - return - else: - r_w, r_h = crop_info[0] - clx, cly, crx, cry = crop_info[1] - lx, ly, rx, ry = crop_info[2] - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - if extended_crop: - oy1, oy2, ox1, ox2 = cly, cry, clx, crx - else: - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - tmp_path = str(uuid.uuid4())+'.mp4' - out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h)) - for crop_frame in tqdm(crop_frames, 'seamlessClone:'): - p = cv2.resize(crop_frame.astype(np.uint8), (ox2-ox1, oy2 - oy1)) - - mask = 255*np.ones(p.shape, p.dtype) - location = ((ox1+ox2) // 2, (oy1+oy2) // 2) - gen_img = cv2.seamlessClone(p, full_img, mask, location, cv2.NORMAL_CLONE) - out_tmp.write(gen_img) - - out_tmp.release() - - save_video_with_watermark(tmp_path, new_audio_path, full_video_path, watermark=False) - os.remove(tmp_path) diff --git a/sadtalker_audio2pose/src/utils/preprocess.py b/sadtalker_audio2pose/src/utils/preprocess.py deleted file mode 100644 index 4956c00d273467f8a0c020312401158b06c4fecd..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/utils/preprocess.py +++ /dev/null @@ -1,170 +0,0 @@ -import numpy as np -import cv2, os, sys, torch -from tqdm import tqdm -from PIL import Image - -# 3dmm extraction -import safetensors -import safetensors.torch -from src.face3d.util.preprocess import align_img -from src.face3d.util.load_mats import load_lm3d -from src.face3d.models import networks - -from scipy.io import loadmat, savemat -from src.utils.croper import Preprocesser - - -import warnings - -from src.utils.safetensor_helper import load_x_from_safetensor -warnings.filterwarnings("ignore") - -def split_coeff(coeffs): - """ - Return: - coeffs_dict -- a dict of torch.tensors - - Parameters: - coeffs -- torch.tensor, size (B, 256) - """ - id_coeffs = coeffs[:, :80] - exp_coeffs = coeffs[:, 80: 144] - tex_coeffs = coeffs[:, 144: 224] - angles = coeffs[:, 224: 227] - gammas = coeffs[:, 227: 254] - translations = coeffs[:, 254:] - return { - 'id': id_coeffs, - 'exp': exp_coeffs, - 'tex': tex_coeffs, - 'angle': angles, - 'gamma': gammas, - 'trans': translations - } - - -class CropAndExtract(): - def __init__(self, sadtalker_path, device): - - self.propress = Preprocesser(device) - self.net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device) - - if sadtalker_path['use_safetensor']: - checkpoint = safetensors.torch.load_file(sadtalker_path['checkpoint']) - self.net_recon.load_state_dict(load_x_from_safetensor(checkpoint, 'face_3drecon')) - else: - checkpoint = torch.load(sadtalker_path['path_of_net_recon_model'], 
map_location=torch.device(device)) - self.net_recon.load_state_dict(checkpoint['net_recon']) - - self.net_recon.eval() - self.lm3d_std = load_lm3d(sadtalker_path['dir_of_BFM_fitting']) - self.device = device - - def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False, pic_size=256): - - pic_name = os.path.splitext(os.path.split(input_path)[-1])[0] - - landmarks_path = os.path.join(save_dir, pic_name+'_landmarks.txt') - coeff_path = os.path.join(save_dir, pic_name+'.mat') - png_path = os.path.join(save_dir, pic_name+'.png') - - #load input - if not os.path.isfile(input_path): - raise ValueError('input_path must be a valid path to video/image file') - elif input_path.split('.')[-1] in ['jpg', 'png', 'jpeg']: - # loader for first frame - full_frames = [cv2.imread(input_path)] - fps = 25 - else: - # loader for videos - video_stream = cv2.VideoCapture(input_path) - fps = video_stream.get(cv2.CAP_PROP_FPS) - full_frames = [] - while 1: - still_reading, frame = video_stream.read() - if not still_reading: - video_stream.release() - break - full_frames.append(frame) - if source_image_flag: - break - - x_full_frames= [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in full_frames] - - #### crop images as the - if 'crop' in crop_or_resize.lower(): # default crop - x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512) - clx, cly, crx, cry = crop - lx, ly, rx, ry = quad - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad) - elif 'full' in crop_or_resize.lower(): - x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512) - clx, cly, crx, cry = crop - lx, ly, rx, ry = quad - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad) - else: # resize mode - oy1, oy2, ox1, ox2 = 0, x_full_frames[0].shape[0], 0, x_full_frames[0].shape[1] - crop_info = ((ox2 - ox1, oy2 - oy1), None, None) - - frames_pil = [Image.fromarray(cv2.resize(frame,(pic_size, pic_size))) for frame in x_full_frames] - if len(frames_pil) == 0: - print('No face is detected in the input file') - return None, None - - # save crop info - for frame in frames_pil: - cv2.imwrite(png_path, cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)) - - # 2. get the landmark according to the detected face. - if not os.path.isfile(landmarks_path): - lm = self.propress.predictor.extract_keypoint(frames_pil, landmarks_path) - else: - print(' Using saved landmarks.') - lm = np.loadtxt(landmarks_path).astype(np.float32) - lm = lm.reshape([len(x_full_frames), -1, 2]) - - if not os.path.isfile(coeff_path): - # load 3dmm paramter generator from Deep3DFaceRecon_pytorch - video_coeffs, full_coeffs = [], [] - for idx in tqdm(range(len(frames_pil)), desc='3DMM Extraction In Video:'): - frame = frames_pil[idx] - W,H = frame.size - lm1 = lm[idx].reshape([-1, 2]) - - if np.mean(lm1) == -1: - lm1 = (self.lm3d_std[:, :2]+1)/2. 
- lm1 = np.concatenate( - [lm1[:, :1]*W, lm1[:, 1:2]*H], 1 - ) - else: - lm1[:, -1] = H - 1 - lm1[:, -1] - - trans_params, im1, lm1, _ = align_img(frame, lm1, self.lm3d_std) - - trans_params_m = np.array([float(item) for item in np.hsplit(trans_params, len(trans_params))]).astype(np.float32) - im_t = torch.tensor(np.array(im1)/255., dtype=torch.float32).permute(2, 0, 1).to(self.device).unsqueeze(0) - - with torch.no_grad(): - full_coeff = self.net_recon(im_t) - coeffs = split_coeff(full_coeff) - - pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs} - - pred_coeff = np.concatenate([ - pred_coeff['exp'], - pred_coeff['angle'], - pred_coeff['trans'], - trans_params_m[2:][None], - ], 1) - video_coeffs.append(pred_coeff) - full_coeffs.append(full_coeff.cpu().numpy()) - - semantic_npy = np.array(video_coeffs)[:,0] - - savemat(coeff_path, {'coeff_3dmm': semantic_npy, 'full_3dmm': np.array(full_coeffs)[0], 'trans_params': trans_params}) - - return coeff_path, png_path, crop_info diff --git a/sadtalker_audio2pose/src/utils/preprocess_fromvideo.py b/sadtalker_audio2pose/src/utils/preprocess_fromvideo.py deleted file mode 100644 index 6c4aad3eef558934d9974c3170b658cff88f568c..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/utils/preprocess_fromvideo.py +++ /dev/null @@ -1,195 +0,0 @@ -import numpy as np -import cv2, os, sys, torch -from tqdm import tqdm -from PIL import Image - -# 3dmm extraction -import safetensors -import safetensors.torch -from src.face3d.util.preprocess import align_img -from src.face3d.util.load_mats import load_lm3d -from src.face3d.models import networks - -from scipy.io import loadmat, savemat -from src.utils.croper import Preprocesser - - -import warnings - -from src.utils.safetensor_helper import load_x_from_safetensor -warnings.filterwarnings("ignore") - - -def smooth_3dmm_params(params, window_size=5): - # create a new array to store the smoothed parameters - smoothed_params = np.zeros_like(params) - - # smooth each parameter independently - for i in range(params.shape[1]): - - # build a sliding averaging window for this parameter - window = np.ones(int(window_size))/float(window_size) - smoothed_param = np.convolve(params[:, i], window, 'same') - - # store the smoothed parameter in the new array - smoothed_params[:, i] = smoothed_param - - return smoothed_params - - - -def split_coeff(coeffs): - """ - Return: - coeffs_dict -- a dict of torch.tensors - - Parameters: - coeffs -- torch.tensor, size (B, 256) - """ - id_coeffs = coeffs[:, :80] - exp_coeffs = coeffs[:, 80: 144] - tex_coeffs = coeffs[:, 144: 224] - angles = coeffs[:, 224: 227] - gammas = coeffs[:, 227: 254] - translations = coeffs[:, 254:] - return { - 'id': id_coeffs, - 'exp': exp_coeffs, - 'tex': tex_coeffs, - 'angle': angles, - 'gamma': gammas, - 'trans': translations - } - - -class CropAndExtract(): - def __init__(self, sadtalker_path, device): - - self.propress = Preprocesser(device) - self.net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device) - - if sadtalker_path['use_safetensor']: - checkpoint = safetensors.torch.load_file(sadtalker_path['checkpoint']) - self.net_recon.load_state_dict(load_x_from_safetensor(checkpoint, 'face_3drecon')) - else: - checkpoint = torch.load(sadtalker_path['path_of_net_recon_model'], map_location=torch.device(device)) - self.net_recon.load_state_dict(checkpoint['net_recon']) - - self.net_recon.eval() - self.lm3d_std = load_lm3d(sadtalker_path['dir_of_BFM_fitting']) - self.device = device - - def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False, pic_size=256, if_smooth=False): - - 
pic_name = os.path.splitext(os.path.split(input_path)[-1])[0] - - landmarks_path = os.path.join(save_dir, pic_name+'_landmarks.txt') - coeff_path = os.path.join(save_dir, pic_name+'.mat') - png_path = os.path.join(save_dir, pic_name+'.png') - - #load input - if not os.path.isfile(input_path): - raise ValueError('input_path must be a valid path to video/image file') - elif input_path.split('.')[-1] in ['jpg', 'png', 'jpeg']: - # loader for first frame - full_frames = [cv2.imread(input_path)] - fps = 25 - else: - # loader for videos - video_stream = cv2.VideoCapture(input_path) - fps = video_stream.get(cv2.CAP_PROP_FPS) - full_frames = [] - while 1: - still_reading, frame = video_stream.read() - if not still_reading: - video_stream.release() - break - full_frames.append(frame) - if source_image_flag: - break - - x_full_frames= [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in full_frames] - - # print(x_full_frames) - - #### crop images as the - if 'crop' in crop_or_resize.lower(): # default crop - x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512) - clx, cly, crx, cry = crop - lx, ly, rx, ry = quad - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad) - elif 'full' in crop_or_resize.lower(): - x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512) - clx, cly, crx, cry = crop - lx, ly, rx, ry = quad - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad) - else: # resize mode - oy1, oy2, ox1, ox2 = 0, x_full_frames[0].shape[0], 0, x_full_frames[0].shape[1] - crop_info = ((ox2 - ox1, oy2 - oy1), None, None) - - frames_pil = [Image.fromarray(cv2.resize(frame,(pic_size, pic_size))) for frame in x_full_frames] - if len(frames_pil) == 0: - print('No face is detected in the input file') - return None, None - - # save crop info - for frame in frames_pil: - cv2.imwrite(png_path, cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)) - - # 2. get the landmark according to the detected face. - if not os.path.isfile(landmarks_path): - lm = self.propress.predictor.extract_keypoint(frames_pil, landmarks_path) - else: - print(' Using saved landmarks.') - lm = np.loadtxt(landmarks_path).astype(np.float32) - lm = lm.reshape([len(x_full_frames), -1, 2]) - - if not os.path.isfile(coeff_path): - # load 3dmm paramter generator from Deep3DFaceRecon_pytorch - video_coeffs, full_coeffs = [], [] - for idx in tqdm(range(len(frames_pil)), desc='3DMM Extraction In Video:'): - frame = frames_pil[idx] - W,H = frame.size - lm1 = lm[idx].reshape([-1, 2]) - - if np.mean(lm1) == -1: - lm1 = (self.lm3d_std[:, :2]+1)/2. 
- lm1 = np.concatenate( - [lm1[:, :1]*W, lm1[:, 1:2]*H], 1 - ) - else: - lm1[:, -1] = H - 1 - lm1[:, -1] - - trans_params, im1, lm1, _ = align_img(frame, lm1, self.lm3d_std) - - trans_params_m = np.array([float(item) for item in np.hsplit(trans_params, len(trans_params))]).astype(np.float32) - im_t = torch.tensor(np.array(im1)/255., dtype=torch.float32).permute(2, 0, 1).to(self.device).unsqueeze(0) - - with torch.no_grad(): - full_coeff = self.net_recon(im_t) - coeffs = split_coeff(full_coeff) - - pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs} - - pred_coeff = np.concatenate([ - pred_coeff['exp'], - pred_coeff['angle'], - pred_coeff['trans'], - # trans_params_m[2:][None], - ], 1) - video_coeffs.append(pred_coeff) - full_coeffs.append(full_coeff.cpu().numpy()) - - semantic_npy = np.array(video_coeffs)[:,0] - - if if_smooth: - # pass - semantic_npy[:, -6:] = smooth_3dmm_params(semantic_npy[:, -6:], window_size=10) - - savemat(coeff_path, {'coeff_3dmm': semantic_npy, 'full_3dmm': np.array(full_coeffs)[0], 'trans_params': trans_params}) - - return coeff_path, png_path, crop_info diff --git a/sadtalker_audio2pose/src/utils/safetensor_helper.py b/sadtalker_audio2pose/src/utils/safetensor_helper.py deleted file mode 100644 index 164ed9621eba24e0b3050ca663fcb60123517158..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/utils/safetensor_helper.py +++ /dev/null @@ -1,8 +0,0 @@ - - -def load_x_from_safetensor(checkpoint, key): - x_generator = {} - for k,v in checkpoint.items(): - if key in k: - x_generator[k.replace(key+'.', '')] = v - return x_generator \ No newline at end of file diff --git a/sadtalker_audio2pose/src/utils/text2speech.py b/sadtalker_audio2pose/src/utils/text2speech.py deleted file mode 100644 index a0fe21daf74fcd01767b17378b7076c9dd424248..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/utils/text2speech.py +++ /dev/null @@ -1,20 +0,0 @@ -import os -import tempfile -from TTS.api import TTS - - -class TTSTalker(): - def __init__(self) -> None: - model_name = TTS.list_models()[0] - self.tts = TTS(model_name) - - def test(self, text, language='en'): - - tempf = tempfile.NamedTemporaryFile( - delete = False, - suffix = ('.'+'wav'), - ) - - self.tts.tts_to_file(text, speaker=self.tts.speakers[0], language=language, file_path=tempf.name) - - return tempf.name \ No newline at end of file diff --git a/sadtalker_audio2pose/src/utils/videoio.py b/sadtalker_audio2pose/src/utils/videoio.py deleted file mode 100644 index d604ae5b098006f3e59cf3c0133779ffd1cc9d5a..0000000000000000000000000000000000000000 --- a/sadtalker_audio2pose/src/utils/videoio.py +++ /dev/null @@ -1,41 +0,0 @@ -import shutil -import uuid - -import os - -import cv2 - -def load_video_to_cv2(input_path): - video_stream = cv2.VideoCapture(input_path) - fps = video_stream.get(cv2.CAP_PROP_FPS) - full_frames = [] - while 1: - still_reading, frame = video_stream.read() - if not still_reading: - video_stream.release() - break - full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) - return full_frames - -def save_video_with_watermark(video, audio, save_path, watermark=False): - temp_file = str(uuid.uuid4())+'.mp4' - cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -vcodec mpeg4 "%s"' % (video, audio, temp_file) - os.system(cmd) - - if watermark is False: - shutil.move(temp_file, save_path) - else: - # watermark - try: - ##### check if stable-diffusion-webui - import webui - from modules import paths - watarmark_path = 
paths.script_path+"/extensions/SadTalker/docs/sadtalker_logo.png" - except: - # get the root path of sadtalker. - dir_path = os.path.dirname(os.path.realpath(__file__)) - watarmark_path = dir_path+"/../../docs/sadtalker_logo.png" - - cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -filter_complex "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10" "%s"' % (temp_file, watarmark_path, save_path) - os.system(cmd) - os.remove(temp_file) \ No newline at end of file diff --git a/sadtalker_video2pose/.DS_Store b/sadtalker_video2pose/.DS_Store deleted file mode 100644 index 9b7747e8985cba181c6477ae341433be1ca71030..0000000000000000000000000000000000000000 Binary files a/sadtalker_video2pose/.DS_Store and /dev/null differ diff --git a/sadtalker_video2pose/inference.py b/sadtalker_video2pose/inference.py deleted file mode 100644 index 7167fdb44be0261ed35b2fa25f39e327190c62bf..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/inference.py +++ /dev/null @@ -1,170 +0,0 @@ -from glob import glob -import shutil -import torch -from time import strftime -import os, sys, time -from argparse import ArgumentParser -import platform -import scipy -import numpy as np - -# from src.utils.preprocess import CropAndExtract -from src.utils.preprocess_fromvideo import CropAndExtract -from src.test_audio2coeff import Audio2Coeff -from src.facerender.animate import AnimateFromCoeff -from src.facerender.pirender_animate import AnimateFromCoeff_PIRender -from src.generate_batch import get_data -from src.generate_facerender_batch import get_facerender_data -from src.utils.init_path import init_path - - -def main(args): - #torch.backends.cudnn.enabled = False - - - - # args.facerender = 'pirender' - - - - pic_path = args.source_image - # audio_path = args.driven_audio - save_dir = args.result_dir - os.makedirs(save_dir, exist_ok=True) - pose_style = args.pose_style - device = args.device - batch_size = args.batch_size - input_yaw_list = args.input_yaw - input_pitch_list = args.input_pitch - input_roll_list = args.input_roll - ref_eyeblink = args.ref_eyeblink - ref_pose = args.ref_pose - - current_root_path = os.path.split(sys.argv[0])[0] - - sadtalker_paths = init_path(args.checkpoint_dir, os.path.join(current_root_path, 'src/config'), args.size, args.old_version, args.preprocess) - - #init model - preprocess_model = CropAndExtract(sadtalker_paths, device) - - audio_to_coeff = Audio2Coeff(sadtalker_paths, device) - - if args.facerender == 'facevid2vid': - animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device) - elif args.facerender == 'pirender': - animate_from_coeff = AnimateFromCoeff_PIRender(sadtalker_paths, device) - else: - raise(RuntimeError('Unknown model: {}'.format(args.facerender))) - - #crop image and extract 3dmm from image - first_frame_dir = os.path.join(save_dir, 'first_frame_dir') - os.makedirs(first_frame_dir, exist_ok=True) - print('3DMM Extraction for source image') - first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess,\ - source_image_flag=True, pic_size=args.size) - if first_coeff_path is None: - print("Can't get the coeffs of the input") - return - - if ref_eyeblink is not None: - ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[0] - ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname) - os.makedirs(ref_eyeblink_frame_dir, exist_ok=True) - print('3DMM Extraction for the reference video providing eye blinking') - ref_eyeblink_coeff_path, _, _ = 
preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir, args.preprocess, source_image_flag=False) - else: - ref_eyeblink_coeff_path=None - - if ref_pose is not None: - if ref_pose == ref_eyeblink: - ref_pose_coeff_path = ref_eyeblink_coeff_path - else: - ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0] - ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname) - os.makedirs(ref_pose_frame_dir, exist_ok=True) - print('3DMM Extraction for the reference video providing pose') - # print(ref_pose) - ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir, args.preprocess, source_image_flag=False, if_smooth=True) - else: - ref_pose_coeff_path=None - - # #audio2ceoff - # batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still) - # coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path) - - # print(ref_pose_coeff_path) - # print(coeff_path) - - # coeff_pred_video = scipy.io.loadmat(ref_pose_coeff_path)['coeff_3dmm'] - # coeff_pred = scipy.io.loadmat(coeff_path)['coeff_3dmm'] - - # print(coeff_pred_video.shape) - # print(coeff_pred.shape) - - coeff_path = ref_pose_coeff_path - # coeff_path = smooth_3dmm_params(ref_pose_coeff_path, window_size=3) - - - - # assert False - - # 3dface render - if args.face3dvis: - from src.face3d.visualize_fromvideo import gen_composed_video - gen_composed_video(args, device, first_coeff_path, coeff_path, \ - os.path.join(save_dir, '3dface.mp4'), os.path.join(save_dir, 'landmarks.mp4'), crop_info, extended_crop= True if 'ext' in args.preprocess else False ) - return - - -if __name__ == '__main__': - - parser = ArgumentParser() - # parser.add_argument("--driven_audio", default='./sadtalker_video2pose/dummy/bus_chinese.wav', help="path to driven audio") - parser.add_argument("--source_image", default='./examples/source_image/full_body_1.png', help="path to source image") - parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking") - parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose") - parser.add_argument("--checkpoint_dir", default='./ckpts/sad_talker', help="path to output") - parser.add_argument("--result_dir", default='./results', help="path to output") - parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)") - parser.add_argument("--batch_size", type=int, default=1, help="the batch size of facerender") - parser.add_argument("--size", type=int, default=256, help="the image size of the facerender") - parser.add_argument("--expression_scale", type=float, default=1., help="the batch size of facerender") - parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user ") - parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user") - parser.add_argument('--input_roll', nargs='+', type=int, default=None, help="the input roll degree of the user") - parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [gfpgan, RestoreFormer]") - parser.add_argument('--background_enhancer', type=str, default=None, help="background enhancer, [realesrgan]") - parser.add_argument("--cpu", dest="cpu", action="store_true") - parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks") - parser.add_argument("--still", action="store_true", help="can crop back to the original 
videos for the full body animation") - parser.add_argument("--preprocess", default='crop', choices=['crop', 'extcrop', 'resize', 'full', 'extfull'], help="how to preprocess the images" ) - parser.add_argument("--verbose",action="store_true", help="saving the intermediate output or not" ) - parser.add_argument("--old_version",action="store_true", help="use the pth other than safetensor version" ) - parser.add_argument("--facerender", default='facevid2vid', choices=['pirender', 'facevid2vid'] ) - - - # net structure and parameters - parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='useless') - parser.add_argument('--init_path', type=str, default=None, help='Useless') - parser.add_argument('--use_last_fc',default=False, help='zero initialize the last fc') - parser.add_argument('--bfm_folder', type=str, default='./ckpts/sad_talker/BFM_Fitting/') - parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model') - - # default renderer parameters - parser.add_argument('--focal', type=float, default=1015.) - parser.add_argument('--center', type=float, default=112.) - parser.add_argument('--camera_d', type=float, default=10.) - parser.add_argument('--z_near', type=float, default=5.) - parser.add_argument('--z_far', type=float, default=15.) - - args = parser.parse_args() - - if torch.cuda.is_available() and not args.cpu: - args.device = "cuda" - elif platform.system() == 'Darwin' and args.facerender == 'pirender': # macos - args.device = "mps" - else: - args.device = "cpu" - - main(args) - diff --git a/sadtalker_video2pose/src/.DS_Store b/sadtalker_video2pose/src/.DS_Store deleted file mode 100644 index 0f8fa60ef73513c0a5ddb5161310a66031c28262..0000000000000000000000000000000000000000 Binary files a/sadtalker_video2pose/src/.DS_Store and /dev/null differ diff --git a/sadtalker_video2pose/src/audio2exp_models/audio2exp.py b/sadtalker_video2pose/src/audio2exp_models/audio2exp.py deleted file mode 100644 index e1062ab6684df01e0b3c48b6b577cc8df0503c91..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/audio2exp_models/audio2exp.py +++ /dev/null @@ -1,41 +0,0 @@ -from tqdm import tqdm -import torch -from torch import nn - - -class Audio2Exp(nn.Module): - def __init__(self, netG, cfg, device, prepare_training_loss=False): - super(Audio2Exp, self).__init__() - self.cfg = cfg - self.device = device - self.netG = netG.to(device) - - def test(self, batch): - - mel_input = batch['indiv_mels'] # bs T 1 80 16 - bs = mel_input.shape[0] - T = mel_input.shape[1] - - exp_coeff_pred = [] - - for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames - - current_mel_input = mel_input[:,i:i+10] - - #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64 - ref = batch['ref'][:, :, :64][:, i:i+10] - ratio = batch['ratio_gt'][:, i:i+10] #bs T - - audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16 - - curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64 - - exp_coeff_pred += [curr_exp_coeff_pred] - - # BS x T x 64 - results_dict = { - 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1) - } - return results_dict - - diff --git a/sadtalker_video2pose/src/audio2exp_models/networks.py b/sadtalker_video2pose/src/audio2exp_models/networks.py deleted file mode 100644 index cd77a2f48d7c00ce85fe2eefe3a3e820730fbb74..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/audio2exp_models/networks.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -import 
torch.nn.functional as F -from torch import nn - -class Conv2d(nn.Module): - def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs): - super().__init__(*args, **kwargs) - self.conv_block = nn.Sequential( - nn.Conv2d(cin, cout, kernel_size, stride, padding), - nn.BatchNorm2d(cout) - ) - self.act = nn.ReLU() - self.residual = residual - self.use_act = use_act - - def forward(self, x): - out = self.conv_block(x) - if self.residual: - out += x - - if self.use_act: - return self.act(out) - else: - return out - -class SimpleWrapperV2(nn.Module): - def __init__(self) -> None: - super().__init__() - self.audio_encoder = nn.Sequential( - Conv2d(1, 32, kernel_size=3, stride=1, padding=1), - Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), - Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), - - Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), - Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), - Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), - - Conv2d(64, 128, kernel_size=3, stride=3, padding=1), - Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), - Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), - - Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), - Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), - - Conv2d(256, 512, kernel_size=3, stride=1, padding=0), - Conv2d(512, 512, kernel_size=1, stride=1, padding=0), - ) - - #### load the pre-trained audio_encoder - #self.audio_encoder = self.audio_encoder.to(device) - ''' - wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict'] - state_dict = self.audio_encoder.state_dict() - - for k,v in wav2lip_state_dict.items(): - if 'audio_encoder' in k: - print('init:', k) - state_dict[k.replace('module.audio_encoder.', '')] = v - self.audio_encoder.load_state_dict(state_dict) - ''' - - self.mapping1 = nn.Linear(512+64+1, 64) - #self.mapping2 = nn.Linear(30, 64) - #nn.init.constant_(self.mapping1.weight, 0.) - nn.init.constant_(self.mapping1.bias, 0.) 
- - def forward(self, x, ref, ratio): - x = self.audio_encoder(x).view(x.size(0), -1) - ref_reshape = ref.reshape(x.size(0), -1) - ratio = ratio.reshape(x.size(0), -1) - - y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1)) - out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial - return out diff --git a/sadtalker_video2pose/src/audio2pose_models/audio2pose.py b/sadtalker_video2pose/src/audio2pose_models/audio2pose.py deleted file mode 100644 index 53883adc508037294ba664d05d34e5459f1879f8..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/audio2pose_models/audio2pose.py +++ /dev/null @@ -1,94 +0,0 @@ -import torch -from torch import nn -from src.audio2pose_models.cvae import CVAE -from src.audio2pose_models.discriminator import PoseSequenceDiscriminator -from src.audio2pose_models.audio_encoder import AudioEncoder - -class Audio2Pose(nn.Module): - def __init__(self, cfg, wav2lip_checkpoint, device='cuda'): - super().__init__() - self.cfg = cfg - self.seq_len = cfg.MODEL.CVAE.SEQ_LEN - self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE - self.device = device - - self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device) - self.audio_encoder.eval() - for param in self.audio_encoder.parameters(): - param.requires_grad = False - - self.netG = CVAE(cfg) - self.netD_motion = PoseSequenceDiscriminator(cfg) - - - def forward(self, x): - - batch = {} - coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73 - batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6 - batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6 - batch['class'] = x['class'].squeeze(0).cuda() # bs - indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16 - - # forward - audio_emb_list = [] - audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512 - batch['audio_emb'] = audio_emb - batch = self.netG(batch) - - pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6 - pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6 - pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6 - - batch['pose_pred'] = pose_pred - batch['pose_gt'] = pose_gt - - return batch - - def test(self, x): - - batch = {} - ref = x['ref'] #bs 1 70 - batch['ref'] = x['ref'][:,0,-6:] - batch['class'] = x['class'] - bs = ref.shape[0] - - indiv_mels= x['indiv_mels'] # bs T 1 80 16 - indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame - num_frames = x['num_frames'] - num_frames = int(num_frames) - 1 - - # - div = num_frames//self.seq_len - re = num_frames%self.seq_len - audio_emb_list = [] - pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype, - device=batch['ref'].device)] - - for i in range(div): - z = torch.randn(bs, self.latent_dim).to(ref.device) - batch['z'] = z - audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512 - batch['audio_emb'] = audio_emb - batch = self.netG.test(batch) - pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6 - - if re != 0: - z = torch.randn(bs, self.latent_dim).to(ref.device) - batch['z'] = z - audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512 - if audio_emb.shape[1] != self.seq_len: - pad_dim = self.seq_len-audio_emb.shape[1] - pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1) - audio_emb = torch.cat([pad_audio_emb, audio_emb], 1) - batch['audio_emb'] = audio_emb - batch = self.netG.test(batch) - 
pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:]) - - pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1) - batch['pose_motion_pred'] = pose_motion_pred - - pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6 - - batch['pose_pred'] = pose_pred - return batch diff --git a/sadtalker_video2pose/src/audio2pose_models/audio_encoder.py b/sadtalker_video2pose/src/audio2pose_models/audio_encoder.py deleted file mode 100644 index a0c165afbc25910cb66828d8676973fe727cb3a3..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/audio2pose_models/audio_encoder.py +++ /dev/null @@ -1,64 +0,0 @@ -import torch -from torch import nn -from torch.nn import functional as F - -class Conv2d(nn.Module): - def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): - super().__init__(*args, **kwargs) - self.conv_block = nn.Sequential( - nn.Conv2d(cin, cout, kernel_size, stride, padding), - nn.BatchNorm2d(cout) - ) - self.act = nn.ReLU() - self.residual = residual - - def forward(self, x): - out = self.conv_block(x) - if self.residual: - out += x - return self.act(out) - -class AudioEncoder(nn.Module): - def __init__(self, wav2lip_checkpoint, device): - super(AudioEncoder, self).__init__() - - self.audio_encoder = nn.Sequential( - Conv2d(1, 32, kernel_size=3, stride=1, padding=1), - Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), - Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), - - Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), - Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), - Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), - - Conv2d(64, 128, kernel_size=3, stride=3, padding=1), - Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), - Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), - - Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), - Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), - - Conv2d(256, 512, kernel_size=3, stride=1, padding=0), - Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) - - #### load the pre-trained audio_encoder, we do not need to load wav2lip model here. 
- # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict'] - # state_dict = self.audio_encoder.state_dict() - - # for k,v in wav2lip_state_dict.items(): - # if 'audio_encoder' in k: - # state_dict[k.replace('module.audio_encoder.', '')] = v - # self.audio_encoder.load_state_dict(state_dict) - - - def forward(self, audio_sequences): - # audio_sequences = (B, T, 1, 80, 16) - B = audio_sequences.size(0) - - audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0) - - audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1 - dim = audio_embedding.shape[1] - audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1)) - - return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512 diff --git a/sadtalker_video2pose/src/audio2pose_models/cvae.py b/sadtalker_video2pose/src/audio2pose_models/cvae.py deleted file mode 100644 index 407b78894cde564dd3f2819772a84e8bb1de251d..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/audio2pose_models/cvae.py +++ /dev/null @@ -1,149 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import nn -from src.audio2pose_models.res_unet import ResUnet - -def class2onehot(idx, class_num): - - assert torch.max(idx).item() < class_num - onehot = torch.zeros(idx.size(0), class_num).to(idx.device) - onehot.scatter_(1, idx, 1) - return onehot - -class CVAE(nn.Module): - def __init__(self, cfg): - super().__init__() - encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES - decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES - latent_size = cfg.MODEL.CVAE.LATENT_SIZE - num_classes = cfg.DATASET.NUM_CLASSES - audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE - audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE - seq_len = cfg.MODEL.CVAE.SEQ_LEN - - self.latent_size = latent_size - - self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes, - audio_emb_in_size, audio_emb_out_size, seq_len) - self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes, - audio_emb_in_size, audio_emb_out_size, seq_len) - def reparameterize(self, mu, logvar): - std = torch.exp(0.5 * logvar) - eps = torch.randn_like(std) - return mu + eps * std - - def forward(self, batch): - batch = self.encoder(batch) - mu = batch['mu'] - logvar = batch['logvar'] - z = self.reparameterize(mu, logvar) - batch['z'] = z - return self.decoder(batch) - - def test(self, batch): - ''' - class_id = batch['class'] - z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device) - batch['z'] = z - ''' - return self.decoder(batch) - -class ENCODER(nn.Module): - def __init__(self, layer_sizes, latent_size, num_classes, - audio_emb_in_size, audio_emb_out_size, seq_len): - super().__init__() - - self.resunet = ResUnet() - self.num_classes = num_classes - self.seq_len = seq_len - - self.MLP = nn.Sequential() - layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6 - for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])): - self.MLP.add_module( - name="L{:d}".format(i), module=nn.Linear(in_size, out_size)) - self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU()) - - self.linear_means = nn.Linear(layer_sizes[-1], latent_size) - self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size) - self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size) - - self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size)) - - def forward(self, batch): - class_id = batch['class'] - 
pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6 - ref = batch['ref'] #bs 6 - bs = pose_motion_gt.shape[0] - audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size - - #pose encode - pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6 - pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6 - - #audio mapping - print(audio_in.shape) - audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size - audio_out = audio_out.reshape(bs, -1) - - class_bias = self.classbias[class_id] #bs latent_size - x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size - x_out = self.MLP(x_in) - - mu = self.linear_means(x_out) - logvar = self.linear_means(x_out) #bs latent_size - - batch.update({'mu':mu, 'logvar':logvar}) - return batch - -class DECODER(nn.Module): - def __init__(self, layer_sizes, latent_size, num_classes, - audio_emb_in_size, audio_emb_out_size, seq_len): - super().__init__() - - self.resunet = ResUnet() - self.num_classes = num_classes - self.seq_len = seq_len - - self.MLP = nn.Sequential() - input_size = latent_size + seq_len*audio_emb_out_size + 6 - for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)): - self.MLP.add_module( - name="L{:d}".format(i), module=nn.Linear(in_size, out_size)) - if i+1 < len(layer_sizes): - self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU()) - else: - self.MLP.add_module(name="sigmoid", module=nn.Sigmoid()) - - self.pose_linear = nn.Linear(6, 6) - self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size) - - self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size)) - - def forward(self, batch): - - z = batch['z'] #bs latent_size - bs = z.shape[0] - class_id = batch['class'] - ref = batch['ref'] #bs 6 - audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size - #print('audio_in: ', audio_in[:, :, :10]) - - audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size - #print('audio_out: ', audio_out[:, :, :10]) - audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size - class_bias = self.classbias[class_id] #bs latent_size - - z = z + class_bias - x_in = torch.cat([ref, z, audio_out], dim=-1) - x_out = self.MLP(x_in) # bs layer_sizes[-1] - x_out = x_out.reshape((bs, self.seq_len, -1)) - - #print('x_out: ', x_out) - - pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6 - - pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6 - - batch.update({'pose_motion_pred':pose_motion_pred}) - return batch diff --git a/sadtalker_video2pose/src/audio2pose_models/discriminator.py b/sadtalker_video2pose/src/audio2pose_models/discriminator.py deleted file mode 100644 index 2f8ed6e36708d4a70227ff90109f56c6f73a17d2..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/audio2pose_models/discriminator.py +++ /dev/null @@ -1,76 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import nn - -class ConvNormRelu(nn.Module): - def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False, - kernel_size=None, stride=None, padding=None, norm='BN', leaky=False): - super().__init__() - if kernel_size is None: - if downsample: - kernel_size, stride, padding = 4, 2, 1 - else: - kernel_size, stride, padding = 3, 1, 1 - - if conv_type == '2d': - self.conv = nn.Conv2d( - in_channels, - out_channels, - kernel_size, - stride, - padding, - bias=False, - ) - if norm == 'BN': - self.norm = nn.BatchNorm2d(out_channels) - elif norm 
== 'IN': - self.norm = nn.InstanceNorm2d(out_channels) - else: - raise NotImplementedError - elif conv_type == '1d': - self.conv = nn.Conv1d( - in_channels, - out_channels, - kernel_size, - stride, - padding, - bias=False, - ) - if norm == 'BN': - self.norm = nn.BatchNorm1d(out_channels) - elif norm == 'IN': - self.norm = nn.InstanceNorm1d(out_channels) - else: - raise NotImplementedError - nn.init.kaiming_normal_(self.conv.weight) - - self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True) - - def forward(self, x): - x = self.conv(x) - if isinstance(self.norm, nn.InstanceNorm1d): - x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C] - else: - x = self.norm(x) - x = self.act(x) - return x - - -class PoseSequenceDiscriminator(nn.Module): - def __init__(self, cfg): - super().__init__() - self.cfg = cfg - leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU - - self.seq = nn.Sequential( - ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64 - ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32 - ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16 - nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16 - ) - - def forward(self, x): - x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2) - x = self.seq(x) - x = x.squeeze(1) - return x \ No newline at end of file diff --git a/sadtalker_video2pose/src/audio2pose_models/networks.py b/sadtalker_video2pose/src/audio2pose_models/networks.py deleted file mode 100644 index 9212b49836d9221895993d1d490a476707599922..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/audio2pose_models/networks.py +++ /dev/null @@ -1,140 +0,0 @@ -import torch.nn as nn -import torch - - -class ResidualConv(nn.Module): - def __init__(self, input_dim, output_dim, stride, padding): - super(ResidualConv, self).__init__() - - self.conv_block = nn.Sequential( - nn.BatchNorm2d(input_dim), - nn.ReLU(), - nn.Conv2d( - input_dim, output_dim, kernel_size=3, stride=stride, padding=padding - ), - nn.BatchNorm2d(output_dim), - nn.ReLU(), - nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), - ) - self.conv_skip = nn.Sequential( - nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1), - nn.BatchNorm2d(output_dim), - ) - - def forward(self, x): - - return self.conv_block(x) + self.conv_skip(x) - - -class Upsample(nn.Module): - def __init__(self, input_dim, output_dim, kernel, stride): - super(Upsample, self).__init__() - - self.upsample = nn.ConvTranspose2d( - input_dim, output_dim, kernel_size=kernel, stride=stride - ) - - def forward(self, x): - return self.upsample(x) - - -class Squeeze_Excite_Block(nn.Module): - def __init__(self, channel, reduction=16): - super(Squeeze_Excite_Block, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.fc = nn.Sequential( - nn.Linear(channel, channel // reduction, bias=False), - nn.ReLU(inplace=True), - nn.Linear(channel // reduction, channel, bias=False), - nn.Sigmoid(), - ) - - def forward(self, x): - b, c, _, _ = x.size() - y = self.avg_pool(x).view(b, c) - y = self.fc(y).view(b, c, 1, 1) - return x * y.expand_as(x) - - -class ASPP(nn.Module): - def __init__(self, in_dims, out_dims, rate=[6, 12, 18]): - super(ASPP, self).__init__() - - self.aspp_block1 = nn.Sequential( - nn.Conv2d( - in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0] - ), - nn.ReLU(inplace=True), - 
nn.BatchNorm2d(out_dims), - ) - self.aspp_block2 = nn.Sequential( - nn.Conv2d( - in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1] - ), - nn.ReLU(inplace=True), - nn.BatchNorm2d(out_dims), - ) - self.aspp_block3 = nn.Sequential( - nn.Conv2d( - in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2] - ), - nn.ReLU(inplace=True), - nn.BatchNorm2d(out_dims), - ) - - self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1) - self._init_weights() - - def forward(self, x): - x1 = self.aspp_block1(x) - x2 = self.aspp_block2(x) - x3 = self.aspp_block3(x) - out = torch.cat([x1, x2, x3], dim=1) - return self.output(out) - - def _init_weights(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight) - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - - -class Upsample_(nn.Module): - def __init__(self, scale=2): - super(Upsample_, self).__init__() - - self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale) - - def forward(self, x): - return self.upsample(x) - - -class AttentionBlock(nn.Module): - def __init__(self, input_encoder, input_decoder, output_dim): - super(AttentionBlock, self).__init__() - - self.conv_encoder = nn.Sequential( - nn.BatchNorm2d(input_encoder), - nn.ReLU(), - nn.Conv2d(input_encoder, output_dim, 3, padding=1), - nn.MaxPool2d(2, 2), - ) - - self.conv_decoder = nn.Sequential( - nn.BatchNorm2d(input_decoder), - nn.ReLU(), - nn.Conv2d(input_decoder, output_dim, 3, padding=1), - ) - - self.conv_attn = nn.Sequential( - nn.BatchNorm2d(output_dim), - nn.ReLU(), - nn.Conv2d(output_dim, 1, 1), - ) - - def forward(self, x1, x2): - out = self.conv_encoder(x1) + self.conv_decoder(x2) - out = self.conv_attn(out) - return out * x2 \ No newline at end of file diff --git a/sadtalker_video2pose/src/audio2pose_models/res_unet.py b/sadtalker_video2pose/src/audio2pose_models/res_unet.py deleted file mode 100644 index 280404c2a2804038705f792dd800ddf707b75cf8..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/audio2pose_models/res_unet.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch -import torch.nn as nn -from src.audio2pose_models.networks import ResidualConv, Upsample - - -class ResUnet(nn.Module): - def __init__(self, channel=1, filters=[32, 64, 128, 256]): - super(ResUnet, self).__init__() - - self.input_layer = nn.Sequential( - nn.Conv2d(channel, filters[0], kernel_size=3, padding=1), - nn.BatchNorm2d(filters[0]), - nn.ReLU(), - nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), - ) - self.input_skip = nn.Sequential( - nn.Conv2d(channel, filters[0], kernel_size=3, padding=1) - ) - - self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1) - self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1) - - self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1) - - self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1)) - self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1) - - self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1)) - self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1) - - self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1)) - self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1) - - self.output_layer = nn.Sequential( - nn.Conv2d(filters[0], 1, 1, 1), - nn.Sigmoid(), 
- ) - - def forward(self, x): - # Encode - x1 = self.input_layer(x) + self.input_skip(x) - x2 = self.residual_conv_1(x1) - x3 = self.residual_conv_2(x2) - # Bridge - x4 = self.bridge(x3) - - # Decode - x4 = self.upsample_1(x4) - x5 = torch.cat([x4, x3], dim=1) - - x6 = self.up_residual_conv1(x5) - - x6 = self.upsample_2(x6) - x7 = torch.cat([x6, x2], dim=1) - - x8 = self.up_residual_conv2(x7) - - x8 = self.upsample_3(x8) - x9 = torch.cat([x8, x1], dim=1) - - x10 = self.up_residual_conv3(x9) - - output = self.output_layer(x10) - - return output \ No newline at end of file diff --git a/sadtalker_video2pose/src/config/auido2exp.yaml b/sadtalker_video2pose/src/config/auido2exp.yaml deleted file mode 100644 index 7e0e8fbba267158d26a147c8cb2ec5acdd73f432..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/config/auido2exp.yaml +++ /dev/null @@ -1,58 +0,0 @@ -DATASET: - TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt - EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt - TRAIN_BATCH_SIZE: 32 - EVAL_BATCH_SIZE: 32 - EXP: True - EXP_DIM: 64 - FRAME_LEN: 32 - COEFF_LEN: 73 - NUM_CLASSES: 46 - AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav - COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm - LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb - DEBUG: True - NUM_REPEATS: 2 - T: 40 - - -MODEL: - FRAMEWORK: V2 - AUDIOENCODER: - LEAKY_RELU: True - NORM: 'IN' - DISCRIMINATOR: - LEAKY_RELU: False - INPUT_CHANNELS: 6 - CVAE: - AUDIO_EMB_IN_SIZE: 512 - AUDIO_EMB_OUT_SIZE: 128 - SEQ_LEN: 32 - LATENT_SIZE: 256 - ENCODER_LAYER_SIZES: [192, 1024] - DECODER_LAYER_SIZES: [1024, 192] - - -TRAIN: - MAX_EPOCH: 300 - GENERATOR: - LR: 2.0e-5 - DISCRIMINATOR: - LR: 1.0e-5 - LOSS: - W_FEAT: 0 - W_COEFF_EXP: 2 - W_LM: 1.0e-2 - W_LM_MOUTH: 0 - W_REG: 0 - W_SYNC: 0 - W_COLOR: 0 - W_EXPRESSION: 0 - W_LIPREADING: 0.01 - W_LIPREADING_VV: 0 - W_EYE_BLINK: 4 - -TAG: - NAME: small_dataset - - diff --git a/sadtalker_video2pose/src/config/auido2pose.yaml b/sadtalker_video2pose/src/config/auido2pose.yaml deleted file mode 100644 index 7702414b11581ff99aef7a3187f0d0d1388ae3f3..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/config/auido2pose.yaml +++ /dev/null @@ -1,49 +0,0 @@ -DATASET: - TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt - EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt - TRAIN_BATCH_SIZE: 64 - EVAL_BATCH_SIZE: 1 - EXP: True - EXP_DIM: 64 - FRAME_LEN: 32 - COEFF_LEN: 73 - NUM_CLASSES: 46 - AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav - COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb - DEBUG: True - - -MODEL: - AUDIOENCODER: - LEAKY_RELU: True - NORM: 'IN' - DISCRIMINATOR: - LEAKY_RELU: False - INPUT_CHANNELS: 6 - CVAE: - AUDIO_EMB_IN_SIZE: 512 - AUDIO_EMB_OUT_SIZE: 6 - SEQ_LEN: 32 - LATENT_SIZE: 64 - ENCODER_LAYER_SIZES: [192, 128] - DECODER_LAYER_SIZES: [128, 192] - - -TRAIN: - MAX_EPOCH: 150 - GENERATOR: - LR: 1.0e-4 - DISCRIMINATOR: - LR: 1.0e-4 - LOSS: - LAMBDA_REG: 1 - LAMBDA_LANDMARKS: 0 - LAMBDA_VERTICES: 0 - LAMBDA_GAN_MOTION: 0.7 - LAMBDA_GAN_COEFF: 0 - LAMBDA_KL: 1 - -TAG: - NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder - - diff --git a/sadtalker_video2pose/src/config/facerender.yaml b/sadtalker_video2pose/src/config/facerender.yaml deleted file 
mode 100644 index dd1e1ddfe265698e49dac4a6e103cba0aac4f3ce..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/config/facerender.yaml +++ /dev/null @@ -1,45 +0,0 @@ -model_params: - common_params: - num_kp: 15 - image_channel: 3 - feature_channel: 32 - estimate_jacobian: False # True - kp_detector_params: - temperature: 0.1 - block_expansion: 32 - max_features: 1024 - scale_factor: 0.25 # 0.25 - num_blocks: 5 - reshape_channel: 16384 # 16384 = 1024 * 16 - reshape_depth: 16 - he_estimator_params: - block_expansion: 64 - max_features: 2048 - num_bins: 66 - generator_params: - block_expansion: 64 - max_features: 512 - num_down_blocks: 2 - reshape_channel: 32 - reshape_depth: 16 # 512 = 32 * 16 - num_resblocks: 6 - estimate_occlusion_map: True - dense_motion_params: - block_expansion: 32 - max_features: 1024 - num_blocks: 5 - reshape_depth: 16 - compress: 4 - discriminator_params: - scales: [1] - block_expansion: 32 - max_features: 512 - num_blocks: 4 - sn: True - mapping_params: - coeff_nc: 70 - descriptor_nc: 1024 - layer: 3 - num_kp: 15 - num_bins: 66 - diff --git a/sadtalker_video2pose/src/config/facerender_pirender.yaml b/sadtalker_video2pose/src/config/facerender_pirender.yaml deleted file mode 100644 index f893b5d0a22f0546642c2d2bdafda88740c81138..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/config/facerender_pirender.yaml +++ /dev/null @@ -1,83 +0,0 @@ -# How often do you want to log the training stats. -# network_list: -# gen: gen_optimizer -# dis: dis_optimizer - -distributed: False -image_to_tensorboard: True -snapshot_save_iter: 40000 -snapshot_save_epoch: 20 -snapshot_save_start_iter: 20000 -snapshot_save_start_epoch: 10 -image_save_iter: 1000 -max_epoch: 200 -logging_iter: 100 -results_dir: ./eval_results - -gen_optimizer: - type: adam - lr: 0.0001 - adam_beta1: 0.5 - adam_beta2: 0.999 - lr_policy: - iteration_mode: True - type: step - step_size: 300000 - gamma: 0.2 - -trainer: - type: trainers.face_trainer::FaceTrainer - pretrain_warp_iteration: 200000 - loss_weight: - weight_perceptual_warp: 2.5 - weight_perceptual_final: 4 - vgg_param_warp: - network: vgg19 - layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1'] - use_style_loss: False - num_scales: 4 - vgg_param_final: - network: vgg19 - layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1'] - use_style_loss: True - num_scales: 4 - style_to_perceptual: 250 - init: - type: 'normal' - gain: 0.02 -gen: - type: generators.face_model::FaceGenerator - param: - mapping_net: - coeff_nc: 73 - descriptor_nc: 256 - layer: 3 - warpping_net: - encoder_layer: 5 - decoder_layer: 3 - base_nc: 32 - editing_net: - layer: 3 - num_res_blocks: 2 - base_nc: 64 - common: - image_nc: 3 - descriptor_nc: 256 - max_nc: 256 - use_spect: False - - -# Data options. 
-data: - type: data.vox_dataset::VoxDataset - path: ./dataset/vox_lmdb - resolution: 256 - semantic_radius: 13 - train: - batch_size: 5 - distributed: True - val: - batch_size: 8 - distributed: True - - diff --git a/sadtalker_video2pose/src/config/facerender_still.yaml b/sadtalker_video2pose/src/config/facerender_still.yaml deleted file mode 100644 index d6b84181763caf7184a0769e53a7e419e2e3f604..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/config/facerender_still.yaml +++ /dev/null @@ -1,45 +0,0 @@ -model_params: - common_params: - num_kp: 15 - image_channel: 3 - feature_channel: 32 - estimate_jacobian: False # True - kp_detector_params: - temperature: 0.1 - block_expansion: 32 - max_features: 1024 - scale_factor: 0.25 # 0.25 - num_blocks: 5 - reshape_channel: 16384 # 16384 = 1024 * 16 - reshape_depth: 16 - he_estimator_params: - block_expansion: 64 - max_features: 2048 - num_bins: 66 - generator_params: - block_expansion: 64 - max_features: 512 - num_down_blocks: 2 - reshape_channel: 32 - reshape_depth: 16 # 512 = 32 * 16 - num_resblocks: 6 - estimate_occlusion_map: True - dense_motion_params: - block_expansion: 32 - max_features: 1024 - num_blocks: 5 - reshape_depth: 16 - compress: 4 - discriminator_params: - scales: [1] - block_expansion: 32 - max_features: 512 - num_blocks: 4 - sn: True - mapping_params: - coeff_nc: 73 - descriptor_nc: 1024 - layer: 3 - num_kp: 15 - num_bins: 66 - diff --git a/sadtalker_video2pose/src/config/similarity_Lm3D_all.mat b/sadtalker_video2pose/src/config/similarity_Lm3D_all.mat deleted file mode 100644 index 9f5b0bd4ecffb926128a29cb1bbf9d9081c3d4e7..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/config/similarity_Lm3D_all.mat +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:53b83ce6e35c50ddc3e97603650cef4970320c157e75c241c844f29c1dcba65a -size 994 diff --git a/sadtalker_video2pose/src/face3d/data/__init__.py b/sadtalker_video2pose/src/face3d/data/__init__.py deleted file mode 100644 index be2378c5877af8e749db18d8a67a382f3eb0912b..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/data/__init__.py +++ /dev/null @@ -1,116 +0,0 @@ -"""This package includes all the modules related to data loading and preprocessing - - To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. - You need to implement four functions: - -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). - -- <__len__>: return the size of dataset. - -- <__getitem__>: get a data point from data loader. - -- : (optionally) add dataset-specific options and set default options. - -Now you can use the dataset class by specifying flag '--dataset_mode dummy'. -See our template dataset class 'template_dataset.py' for more details. -""" -import numpy as np -import importlib -import torch.utils.data -from face3d.data.base_dataset import BaseDataset - - -def find_dataset_using_name(dataset_name): - """Import the module "data/[dataset_name]_dataset.py". - - In the file, the class called DatasetNameDataset() will - be instantiated. It has to be a subclass of BaseDataset, - and it is case-insensitive. - """ - dataset_filename = "data." 
+ dataset_name + "_dataset" - datasetlib = importlib.import_module(dataset_filename) - - dataset = None - target_dataset_name = dataset_name.replace('_', '') + 'dataset' - for name, cls in datasetlib.__dict__.items(): - if name.lower() == target_dataset_name.lower() \ - and issubclass(cls, BaseDataset): - dataset = cls - - if dataset is None: - raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) - - return dataset - - -def get_option_setter(dataset_name): - """Return the static method of the dataset class.""" - dataset_class = find_dataset_using_name(dataset_name) - return dataset_class.modify_commandline_options - - -def create_dataset(opt, rank=0): - """Create a dataset given the option. - - This function wraps the class CustomDatasetDataLoader. - This is the main interface between this package and 'train.py'/'test.py' - - Example: - >>> from data import create_dataset - >>> dataset = create_dataset(opt) - """ - data_loader = CustomDatasetDataLoader(opt, rank=rank) - dataset = data_loader.load_data() - return dataset - -class CustomDatasetDataLoader(): - """Wrapper class of Dataset class that performs multi-threaded data loading""" - - def __init__(self, opt, rank=0): - """Initialize this class - - Step 1: create a dataset instance given the name [dataset_mode] - Step 2: create a multi-threaded data loader. - """ - self.opt = opt - dataset_class = find_dataset_using_name(opt.dataset_mode) - self.dataset = dataset_class(opt) - self.sampler = None - print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__)) - if opt.use_ddp and opt.isTrain: - world_size = opt.world_size - self.sampler = torch.utils.data.distributed.DistributedSampler( - self.dataset, - num_replicas=world_size, - rank=rank, - shuffle=not opt.serial_batches - ) - self.dataloader = torch.utils.data.DataLoader( - self.dataset, - sampler=self.sampler, - num_workers=int(opt.num_threads / world_size), - batch_size=int(opt.batch_size / world_size), - drop_last=True) - else: - self.dataloader = torch.utils.data.DataLoader( - self.dataset, - batch_size=opt.batch_size, - shuffle=(not opt.serial_batches) and opt.isTrain, - num_workers=int(opt.num_threads), - drop_last=True - ) - - def set_epoch(self, epoch): - self.dataset.current_epoch = epoch - if self.sampler is not None: - self.sampler.set_epoch(epoch) - - def load_data(self): - return self - - def __len__(self): - """Return the number of data in the dataset""" - return min(len(self.dataset), self.opt.max_dataset_size) - - def __iter__(self): - """Return a batch of data""" - for i, data in enumerate(self.dataloader): - if i * self.opt.batch_size >= self.opt.max_dataset_size: - break - yield data diff --git a/sadtalker_video2pose/src/face3d/data/base_dataset.py b/sadtalker_video2pose/src/face3d/data/base_dataset.py deleted file mode 100644 index 34a7ea5024206e6e58c2f404ac6a1bf0987f5fd4..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/data/base_dataset.py +++ /dev/null @@ -1,125 +0,0 @@ -"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets. - -It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 
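The `find_dataset_using_name` / `create_dataset` pair above resolves `--dataset_mode` values by a naming convention: `flist` is imported as `data.flist_dataset` and matched to the `BaseDataset` subclass whose lowercased name is `flistdataset`. A hedged illustration, assuming the package above is importable (it mixes the `face3d.data` and bare `data` import roots, so `face3d/` itself typically needs to be on `sys.path`):

```python
from face3d.data import find_dataset_using_name

# '--dataset_mode flist' -> module data.flist_dataset -> class FlistDataset
dataset_cls = find_dataset_using_name("flist")
print(dataset_cls.__name__)   # 'FlistDataset'
```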
-""" -import random -import numpy as np -import torch.utils.data as data -from PIL import Image -import torchvision.transforms as transforms -from abc import ABC, abstractmethod - - -class BaseDataset(data.Dataset, ABC): - """This class is an abstract base class (ABC) for datasets. - - To create a subclass, you need to implement the following four functions: - -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). - -- <__len__>: return the size of dataset. - -- <__getitem__>: get a data point. - -- : (optionally) add dataset-specific options and set default options. - """ - - def __init__(self, opt): - """Initialize the class; save the options in the class - - Parameters: - opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions - """ - self.opt = opt - # self.root = opt.dataroot - self.current_epoch = 0 - - @staticmethod - def modify_commandline_options(parser, is_train): - """Add new dataset-specific options, and rewrite default values for existing options. - - Parameters: - parser -- original option parser - is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. - - Returns: - the modified parser. - """ - return parser - - @abstractmethod - def __len__(self): - """Return the total number of images in the dataset.""" - return 0 - - @abstractmethod - def __getitem__(self, index): - """Return a data point and its metadata information. - - Parameters: - index - - a random integer for data indexing - - Returns: - a dictionary of data with their names. It ususally contains the data itself and its metadata information. - """ - pass - - -def get_transform(grayscale=False): - transform_list = [] - if grayscale: - transform_list.append(transforms.Grayscale(1)) - transform_list += [transforms.ToTensor()] - return transforms.Compose(transform_list) - -def get_affine_mat(opt, size): - shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False - w, h = size - - if 'shift' in opt.preprocess: - shift_pixs = int(opt.shift_pixs) - shift_x = random.randint(-shift_pixs, shift_pixs) - shift_y = random.randint(-shift_pixs, shift_pixs) - if 'scale' in opt.preprocess: - scale = 1 + opt.scale_delta * (2 * random.random() - 1) - if 'rot' in opt.preprocess: - rot_angle = opt.rot_angle * (2 * random.random() - 1) - rot_rad = -rot_angle * np.pi/180 - if 'flip' in opt.preprocess: - flip = random.random() > 0.5 - - shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3]) - flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3]) - shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3]) - rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3]) - scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3]) - shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3]) - - affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin - affine_inv = np.linalg.inv(affine) - return affine, affine_inv, flip - -def apply_img_affine(img, affine_inv, method=Image.BICUBIC): - return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC) - -def apply_lm_affine(landmark, affine, flip, size): - _, h = size - lm = landmark.copy() - lm[:, 1] = h - 1 - lm[:, 1] - lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1) - lm = lm @ np.transpose(affine) - lm[:, :2] = 
lm[:, :2] / lm[:, 2:] - lm = lm[:, :2] - lm[:, 1] = h - 1 - lm[:, 1] - if flip: - lm_ = lm.copy() - lm_[:17] = lm[16::-1] - lm_[17:22] = lm[26:21:-1] - lm_[22:27] = lm[21:16:-1] - lm_[31:36] = lm[35:30:-1] - lm_[36:40] = lm[45:41:-1] - lm_[40:42] = lm[47:45:-1] - lm_[42:46] = lm[39:35:-1] - lm_[46:48] = lm[41:39:-1] - lm_[48:55] = lm[54:47:-1] - lm_[55:60] = lm[59:54:-1] - lm_[60:65] = lm[64:59:-1] - lm_[65:68] = lm[67:64:-1] - lm = lm_ - return lm diff --git a/sadtalker_video2pose/src/face3d/data/flist_dataset.py b/sadtalker_video2pose/src/face3d/data/flist_dataset.py deleted file mode 100644 index 63b49caa8020f8e9aedb73a839b7112320cad68a..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/data/flist_dataset.py +++ /dev/null @@ -1,125 +0,0 @@ -"""This script defines the custom dataset for Deep3DFaceRecon_pytorch -""" - -import os.path -from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine -from data.image_folder import make_dataset -from PIL import Image -import random -import util.util as util -import numpy as np -import json -import torch -from scipy.io import loadmat, savemat -import pickle -from util.preprocess import align_img, estimate_norm -from util.load_mats import load_lm3d - - -def default_flist_reader(flist): - """ - flist format: impath label\nimpath label\n ...(same to caffe's filelist) - """ - imlist = [] - with open(flist, 'r') as rf: - for line in rf.readlines(): - impath = line.strip() - imlist.append(impath) - - return imlist - -def jason_flist_reader(flist): - with open(flist, 'r') as fp: - info = json.load(fp) - return info - -def parse_label(label): - return torch.tensor(np.array(label).astype(np.float32)) - - -class FlistDataset(BaseDataset): - """ - It requires one directories to host training images '/path/to/data/train' - You can train the model with the dataset flag '--dataroot /path/to/data'. - """ - - def __init__(self, opt): - """Initialize this dataset class. - - Parameters: - opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions - """ - BaseDataset.__init__(self, opt) - - self.lm3d_std = load_lm3d(opt.bfm_folder) - - msk_names = default_flist_reader(opt.flist) - self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names] - - self.size = len(self.msk_paths) - self.opt = opt - - self.name = 'train' if opt.isTrain else 'val' - if '_' in opt.flist: - self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0] - - - def __getitem__(self, index): - """Return a data point and its metadata information. 
- - Parameters: - index (int) -- a random integer for data indexing - - Returns a dictionary that contains A, B, A_paths and B_paths - img (tensor) -- an image in the input domain - msk (tensor) -- its corresponding attention mask - lm (tensor) -- its corresponding 3d landmarks - im_paths (str) -- image paths - aug_flag (bool) -- a flag used to tell whether its raw or augmented - """ - msk_path = self.msk_paths[index % self.size] # make sure index is within then range - img_path = msk_path.replace('mask/', '') - lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt' - - raw_img = Image.open(img_path).convert('RGB') - raw_msk = Image.open(msk_path).convert('RGB') - raw_lm = np.loadtxt(lm_path).astype(np.float32) - - _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk) - - aug_flag = self.opt.use_aug and self.opt.isTrain - if aug_flag: - img, lm, msk = self._augmentation(img, lm, self.opt, msk) - - _, H = img.size - M = estimate_norm(lm, H) - transform = get_transform() - img_tensor = transform(img) - msk_tensor = transform(msk)[:1, ...] - lm_tensor = parse_label(lm) - M_tensor = parse_label(M) - - - return {'imgs': img_tensor, - 'lms': lm_tensor, - 'msks': msk_tensor, - 'M': M_tensor, - 'im_paths': img_path, - 'aug_flag': aug_flag, - 'dataset': self.name} - - def _augmentation(self, img, lm, opt, msk=None): - affine, affine_inv, flip = get_affine_mat(opt, img.size) - img = apply_img_affine(img, affine_inv) - lm = apply_lm_affine(lm, affine, flip, img.size) - if msk is not None: - msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR) - return img, lm, msk - - - - - def __len__(self): - """Return the total number of images in the dataset. - """ - return self.size diff --git a/sadtalker_video2pose/src/face3d/data/image_folder.py b/sadtalker_video2pose/src/face3d/data/image_folder.py deleted file mode 100644 index 07ef069029b0db1fc40b9b5f9a6f52a48c1cd162..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/data/image_folder.py +++ /dev/null @@ -1,66 +0,0 @@ -"""A modified image folder class - -We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) -so that this class can load images from both current directory and its subdirectories. 
-""" -import numpy as np -import torch.utils.data as data - -from PIL import Image -import os -import os.path - -IMG_EXTENSIONS = [ - '.jpg', '.JPG', '.jpeg', '.JPEG', - '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', - '.tif', '.TIF', '.tiff', '.TIFF', -] - - -def is_image_file(filename): - return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) - - -def make_dataset(dir, max_dataset_size=float("inf")): - images = [] - assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir - - for root, _, fnames in sorted(os.walk(dir, followlinks=True)): - for fname in fnames: - if is_image_file(fname): - path = os.path.join(root, fname) - images.append(path) - return images[:min(max_dataset_size, len(images))] - - -def default_loader(path): - return Image.open(path).convert('RGB') - - -class ImageFolder(data.Dataset): - - def __init__(self, root, transform=None, return_paths=False, - loader=default_loader): - imgs = make_dataset(root) - if len(imgs) == 0: - raise(RuntimeError("Found 0 images in: " + root + "\n" - "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) - - self.root = root - self.imgs = imgs - self.transform = transform - self.return_paths = return_paths - self.loader = loader - - def __getitem__(self, index): - path = self.imgs[index] - img = self.loader(path) - if self.transform is not None: - img = self.transform(img) - if self.return_paths: - return img, path - else: - return img - - def __len__(self): - return len(self.imgs) diff --git a/sadtalker_video2pose/src/face3d/data/template_dataset.py b/sadtalker_video2pose/src/face3d/data/template_dataset.py deleted file mode 100644 index 693b6b09085ad424e53f26e0938b61eea30ed644..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/data/template_dataset.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Dataset class template - -This module provides a template for users to implement custom datasets. -You can specify '--dataset_mode template' to use this dataset. -The class name should be consistent with both the filename and its dataset_mode option. -The filename should be _dataset.py -The class name should be Dataset.py -You need to implement the following functions: - -- : Add dataset-specific options and rewrite default values for existing options. - -- <__init__>: Initialize this dataset class. - -- <__getitem__>: Return a data point and its metadata information. - -- <__len__>: Return the number of images. -""" -from data.base_dataset import BaseDataset, get_transform -# from data.image_folder import make_dataset -# from PIL import Image - - -class TemplateDataset(BaseDataset): - """A template dataset class for you to implement custom datasets.""" - @staticmethod - def modify_commandline_options(parser, is_train): - """Add new dataset-specific options, and rewrite default values for existing options. - - Parameters: - parser -- original option parser - is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. - - Returns: - the modified parser. - """ - parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option') - parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values - return parser - - def __init__(self, opt): - """Initialize this dataset class. - - Parameters: - opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions - - A few things can be done here. 
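The `ImageFolder` variant above walks a directory tree (following symlinks) and can optionally return the file path along with each image. A brief usage sketch with hypothetical paths:

```python
import torchvision.transforms as transforms

# Hypothetical image directory; ImageFolder is the class defined above.
dataset = ImageFolder(root="/path/to/images",
                      transform=transforms.ToTensor(),
                      return_paths=True)
img, path = dataset[0]     # img: 3xHxW float tensor, path: source file on disk
print(len(dataset), path)
```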
- - save the options (have been done in BaseDataset) - - get image paths and meta information of the dataset. - - define the image transformation. - """ - # save the option and dataset root - BaseDataset.__init__(self, opt) - # get the image paths of your dataset; - self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root - # define the default transform function. You can use ; You can also define your custom transform function - self.transform = get_transform(opt) - - def __getitem__(self, index): - """Return a data point and its metadata information. - - Parameters: - index -- a random integer for data indexing - - Returns: - a dictionary of data with their names. It usually contains the data itself and its metadata information. - - Step 1: get a random image path: e.g., path = self.image_paths[index] - Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB'). - Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image) - Step 4: return a data point as a dictionary. - """ - path = 'temp' # needs to be a string - data_A = None # needs to be a tensor - data_B = None # needs to be a tensor - return {'data_A': data_A, 'data_B': data_B, 'path': path} - - def __len__(self): - """Return the total number of images.""" - return len(self.image_paths) diff --git a/sadtalker_video2pose/src/face3d/extract_kp_videos.py b/sadtalker_video2pose/src/face3d/extract_kp_videos.py deleted file mode 100644 index 68dd79badafd406113ee85cde83492b6c7c66a9b..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/extract_kp_videos.py +++ /dev/null @@ -1,108 +0,0 @@ -import os -import cv2 -import time -import glob -import argparse -import face_alignment -import numpy as np -from PIL import Image -from tqdm import tqdm -from itertools import cycle - -from torch.multiprocessing import Pool, Process, set_start_method - -class KeypointExtractor(): - def __init__(self, device): - self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, - device=device) - - def extract_keypoint(self, images, name=None, info=True): - if isinstance(images, list): - keypoints = [] - if info: - i_range = tqdm(images,desc='landmark Det:') - else: - i_range = images - - for image in i_range: - current_kp = self.extract_keypoint(image) - if np.mean(current_kp) == -1 and keypoints: - keypoints.append(keypoints[-1]) - else: - keypoints.append(current_kp[None]) - - keypoints = np.concatenate(keypoints, 0) - np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) - return keypoints - else: - while True: - try: - keypoints = self.detector.get_landmarks_from_image(np.array(images))[0] - break - except RuntimeError as e: - if str(e).startswith('CUDA'): - print("Warning: out of memory, sleep for 1s") - time.sleep(1) - else: - print(e) - break - except TypeError: - print('No face detected in this image') - shape = [68, 2] - keypoints = -1. 
* np.ones(shape) - break - if name is not None: - np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) - return keypoints - -def read_video(filename): - frames = [] - cap = cv2.VideoCapture(filename) - while cap.isOpened(): - ret, frame = cap.read() - if ret: - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - frame = Image.fromarray(frame) - frames.append(frame) - else: - break - cap.release() - return frames - -def run(data): - filename, opt, device = data - os.environ['CUDA_VISIBLE_DEVICES'] = device - kp_extractor = KeypointExtractor() - images = read_video(filename) - name = filename.split('/')[-2:] - os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True) - kp_extractor.extract_keypoint( - images, - name=os.path.join(opt.output_dir, name[-2], name[-1]) - ) - -if __name__ == '__main__': - set_start_method('spawn') - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--input_dir', type=str, help='the folder of the input files') - parser.add_argument('--output_dir', type=str, help='the folder of the output files') - parser.add_argument('--device_ids', type=str, default='0,1') - parser.add_argument('--workers', type=int, default=4) - - opt = parser.parse_args() - filenames = list() - VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} - VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) - extensions = VIDEO_EXTENSIONS - - for ext in extensions: - os.listdir(f'{opt.input_dir}') - print(f'{opt.input_dir}/*.{ext}') - filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}')) - print('Total number of videos:', len(filenames)) - pool = Pool(opt.workers) - args_list = cycle([opt]) - device_ids = opt.device_ids.split(",") - device_ids = cycle(device_ids) - for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))): - None diff --git a/sadtalker_video2pose/src/face3d/extract_kp_videos_safe.py b/sadtalker_video2pose/src/face3d/extract_kp_videos_safe.py deleted file mode 100644 index 1bc08ec9b5b48d7d7ecb53a018d24065461a4347..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/extract_kp_videos_safe.py +++ /dev/null @@ -1,145 +0,0 @@ -import os -import cv2 -import time -import glob -import argparse -import numpy as np -from PIL import Image -import torch -from tqdm import tqdm -from itertools import cycle -from torch.multiprocessing import Pool, Process, set_start_method - -from facexlib.alignment import landmark_98_to_68 -from facexlib.detection import init_detection_model - -from facexlib.utils import load_file_from_url -from facexlib.alignment.awing_arch import FAN - -def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None): - if model_name == 'awing_fan': - model = FAN(num_modules=4, num_landmarks=98, device=device) - model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth' - else: - raise NotImplementedError(f'{model_name} is not implemented.') - - model_path = load_file_from_url( - url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) - model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True) - model.eval() - model = model.to(device) - return model - - -class KeypointExtractor(): - def __init__(self, device='cuda'): - - root_path = './ckpts/gfpgan' - - self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path) - self.det_net = 
init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path) - - def extract_keypoint(self, images, name=None, info=True): - if isinstance(images, list): - keypoints = [] - if info: - i_range = tqdm(images,desc='landmark Det:') - else: - i_range = images - - for image in i_range: - current_kp = self.extract_keypoint(image) - # current_kp = self.detector.get_landmarks(np.array(image)) - if np.mean(current_kp) == -1 and keypoints: - keypoints.append(keypoints[-1]) - else: - keypoints.append(current_kp[None]) - - keypoints = np.concatenate(keypoints, 0) - np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) - return keypoints - else: - while True: - try: - with torch.no_grad(): - # face detection -> face alignment. - img = np.array(images) - bboxes = self.det_net.detect_faces(images, 0.97) - - bboxes = bboxes[0] - img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :] - - keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0] - - #### keypoints to the original location - keypoints[:,0] += int(bboxes[0]) - keypoints[:,1] += int(bboxes[1]) - - break - except RuntimeError as e: - if str(e).startswith('CUDA'): - print("Warning: out of memory, sleep for 1s") - time.sleep(1) - else: - print(e) - break - except TypeError: - print('No face detected in this image') - shape = [68, 2] - keypoints = -1. * np.ones(shape) - break - if name is not None: - np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) - return keypoints - -def read_video(filename): - frames = [] - cap = cv2.VideoCapture(filename) - while cap.isOpened(): - ret, frame = cap.read() - if ret: - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - frame = Image.fromarray(frame) - frames.append(frame) - else: - break - cap.release() - return frames - -def run(data): - filename, opt, device = data - os.environ['CUDA_VISIBLE_DEVICES'] = device - kp_extractor = KeypointExtractor() - images = read_video(filename) - name = filename.split('/')[-2:] - os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True) - kp_extractor.extract_keypoint( - images, - name=os.path.join(opt.output_dir, name[-2], name[-1]) - ) - -if __name__ == '__main__': - set_start_method('spawn') - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--input_dir', type=str, help='the folder of the input files') - parser.add_argument('--output_dir', type=str, help='the folder of the output files') - parser.add_argument('--device_ids', type=str, default='0,1') - parser.add_argument('--workers', type=int, default=4) - - opt = parser.parse_args() - filenames = list() - VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} - VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) - extensions = VIDEO_EXTENSIONS - - for ext in extensions: - os.listdir(f'{opt.input_dir}') - print(f'{opt.input_dir}/*.{ext}') - filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}')) - print('Total number of videos:', len(filenames)) - pool = Pool(opt.workers) - args_list = cycle([opt]) - device_ids = opt.device_ids.split(",") - device_ids = cycle(device_ids) - for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))): - None diff --git a/sadtalker_video2pose/src/face3d/models/__init__.py b/sadtalker_video2pose/src/face3d/models/__init__.py deleted file mode 100644 index ef6b5e399254bd42850f3385878f35d4acf90852..0000000000000000000000000000000000000000 --- 
a/sadtalker_video2pose/src/face3d/models/__init__.py +++ /dev/null @@ -1,67 +0,0 @@ -"""This package contains modules related to objective functions, optimizations, and network architectures. - -To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. -You need to implement the following five functions: - -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). - -- : unpack data from dataset and apply preprocessing. - -- : produce intermediate results. - -- : calculate loss, gradients, and update network weights. - -- : (optionally) add model-specific options and set default options. - -In the function <__init__>, you need to define four lists: - -- self.loss_names (str list): specify the training losses that you want to plot and save. - -- self.model_names (str list): define networks used in our training. - -- self.visual_names (str list): specify the images that you want to display and save. - -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. - -Now you can use the model class by specifying flag '--model dummy'. -See our template model class 'template_model.py' for more details. -""" - -import importlib -from src.face3d.models.base_model import BaseModel - - -def find_model_using_name(model_name): - """Import the module "models/[model_name]_model.py". - - In the file, the class called DatasetNameModel() will - be instantiated. It has to be a subclass of BaseModel, - and it is case-insensitive. - """ - model_filename = "face3d.models." + model_name + "_model" - modellib = importlib.import_module(model_filename) - model = None - target_model_name = model_name.replace('_', '') + 'model' - for name, cls in modellib.__dict__.items(): - if name.lower() == target_model_name.lower() \ - and issubclass(cls, BaseModel): - model = cls - - if model is None: - print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) - exit(0) - - return model - - -def get_option_setter(model_name): - """Return the static method of the model class.""" - model_class = find_model_using_name(model_name) - return model_class.modify_commandline_options - - -def create_model(opt): - """Create a model given the option. - - This function warps the class CustomDatasetDataLoader. - This is the main interface between this package and 'train.py'/'test.py' - - Example: - >>> from models import create_model - >>> model = create_model(opt) - """ - model = find_model_using_name(opt.model) - instance = model(opt) - print("model [%s] was created" % type(instance).__name__) - return instance diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/README.md b/sadtalker_video2pose/src/face3d/models/arcface_torch/README.md deleted file mode 100644 index cc7f1d45f2f5e4b752c42dc81d3e2879c1459c6e..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/README.md +++ /dev/null @@ -1,164 +0,0 @@ -# Distributed Arcface Training in Pytorch - -This is a deep learning library that makes face recognition efficient, and effective, which can train tens of millions -identity on a single server. 
- -## Requirements - -- Install [pytorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md). -- `pip install -r requirements.txt`. -- Download the dataset - from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_) - . - -## How to Training - -To train a model, run `train.py` with the path to the configs: - -### 1. Single node, 8 GPUs: - -```shell -python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 -``` - -### 2. Multiple nodes, each node 8 GPUs: - -Node 0: - -```shell -python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50 -``` - -Node 1: - -```shell -python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50 -``` - -### 3.Training resnet2060 with 8 GPUs: - -```shell -python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py -``` - -## Model Zoo - -- The models are available for non-commercial research purposes only. -- All models can be found in here. -- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw -- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d) - -### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/) - -ICCV2021-MFR testset consists of non-celebrities so we can ensure that it has very few overlap with public available face -recognition training set, such as MS1M and CASIA as they mostly collected from online celebrities. -As the result, we can evaluate the FAIR performance for different algorithms. - -For **ICCV2021-MFR-ALL** set, TAR is measured on all-to-all 1:1 protocal, with FAR less than 0.000001(e-6). The -globalised multi-racial testset contains 242,143 identities and 1,624,305 images. - -For **ICCV2021-MFR-MASK** set, TAR is measured on mask-to-nonmask 1:1 protocal, with FAR less than 0.0001(e-4). -Mask testset contains 6,964 identities, 6,964 masked images and 13,928 non-masked images. -There are totally 13,928 positive pairs and 96,983,824 negative pairs. 
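As a back-of-the-envelope using only the pair counts quoted above: at FAR < 1e-4 on the MASK set, the verification threshold may admit at most roughly 9,698 of the 96,983,824 negative pairs, and TAR is then the fraction of the 13,928 positive pairs that still pass at that threshold.

```python
# ICCV2021-MFR-MASK operating point, from the pair counts quoted above.
negative_pairs = 96_983_824
positive_pairs = 13_928
far_budget = 1e-4

max_false_accepts = int(negative_pairs * far_budget)
print(max_false_accepts)   # 9698 negative pairs may be falsely accepted
# TAR = (positive pairs accepted at that threshold) / positive_pairs
```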
- -| Datasets | backbone | Training throughout | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** | -| :---: | :--- | :--- | :--- |:--- |:--- | -| MS1MV3 | r18 | - | 91 | **47.85** | **68.33** | -| Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** | -| MS1MV3 | r34 | - | 130 | **58.72** | **77.36** | -| Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** | -| MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** | -| Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** | -| MS1MV3 | r100 | - | 248 | **69.09** | **84.31** | -| Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** | -| MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** | -| Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** | - -### Performance on IJB-C and Verification Datasets - -| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log | -| :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- | -| MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)| -| MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)| -| MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)| -| MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)| -| MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)| -| Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)| -| Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)| -| Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)| -| Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)| - -[comment]: <> (More details see [model.md](docs/modelzoo.md) in docs.) - - -## [Speed Benchmark](docs/speed_benchmark.md) - -**Arcface Torch** can train large-scale face recognition training set efficiently and quickly. When the number of -classes in training sets is greater than 300K and the training is sufficient, partial fc sampling strategy will get same -accuracy with several times faster training performance and smaller GPU memory. -Partial FC is a sparse variant of the model parallel architecture for large sacle face recognition. Partial FC use a -sparse softmax, where each batch dynamicly sample a subset of class centers for training. In each iteration, only a -sparse part of the parameters will be updated, which can reduce a lot of GPU memory and calculations. With Partial FC, -we can scale trainset of 29 millions identities, the largest to date. Partial FC also supports multi-machine distributed -training and mixed precision training. 
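The snippet below is a minimal single-GPU sketch of the sampling idea described above, not the repository's actual `partial_fc` implementation: each step keeps every positive class center in the batch plus enough randomly drawn negatives to reach the sampling rate, computes logits only against that subset, and remaps the labels accordingly, so only the sampled centers receive gradient.

```python
import torch
import torch.nn.functional as F

def partial_fc_logits(features, labels, weight, sample_rate=0.1):
    """Sample class centers: all positives in the batch + random negatives."""
    num_classes = weight.size(0)
    positive = labels.unique()
    num_sample = max(int(num_classes * sample_rate), positive.numel())
    # fill the sampled set with randomly chosen negative class centers
    perm = torch.randperm(num_classes, device=weight.device)
    is_negative = torch.ones(num_classes, dtype=torch.bool, device=weight.device)
    is_negative[positive] = False
    negatives = perm[is_negative[perm]][: num_sample - positive.numel()]
    sampled = torch.cat([positive, negatives])
    # remap global class ids to positions inside the sampled subset
    global_to_local = torch.full((num_classes,), -1, dtype=torch.long, device=weight.device)
    global_to_local[sampled] = torch.arange(sampled.numel(), device=weight.device)
    logits = F.linear(F.normalize(features), F.normalize(weight[sampled]))
    return logits, global_to_local[labels]

# Toy step: only ~10% of the 10,000 class centers get a non-zero gradient.
feats = torch.randn(8, 512)
labels = torch.randint(0, 10_000, (8,))
centers = torch.randn(10_000, 512, requires_grad=True)
logits, local_labels = partial_fc_logits(feats, labels, centers)
F.cross_entropy(logits, local_labels).backward()
```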
- -![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png) - -More details see -[speed_benchmark.md](docs/speed_benchmark.md) in docs. - -### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better) - -`-` means training failed because of gpu memory limitations. - -| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | -| :--- | :--- | :--- | :--- | -|125000 | 4681 | 4824 | 5004 | -|1400000 | **1672** | 3043 | 4738 | -|5500000 | **-** | **1389** | 3975 | -|8000000 | **-** | **-** | 3565 | -|16000000 | **-** | **-** | 2679 | -|29000000 | **-** | **-** | **1855** | - -### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better) - -| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | -| :--- | :--- | :--- | :--- | -|125000 | 7358 | 5306 | 4868 | -|1400000 | 32252 | 11178 | 6056 | -|5500000 | **-** | 32188 | 9854 | -|8000000 | **-** | **-** | 12310 | -|16000000 | **-** | **-** | 19950 | -|29000000 | **-** | **-** | 32324 | - -## Evaluation ICCV2021-MFR and IJB-C - -More details see [eval.md](docs/eval.md) in docs. - -## Test - -We tested many versions of PyTorch. Please create an issue if you are having trouble. - -- [x] torch 1.6.0 -- [x] torch 1.7.1 -- [x] torch 1.8.0 -- [x] torch 1.9.0 - -## Citation - -``` -@inproceedings{deng2019arcface, - title={Arcface: Additive angular margin loss for deep face recognition}, - author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos}, - booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, - pages={4690--4699}, - year={2019} -} -@inproceedings{an2020partical_fc, - title={Partial FC: Training 10 Million Identities on a Single Machine}, - author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and - Zhang, Debing and Fu Ying}, - booktitle={Arxiv 2010.05222}, - year={2020} -} -``` diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/__init__.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/__init__.py deleted file mode 100644 index 5650187b4fdea84c5a23e0445440901690ab682a..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200 -from .mobilefacenet import get_mbf - - -def get_model(name, **kwargs): - # resnet - if name == "r18": - return iresnet18(False, **kwargs) - elif name == "r34": - return iresnet34(False, **kwargs) - elif name == "r50": - return iresnet50(False, **kwargs) - elif name == "r100": - return iresnet100(False, **kwargs) - elif name == "r200": - return iresnet200(False, **kwargs) - elif name == "r2060": - from .iresnet2060 import iresnet2060 - return iresnet2060(False, **kwargs) - elif name == "mbf": - fp16 = kwargs.get("fp16", False) - num_features = kwargs.get("num_features", 512) - return get_mbf(fp16=fp16, num_features=num_features) - else: - raise ValueError() \ No newline at end of file diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/iresnet.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/iresnet.py deleted file mode 100644 index d29f5f2bfbd444273717c4bc8aa20ba7edd08f80..0000000000000000000000000000000000000000 --- 
a/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/iresnet.py +++ /dev/null @@ -1,187 +0,0 @@ -import torch -from torch import nn - -__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200'] - - -def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, - out_planes, - kernel_size=3, - stride=stride, - padding=dilation, - groups=groups, - bias=False, - dilation=dilation) - - -def conv1x1(in_planes, out_planes, stride=1): - """1x1 convolution""" - return nn.Conv2d(in_planes, - out_planes, - kernel_size=1, - stride=stride, - bias=False) - - -class IBasicBlock(nn.Module): - expansion = 1 - def __init__(self, inplanes, planes, stride=1, downsample=None, - groups=1, base_width=64, dilation=1): - super(IBasicBlock, self).__init__() - if groups != 1 or base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') - if dilation > 1: - raise NotImplementedError("Dilation > 1 not supported in BasicBlock") - self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) - self.conv1 = conv3x3(inplanes, planes) - self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) - self.prelu = nn.PReLU(planes) - self.conv2 = conv3x3(planes, planes, stride) - self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - identity = x - out = self.bn1(x) - out = self.conv1(out) - out = self.bn2(out) - out = self.prelu(out) - out = self.conv2(out) - out = self.bn3(out) - if self.downsample is not None: - identity = self.downsample(x) - out += identity - return out - - -class IResNet(nn.Module): - fc_scale = 7 * 7 - def __init__(self, - block, layers, dropout=0, num_features=512, zero_init_residual=False, - groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): - super(IResNet, self).__init__() - self.fp16 = fp16 - self.inplanes = 64 - self.dilation = 1 - if replace_stride_with_dilation is None: - replace_stride_with_dilation = [False, False, False] - if len(replace_stride_with_dilation) != 3: - raise ValueError("replace_stride_with_dilation should be None " - "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) - self.groups = groups - self.base_width = width_per_group - self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) - self.prelu = nn.PReLU(self.inplanes) - self.layer1 = self._make_layer(block, 64, layers[0], stride=2) - self.layer2 = self._make_layer(block, - 128, - layers[1], - stride=2, - dilate=replace_stride_with_dilation[0]) - self.layer3 = self._make_layer(block, - 256, - layers[2], - stride=2, - dilate=replace_stride_with_dilation[1]) - self.layer4 = self._make_layer(block, - 512, - layers[3], - stride=2, - dilate=replace_stride_with_dilation[2]) - self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) - self.dropout = nn.Dropout(p=dropout, inplace=True) - self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) - self.features = nn.BatchNorm1d(num_features, eps=1e-05) - nn.init.constant_(self.features.weight, 1.0) - self.features.weight.requires_grad = False - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.normal_(m.weight, 0, 0.1) - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - if zero_init_residual: - for m in self.modules(): - if isinstance(m, IBasicBlock): - 
nn.init.constant_(m.bn2.weight, 0) - - def _make_layer(self, block, planes, blocks, stride=1, dilate=False): - downsample = None - previous_dilation = self.dilation - if dilate: - self.dilation *= stride - stride = 1 - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - conv1x1(self.inplanes, planes * block.expansion, stride), - nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), - ) - layers = [] - layers.append( - block(self.inplanes, planes, stride, downsample, self.groups, - self.base_width, previous_dilation)) - self.inplanes = planes * block.expansion - for _ in range(1, blocks): - layers.append( - block(self.inplanes, - planes, - groups=self.groups, - base_width=self.base_width, - dilation=self.dilation)) - - return nn.Sequential(*layers) - - def forward(self, x): - with torch.cuda.amp.autocast(self.fp16): - x = self.conv1(x) - x = self.bn1(x) - x = self.prelu(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - x = self.bn2(x) - x = torch.flatten(x, 1) - x = self.dropout(x) - x = self.fc(x.float() if self.fp16 else x) - x = self.features(x) - return x - - -def _iresnet(arch, block, layers, pretrained, progress, **kwargs): - model = IResNet(block, layers, **kwargs) - if pretrained: - raise ValueError() - return model - - -def iresnet18(pretrained=False, progress=True, **kwargs): - return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, - progress, **kwargs) - - -def iresnet34(pretrained=False, progress=True, **kwargs): - return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, - progress, **kwargs) - - -def iresnet50(pretrained=False, progress=True, **kwargs): - return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, - progress, **kwargs) - - -def iresnet100(pretrained=False, progress=True, **kwargs): - return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, - progress, **kwargs) - - -def iresnet200(pretrained=False, progress=True, **kwargs): - return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, - progress, **kwargs) - diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/iresnet2060.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/iresnet2060.py deleted file mode 100644 index 39bb4335716b653bd5924e20d616d825ef48339f..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/iresnet2060.py +++ /dev/null @@ -1,176 +0,0 @@ -import torch -from torch import nn - -assert torch.__version__ >= "1.8.1" -from torch.utils.checkpoint import checkpoint_sequential - -__all__ = ['iresnet2060'] - - -def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, - out_planes, - kernel_size=3, - stride=stride, - padding=dilation, - groups=groups, - bias=False, - dilation=dilation) - - -def conv1x1(in_planes, out_planes, stride=1): - """1x1 convolution""" - return nn.Conv2d(in_planes, - out_planes, - kernel_size=1, - stride=stride, - bias=False) - - -class IBasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None, - groups=1, base_width=64, dilation=1): - super(IBasicBlock, self).__init__() - if groups != 1 or base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') - if dilation > 1: - raise NotImplementedError("Dilation > 1 not supported in BasicBlock") - self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, ) - 
self.conv1 = conv3x3(inplanes, planes) - self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, ) - self.prelu = nn.PReLU(planes) - self.conv2 = conv3x3(planes, planes, stride) - self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, ) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - identity = x - out = self.bn1(x) - out = self.conv1(out) - out = self.bn2(out) - out = self.prelu(out) - out = self.conv2(out) - out = self.bn3(out) - if self.downsample is not None: - identity = self.downsample(x) - out += identity - return out - - -class IResNet(nn.Module): - fc_scale = 7 * 7 - - def __init__(self, - block, layers, dropout=0, num_features=512, zero_init_residual=False, - groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): - super(IResNet, self).__init__() - self.fp16 = fp16 - self.inplanes = 64 - self.dilation = 1 - if replace_stride_with_dilation is None: - replace_stride_with_dilation = [False, False, False] - if len(replace_stride_with_dilation) != 3: - raise ValueError("replace_stride_with_dilation should be None " - "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) - self.groups = groups - self.base_width = width_per_group - self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) - self.prelu = nn.PReLU(self.inplanes) - self.layer1 = self._make_layer(block, 64, layers[0], stride=2) - self.layer2 = self._make_layer(block, - 128, - layers[1], - stride=2, - dilate=replace_stride_with_dilation[0]) - self.layer3 = self._make_layer(block, - 256, - layers[2], - stride=2, - dilate=replace_stride_with_dilation[1]) - self.layer4 = self._make_layer(block, - 512, - layers[3], - stride=2, - dilate=replace_stride_with_dilation[2]) - self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, ) - self.dropout = nn.Dropout(p=dropout, inplace=True) - self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) - self.features = nn.BatchNorm1d(num_features, eps=1e-05) - nn.init.constant_(self.features.weight, 1.0) - self.features.weight.requires_grad = False - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.normal_(m.weight, 0, 0.1) - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - if zero_init_residual: - for m in self.modules(): - if isinstance(m, IBasicBlock): - nn.init.constant_(m.bn2.weight, 0) - - def _make_layer(self, block, planes, blocks, stride=1, dilate=False): - downsample = None - previous_dilation = self.dilation - if dilate: - self.dilation *= stride - stride = 1 - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - conv1x1(self.inplanes, planes * block.expansion, stride), - nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), - ) - layers = [] - layers.append( - block(self.inplanes, planes, stride, downsample, self.groups, - self.base_width, previous_dilation)) - self.inplanes = planes * block.expansion - for _ in range(1, blocks): - layers.append( - block(self.inplanes, - planes, - groups=self.groups, - base_width=self.base_width, - dilation=self.dilation)) - - return nn.Sequential(*layers) - - def checkpoint(self, func, num_seg, x): - if self.training: - return checkpoint_sequential(func, num_seg, x) - else: - return func(x) - - def forward(self, x): - with torch.cuda.amp.autocast(self.fp16): - x = self.conv1(x) - x = self.bn1(x) - x = self.prelu(x) - x = self.layer1(x) - x = 
self.checkpoint(self.layer2, 20, x) - x = self.checkpoint(self.layer3, 100, x) - x = self.layer4(x) - x = self.bn2(x) - x = torch.flatten(x, 1) - x = self.dropout(x) - x = self.fc(x.float() if self.fp16 else x) - x = self.features(x) - return x - - -def _iresnet(arch, block, layers, pretrained, progress, **kwargs): - model = IResNet(block, layers, **kwargs) - if pretrained: - raise ValueError() - return model - - -def iresnet2060(pretrained=False, progress=True, **kwargs): - return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs) diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/mobilefacenet.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/mobilefacenet.py deleted file mode 100644 index c02c6c1e4fa6a6ddf09f5b01dec96971427cb110..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/backbones/mobilefacenet.py +++ /dev/null @@ -1,130 +0,0 @@ -''' -Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py -Original author cavalleria -''' - -import torch.nn as nn -from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module -import torch - - -class Flatten(Module): - def forward(self, x): - return x.view(x.size(0), -1) - - -class ConvBlock(Module): - def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): - super(ConvBlock, self).__init__() - self.layers = nn.Sequential( - Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False), - BatchNorm2d(num_features=out_c), - PReLU(num_parameters=out_c) - ) - - def forward(self, x): - return self.layers(x) - - -class LinearBlock(Module): - def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): - super(LinearBlock, self).__init__() - self.layers = nn.Sequential( - Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), - BatchNorm2d(num_features=out_c) - ) - - def forward(self, x): - return self.layers(x) - - -class DepthWise(Module): - def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): - super(DepthWise, self).__init__() - self.residual = residual - self.layers = nn.Sequential( - ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)), - ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride), - LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) - ) - - def forward(self, x): - short_cut = None - if self.residual: - short_cut = x - x = self.layers(x) - if self.residual: - output = short_cut + x - else: - output = x - return output - - -class Residual(Module): - def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): - super(Residual, self).__init__() - modules = [] - for _ in range(num_block): - modules.append(DepthWise(c, c, True, kernel, stride, padding, groups)) - self.layers = Sequential(*modules) - - def forward(self, x): - return self.layers(x) - - -class GDC(Module): - def __init__(self, embedding_size): - super(GDC, self).__init__() - self.layers = nn.Sequential( - LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)), - Flatten(), - Linear(512, embedding_size, bias=False), - BatchNorm1d(embedding_size)) - - def forward(self, x): - return self.layers(x) - - -class MobileFaceNet(Module): - def __init__(self, fp16=False, num_features=512): - 
super(MobileFaceNet, self).__init__() - scale = 2 - self.fp16 = fp16 - self.layers = nn.Sequential( - ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)), - ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64), - DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128), - Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), - DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256), - Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), - DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512), - Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), - ) - self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) - self.features = GDC(num_features) - self._initialize_weights() - - def _initialize_weights(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - if m.bias is not None: - m.bias.data.zero_() - - def forward(self, x): - with torch.cuda.amp.autocast(self.fp16): - x = self.layers(x) - x = self.conv_sep(x.float() if self.fp16 else x) - x = self.features(x) - return x - - -def get_mbf(fp16, num_features): - return MobileFaceNet(fp16, num_features) \ No newline at end of file diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/3millions.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/3millions.py deleted file mode 100644 index 3bee7cb4236e8b842a1bd1e8c26de7a11df0bf43..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/3millions.py +++ /dev/null @@ -1,23 +0,0 @@ -from easydict import EasyDict as edict - -# configs for test speed - -config = edict() -config.loss = "arcface" -config.network = "r50" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "synthetic" -config.num_classes = 300 * 10000 -config.num_epoch = 30 -config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] -config.val_targets = [] diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/3millions_pfc.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/3millions_pfc.py deleted file mode 100644 index bf7df5f04e2509e5dcc14adebbb9302a18f03f2b..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/3millions_pfc.py +++ /dev/null @@ -1,23 +0,0 @@ -from easydict import EasyDict as edict - -# configs for test speed - -config = edict() -config.loss = "arcface" -config.network = "r50" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 0.1 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "synthetic" -config.num_classes = 300 * 10000 -config.num_epoch = 30 -config.warmup_epoch = -1 
-config.decay_epoch = [10, 16, 22] -config.val_targets = [] diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/__init__.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/base.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/base.py deleted file mode 100644 index f98c62fed44afde276dcbacecd9da0a8f474963c..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/base.py +++ /dev/null @@ -1,56 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "arcface" -config.network = "r50" -config.resume = False -config.output = "ms1mv3_arcface_r50" - -config.dataset = "ms1m-retinaface-t1" -config.embedding_size = 512 -config.sample_rate = 1 -config.fp16 = False -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -if config.dataset == "emore": - config.rec = "/train_tmp/faces_emore" - config.num_classes = 85742 - config.num_image = 5822653 - config.num_epoch = 16 - config.warmup_epoch = -1 - config.decay_epoch = [8, 14, ] - config.val_targets = ["lfw", ] - -elif config.dataset == "ms1m-retinaface-t1": - config.rec = "/train_tmp/ms1m-retinaface-t1" - config.num_classes = 93431 - config.num_image = 5179510 - config.num_epoch = 25 - config.warmup_epoch = -1 - config.decay_epoch = [11, 17, 22] - config.val_targets = ["lfw", "cfp_fp", "agedb_30"] - -elif config.dataset == "glint360k": - config.rec = "/train_tmp/glint360k" - config.num_classes = 360232 - config.num_image = 17091657 - config.num_epoch = 20 - config.warmup_epoch = -1 - config.decay_epoch = [8, 12, 15, 18] - config.val_targets = ["lfw", "cfp_fp", "agedb_30"] - -elif config.dataset == "webface": - config.rec = "/train_tmp/faces_webface_112x112" - config.num_classes = 10572 - config.num_image = "forget" - config.num_epoch = 34 - config.warmup_epoch = -1 - config.decay_epoch = [20, 28, 32] - config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_mbf.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_mbf.py deleted file mode 100644 index 44ee5e8d96249d57196df43418f6fda4ab339877..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_mbf.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "cosface" -config.network = "mbf" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 0.1 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 2e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/glint360k" -config.num_classes = 360232 -config.num_image = 17091657 -config.num_epoch = 20 -config.warmup_epoch = -1 -config.decay_epoch = [8, 12, 15, 18] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r100.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r100.py 
deleted file mode 100644 index f8f8ef745c0efb9d5ea67409edc8c904def8a9d9..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r100.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "cosface" -config.network = "r100" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/glint360k" -config.num_classes = 360232 -config.num_image = 17091657 -config.num_epoch = 20 -config.warmup_epoch = -1 -config.decay_epoch = [8, 12, 15, 18] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r18.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r18.py deleted file mode 100644 index 473b59a954fffcaddca132fb6e0f32cbe70c70f4..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r18.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "cosface" -config.network = "r18" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/glint360k" -config.num_classes = 360232 -config.num_image = 17091657 -config.num_epoch = 20 -config.warmup_epoch = -1 -config.decay_epoch = [8, 12, 15, 18] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r34.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r34.py deleted file mode 100644 index d9c22ff0c82cc98bbbe81c9a1c26c9b3fc186105..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r34.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "cosface" -config.network = "r34" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/glint360k" -config.num_classes = 360232 -config.num_image = 17091657 -config.num_epoch = 20 -config.warmup_epoch = -1 -config.decay_epoch = [8, 12, 15, 18] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r50.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r50.py deleted file mode 100644 index 8ecbfda06730e3842e7b347db366e82f0714912f..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/glint360k_r50.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o 
size=140G tmpfs /train_tmp - -config = edict() -config.loss = "cosface" -config.network = "r50" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/glint360k" -config.num_classes = 360232 -config.num_image = 17091657 -config.num_epoch = 20 -config.warmup_epoch = -1 -config.decay_epoch = [8, 12, 15, 18] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py deleted file mode 100644 index 47c87a99867db55c7f689574c331c14cda23ea96..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "arcface" -config.network = "mbf" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 2e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/ms1m-retinaface-t1" -config.num_classes = 93431 -config.num_image = 5179510 -config.num_epoch = 30 -config.warmup_epoch = -1 -config.decay_epoch = [10, 20, 25] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py deleted file mode 100644 index 1aeb851b05ea22e01da87b3d387812f0253989f8..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "arcface" -config.network = "r18" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/ms1m-retinaface-t1" -config.num_classes = 93431 -config.num_image = 5179510 -config.num_epoch = 25 -config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py deleted file mode 100644 index 8693e67080dac7e7b84da08a62df326c7b12d465..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "arcface" -config.network = "r2060" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 64 -config.lr = 0.1 # batch size is 512 - -config.rec = 
"/train_tmp/ms1m-retinaface-t1" -config.num_classes = 93431 -config.num_image = 5179510 -config.num_epoch = 25 -config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py deleted file mode 100644 index 52bff483db179045c0e3acc8e2975477182b0756..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "arcface" -config.network = "r34" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/ms1m-retinaface-t1" -config.num_classes = 93431 -config.num_image = 5179510 -config.num_epoch = 25 -config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py deleted file mode 100644 index de81ffdd84edd6fcea7fcb4d3594db031b9e4e26..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py +++ /dev/null @@ -1,26 +0,0 @@ -from easydict import EasyDict as edict - -# make training faster -# our RAM is 256G -# mount -t tmpfs -o size=140G tmpfs /train_tmp - -config = edict() -config.loss = "arcface" -config.network = "r50" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "/train_tmp/ms1m-retinaface-t1" -config.num_classes = 93431 -config.num_image = 5179510 -config.num_epoch = 25 -config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] -config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/speed.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/speed.py deleted file mode 100644 index c172f9d44d39b534f2253630471e91cf78e6fba7..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/configs/speed.py +++ /dev/null @@ -1,23 +0,0 @@ -from easydict import EasyDict as edict - -# configs for test speed - -config = edict() -config.loss = "arcface" -config.network = "r50" -config.resume = False -config.output = None -config.embedding_size = 512 -config.sample_rate = 1.0 -config.fp16 = True -config.momentum = 0.9 -config.weight_decay = 5e-4 -config.batch_size = 128 -config.lr = 0.1 # batch size is 512 - -config.rec = "synthetic" -config.num_classes = 100 * 10000 -config.num_epoch = 30 -config.warmup_epoch = -1 -config.decay_epoch = [10, 16, 22] -config.val_targets = [] diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/dataset.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/dataset.py deleted file mode 100644 index 8bead250243237c650fa3138f6aa172d4f98535f..0000000000000000000000000000000000000000 --- 
a/sadtalker_video2pose/src/face3d/models/arcface_torch/dataset.py +++ /dev/null @@ -1,124 +0,0 @@ -import numbers -import os -import queue as Queue -import threading - -import mxnet as mx -import numpy as np -import torch -from torch.utils.data import DataLoader, Dataset -from torchvision import transforms - - -class BackgroundGenerator(threading.Thread): - def __init__(self, generator, local_rank, max_prefetch=6): - super(BackgroundGenerator, self).__init__() - self.queue = Queue.Queue(max_prefetch) - self.generator = generator - self.local_rank = local_rank - self.daemon = True - self.start() - - def run(self): - torch.cuda.set_device(self.local_rank) - for item in self.generator: - self.queue.put(item) - self.queue.put(None) - - def next(self): - next_item = self.queue.get() - if next_item is None: - raise StopIteration - return next_item - - def __next__(self): - return self.next() - - def __iter__(self): - return self - - -class DataLoaderX(DataLoader): - - def __init__(self, local_rank, **kwargs): - super(DataLoaderX, self).__init__(**kwargs) - self.stream = torch.cuda.Stream(local_rank) - self.local_rank = local_rank - - def __iter__(self): - self.iter = super(DataLoaderX, self).__iter__() - self.iter = BackgroundGenerator(self.iter, self.local_rank) - self.preload() - return self - - def preload(self): - self.batch = next(self.iter, None) - if self.batch is None: - return None - with torch.cuda.stream(self.stream): - for k in range(len(self.batch)): - self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True) - - def __next__(self): - torch.cuda.current_stream().wait_stream(self.stream) - batch = self.batch - if batch is None: - raise StopIteration - self.preload() - return batch - - -class MXFaceDataset(Dataset): - def __init__(self, root_dir, local_rank): - super(MXFaceDataset, self).__init__() - self.transform = transforms.Compose( - [transforms.ToPILImage(), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - ]) - self.root_dir = root_dir - self.local_rank = local_rank - path_imgrec = os.path.join(root_dir, 'train.rec') - path_imgidx = os.path.join(root_dir, 'train.idx') - self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') - s = self.imgrec.read_idx(0) - header, _ = mx.recordio.unpack(s) - if header.flag > 0: - self.header0 = (int(header.label[0]), int(header.label[1])) - self.imgidx = np.array(range(1, int(header.label[0]))) - else: - self.imgidx = np.array(list(self.imgrec.keys)) - - def __getitem__(self, index): - idx = self.imgidx[index] - s = self.imgrec.read_idx(idx) - header, img = mx.recordio.unpack(s) - label = header.label - if not isinstance(label, numbers.Number): - label = label[0] - label = torch.tensor(label, dtype=torch.long) - sample = mx.image.imdecode(img).asnumpy() - if self.transform is not None: - sample = self.transform(sample) - return sample, label - - def __len__(self): - return len(self.imgidx) - - -class SyntheticDataset(Dataset): - def __init__(self, local_rank): - super(SyntheticDataset, self).__init__() - img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) - img = np.transpose(img, (2, 0, 1)) - img = torch.from_numpy(img).squeeze(0).float() - img = ((img / 255) - 0.5) / 0.5 - self.img = img - self.label = 1 - - def __getitem__(self, index): - return self.img, self.label - - def __len__(self): - return 1000000 diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/eval.md 
b/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/eval.md deleted file mode 100644 index 4d29c855fc6e4245ed264216c1f96ab2efc57248..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/eval.md +++ /dev/null @@ -1,31 +0,0 @@ -## Eval on ICCV2021-MFR - -coming soon. - - -## Eval IJBC -You can eval ijbc with pytorch or onnx. - - -1. Eval IJBC With Onnx -```shell -CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50 -``` - -2. Eval IJBC With Pytorch -```shell -CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \ ---model-prefix ms1mv3_arcface_r50/backbone.pth \ ---image-path IJB_release/IJBC \ ---result-dir ms1mv3_arcface_r50 \ ---batch-size 128 \ ---job ms1mv3_arcface_r50 \ ---target IJBC \ ---network iresnet50 -``` - -## Inference - -```shell -python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50 -``` diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/install.md b/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/install.md deleted file mode 100644 index b1b770a0d93dac1f160185b5bbf4da2f414f21f6..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/install.md +++ /dev/null @@ -1,51 +0,0 @@ -## v1.8.0 -### Linux and Windows -```shell -# CUDA 11.0 -pip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html - -# CUDA 10.2 -pip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 - -# CPU only -pip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html - -``` - - -## v1.7.1 -### Linux and Windows -```shell -# CUDA 11.0 -pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html - -# CUDA 10.2 -pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 - -# CUDA 10.1 -pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html - -# CUDA 9.2 -pip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html - -# CPU only -pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html -``` - - -## v1.6.0 - -### Linux and Windows -```shell -# CUDA 10.2 -pip install torch==1.6.0 torchvision==0.7.0 - -# CUDA 10.1 -pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html - -# CUDA 9.2 -pip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html - -# CPU only -pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html -``` \ No newline at end of file diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/modelzoo.md b/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/modelzoo.md deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/speed_benchmark.md b/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/speed_benchmark.md deleted file mode 100644 index 
d54904587df4e13784dc68d5709b4d7d97490890..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/docs/speed_benchmark.md +++ /dev/null @@ -1,93 +0,0 @@ -## Test Training Speed - -- Test Commands - -You need to use the following two commands to test the Partial FC training performance. -The number of identites is **3 millions** (synthetic data), turn mixed precision training on, backbone is resnet50, -batch size is 1024. -```shell -# Model Parallel -python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions -# Partial FC 0.1 -python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc -``` - -- GPU Memory - -``` -# (Model Parallel) gpustat -i -[0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB -[1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB -[2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB -[3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB -[4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB -[5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB -[6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB -[7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB - -# (Partial FC 0.1) gpustat -i -[0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB │······················· -[1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB │······················· -[2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB │······················· -[3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB │······················· -[4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB │······················· -[5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB │······················· -[6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB │······················· -[7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB │······················· -``` - -- Training Speed - -```python -# (Model Parallel) trainging.log -Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100 -Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150 -Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200 -Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250 -Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300 - -# (Partial FC 0.1) trainging.log -Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100 -Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150 -Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200 -Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250 -Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300 -``` - -In this test case, Partial FC 0.1 only use1 1/3 of the GPU memory of the model parallel, -and the training speed is 2.5 times faster than the model parallel. - - -## Speed Benchmark - -1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. 
(Larger is better) - -| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | -| :--- | :--- | :--- | :--- | -|125000 | 4681 | 4824 | 5004 | -|250000 | 4047 | 4521 | 4976 | -|500000 | 3087 | 4013 | 4900 | -|1000000 | 2090 | 3449 | 4803 | -|1400000 | 1672 | 3043 | 4738 | -|2000000 | - | 2593 | 4626 | -|4000000 | - | 1748 | 4208 | -|5500000 | - | 1389 | 3975 | -|8000000 | - | - | 3565 | -|16000000 | - | - | 2679 | -|29000000 | - | - | 1855 | - -2. GPU memory cost of different parallel methods (GB per GPU), Tesla V100 32GB * 8. (Smaller is better) - -| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | -| :--- | :--- | :--- | :--- | -|125000 | 7358 | 5306 | 4868 | -|250000 | 9940 | 5826 | 5004 | -|500000 | 14220 | 7114 | 5202 | -|1000000 | 23708 | 9966 | 5620 | -|1400000 | 32252 | 11178 | 6056 | -|2000000 | - | 13978 | 6472 | -|4000000 | - | 23238 | 8284 | -|5500000 | - | 32188 | 9854 | -|8000000 | - | - | 12310 | -|16000000 | - | - | 19950 | -|29000000 | - | - | 32324 | diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/eval/__init__.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/eval/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/eval/verification.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/eval/verification.py deleted file mode 100644 index 5b1f5618184effae64895847af1a65d43d2e4418..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/eval/verification.py +++ /dev/null @@ -1,407 +0,0 @@ -"""Helper for evaluation on the Labeled Faces in the Wild dataset -""" - -# MIT License -# -# Copyright (c) 2016 David Sandberg -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- - -import datetime -import os -import pickle - -import mxnet as mx -import numpy as np -import sklearn -import torch -from mxnet import ndarray as nd -from scipy import interpolate -from sklearn.decomposition import PCA -from sklearn.model_selection import KFold - - -class LFold: - def __init__(self, n_splits=2, shuffle=False): - self.n_splits = n_splits - if self.n_splits > 1: - self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle) - - def split(self, indices): - if self.n_splits > 1: - return self.k_fold.split(indices) - else: - return [(indices, indices)] - - -def calculate_roc(thresholds, - embeddings1, - embeddings2, - actual_issame, - nrof_folds=10, - pca=0): - assert (embeddings1.shape[0] == embeddings2.shape[0]) - assert (embeddings1.shape[1] == embeddings2.shape[1]) - nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) - nrof_thresholds = len(thresholds) - k_fold = LFold(n_splits=nrof_folds, shuffle=False) - - tprs = np.zeros((nrof_folds, nrof_thresholds)) - fprs = np.zeros((nrof_folds, nrof_thresholds)) - accuracy = np.zeros((nrof_folds)) - indices = np.arange(nrof_pairs) - - if pca == 0: - diff = np.subtract(embeddings1, embeddings2) - dist = np.sum(np.square(diff), 1) - - for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): - if pca > 0: - print('doing pca on', fold_idx) - embed1_train = embeddings1[train_set] - embed2_train = embeddings2[train_set] - _embed_train = np.concatenate((embed1_train, embed2_train), axis=0) - pca_model = PCA(n_components=pca) - pca_model.fit(_embed_train) - embed1 = pca_model.transform(embeddings1) - embed2 = pca_model.transform(embeddings2) - embed1 = sklearn.preprocessing.normalize(embed1) - embed2 = sklearn.preprocessing.normalize(embed2) - diff = np.subtract(embed1, embed2) - dist = np.sum(np.square(diff), 1) - - # Find the best threshold for the fold - acc_train = np.zeros((nrof_thresholds)) - for threshold_idx, threshold in enumerate(thresholds): - _, _, acc_train[threshold_idx] = calculate_accuracy( - threshold, dist[train_set], actual_issame[train_set]) - best_threshold_index = np.argmax(acc_train) - for threshold_idx, threshold in enumerate(thresholds): - tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy( - threshold, dist[test_set], - actual_issame[test_set]) - _, _, accuracy[fold_idx] = calculate_accuracy( - thresholds[best_threshold_index], dist[test_set], - actual_issame[test_set]) - - tpr = np.mean(tprs, 0) - fpr = np.mean(fprs, 0) - return tpr, fpr, accuracy - - -def calculate_accuracy(threshold, dist, actual_issame): - predict_issame = np.less(dist, threshold) - tp = np.sum(np.logical_and(predict_issame, actual_issame)) - fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame))) - tn = np.sum( - np.logical_and(np.logical_not(predict_issame), - np.logical_not(actual_issame))) - fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame)) - - tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn) - fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn) - acc = float(tp + tn) / dist.size - return tpr, fpr, acc - - -def calculate_val(thresholds, - embeddings1, - embeddings2, - actual_issame, - far_target, - nrof_folds=10): - assert (embeddings1.shape[0] == embeddings2.shape[0]) - assert (embeddings1.shape[1] == embeddings2.shape[1]) - nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) - nrof_thresholds = len(thresholds) - k_fold = LFold(n_splits=nrof_folds, shuffle=False) - - val = np.zeros(nrof_folds) - far = 
np.zeros(nrof_folds) - - diff = np.subtract(embeddings1, embeddings2) - dist = np.sum(np.square(diff), 1) - indices = np.arange(nrof_pairs) - - for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): - - # Find the threshold that gives FAR = far_target - far_train = np.zeros(nrof_thresholds) - for threshold_idx, threshold in enumerate(thresholds): - _, far_train[threshold_idx] = calculate_val_far( - threshold, dist[train_set], actual_issame[train_set]) - if np.max(far_train) >= far_target: - f = interpolate.interp1d(far_train, thresholds, kind='slinear') - threshold = f(far_target) - else: - threshold = 0.0 - - val[fold_idx], far[fold_idx] = calculate_val_far( - threshold, dist[test_set], actual_issame[test_set]) - - val_mean = np.mean(val) - far_mean = np.mean(far) - val_std = np.std(val) - return val_mean, val_std, far_mean - - -def calculate_val_far(threshold, dist, actual_issame): - predict_issame = np.less(dist, threshold) - true_accept = np.sum(np.logical_and(predict_issame, actual_issame)) - false_accept = np.sum( - np.logical_and(predict_issame, np.logical_not(actual_issame))) - n_same = np.sum(actual_issame) - n_diff = np.sum(np.logical_not(actual_issame)) - # print(true_accept, false_accept) - # print(n_same, n_diff) - val = float(true_accept) / float(n_same) - far = float(false_accept) / float(n_diff) - return val, far - - -def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0): - # Calculate evaluation metrics - thresholds = np.arange(0, 4, 0.01) - embeddings1 = embeddings[0::2] - embeddings2 = embeddings[1::2] - tpr, fpr, accuracy = calculate_roc(thresholds, - embeddings1, - embeddings2, - np.asarray(actual_issame), - nrof_folds=nrof_folds, - pca=pca) - thresholds = np.arange(0, 4, 0.001) - val, val_std, far = calculate_val(thresholds, - embeddings1, - embeddings2, - np.asarray(actual_issame), - 1e-3, - nrof_folds=nrof_folds) - return tpr, fpr, accuracy, val, val_std, far - -@torch.no_grad() -def load_bin(path, image_size): - try: - with open(path, 'rb') as f: - bins, issame_list = pickle.load(f) # py2 - except UnicodeDecodeError as e: - with open(path, 'rb') as f: - bins, issame_list = pickle.load(f, encoding='bytes') # py3 - data_list = [] - for flip in [0, 1]: - data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1])) - data_list.append(data) - for idx in range(len(issame_list) * 2): - _bin = bins[idx] - img = mx.image.imdecode(_bin) - if img.shape[1] != image_size[0]: - img = mx.image.resize_short(img, image_size[0]) - img = nd.transpose(img, axes=(2, 0, 1)) - for flip in [0, 1]: - if flip == 1: - img = mx.ndarray.flip(data=img, axis=2) - data_list[flip][idx][:] = torch.from_numpy(img.asnumpy()) - if idx % 1000 == 0: - print('loading bin', idx) - print(data_list[0].shape) - return data_list, issame_list - -@torch.no_grad() -def test(data_set, backbone, batch_size, nfolds=10): - print('testing verification..') - data_list = data_set[0] - issame_list = data_set[1] - embeddings_list = [] - time_consumed = 0.0 - for i in range(len(data_list)): - data = data_list[i] - embeddings = None - ba = 0 - while ba < data.shape[0]: - bb = min(ba + batch_size, data.shape[0]) - count = bb - ba - _data = data[bb - batch_size: bb] - time0 = datetime.datetime.now() - img = ((_data / 255) - 0.5) / 0.5 - net_out: torch.Tensor = backbone(img) - _embeddings = net_out.detach().cpu().numpy() - time_now = datetime.datetime.now() - diff = time_now - time0 - time_consumed += diff.total_seconds() - if embeddings is None: - embeddings = np.zeros((data.shape[0], 
_embeddings.shape[1])) - embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] - ba = bb - embeddings_list.append(embeddings) - - _xnorm = 0.0 - _xnorm_cnt = 0 - for embed in embeddings_list: - for i in range(embed.shape[0]): - _em = embed[i] - _norm = np.linalg.norm(_em) - _xnorm += _norm - _xnorm_cnt += 1 - _xnorm /= _xnorm_cnt - - acc1 = 0.0 - std1 = 0.0 - embeddings = embeddings_list[0] + embeddings_list[1] - embeddings = sklearn.preprocessing.normalize(embeddings) - print(embeddings.shape) - print('infer time', time_consumed) - _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds) - acc2, std2 = np.mean(accuracy), np.std(accuracy) - return acc1, std1, acc2, std2, _xnorm, embeddings_list - - -def dumpR(data_set, - backbone, - batch_size, - name='', - data_extra=None, - label_shape=None): - print('dump verification embedding..') - data_list = data_set[0] - issame_list = data_set[1] - embeddings_list = [] - time_consumed = 0.0 - for i in range(len(data_list)): - data = data_list[i] - embeddings = None - ba = 0 - while ba < data.shape[0]: - bb = min(ba + batch_size, data.shape[0]) - count = bb - ba - - _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb) - time0 = datetime.datetime.now() - if data_extra is None: - db = mx.io.DataBatch(data=(_data,), label=(_label,)) - else: - db = mx.io.DataBatch(data=(_data, _data_extra), - label=(_label,)) - model.forward(db, is_train=False) - net_out = model.get_outputs() - _embeddings = net_out[0].asnumpy() - time_now = datetime.datetime.now() - diff = time_now - time0 - time_consumed += diff.total_seconds() - if embeddings is None: - embeddings = np.zeros((data.shape[0], _embeddings.shape[1])) - embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] - ba = bb - embeddings_list.append(embeddings) - embeddings = embeddings_list[0] + embeddings_list[1] - embeddings = sklearn.preprocessing.normalize(embeddings) - actual_issame = np.asarray(issame_list) - outname = os.path.join('temp.bin') - with open(outname, 'wb') as f: - pickle.dump((embeddings, issame_list), - f, - protocol=pickle.HIGHEST_PROTOCOL) - - -# if __name__ == '__main__': -# -# parser = argparse.ArgumentParser(description='do verification') -# # general -# parser.add_argument('--data-dir', default='', help='') -# parser.add_argument('--model', -# default='../model/softmax,50', -# help='path to load model.') -# parser.add_argument('--target', -# default='lfw,cfp_ff,cfp_fp,agedb_30', -# help='test targets.') -# parser.add_argument('--gpu', default=0, type=int, help='gpu id') -# parser.add_argument('--batch-size', default=32, type=int, help='') -# parser.add_argument('--max', default='', type=str, help='') -# parser.add_argument('--mode', default=0, type=int, help='') -# parser.add_argument('--nfolds', default=10, type=int, help='') -# args = parser.parse_args() -# image_size = [112, 112] -# print('image_size', image_size) -# ctx = mx.gpu(args.gpu) -# nets = [] -# vec = args.model.split(',') -# prefix = args.model.split(',')[0] -# epochs = [] -# if len(vec) == 1: -# pdir = os.path.dirname(prefix) -# for fname in os.listdir(pdir): -# if not fname.endswith('.params'): -# continue -# _file = os.path.join(pdir, fname) -# if _file.startswith(prefix): -# epoch = int(fname.split('.')[0].split('-')[1]) -# epochs.append(epoch) -# epochs = sorted(epochs, reverse=True) -# if len(args.max) > 0: -# _max = [int(x) for x in args.max.split(',')] -# assert len(_max) == 2 -# if len(epochs) > _max[1]: -# epochs = epochs[_max[0]:_max[1]] -# -# else: -# 
epochs = [int(x) for x in vec[1].split('|')] -# print('model number', len(epochs)) -# time0 = datetime.datetime.now() -# for epoch in epochs: -# print('loading', prefix, epoch) -# sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) -# # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx) -# all_layers = sym.get_internals() -# sym = all_layers['fc1_output'] -# model = mx.mod.Module(symbol=sym, context=ctx, label_names=None) -# # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))]) -# model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], -# image_size[1]))]) -# model.set_params(arg_params, aux_params) -# nets.append(model) -# time_now = datetime.datetime.now() -# diff = time_now - time0 -# print('model loading time', diff.total_seconds()) -# -# ver_list = [] -# ver_name_list = [] -# for name in args.target.split(','): -# path = os.path.join(args.data_dir, name + ".bin") -# if os.path.exists(path): -# print('loading.. ', name) -# data_set = load_bin(path, image_size) -# ver_list.append(data_set) -# ver_name_list.append(name) -# -# if args.mode == 0: -# for i in range(len(ver_list)): -# results = [] -# for model in nets: -# acc1, std1, acc2, std2, xnorm, embeddings_list = test( -# ver_list[i], model, args.batch_size, args.nfolds) -# print('[%s]XNorm: %f' % (ver_name_list[i], xnorm)) -# print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1)) -# print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2)) -# results.append(acc2) -# print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results))) -# elif args.mode == 1: -# raise ValueError -# else: -# model = nets[0] -# dumpR(ver_list[0], model, args.batch_size, args.target) diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/eval_ijbc.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/eval_ijbc.py deleted file mode 100644 index 64844c4723a88b4b160d2fee9a7b626b987981d9..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/eval_ijbc.py +++ /dev/null @@ -1,483 +0,0 @@ -# coding: utf-8 - -import os -import pickle - -import matplotlib -import pandas as pd - -matplotlib.use('Agg') -import matplotlib.pyplot as plt -import timeit -import sklearn -import argparse -import cv2 -import numpy as np -import torch -from skimage import transform as trans -from backbones import get_model -from sklearn.metrics import roc_curve, auc - -from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap -from prettytable import PrettyTable -from pathlib import Path - -import sys -import warnings - -sys.path.insert(0, "../") -warnings.filterwarnings("ignore") - -parser = argparse.ArgumentParser(description='do ijb test') -# general -parser.add_argument('--model-prefix', default='', help='path to load model.') -parser.add_argument('--image-path', default='', type=str, help='') -parser.add_argument('--result-dir', default='.', type=str, help='') -parser.add_argument('--batch-size', default=128, type=int, help='') -parser.add_argument('--network', default='iresnet50', type=str, help='') -parser.add_argument('--job', default='insightface', type=str, help='job name') -parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB') -args = parser.parse_args() - -target = args.target -model_path = args.model_prefix -image_path = args.image_path -result_dir = args.result_dir -gpu_id = None -use_norm_score = True # if 
Ture, TestMode(N1) -use_detector_score = True # if Ture, TestMode(D1) -use_flip_test = True # if Ture, TestMode(F1) -job = args.job -batch_size = args.batch_size - - -class Embedding(object): - def __init__(self, prefix, data_shape, batch_size=1): - image_size = (112, 112) - self.image_size = image_size - weight = torch.load(prefix) - resnet = get_model(args.network, dropout=0, fp16=False).cuda() - resnet.load_state_dict(weight) - model = torch.nn.DataParallel(resnet) - self.model = model - self.model.eval() - src = np.array([ - [30.2946, 51.6963], - [65.5318, 51.5014], - [48.0252, 71.7366], - [33.5493, 92.3655], - [62.7299, 92.2041]], dtype=np.float32) - src[:, 0] += 8.0 - self.src = src - self.batch_size = batch_size - self.data_shape = data_shape - - def get(self, rimg, landmark): - - assert landmark.shape[0] == 68 or landmark.shape[0] == 5 - assert landmark.shape[1] == 2 - if landmark.shape[0] == 68: - landmark5 = np.zeros((5, 2), dtype=np.float32) - landmark5[0] = (landmark[36] + landmark[39]) / 2 - landmark5[1] = (landmark[42] + landmark[45]) / 2 - landmark5[2] = landmark[30] - landmark5[3] = landmark[48] - landmark5[4] = landmark[54] - else: - landmark5 = landmark - tform = trans.SimilarityTransform() - tform.estimate(landmark5, self.src) - M = tform.params[0:2, :] - img = cv2.warpAffine(rimg, - M, (self.image_size[1], self.image_size[0]), - borderValue=0.0) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img_flip = np.fliplr(img) - img = np.transpose(img, (2, 0, 1)) # 3*112*112, RGB - img_flip = np.transpose(img_flip, (2, 0, 1)) - input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8) - input_blob[0] = img - input_blob[1] = img_flip - return input_blob - - @torch.no_grad() - def forward_db(self, batch_data): - imgs = torch.Tensor(batch_data).cuda() - imgs.div_(255).sub_(0.5).div_(0.5) - feat = self.model(imgs) - feat = feat.reshape([self.batch_size, 2 * feat.shape[1]]) - return feat.cpu().numpy() - - -# 将一个list尽量均分成n份,限制len(list)==n,份数大于原list内元素个数则分配空list[] -def divideIntoNstrand(listTemp, n): - twoList = [[] for i in range(n)] - for i, e in enumerate(listTemp): - twoList[i % n].append(e) - return twoList - - -def read_template_media_list(path): - # ijb_meta = np.loadtxt(path, dtype=str) - ijb_meta = pd.read_csv(path, sep=' ', header=None).values - templates = ijb_meta[:, 1].astype(np.int) - medias = ijb_meta[:, 2].astype(np.int) - return templates, medias - - -# In[ ]: - - -def read_template_pair_list(path): - # pairs = np.loadtxt(path, dtype=str) - pairs = pd.read_csv(path, sep=' ', header=None).values - # print(pairs.shape) - # print(pairs[:, 0].astype(np.int)) - t1 = pairs[:, 0].astype(np.int) - t2 = pairs[:, 1].astype(np.int) - label = pairs[:, 2].astype(np.int) - return t1, t2, label - - -# In[ ]: - - -def read_image_feature(path): - with open(path, 'rb') as fid: - img_feats = pickle.load(fid) - return img_feats - - -# In[ ]: - - -def get_image_feature(img_path, files_list, model_path, epoch, gpu_id): - batch_size = args.batch_size - data_shape = (3, 112, 112) - - files = files_list - print('files:', len(files)) - rare_size = len(files) % batch_size - faceness_scores = [] - batch = 0 - img_feats = np.empty((len(files), 1024), dtype=np.float32) - - batch_data = np.empty((2 * batch_size, 3, 112, 112)) - embedding = Embedding(model_path, data_shape, batch_size) - for img_index, each_line in enumerate(files[:len(files) - rare_size]): - name_lmk_score = each_line.strip().split(' ') - img_name = os.path.join(img_path, name_lmk_score[0]) - img = 
cv2.imread(img_name) - lmk = np.array([float(x) for x in name_lmk_score[1:-1]], - dtype=np.float32) - lmk = lmk.reshape((5, 2)) - input_blob = embedding.get(img, lmk) - - batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0] - batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1] - if (img_index + 1) % batch_size == 0: - print('batch', batch) - img_feats[batch * batch_size:batch * batch_size + - batch_size][:] = embedding.forward_db(batch_data) - batch += 1 - faceness_scores.append(name_lmk_score[-1]) - - batch_data = np.empty((2 * rare_size, 3, 112, 112)) - embedding = Embedding(model_path, data_shape, rare_size) - for img_index, each_line in enumerate(files[len(files) - rare_size:]): - name_lmk_score = each_line.strip().split(' ') - img_name = os.path.join(img_path, name_lmk_score[0]) - img = cv2.imread(img_name) - lmk = np.array([float(x) for x in name_lmk_score[1:-1]], - dtype=np.float32) - lmk = lmk.reshape((5, 2)) - input_blob = embedding.get(img, lmk) - batch_data[2 * img_index][:] = input_blob[0] - batch_data[2 * img_index + 1][:] = input_blob[1] - if (img_index + 1) % rare_size == 0: - print('batch', batch) - img_feats[len(files) - - rare_size:][:] = embedding.forward_db(batch_data) - batch += 1 - faceness_scores.append(name_lmk_score[-1]) - faceness_scores = np.array(faceness_scores).astype(np.float32) - # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01 - # faceness_scores = np.ones( (len(files), ), dtype=np.float32 ) - return img_feats, faceness_scores - - -# In[ ]: - - -def image2template_feature(img_feats=None, templates=None, medias=None): - # ========================================================== - # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim] - # 2. compute media feature. - # 3. compute template feature. - # ========================================================== - unique_templates = np.unique(templates) - template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) - - for count_template, uqt in enumerate(unique_templates): - - (ind_t,) = np.where(templates == uqt) - face_norm_feats = img_feats[ind_t] - face_medias = medias[ind_t] - unique_medias, unique_media_counts = np.unique(face_medias, - return_counts=True) - media_norm_feats = [] - for u, ct in zip(unique_medias, unique_media_counts): - (ind_m,) = np.where(face_medias == u) - if ct == 1: - media_norm_feats += [face_norm_feats[ind_m]] - else: # image features from the same video will be aggregated into one feature - media_norm_feats += [ - np.mean(face_norm_feats[ind_m], axis=0, keepdims=True) - ] - media_norm_feats = np.array(media_norm_feats) - # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True)) - template_feats[count_template] = np.sum(media_norm_feats, axis=0) - if count_template % 2000 == 0: - print('Finish Calculating {} template features.'.format( - count_template)) - # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True)) - template_norm_feats = sklearn.preprocessing.normalize(template_feats) - # print(template_norm_feats.shape) - return template_norm_feats, unique_templates - - -# In[ ]: - - -def verification(template_norm_feats=None, - unique_templates=None, - p1=None, - p2=None): - # ========================================================== - # Compute set-to-set Similarity Score. 
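- #    i.e. the cosine similarity between the L2-normalized template features; pairs are scored in sublists of 100k to bound memory.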
- # ========================================================== - template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) - for count_template, uqt in enumerate(unique_templates): - template2id[uqt] = count_template - - score = np.zeros((len(p1),)) # save cosine distance between pairs - - total_pairs = np.array(range(len(p1))) - batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation - sublists = [ - total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize) - ] - total_sublists = len(sublists) - for c, s in enumerate(sublists): - feat1 = template_norm_feats[template2id[p1[s]]] - feat2 = template_norm_feats[template2id[p2[s]]] - similarity_score = np.sum(feat1 * feat2, -1) - score[s] = similarity_score.flatten() - if c % 10 == 0: - print('Finish {}/{} pairs.'.format(c, total_sublists)) - return score - - -# In[ ]: -def verification2(template_norm_feats=None, - unique_templates=None, - p1=None, - p2=None): - template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) - for count_template, uqt in enumerate(unique_templates): - template2id[uqt] = count_template - score = np.zeros((len(p1),)) # save cosine distance between pairs - total_pairs = np.array(range(len(p1))) - batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation - sublists = [ - total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize) - ] - total_sublists = len(sublists) - for c, s in enumerate(sublists): - feat1 = template_norm_feats[template2id[p1[s]]] - feat2 = template_norm_feats[template2id[p2[s]]] - similarity_score = np.sum(feat1 * feat2, -1) - score[s] = similarity_score.flatten() - if c % 10 == 0: - print('Finish {}/{} pairs.'.format(c, total_sublists)) - return score - - -def read_score(path): - with open(path, 'rb') as fid: - img_feats = pickle.load(fid) - return img_feats - - -# # Step1: Load Meta Data - -# In[ ]: - -assert target == 'IJBC' or target == 'IJBB' - -# ============================================================= -# load image and template relationships for template feature embedding -# tid --> template id, mid --> media id -# format: -# image_name tid mid -# ============================================================= -start = timeit.default_timer() -templates, medias = read_template_media_list( - os.path.join('%s/meta' % image_path, - '%s_face_tid_mid.txt' % target.lower())) -stop = timeit.default_timer() -print('Time: %.2f s. ' % (stop - start)) - -# In[ ]: - -# ============================================================= -# load template pairs for template-to-template verification -# tid : template id, label : 1/0 -# format: -# tid_1 tid_2 label -# ============================================================= -start = timeit.default_timer() -p1, p2, label = read_template_pair_list( - os.path.join('%s/meta' % image_path, - '%s_template_pair_label.txt' % target.lower())) -stop = timeit.default_timer() -print('Time: %.2f s. 
' % (stop - start)) - -# # Step 2: Get Image Features - -# In[ ]: - -# ============================================================= -# load image features -# format: -# img_feats: [image_num x feats_dim] (227630, 512) -# ============================================================= -start = timeit.default_timer() -img_path = '%s/loose_crop' % image_path -img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower()) -img_list = open(img_list_path) -files = img_list.readlines() -# files_list = divideIntoNstrand(files, rank_size) -files_list = files - -# img_feats -# for i in range(rank_size): -img_feats, faceness_scores = get_image_feature(img_path, files_list, - model_path, 0, gpu_id) -stop = timeit.default_timer() -print('Time: %.2f s. ' % (stop - start)) -print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], - img_feats.shape[1])) - -# # Step3: Get Template Features - -# In[ ]: - -# ============================================================= -# compute template features from image features. -# ============================================================= -start = timeit.default_timer() -# ========================================================== -# Norm feature before aggregation into template feature? -# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face). -# ========================================================== -# 1. FaceScore (Feature Norm) -# 2. FaceScore (Detector) - -if use_flip_test: - # concat --- F1 - # img_input_feats = img_feats - # add --- F2 - img_input_feats = img_feats[:, 0:img_feats.shape[1] // - 2] + img_feats[:, img_feats.shape[1] // 2:] -else: - img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] - -if use_norm_score: - img_input_feats = img_input_feats -else: - # normalise features to remove norm information - img_input_feats = img_input_feats / np.sqrt( - np.sum(img_input_feats ** 2, -1, keepdims=True)) - -if use_detector_score: - print(img_input_feats.shape, faceness_scores.shape) - img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] -else: - img_input_feats = img_input_feats - -template_norm_feats, unique_templates = image2template_feature( - img_input_feats, templates, medias) -stop = timeit.default_timer() -print('Time: %.2f s. ' % (stop - start)) - -# # Step 4: Get Template Similarity Scores - -# In[ ]: - -# ============================================================= -# compute verification scores between template pairs. -# ============================================================= -start = timeit.default_timer() -score = verification(template_norm_feats, unique_templates, p1, p2) -stop = timeit.default_timer() -print('Time: %.2f s. 
' % (stop - start)) - -# In[ ]: -save_path = os.path.join(result_dir, args.job) -# save_path = result_dir + '/%s_result' % target - -if not os.path.exists(save_path): - os.makedirs(save_path) - -score_save_file = os.path.join(save_path, "%s.npy" % target.lower()) -np.save(score_save_file, score) - -# # Step 5: Get ROC Curves and TPR@FPR Table - -# In[ ]: - -files = [score_save_file] -methods = [] -scores = [] -for file in files: - methods.append(Path(file).stem) - scores.append(np.load(file)) - -methods = np.array(methods) -scores = dict(zip(methods, scores)) -colours = dict( - zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) -x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] -tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) -fig = plt.figure() -for method in methods: - fpr, tpr, _ = roc_curve(label, scores[method]) - roc_auc = auc(fpr, tpr) - fpr = np.flipud(fpr) - tpr = np.flipud(tpr) # select largest tpr at same fpr - plt.plot(fpr, - tpr, - color=colours[method], - lw=1, - label=('[%s (AUC = %0.4f %%)]' % - (method.split('-')[-1], roc_auc * 100))) - tpr_fpr_row = [] - tpr_fpr_row.append("%s-%s" % (method, target)) - for fpr_iter in np.arange(len(x_labels)): - _, min_index = min( - list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) - tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) - tpr_fpr_table.add_row(tpr_fpr_row) -plt.xlim([10 ** -6, 0.1]) -plt.ylim([0.3, 1.0]) -plt.grid(linestyle='--', linewidth=1) -plt.xticks(x_labels) -plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) -plt.xscale('log') -plt.xlabel('False Positive Rate') -plt.ylabel('True Positive Rate') -plt.title('ROC on IJB') -plt.legend(loc="lower right") -fig.savefig(os.path.join(save_path, '%s.pdf' % target.lower())) -print(tpr_fpr_table) diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/inference.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/inference.py deleted file mode 100644 index 1929d4abb640d040398dda57b491b9bd96deac9d..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/inference.py +++ /dev/null @@ -1,35 +0,0 @@ -import argparse - -import cv2 -import numpy as np -import torch - -from backbones import get_model - - -@torch.no_grad() -def inference(weight, name, img): - if img is None: - img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8) - else: - img = cv2.imread(img) - img = cv2.resize(img, (112, 112)) - - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img = np.transpose(img, (2, 0, 1)) - img = torch.from_numpy(img).unsqueeze(0).float() - img.div_(255).sub_(0.5).div_(0.5) - net = get_model(name, fp16=False) - net.load_state_dict(torch.load(weight)) - net.eval() - feat = net(img).numpy() - print(feat) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') - parser.add_argument('--network', type=str, default='r50', help='backbone network') - parser.add_argument('--weight', type=str, default='') - parser.add_argument('--img', type=str, default=None) - args = parser.parse_args() - inference(args.weight, args.network, args.img) diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/losses.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/losses.py deleted file mode 100644 index 7bfdd8c6b7f6b0d465928f19c554e62340e5ad7b..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/losses.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch -from torch 
import nn - - -def get_loss(name): - if name == "cosface": - return CosFace() - elif name == "arcface": - return ArcFace() - else: - raise ValueError() - - -class CosFace(nn.Module): - def __init__(self, s=64.0, m=0.40): - super(CosFace, self).__init__() - self.s = s - self.m = m - - def forward(self, cosine, label): - index = torch.where(label != -1)[0] - m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) - m_hot.scatter_(1, label[index, None], self.m) - cosine[index] -= m_hot - ret = cosine * self.s - return ret - - -class ArcFace(nn.Module): - def __init__(self, s=64.0, m=0.5): - super(ArcFace, self).__init__() - self.s = s - self.m = m - - def forward(self, cosine: torch.Tensor, label): - index = torch.where(label != -1)[0] - m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) - m_hot.scatter_(1, label[index, None], self.m) - cosine.acos_() - cosine[index] += m_hot - cosine.cos_().mul_(self.s) - return cosine diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/onnx_helper.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/onnx_helper.py deleted file mode 100644 index 4a01a46621dc0ea695bd903de5d1e212d424c860..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/onnx_helper.py +++ /dev/null @@ -1,250 +0,0 @@ -from __future__ import division -import datetime -import os -import os.path as osp -import glob -import numpy as np -import cv2 -import sys -import onnxruntime -import onnx -import argparse -from onnx import numpy_helper -from insightface.data import get_image - -class ArcFaceORT: - def __init__(self, model_path, cpu=False): - self.model_path = model_path - # providers = None will use available provider, for onnxruntime-gpu it will be "CUDAExecutionProvider" - self.providers = ['CPUExecutionProvider'] if cpu else None - - #input_size is (w,h), return error message, return None if success - def check(self, track='cfat', test_img = None): - #default is cfat - max_model_size_mb=1024 - max_feat_dim=512 - max_time_cost=15 - if track.startswith('ms1m'): - max_model_size_mb=1024 - max_feat_dim=512 - max_time_cost=10 - elif track.startswith('glint'): - max_model_size_mb=1024 - max_feat_dim=1024 - max_time_cost=20 - elif track.startswith('cfat'): - max_model_size_mb = 1024 - max_feat_dim = 512 - max_time_cost = 15 - elif track.startswith('unconstrained'): - max_model_size_mb=1024 - max_feat_dim=1024 - max_time_cost=30 - else: - return "track not found" - - if not os.path.exists(self.model_path): - return "model_path not exists" - if not os.path.isdir(self.model_path): - return "model_path should be directory" - onnx_files = [] - for _file in os.listdir(self.model_path): - if _file.endswith('.onnx'): - onnx_files.append(osp.join(self.model_path, _file)) - if len(onnx_files)==0: - return "do not have onnx files" - self.model_file = sorted(onnx_files)[-1] - print('use onnx-model:', self.model_file) - try: - session = onnxruntime.InferenceSession(self.model_file, providers=self.providers) - except: - return "load onnx failed" - input_cfg = session.get_inputs()[0] - input_shape = input_cfg.shape - print('input-shape:', input_shape) - if len(input_shape)!=4: - return "length of input_shape should be 4" - if not isinstance(input_shape[0], str): - #return "input_shape[0] should be str to support batch-inference" - print('reset input-shape[0] to None') - model = onnx.load(self.model_file) - model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' - new_model_file = 
osp.join(self.model_path, 'zzzzrefined.onnx') - onnx.save(model, new_model_file) - self.model_file = new_model_file - print('use new onnx-model:', self.model_file) - try: - session = onnxruntime.InferenceSession(self.model_file, providers=self.providers) - except: - return "load onnx failed" - input_cfg = session.get_inputs()[0] - input_shape = input_cfg.shape - print('new-input-shape:', input_shape) - - self.image_size = tuple(input_shape[2:4][::-1]) - #print('image_size:', self.image_size) - input_name = input_cfg.name - outputs = session.get_outputs() - output_names = [] - for o in outputs: - output_names.append(o.name) - #print(o.name, o.shape) - if len(output_names)!=1: - return "number of output nodes should be 1" - self.session = session - self.input_name = input_name - self.output_names = output_names - #print(self.output_names) - model = onnx.load(self.model_file) - graph = model.graph - if len(graph.node)<8: - return "too small onnx graph" - - input_size = (112,112) - self.crop = None - if track=='cfat': - crop_file = osp.join(self.model_path, 'crop.txt') - if osp.exists(crop_file): - lines = open(crop_file,'r').readlines() - if len(lines)!=6: - return "crop.txt should contain 6 lines" - lines = [int(x) for x in lines] - self.crop = lines[:4] - input_size = tuple(lines[4:6]) - if input_size!=self.image_size: - return "input-size is inconsistant with onnx model input, %s vs %s"%(input_size, self.image_size) - - self.model_size_mb = os.path.getsize(self.model_file) / float(1024*1024) - if self.model_size_mb > max_model_size_mb: - return "max model size exceed, given %.3f-MB"%self.model_size_mb - - input_mean = None - input_std = None - if track=='cfat': - pn_file = osp.join(self.model_path, 'pixel_norm.txt') - if osp.exists(pn_file): - lines = open(pn_file,'r').readlines() - if len(lines)!=2: - return "pixel_norm.txt should contain 2 lines" - input_mean = float(lines[0]) - input_std = float(lines[1]) - if input_mean is not None or input_std is not None: - if input_mean is None or input_std is None: - return "please set input_mean and input_std simultaneously" - else: - find_sub = False - find_mul = False - for nid, node in enumerate(graph.node[:8]): - print(nid, node.name) - if node.name.startswith('Sub') or node.name.startswith('_minus'): - find_sub = True - if node.name.startswith('Mul') or node.name.startswith('_mul') or node.name.startswith('Div'): - find_mul = True - if find_sub and find_mul: - print("find sub and mul") - #mxnet arcface model - input_mean = 0.0 - input_std = 1.0 - else: - input_mean = 127.5 - input_std = 127.5 - self.input_mean = input_mean - self.input_std = input_std - for initn in graph.initializer: - weight_array = numpy_helper.to_array(initn) - dt = weight_array.dtype - if dt.itemsize<4: - return 'invalid weight type - (%s:%s)' % (initn.name, dt.name) - if test_img is None: - test_img = get_image('Tom_Hanks_54745') - test_img = cv2.resize(test_img, self.image_size) - else: - test_img = cv2.resize(test_img, self.image_size) - feat, cost = self.benchmark(test_img) - batch_result = self.check_batch(test_img) - batch_result_sum = float(np.sum(batch_result)) - if batch_result_sum in [float('inf'), -float('inf')] or batch_result_sum != batch_result_sum: - print(batch_result) - print(batch_result_sum) - return "batch result output contains NaN!" 
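- # past this point the batched output is finite; the remaining checks bound the feature dimension (max_feat_dim) and the single-image inference latency (max_time_cost)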
- - if len(feat.shape) < 2: - return "the shape of the feature must be two, but get {}".format(str(feat.shape)) - - if feat.shape[1] > max_feat_dim: - return "max feat dim exceed, given %d"%feat.shape[1] - self.feat_dim = feat.shape[1] - cost_ms = cost*1000 - if cost_ms>max_time_cost: - return "max time cost exceed, given %.4f"%cost_ms - self.cost_ms = cost_ms - print('check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f'%(self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std)) - return None - - def check_batch(self, img): - if not isinstance(img, list): - imgs = [img, ] * 32 - if self.crop is not None: - nimgs = [] - for img in imgs: - nimg = img[self.crop[1]:self.crop[3], self.crop[0]:self.crop[2], :] - if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]: - nimg = cv2.resize(nimg, self.image_size) - nimgs.append(nimg) - imgs = nimgs - blob = cv2.dnn.blobFromImages( - images=imgs, scalefactor=1.0 / self.input_std, size=self.image_size, - mean=(self.input_mean, self.input_mean, self.input_mean), swapRB=True) - net_out = self.session.run(self.output_names, {self.input_name: blob})[0] - return net_out - - - def meta_info(self): - return {'model-size-mb':self.model_size_mb, 'feature-dim':self.feat_dim, 'infer': self.cost_ms} - - - def forward(self, imgs): - if not isinstance(imgs, list): - imgs = [imgs] - input_size = self.image_size - if self.crop is not None: - nimgs = [] - for img in imgs: - nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:] - if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]: - nimg = cv2.resize(nimg, input_size) - nimgs.append(nimg) - imgs = nimgs - blob = cv2.dnn.blobFromImages(imgs, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) - net_out = self.session.run(self.output_names, {self.input_name : blob})[0] - return net_out - - def benchmark(self, img): - input_size = self.image_size - if self.crop is not None: - nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:] - if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]: - nimg = cv2.resize(nimg, input_size) - img = nimg - blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) - costs = [] - for _ in range(50): - ta = datetime.datetime.now() - net_out = self.session.run(self.output_names, {self.input_name : blob})[0] - tb = datetime.datetime.now() - cost = (tb-ta).total_seconds() - costs.append(cost) - costs = sorted(costs) - cost = costs[5] - return net_out, cost - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='') - # general - parser.add_argument('workdir', help='submitted work dir', type=str) - parser.add_argument('--track', help='track name, for different challenge', type=str, default='cfat') - args = parser.parse_args() - handler = ArcFaceORT(args.workdir) - err = handler.check(args.track) - print('err:', err) diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/onnx_ijbc.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/onnx_ijbc.py deleted file mode 100644 index aa96b96745e23d4d6642d99f71456c10af5e4e4e..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/onnx_ijbc.py +++ /dev/null @@ -1,267 +0,0 @@ -import argparse -import os -import pickle -import timeit - -import cv2 -import mxnet as mx -import numpy as np -import pandas as pd -import 
prettytable -import skimage.transform -from sklearn.metrics import roc_curve -from sklearn.preprocessing import normalize - -from onnx_helper import ArcFaceORT - -SRC = np.array( - [ - [30.2946, 51.6963], - [65.5318, 51.5014], - [48.0252, 71.7366], - [33.5493, 92.3655], - [62.7299, 92.2041]] - , dtype=np.float32) -SRC[:, 0] += 8.0 - - -class AlignedDataSet(mx.gluon.data.Dataset): - def __init__(self, root, lines, align=True): - self.lines = lines - self.root = root - self.align = align - - def __len__(self): - return len(self.lines) - - def __getitem__(self, idx): - each_line = self.lines[idx] - name_lmk_score = each_line.strip().split(' ') - name = os.path.join(self.root, name_lmk_score[0]) - img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB) - landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2)) - st = skimage.transform.SimilarityTransform() - st.estimate(landmark5, SRC) - img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0) - img_1 = np.expand_dims(img, 0) - img_2 = np.expand_dims(np.fliplr(img), 0) - output = np.concatenate((img_1, img_2), axis=0).astype(np.float32) - output = np.transpose(output, (0, 3, 1, 2)) - output = mx.nd.array(output) - return output - - -def extract(model_root, dataset): - model = ArcFaceORT(model_path=model_root) - model.check() - feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim)) - - def batchify_fn(data): - return mx.nd.concat(*data, dim=0) - - data_loader = mx.gluon.data.DataLoader( - dataset, 128, last_batch='keep', num_workers=4, - thread_pool=True, prefetch=16, batchify_fn=batchify_fn) - num_iter = 0 - for batch in data_loader: - batch = batch.asnumpy() - batch = (batch - model.input_mean) / model.input_std - feat = model.session.run(model.output_names, {model.input_name: batch})[0] - feat = np.reshape(feat, (-1, model.feat_dim * 2)) - feat_mat[128 * num_iter: 128 * num_iter + feat.shape[0], :] = feat - num_iter += 1 - if num_iter % 50 == 0: - print(num_iter) - return feat_mat - - -def read_template_media_list(path): - ijb_meta = pd.read_csv(path, sep=' ', header=None).values - templates = ijb_meta[:, 1].astype(np.int) - medias = ijb_meta[:, 2].astype(np.int) - return templates, medias - - -def read_template_pair_list(path): - pairs = pd.read_csv(path, sep=' ', header=None).values - t1 = pairs[:, 0].astype(np.int) - t2 = pairs[:, 1].astype(np.int) - label = pairs[:, 2].astype(np.int) - return t1, t2, label - - -def read_image_feature(path): - with open(path, 'rb') as fid: - img_feats = pickle.load(fid) - return img_feats - - -def image2template_feature(img_feats=None, - templates=None, - medias=None): - unique_templates = np.unique(templates) - template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) - for count_template, uqt in enumerate(unique_templates): - (ind_t,) = np.where(templates == uqt) - face_norm_feats = img_feats[ind_t] - face_medias = medias[ind_t] - unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True) - media_norm_feats = [] - for u, ct in zip(unique_medias, unique_media_counts): - (ind_m,) = np.where(face_medias == u) - if ct == 1: - media_norm_feats += [face_norm_feats[ind_m]] - else: # image features from the same video will be aggregated into one feature - media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True), ] - media_norm_feats = np.array(media_norm_feats) - template_feats[count_template] = np.sum(media_norm_feats, axis=0) - if count_template % 2000 == 0: - print('Finish Calculating {} 
template features.'.format( - count_template)) - template_norm_feats = normalize(template_feats) - return template_norm_feats, unique_templates - - -def verification(template_norm_feats=None, - unique_templates=None, - p1=None, - p2=None): - template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) - for count_template, uqt in enumerate(unique_templates): - template2id[uqt] = count_template - score = np.zeros((len(p1),)) - total_pairs = np.array(range(len(p1))) - batchsize = 100000 - sublists = [total_pairs[i: i + batchsize] for i in range(0, len(p1), batchsize)] - total_sublists = len(sublists) - for c, s in enumerate(sublists): - feat1 = template_norm_feats[template2id[p1[s]]] - feat2 = template_norm_feats[template2id[p2[s]]] - similarity_score = np.sum(feat1 * feat2, -1) - score[s] = similarity_score.flatten() - if c % 10 == 0: - print('Finish {}/{} pairs.'.format(c, total_sublists)) - return score - - -def verification2(template_norm_feats=None, - unique_templates=None, - p1=None, - p2=None): - template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) - for count_template, uqt in enumerate(unique_templates): - template2id[uqt] = count_template - score = np.zeros((len(p1),)) # save cosine distance between pairs - total_pairs = np.array(range(len(p1))) - batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation - sublists = [total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)] - total_sublists = len(sublists) - for c, s in enumerate(sublists): - feat1 = template_norm_feats[template2id[p1[s]]] - feat2 = template_norm_feats[template2id[p2[s]]] - similarity_score = np.sum(feat1 * feat2, -1) - score[s] = similarity_score.flatten() - if c % 10 == 0: - print('Finish {}/{} pairs.'.format(c, total_sublists)) - return score - - -def main(args): - use_norm_score = True # if Ture, TestMode(N1) - use_detector_score = True # if Ture, TestMode(D1) - use_flip_test = True # if Ture, TestMode(F1) - assert args.target == 'IJBC' or args.target == 'IJBB' - - start = timeit.default_timer() - templates, medias = read_template_media_list( - os.path.join('%s/meta' % args.image_path, '%s_face_tid_mid.txt' % args.target.lower())) - stop = timeit.default_timer() - print('Time: %.2f s. ' % (stop - start)) - - start = timeit.default_timer() - p1, p2, label = read_template_pair_list( - os.path.join('%s/meta' % args.image_path, - '%s_template_pair_label.txt' % args.target.lower())) - stop = timeit.default_timer() - print('Time: %.2f s. ' % (stop - start)) - - start = timeit.default_timer() - img_path = '%s/loose_crop' % args.image_path - img_list_path = '%s/meta/%s_name_5pts_score.txt' % (args.image_path, args.target.lower()) - img_list = open(img_list_path) - files = img_list.readlines() - dataset = AlignedDataSet(root=img_path, lines=files, align=True) - img_feats = extract(args.model_root, dataset) - - faceness_scores = [] - for each_line in files: - name_lmk_score = each_line.split() - faceness_scores.append(name_lmk_score[-1]) - faceness_scores = np.array(faceness_scores).astype(np.float32) - stop = timeit.default_timer() - print('Time: %.2f s. 
' % (stop - start)) - print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1])) - start = timeit.default_timer() - - if use_flip_test: - img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:] - else: - img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] - - if use_norm_score: - img_input_feats = img_input_feats - else: - img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True)) - - if use_detector_score: - print(img_input_feats.shape, faceness_scores.shape) - img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] - else: - img_input_feats = img_input_feats - - template_norm_feats, unique_templates = image2template_feature( - img_input_feats, templates, medias) - stop = timeit.default_timer() - print('Time: %.2f s. ' % (stop - start)) - - start = timeit.default_timer() - score = verification(template_norm_feats, unique_templates, p1, p2) - stop = timeit.default_timer() - print('Time: %.2f s. ' % (stop - start)) - save_path = os.path.join(args.result_dir, "{}_result".format(args.target)) - if not os.path.exists(save_path): - os.makedirs(save_path) - score_save_file = os.path.join(save_path, "{}.npy".format(args.model_root)) - np.save(score_save_file, score) - files = [score_save_file] - methods = [] - scores = [] - for file in files: - methods.append(os.path.basename(file)) - scores.append(np.load(file)) - methods = np.array(methods) - scores = dict(zip(methods, scores)) - x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] - tpr_fpr_table = prettytable.PrettyTable(['Methods'] + [str(x) for x in x_labels]) - for method in methods: - fpr, tpr, _ = roc_curve(label, scores[method]) - fpr = np.flipud(fpr) - tpr = np.flipud(tpr) - tpr_fpr_row = [] - tpr_fpr_row.append("%s-%s" % (method, args.target)) - for fpr_iter in np.arange(len(x_labels)): - _, min_index = min( - list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) - tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) - tpr_fpr_table.add_row(tpr_fpr_row) - print(tpr_fpr_table) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='do ijb test') - # general - parser.add_argument('--model-root', default='', help='path to load model.') - parser.add_argument('--image-path', default='', type=str, help='') - parser.add_argument('--result-dir', default='.', type=str, help='') - parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB') - main(parser.parse_args()) diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/partial_fc.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/partial_fc.py deleted file mode 100644 index e0286dd437319c920ecb61f4eb3a32333dcf49eb..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/partial_fc.py +++ /dev/null @@ -1,222 +0,0 @@ -import logging -import os - -import torch -import torch.distributed as dist -from torch.nn import Module -from torch.nn.functional import normalize, linear -from torch.nn.parameter import Parameter - - -class PartialFC(Module): - """ - Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint, - Partial FC: Training 10 Million Identities on a Single Machine - See the original paper: - https://arxiv.org/abs/2010.05222 - """ - - @torch.no_grad() - def __init__(self, rank, local_rank, world_size, batch_size, resume, - margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"): - """ - rank: int - Unique 
process(GPU) ID from 0 to world_size - 1. - local_rank: int - Unique process(GPU) ID within the server from 0 to 7. - world_size: int - Number of GPU. - batch_size: int - Batch size on current rank(GPU). - resume: bool - Select whether to restore the weight of softmax. - margin_softmax: callable - A function of margin softmax, eg: cosface, arcface. - num_classes: int - The number of class center storage in current rank(CPU/GPU), usually is total_classes // world_size, - required. - sample_rate: float - The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling - can greatly speed up training, and reduce a lot of GPU memory, default is 1.0. - embedding_size: int - The feature dimension, default is 512. - prefix: str - Path for save checkpoint, default is './'. - """ - super(PartialFC, self).__init__() - # - self.num_classes: int = num_classes - self.rank: int = rank - self.local_rank: int = local_rank - self.device: torch.device = torch.device("cuda:{}".format(self.local_rank)) - self.world_size: int = world_size - self.batch_size: int = batch_size - self.margin_softmax: callable = margin_softmax - self.sample_rate: float = sample_rate - self.embedding_size: int = embedding_size - self.prefix: str = prefix - self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size) - self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size) - self.num_sample: int = int(self.sample_rate * self.num_local) - - self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank)) - self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank)) - - if resume: - try: - self.weight: torch.Tensor = torch.load(self.weight_name) - self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name) - if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local: - raise IndexError - logging.info("softmax weight resume successfully!") - logging.info("softmax weight mom resume successfully!") - except (FileNotFoundError, KeyError, IndexError): - self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) - self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) - logging.info("softmax weight init!") - logging.info("softmax weight mom init!") - else: - self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) - self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) - logging.info("softmax weight init successfully!") - logging.info("softmax weight mom init successfully!") - self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank) - - self.index = None - if int(self.sample_rate) == 1: - self.update = lambda: 0 - self.sub_weight = Parameter(self.weight) - self.sub_weight_mom = self.weight_mom - else: - self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank)) - - def save_params(self): - """ Save softmax weight for each rank on prefix - """ - torch.save(self.weight.data, self.weight_name) - torch.save(self.weight_mom, self.weight_mom_name) - - @torch.no_grad() - def sample(self, total_label): - """ - Sample all positive class centers in each rank, and random select neg class centers to filling a fixed - `num_sample`. - - total_label: tensor - Label after all gather, which cross all GPUs. 
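- Note: negative class centers are selected by ranking random scores with the positives forced to the top, so every positive class present in the batch is always kept in the sampled subset.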
- """ - index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local) - total_label[~index_positive] = -1 - total_label[index_positive] -= self.class_start - if int(self.sample_rate) != 1: - positive = torch.unique(total_label[index_positive], sorted=True) - if self.num_sample - positive.size(0) >= 0: - perm = torch.rand(size=[self.num_local], device=self.device) - perm[positive] = 2.0 - index = torch.topk(perm, k=self.num_sample)[1] - index = index.sort()[0] - else: - index = positive - self.index = index - total_label[index_positive] = torch.searchsorted(index, total_label[index_positive]) - self.sub_weight = Parameter(self.weight[index]) - self.sub_weight_mom = self.weight_mom[index] - - def forward(self, total_features, norm_weight): - """ Partial fc forward, `logits = X * sample(W)` - """ - torch.cuda.current_stream().wait_stream(self.stream) - logits = linear(total_features, norm_weight) - return logits - - @torch.no_grad() - def update(self): - """ Set updated weight and weight_mom to memory bank. - """ - self.weight_mom[self.index] = self.sub_weight_mom - self.weight[self.index] = self.sub_weight - - def prepare(self, label, optimizer): - """ - get sampled class centers for cal softmax. - - label: tensor - Label tensor on each rank. - optimizer: opt - Optimizer for partial fc, which need to get weight mom. - """ - with torch.cuda.stream(self.stream): - total_label = torch.zeros( - size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long) - dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label) - self.sample(total_label) - optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None) - optimizer.param_groups[-1]['params'][0] = self.sub_weight - optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom - norm_weight = normalize(self.sub_weight) - return total_label, norm_weight - - def forward_backward(self, label, features, optimizer): - """ - Partial fc forward and backward with model parallel - - label: tensor - Label tensor on each rank(GPU) - features: tensor - Features tensor on each rank(GPU) - optimizer: optimizer - Optimizer for partial fc - - Returns: - -------- - x_grad: tensor - The gradient of features. - loss_v: tensor - Loss value for cross entropy. 
- """ - total_label, norm_weight = self.prepare(label, optimizer) - total_features = torch.zeros( - size=[self.batch_size * self.world_size, self.embedding_size], device=self.device) - dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data) - total_features.requires_grad = True - - logits = self.forward(total_features, norm_weight) - logits = self.margin_softmax(logits, total_label) - - with torch.no_grad(): - max_fc = torch.max(logits, dim=1, keepdim=True)[0] - dist.all_reduce(max_fc, dist.ReduceOp.MAX) - - # calculate exp(logits) and all-reduce - logits_exp = torch.exp(logits - max_fc) - logits_sum_exp = logits_exp.sum(dim=1, keepdims=True) - dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM) - - # calculate prob - logits_exp.div_(logits_sum_exp) - - # get one-hot - grad = logits_exp - index = torch.where(total_label != -1)[0] - one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device) - one_hot.scatter_(1, total_label[index, None], 1) - - # calculate loss - loss = torch.zeros(grad.size()[0], 1, device=grad.device) - loss[index] = grad[index].gather(1, total_label[index, None]) - dist.all_reduce(loss, dist.ReduceOp.SUM) - loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1) - - # calculate grad - grad[index] -= one_hot - grad.div_(self.batch_size * self.world_size) - - logits.backward(grad) - if total_features.grad is not None: - total_features.grad.detach_() - x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True) - # feature gradient all-reduce - dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0))) - x_grad = x_grad * self.world_size - # backward backbone - return x_grad, loss_v diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/requirement.txt b/sadtalker_video2pose/src/face3d/models/arcface_torch/requirement.txt deleted file mode 100644 index 99aef673e30b99cbe56ce82a564c1df9df24ba21..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/requirement.txt +++ /dev/null @@ -1,5 +0,0 @@ -tensorboard -easydict -mxnet -onnx -sklearn diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/run.sh b/sadtalker_video2pose/src/face3d/models/arcface_torch/run.sh deleted file mode 100644 index 67b25fd63ef3921733d81d5be844aacc5a5c84ed..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/run.sh +++ /dev/null @@ -1,2 +0,0 @@ -CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 -ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/torch2onnx.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/torch2onnx.py deleted file mode 100644 index 458660df7cc7f9a567aaf492c45f232e776a9ef0..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/torch2onnx.py +++ /dev/null @@ -1,59 +0,0 @@ -import numpy as np -import onnx -import torch - - -def convert_onnx(net, path_module, output, opset=11, simplify=False): - assert isinstance(net, torch.nn.Module) - img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) - img = img.astype(np.float) - img = (img / 255. 
- 0.5) / 0.5 # torch style norm - img = img.transpose((2, 0, 1)) - img = torch.from_numpy(img).unsqueeze(0).float() - - weight = torch.load(path_module) - net.load_state_dict(weight) - net.eval() - torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset) - model = onnx.load(output) - graph = model.graph - graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' - if simplify: - from onnxsim import simplify - model, check = simplify(model) - assert check, "Simplified ONNX model could not be validated" - onnx.save(model, output) - - -if __name__ == '__main__': - import os - import argparse - from backbones import get_model - - parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx') - parser.add_argument('input', type=str, help='input backbone.pth file or path') - parser.add_argument('--output', type=str, default=None, help='output onnx path') - parser.add_argument('--network', type=str, default=None, help='backbone network') - parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify') - args = parser.parse_args() - input_file = args.input - if os.path.isdir(input_file): - input_file = os.path.join(input_file, "backbone.pth") - assert os.path.exists(input_file) - model_name = os.path.basename(os.path.dirname(input_file)).lower() - params = model_name.split("_") - if len(params) >= 3 and params[1] in ('arcface', 'cosface'): - if args.network is None: - args.network = params[2] - assert args.network is not None - print(args) - backbone_onnx = get_model(args.network, dropout=0) - - output_path = args.output - if output_path is None: - output_path = os.path.join(os.path.dirname(__file__), 'onnx') - if not os.path.exists(output_path): - os.makedirs(output_path) - assert os.path.isdir(output_path) - output_file = os.path.join(output_path, "%s.onnx" % model_name) - convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify) diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/train.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/train.py deleted file mode 100644 index 0c5491de9af8fc7a2f3d0648c53b89584864f20e..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/train.py +++ /dev/null @@ -1,141 +0,0 @@ -import argparse -import logging -import os - -import torch -import torch.distributed as dist -import torch.nn.functional as F -import torch.utils.data.distributed -from torch.nn.utils import clip_grad_norm_ - -import losses -from backbones import get_model -from dataset import MXFaceDataset, SyntheticDataset, DataLoaderX -from partial_fc import PartialFC -from utils.utils_amp import MaxClipGradScaler -from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint -from utils.utils_config import get_config -from utils.utils_logging import AverageMeter, init_logging - - -def main(args): - cfg = get_config(args.config) - try: - world_size = int(os.environ['WORLD_SIZE']) - rank = int(os.environ['RANK']) - dist.init_process_group('nccl') - except KeyError: - world_size = 1 - rank = 0 - dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:12584", rank=rank, world_size=world_size) - - local_rank = args.local_rank - torch.cuda.set_device(local_rank) - os.makedirs(cfg.output, exist_ok=True) - init_logging(rank, cfg.output) - - if cfg.rec == "synthetic": - train_set = SyntheticDataset(local_rank=local_rank) - else: - train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank) 
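- # each rank trains on its own shard of the data via DistributedSampler; the effective global batch size is cfg.batch_size * world_size, and the learning rates below scale linearly with it (cfg.lr / 512 * batch_size * world_size)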
- - train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True) - train_loader = DataLoaderX( - local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size, - sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True) - backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank) - - if cfg.resume: - try: - backbone_pth = os.path.join(cfg.output, "backbone.pth") - backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank))) - if rank == 0: - logging.info("backbone resume successfully!") - except (FileNotFoundError, KeyError, IndexError, RuntimeError): - if rank == 0: - logging.info("resume fail, backbone init successfully!") - - backbone = torch.nn.parallel.DistributedDataParallel( - module=backbone, broadcast_buffers=False, device_ids=[local_rank]) - backbone.train() - margin_softmax = losses.get_loss(cfg.loss) - module_partial_fc = PartialFC( - rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume, - batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes, - sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output) - - opt_backbone = torch.optim.SGD( - params=[{'params': backbone.parameters()}], - lr=cfg.lr / 512 * cfg.batch_size * world_size, - momentum=0.9, weight_decay=cfg.weight_decay) - opt_pfc = torch.optim.SGD( - params=[{'params': module_partial_fc.parameters()}], - lr=cfg.lr / 512 * cfg.batch_size * world_size, - momentum=0.9, weight_decay=cfg.weight_decay) - - num_image = len(train_set) - total_batch_size = cfg.batch_size * world_size - cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch - cfg.total_step = num_image // total_batch_size * cfg.num_epoch - - def lr_step_func(current_step): - cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch] - if current_step < cfg.warmup_step: - return current_step / cfg.warmup_step - else: - return 0.1 ** len([m for m in cfg.decay_step if m <= current_step]) - - scheduler_backbone = torch.optim.lr_scheduler.LambdaLR( - optimizer=opt_backbone, lr_lambda=lr_step_func) - scheduler_pfc = torch.optim.lr_scheduler.LambdaLR( - optimizer=opt_pfc, lr_lambda=lr_step_func) - - for key, value in cfg.items(): - num_space = 25 - len(key) - logging.info(": " + key + " " * num_space + str(value)) - - val_target = cfg.val_targets - callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec) - callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None) - callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output) - - loss = AverageMeter() - start_epoch = 0 - global_step = 0 - grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None - for epoch in range(start_epoch, cfg.num_epoch): - train_sampler.set_epoch(epoch) - for step, (img, label) in enumerate(train_loader): - global_step += 1 - features = F.normalize(backbone(img)) - x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc) - if cfg.fp16: - features.backward(grad_amp.scale(x_grad)) - grad_amp.unscale_(opt_backbone) - clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2) - grad_amp.step(opt_backbone) - grad_amp.update() - else: - features.backward(x_grad) - clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2) - opt_backbone.step() - - opt_pfc.step() - module_partial_fc.update() - opt_backbone.zero_grad() - 
opt_pfc.zero_grad() - loss.update(loss_v, 1) - callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp) - callback_verification(global_step, backbone) - scheduler_backbone.step() - scheduler_pfc.step() - callback_checkpoint(global_step, backbone, module_partial_fc) - dist.destroy_process_group() - - -if __name__ == "__main__": - torch.backends.cudnn.benchmark = True - parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') - parser.add_argument('config', type=str, help='py config file') - parser.add_argument('--local_rank', type=int, default=0, help='local_rank') - main(parser.parse_args()) diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/__init__.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/plot.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/plot.py deleted file mode 100644 index 4fce6cc0ae526d5aebc8e7a1550300ceae3a2034..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/plot.py +++ /dev/null @@ -1,72 +0,0 @@ -# coding: utf-8 - -import os -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap -from prettytable import PrettyTable -from sklearn.metrics import roc_curve, auc - -image_path = "/data/anxiang/IJB_release/IJBC" -files = [ - "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy" -] - - -def read_template_pair_list(path): - pairs = pd.read_csv(path, sep=' ', header=None).values - t1 = pairs[:, 0].astype(np.int) - t2 = pairs[:, 1].astype(np.int) - label = pairs[:, 2].astype(np.int) - return t1, t2, label - - -p1, p2, label = read_template_pair_list( - os.path.join('%s/meta' % image_path, - '%s_template_pair_label.txt' % 'ijbc')) - -methods = [] -scores = [] -for file in files: - methods.append(file.split('/')[-2]) - scores.append(np.load(file)) - -methods = np.array(methods) -scores = dict(zip(methods, scores)) -colours = dict( - zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) -x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] -tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) -fig = plt.figure() -for method in methods: - fpr, tpr, _ = roc_curve(label, scores[method]) - roc_auc = auc(fpr, tpr) - fpr = np.flipud(fpr) - tpr = np.flipud(tpr) # select largest tpr at same fpr - plt.plot(fpr, - tpr, - color=colours[method], - lw=1, - label=('[%s (AUC = %0.4f %%)]' % - (method.split('-')[-1], roc_auc * 100))) - tpr_fpr_row = [] - tpr_fpr_row.append("%s-%s" % (method, "IJBC")) - for fpr_iter in np.arange(len(x_labels)): - _, min_index = min( - list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) - tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) - tpr_fpr_table.add_row(tpr_fpr_row) -plt.xlim([10 ** -6, 0.1]) -plt.ylim([0.3, 1.0]) -plt.grid(linestyle='--', linewidth=1) -plt.xticks(x_labels) -plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) -plt.xscale('log') -plt.xlabel('False Positive Rate') -plt.ylabel('True Positive Rate') -plt.title('ROC on IJB') -plt.legend(loc="lower right") -print(tpr_fpr_table) diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_amp.py 
b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_amp.py deleted file mode 100644 index a6d5bcbb540ff8b04535e71c0057e124338df5bd..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_amp.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Dict, List - -import torch - -if torch.__version__ < '1.9': - Iterable = torch._six.container_abcs.Iterable -else: - import collections - - Iterable = collections.abc.Iterable -from torch.cuda.amp import GradScaler - - -class _MultiDeviceReplicator(object): - """ - Lazily serves copies of a tensor to requested devices. Copies are cached per-device. - """ - - def __init__(self, master_tensor: torch.Tensor) -> None: - assert master_tensor.is_cuda - self.master = master_tensor - self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} - - def get(self, device) -> torch.Tensor: - retval = self._per_device_tensors.get(device, None) - if retval is None: - retval = self.master.to(device=device, non_blocking=True, copy=True) - self._per_device_tensors[device] = retval - return retval - - -class MaxClipGradScaler(GradScaler): - def __init__(self, init_scale, max_scale: float, growth_interval=100): - GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval) - self.max_scale = max_scale - - def scale_clip(self): - if self.get_scale() == self.max_scale: - self.set_growth_factor(1) - elif self.get_scale() < self.max_scale: - self.set_growth_factor(2) - elif self.get_scale() > self.max_scale: - self._scale.fill_(self.max_scale) - self.set_growth_factor(1) - - def scale(self, outputs): - """ - Multiplies ('scales') a tensor or list of tensors by the scale factor. - - Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned - unmodified. - - Arguments: - outputs (Tensor or iterable of Tensors): Outputs to scale. - """ - if not self._enabled: - return outputs - self.scale_clip() - # Short-circuit for the common case. - if isinstance(outputs, torch.Tensor): - assert outputs.is_cuda - if self._scale is None: - self._lazy_init_scale_growth_tracker(outputs.device) - assert self._scale is not None - return outputs * self._scale.to(device=outputs.device, non_blocking=True) - - # Invoke the more complex machinery only if we're treating multiple outputs. 
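- # apply_scale below walks nested iterables recursively, caching one copy of the scale tensor per device through _MultiDeviceReplicator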
- stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale - - def apply_scale(val): - if isinstance(val, torch.Tensor): - assert val.is_cuda - if len(stash) == 0: - if self._scale is None: - self._lazy_init_scale_growth_tracker(val.device) - assert self._scale is not None - stash.append(_MultiDeviceReplicator(self._scale)) - return val * stash[0].get(val.device) - elif isinstance(val, Iterable): - iterable = map(apply_scale, val) - if isinstance(val, list) or isinstance(val, tuple): - return type(val)(iterable) - else: - return iterable - else: - raise ValueError("outputs must be a Tensor or an iterable of Tensors") - - return apply_scale(outputs) diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_callbacks.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_callbacks.py deleted file mode 100644 index 748923b36358bd118efa0532a6f512b6ca96ff34..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_callbacks.py +++ /dev/null @@ -1,117 +0,0 @@ -import logging -import os -import time -from typing import List - -import torch - -from eval import verification -from utils.utils_logging import AverageMeter - - -class CallBackVerification(object): - def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)): - self.frequent: int = frequent - self.rank: int = rank - self.highest_acc: float = 0.0 - self.highest_acc_list: List[float] = [0.0] * len(val_targets) - self.ver_list: List[object] = [] - self.ver_name_list: List[str] = [] - if self.rank is 0: - self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size) - - def ver_test(self, backbone: torch.nn.Module, global_step: int): - results = [] - for i in range(len(self.ver_list)): - acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test( - self.ver_list[i], backbone, 10, 10) - logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm)) - logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2)) - if acc2 > self.highest_acc_list[i]: - self.highest_acc_list[i] = acc2 - logging.info( - '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i])) - results.append(acc2) - - def init_dataset(self, val_targets, data_dir, image_size): - for name in val_targets: - path = os.path.join(data_dir, name + ".bin") - if os.path.exists(path): - data_set = verification.load_bin(path, image_size) - self.ver_list.append(data_set) - self.ver_name_list.append(name) - - def __call__(self, num_update, backbone: torch.nn.Module): - if self.rank is 0 and num_update > 0 and num_update % self.frequent == 0: - backbone.eval() - self.ver_test(backbone, num_update) - backbone.train() - - -class CallBackLogging(object): - def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None): - self.frequent: int = frequent - self.rank: int = rank - self.time_start = time.time() - self.total_step: int = total_step - self.batch_size: int = batch_size - self.world_size: int = world_size - self.writer = writer - - self.init = False - self.tic = 0 - - def __call__(self, - global_step: int, - loss: AverageMeter, - epoch: int, - fp16: bool, - learning_rate: float, - grad_scaler: torch.cuda.amp.GradScaler): - if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0: - if self.init: - try: - speed: float = self.frequent * self.batch_size / (time.time() 
- self.tic) - speed_total = speed * self.world_size - except ZeroDivisionError: - speed_total = float('inf') - - time_now = (time.time() - self.time_start) / 3600 - time_total = time_now / ((global_step + 1) / self.total_step) - time_for_end = time_total - time_now - if self.writer is not None: - self.writer.add_scalar('time_for_end', time_for_end, global_step) - self.writer.add_scalar('learning_rate', learning_rate, global_step) - self.writer.add_scalar('loss', loss.avg, global_step) - if fp16: - msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \ - "Fp16 Grad Scale: %2.f Required: %1.f hours" % ( - speed_total, loss.avg, learning_rate, epoch, global_step, - grad_scaler.get_scale(), time_for_end - ) - else: - msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \ - "Required: %1.f hours" % ( - speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end - ) - logging.info(msg) - loss.reset() - self.tic = time.time() - else: - self.init = True - self.tic = time.time() - - -class CallBackModelCheckpoint(object): - def __init__(self, rank, output="./"): - self.rank: int = rank - self.output: str = output - - def __call__(self, global_step, backbone, partial_fc, ): - if global_step > 100 and self.rank == 0: - path_module = os.path.join(self.output, "backbone.pth") - torch.save(backbone.module.state_dict(), path_module) - logging.info("Pytorch Model Saved in '{}'".format(path_module)) - - if global_step > 100 and partial_fc is not None: - partial_fc.save_params() diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_config.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_config.py deleted file mode 100644 index b60a1e5a2e860ce5511a2d3863c8b57a4df292d7..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_config.py +++ /dev/null @@ -1,16 +0,0 @@ -import importlib -import os.path as osp - - -def get_config(config_file): - assert config_file.startswith('configs/'), 'config file setting must start with configs/' - temp_config_name = osp.basename(config_file) - temp_module_name = osp.splitext(temp_config_name)[0] - config = importlib.import_module("configs.base") - cfg = config.config - config = importlib.import_module("configs.%s" % temp_module_name) - job_cfg = config.config - cfg.update(job_cfg) - if cfg.output is None: - cfg.output = osp.join('work_dirs', temp_module_name) - return cfg \ No newline at end of file diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_logging.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_logging.py deleted file mode 100644 index f2b43b851c9e06230abd94c73a1f64cfa1b6f3ac..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_logging.py +++ /dev/null @@ -1,41 +0,0 @@ -import logging -import os -import sys - - -class AverageMeter(object): - """Computes and stores the average and current value - """ - - def __init__(self): - self.val = None - self.avg = None - self.sum = None - self.count = None - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - -def init_logging(rank, models_root): - if rank == 0: - log_root = logging.getLogger() - log_root.setLevel(logging.INFO) - formatter = logging.Formatter("Training: 
%(asctime)s-%(message)s") - handler_file = logging.FileHandler(os.path.join(models_root, "training.log")) - handler_stream = logging.StreamHandler(sys.stdout) - handler_file.setFormatter(formatter) - handler_stream.setFormatter(formatter) - log_root.addHandler(handler_file) - log_root.addHandler(handler_stream) - log_root.info('rank_id: %d' % rank) diff --git a/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_os.py b/sadtalker_video2pose/src/face3d/models/arcface_torch/utils/utils_os.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sadtalker_video2pose/src/face3d/models/base_model.py b/sadtalker_video2pose/src/face3d/models/base_model.py deleted file mode 100644 index b975223f6148febfe32d20d63980583c97b61eb3..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/base_model.py +++ /dev/null @@ -1,316 +0,0 @@ -"""This script defines the base network model for Deep3DFaceRecon_pytorch -""" - -import os -import numpy as np -import torch -from collections import OrderedDict -from abc import ABC, abstractmethod -from . import networks - - -class BaseModel(ABC): - """This class is an abstract base class (ABC) for models. - To create a subclass, you need to implement the following five functions: - -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). - -- : unpack data from dataset and apply preprocessing. - -- : produce intermediate results. - -- : calculate losses, gradients, and update network weights. - -- : (optionally) add model-specific options and set default options. - """ - - def __init__(self, opt): - """Initialize the BaseModel class. - - Parameters: - opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions - - When creating your custom class, you need to implement your own initialization. - In this fucntion, you should first call - Then, you need to define four lists: - -- self.loss_names (str list): specify the training losses that you want to plot and save. - -- self.model_names (str list): specify the images that you want to display and save. - -- self.visual_names (str list): define networks used in our training. - -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. - """ - self.opt = opt - self.isTrain = False - self.device = torch.device('cpu') - self.save_dir = " " # os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir - self.loss_names = [] - self.model_names = [] - self.visual_names = [] - self.parallel_names = [] - self.optimizers = [] - self.image_paths = [] - self.metric = 0 # used for learning rate policy 'plateau' - - @staticmethod - def dict_grad_hook_factory(add_func=lambda x: x): - saved_dict = dict() - - def hook_gen(name): - def grad_hook(grad): - saved_vals = add_func(grad) - saved_dict[name] = saved_vals - return grad_hook - return hook_gen, saved_dict - - @staticmethod - def modify_commandline_options(parser, is_train): - """Add new model-specific options, and rewrite default values for existing options. - - Parameters: - parser -- original option parser - is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. - - Returns: - the modified parser. 
- """ - return parser - - @abstractmethod - def set_input(self, input): - """Unpack input data from the dataloader and perform necessary pre-processing steps. - - Parameters: - input (dict): includes the data itself and its metadata information. - """ - pass - - @abstractmethod - def forward(self): - """Run forward pass; called by both functions and .""" - pass - - @abstractmethod - def optimize_parameters(self): - """Calculate losses, gradients, and update network weights; called in every training iteration""" - pass - - def setup(self, opt): - """Load and print networks; create schedulers - - Parameters: - opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions - """ - if self.isTrain: - self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] - - if not self.isTrain or opt.continue_train: - load_suffix = opt.epoch - self.load_networks(load_suffix) - - - # self.print_networks(opt.verbose) - - def parallelize(self, convert_sync_batchnorm=True): - if not self.opt.use_ddp: - for name in self.parallel_names: - if isinstance(name, str): - module = getattr(self, name) - setattr(self, name, module.to(self.device)) - else: - for name in self.model_names: - if isinstance(name, str): - module = getattr(self, name) - if convert_sync_batchnorm: - module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module) - setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device), - device_ids=[self.device.index], - find_unused_parameters=True, broadcast_buffers=True)) - - # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient. - for name in self.parallel_names: - if isinstance(name, str) and name not in self.model_names: - module = getattr(self, name) - setattr(self, name, module.to(self.device)) - - # put state_dict of optimizer to gpu device - if self.opt.phase != 'test': - if self.opt.continue_train: - for optim in self.optimizers: - for state in optim.state.values(): - for k, v in state.items(): - if isinstance(v, torch.Tensor): - state[k] = v.to(self.device) - - def data_dependent_initialize(self, data): - pass - - def train(self): - """Make models train mode""" - for name in self.model_names: - if isinstance(name, str): - net = getattr(self, name) - net.train() - - def eval(self): - """Make models eval mode""" - for name in self.model_names: - if isinstance(name, str): - net = getattr(self, name) - net.eval() - - def test(self): - """Forward function used in test time. - - This function wraps function in no_grad() so we don't save intermediate steps for backprop - It also calls to produce additional visualization results - """ - with torch.no_grad(): - self.forward() - self.compute_visuals() - - def compute_visuals(self): - """Calculate additional output images for visdom and HTML visualization""" - pass - - def get_image_paths(self, name='A'): - """ Return image paths that are used to load current data""" - return self.image_paths if name =='A' else self.image_paths_B - - def update_learning_rate(self): - """Update learning rates for all the networks; called at the end of every epoch""" - for scheduler in self.schedulers: - if self.opt.lr_policy == 'plateau': - scheduler.step(self.metric) - else: - scheduler.step() - - lr = self.optimizers[0].param_groups[0]['lr'] - print('learning rate = %.7f' % lr) - - def get_current_visuals(self): - """Return visualization images. 
train.py will display these images with visdom, and save the images to a HTML""" - visual_ret = OrderedDict() - for name in self.visual_names: - if isinstance(name, str): - visual_ret[name] = getattr(self, name)[:, :3, ...] - return visual_ret - - def get_current_losses(self): - """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" - errors_ret = OrderedDict() - for name in self.loss_names: - if isinstance(name, str): - errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number - return errors_ret - - def save_networks(self, epoch): - """Save all the networks to the disk. - - Parameters: - epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) - """ - if not os.path.isdir(self.save_dir): - os.makedirs(self.save_dir) - - save_filename = 'epoch_%s.pth' % (epoch) - save_path = os.path.join(self.save_dir, save_filename) - - save_dict = {} - for name in self.model_names: - if isinstance(name, str): - net = getattr(self, name) - if isinstance(net, torch.nn.DataParallel) or isinstance(net, - torch.nn.parallel.DistributedDataParallel): - net = net.module - save_dict[name] = net.state_dict() - - - for i, optim in enumerate(self.optimizers): - save_dict['opt_%02d'%i] = optim.state_dict() - - for i, sched in enumerate(self.schedulers): - save_dict['sched_%02d'%i] = sched.state_dict() - - torch.save(save_dict, save_path) - - def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): - """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" - key = keys[i] - if i + 1 == len(keys): # at the end, pointing to a parameter/buffer - if module.__class__.__name__.startswith('InstanceNorm') and \ - (key == 'running_mean' or key == 'running_var'): - if getattr(module, key) is None: - state_dict.pop('.'.join(keys)) - if module.__class__.__name__.startswith('InstanceNorm') and \ - (key == 'num_batches_tracked'): - state_dict.pop('.'.join(keys)) - else: - self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) - - def load_networks(self, epoch): - """Load all the networks from the disk. 
- - Parameters: - epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) - """ - if self.opt.isTrain and self.opt.pretrained_name is not None: - load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name) - else: - load_dir = self.save_dir - load_filename = 'epoch_%s.pth' % (epoch) - load_path = os.path.join(load_dir, load_filename) - state_dict = torch.load(load_path, map_location=self.device) - print('loading the model from %s' % load_path) - - for name in self.model_names: - if isinstance(name, str): - net = getattr(self, name) - if isinstance(net, torch.nn.DataParallel): - net = net.module - net.load_state_dict(state_dict[name]) - - if self.opt.phase != 'test': - if self.opt.continue_train: - print('loading the optim from %s' % load_path) - for i, optim in enumerate(self.optimizers): - optim.load_state_dict(state_dict['opt_%02d'%i]) - - try: - print('loading the sched from %s' % load_path) - for i, sched in enumerate(self.schedulers): - sched.load_state_dict(state_dict['sched_%02d'%i]) - except: - print('Failed to load schedulers, set schedulers according to epoch count manually') - for i, sched in enumerate(self.schedulers): - sched.last_epoch = self.opt.epoch_count - 1 - - - - - def print_networks(self, verbose): - """Print the total number of parameters in the network and (if verbose) network architecture - - Parameters: - verbose (bool) -- if verbose: print the network architecture - """ - print('---------- Networks initialized -------------') - for name in self.model_names: - if isinstance(name, str): - net = getattr(self, name) - num_params = 0 - for param in net.parameters(): - num_params += param.numel() - if verbose: - print(net) - print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) - print('-----------------------------------------------') - - def set_requires_grad(self, nets, requires_grad=False): - """Set requies_grad=Fasle for all the networks to avoid unnecessary computations - Parameters: - nets (network list) -- a list of networks - requires_grad (bool) -- whether the networks require gradients or not - """ - if not isinstance(nets, list): - nets = [nets] - for net in nets: - if net is not None: - for param in net.parameters(): - param.requires_grad = requires_grad - - def generate_visuals_for_evaluation(self, data, mode): - return {} diff --git a/sadtalker_video2pose/src/face3d/models/bfm.py b/sadtalker_video2pose/src/face3d/models/bfm.py deleted file mode 100644 index 0cecaf589befac790cf9c124737ba01e27bc29e6..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/bfm.py +++ /dev/null @@ -1,331 +0,0 @@ -"""This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch -""" - -import numpy as np -import torch -import torch.nn.functional as F -from scipy.io import loadmat -from src.face3d.util.load_mats import transferBFM09 -import os - -def perspective_projection(focal, center): - # return p.T (N, 3) @ (3, 3) - return np.array([ - focal, 0, center, - 0, focal, center, - 0, 0, 1 - ]).reshape([3, 3]).astype(np.float32).transpose() - -class SH: - def __init__(self): - self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)] - self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) 
/ np.sqrt(12 * np.pi)] - - - -class ParametricFaceModel: - def __init__(self, - bfm_folder='./BFM', - recenter=True, - camera_distance=10., - init_lit=np.array([ - 0.8, 0, 0, 0, 0, 0, 0, 0, 0 - ]), - focal=1015., - center=112., - is_train=True, - default_name='BFM_model_front.mat'): - - if not os.path.isfile(os.path.join(bfm_folder, default_name)): - transferBFM09(bfm_folder) - - model = loadmat(os.path.join(bfm_folder, default_name)) - # mean face shape. [3*N,1] - self.mean_shape = model['meanshape'].astype(np.float32) - # identity basis. [3*N,80] - self.id_base = model['idBase'].astype(np.float32) - # expression basis. [3*N,64] - self.exp_base = model['exBase'].astype(np.float32) - # mean face texture. [3*N,1] (0-255) - self.mean_tex = model['meantex'].astype(np.float32) - # texture basis. [3*N,80] - self.tex_base = model['texBase'].astype(np.float32) - # face indices for each vertex that lies in. starts from 0. [N,8] - self.point_buf = model['point_buf'].astype(np.int64) - 1 - # vertex indices for each face. starts from 0. [F,3] - self.face_buf = model['tri'].astype(np.int64) - 1 - # vertex indices for 68 landmarks. starts from 0. [68,1] - self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1 - - if is_train: - # vertex indices for small face region to compute photometric error. starts from 0. - self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1 - # vertex indices for each face from small face region. starts from 0. [f,3] - self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1 - # vertex indices for pre-defined skin region to compute reflectance loss - self.skin_mask = np.squeeze(model['skinmask']) - - if recenter: - mean_shape = self.mean_shape.reshape([-1, 3]) - mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True) - self.mean_shape = mean_shape.reshape([-1, 1]) - - self.persc_proj = perspective_projection(focal, center) - self.device = 'cpu' - self.camera_distance = camera_distance - self.SH = SH() - self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32) - - - def to(self, device): - self.device = device - for key, value in self.__dict__.items(): - if type(value).__module__ == np.__name__: - setattr(self, key, torch.tensor(value).to(device)) - - - def compute_shape(self, id_coeff, exp_coeff): - """ - Return: - face_shape -- torch.tensor, size (B, N, 3) - - Parameters: - id_coeff -- torch.tensor, size (B, 80), identity coeffs - exp_coeff -- torch.tensor, size (B, 64), expression coeffs - """ - batch_size = id_coeff.shape[0] - id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff) - exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff) - face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1]) - return face_shape.reshape([batch_size, -1, 3]) - - - def compute_texture(self, tex_coeff, normalize=True): - """ - Return: - face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.) - - Parameters: - tex_coeff -- torch.tensor, size (B, 80) - """ - batch_size = tex_coeff.shape[0] - face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex - if normalize: - face_texture = face_texture / 255. 
- return face_texture.reshape([batch_size, -1, 3]) - - - def compute_norm(self, face_shape): - """ - Return: - vertex_norm -- torch.tensor, size (B, N, 3) - - Parameters: - face_shape -- torch.tensor, size (B, N, 3) - """ - - v1 = face_shape[:, self.face_buf[:, 0]] - v2 = face_shape[:, self.face_buf[:, 1]] - v3 = face_shape[:, self.face_buf[:, 2]] - e1 = v1 - v2 - e2 = v2 - v3 - face_norm = torch.cross(e1, e2, dim=-1) - face_norm = F.normalize(face_norm, dim=-1, p=2) - face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1) - - vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2) - vertex_norm = F.normalize(vertex_norm, dim=-1, p=2) - return vertex_norm - - - def compute_color(self, face_texture, face_norm, gamma): - """ - Return: - face_color -- torch.tensor, size (B, N, 3), range (0, 1.) - - Parameters: - face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.) - face_norm -- torch.tensor, size (B, N, 3), rotated face normal - gamma -- torch.tensor, size (B, 27), SH coeffs - """ - batch_size = gamma.shape[0] - v_num = face_texture.shape[1] - a, c = self.SH.a, self.SH.c - gamma = gamma.reshape([batch_size, 3, 9]) - gamma = gamma + self.init_lit - gamma = gamma.permute(0, 2, 1) - Y = torch.cat([ - a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device), - -a[1] * c[1] * face_norm[..., 1:2], - a[1] * c[1] * face_norm[..., 2:], - -a[1] * c[1] * face_norm[..., :1], - a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2], - -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:], - 0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1), - -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:], - 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2) - ], dim=-1) - r = Y @ gamma[..., :1] - g = Y @ gamma[..., 1:2] - b = Y @ gamma[..., 2:] - face_color = torch.cat([r, g, b], dim=-1) * face_texture - return face_color - - - def compute_rotation(self, angles): - """ - Return: - rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat - - Parameters: - angles -- torch.tensor, size (B, 3), radian - """ - - batch_size = angles.shape[0] - ones = torch.ones([batch_size, 1]).to(self.device) - zeros = torch.zeros([batch_size, 1]).to(self.device) - x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:], - - rot_x = torch.cat([ - ones, zeros, zeros, - zeros, torch.cos(x), -torch.sin(x), - zeros, torch.sin(x), torch.cos(x) - ], dim=1).reshape([batch_size, 3, 3]) - - rot_y = torch.cat([ - torch.cos(y), zeros, torch.sin(y), - zeros, ones, zeros, - -torch.sin(y), zeros, torch.cos(y) - ], dim=1).reshape([batch_size, 3, 3]) - - rot_z = torch.cat([ - torch.cos(z), -torch.sin(z), zeros, - torch.sin(z), torch.cos(z), zeros, - zeros, zeros, ones - ], dim=1).reshape([batch_size, 3, 3]) - - rot = rot_z @ rot_y @ rot_x - return rot.permute(0, 2, 1) - - - def to_camera(self, face_shape): - face_shape[..., -1] = self.camera_distance - face_shape[..., -1] - return face_shape - - def to_image(self, face_shape): - """ - Return: - face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction - - Parameters: - face_shape -- torch.tensor, size (B, N, 3) - """ - # to image_plane - face_proj = face_shape @ self.persc_proj - face_proj = face_proj[..., :2] / face_proj[..., 2:] - - return face_proj - - - def transform(self, face_shape, rot, trans): - """ - Return: - face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans - - Parameters: - face_shape -- torch.tensor, size (B, N, 3) - rot -- 
torch.tensor, size (B, 3, 3) - trans -- torch.tensor, size (B, 3) - """ - return face_shape @ rot + trans.unsqueeze(1) - - - def get_landmarks(self, face_proj): - """ - Return: - face_lms -- torch.tensor, size (B, 68, 2) - - Parameters: - face_proj -- torch.tensor, size (B, N, 2) - """ - return face_proj[:, self.keypoints] - - def split_coeff(self, coeffs): - """ - Return: - coeffs_dict -- a dict of torch.tensors - - Parameters: - coeffs -- torch.tensor, size (B, 256) - """ - id_coeffs = coeffs[:, :80] - exp_coeffs = coeffs[:, 80: 144] - tex_coeffs = coeffs[:, 144: 224] - angles = coeffs[:, 224: 227] - gammas = coeffs[:, 227: 254] - translations = coeffs[:, 254:] - return { - 'id': id_coeffs, - 'exp': exp_coeffs, - 'tex': tex_coeffs, - 'angle': angles, - 'gamma': gammas, - 'trans': translations - } - def compute_for_render(self, coeffs): - """ - Return: - face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate - face_color -- torch.tensor, size (B, N, 3), in RGB order - landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction - Parameters: - coeffs -- torch.tensor, size (B, 257) - """ - coef_dict = self.split_coeff(coeffs) - face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp']) - rotation = self.compute_rotation(coef_dict['angle']) - - - face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans']) - face_vertex = self.to_camera(face_shape_transformed) - - face_proj = self.to_image(face_vertex) - landmark = self.get_landmarks(face_proj) - - face_texture = self.compute_texture(coef_dict['tex']) - face_norm = self.compute_norm(face_shape) - face_norm_roted = face_norm @ rotation - face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma']) - - return face_vertex, face_texture, face_color, landmark - - def compute_for_render_woRotation(self, coeffs): - """ - Return: - face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate - face_color -- torch.tensor, size (B, N, 3), in RGB order - landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction - Parameters: - coeffs -- torch.tensor, size (B, 257) - """ - coef_dict = self.split_coeff(coeffs) - face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp']) - #rotation = self.compute_rotation(coef_dict['angle']) - - - #face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans']) - face_vertex = self.to_camera(face_shape) - - face_proj = self.to_image(face_vertex) - landmark = self.get_landmarks(face_proj) - - face_texture = self.compute_texture(coef_dict['tex']) - face_norm = self.compute_norm(face_shape) - face_norm_roted = face_norm # @ rotation - face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma']) - - return face_vertex, face_texture, face_color, landmark - - -if __name__ == '__main__': - transferBFM09() \ No newline at end of file diff --git a/sadtalker_video2pose/src/face3d/models/facerecon_model.py b/sadtalker_video2pose/src/face3d/models/facerecon_model.py deleted file mode 100644 index 58a836a45a05fa192591cca5cf684783a6fb8533..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/facerecon_model.py +++ /dev/null @@ -1,220 +0,0 @@ -"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch -""" - -import numpy as np -import torch -from src.face3d.models.base_model import BaseModel -from src.face3d.models import networks -from src.face3d.models.bfm import ParametricFaceModel -from src.face3d.models.losses import 
perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss -from src.face3d.util import util -from src.face3d.util.nvdiffrast import MeshRenderer -# from src.face3d.util.preprocess import estimate_norm_torch - -import trimesh -from scipy.io import savemat - -class FaceReconModel(BaseModel): - - @staticmethod - def modify_commandline_options(parser, is_train=False): - """ Configures options specific for CUT model - """ - # net structure and parameters - parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure') - parser.add_argument('--init_path', type=str, default='./ckpts/sad_talkers/init_model/resnet50-0676ba61.pth') - parser.add_argument('--use_last_fc', type=util.str2bool, nargs='?', const=True, default=False, help='zero initialize the last fc') - parser.add_argument('--bfm_folder', type=str, default='./ckpts/sad_talkers/BFM_Fitting/') - parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model') - - # renderer parameters - parser.add_argument('--focal', type=float, default=1015.) - parser.add_argument('--center', type=float, default=112.) - parser.add_argument('--camera_d', type=float, default=10.) - parser.add_argument('--z_near', type=float, default=5.) - parser.add_argument('--z_far', type=float, default=15.) - - if is_train: - # training parameters - parser.add_argument('--net_recog', type=str, default='r50', choices=['r18', 'r43', 'r50'], help='face recog network structure') - parser.add_argument('--net_recog_path', type=str, default='checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth') - parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss') - parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face') - - - # augmentation parameters - parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels') - parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor') - parser.add_argument('--rot_angle', type=float, default=10., help='rot angles, degree') - - # loss weights - parser.add_argument('--w_feat', type=float, default=0.2, help='weight for feat loss') - parser.add_argument('--w_color', type=float, default=1.92, help='weight for loss loss') - parser.add_argument('--w_reg', type=float, default=3.0e-4, help='weight for reg loss') - parser.add_argument('--w_id', type=float, default=1.0, help='weight for id_reg loss') - parser.add_argument('--w_exp', type=float, default=0.8, help='weight for exp_reg loss') - parser.add_argument('--w_tex', type=float, default=1.7e-2, help='weight for tex_reg loss') - parser.add_argument('--w_gamma', type=float, default=10.0, help='weight for gamma loss') - parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss') - parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss') - - opt, _ = parser.parse_known_args() - parser.set_defaults( - focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15. - ) - if is_train: - parser.set_defaults( - use_crop_face=True, use_predef_M=False - ) - return parser - - def __init__(self, opt): - """Initialize this model class. - - Parameters: - opt -- training/test options - - A few things can be done here. 
- - (required) call the initialization function of BaseModel - - define loss function, visualization images, model names, and optimizers - """ - BaseModel.__init__(self, opt) # call the initialization method of BaseModel - - self.visual_names = ['output_vis'] - self.model_names = ['net_recon'] - self.parallel_names = self.model_names + ['renderer'] - - self.facemodel = ParametricFaceModel( - bfm_folder=opt.bfm_folder, camera_distance=opt.camera_d, focal=opt.focal, center=opt.center, - is_train=self.isTrain, default_name=opt.bfm_model - ) - - fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi - self.renderer = MeshRenderer( - rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center) - ) - - if self.isTrain: - self.loss_names = ['all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc'] - - self.net_recog = networks.define_net_recog( - net_recog=opt.net_recog, pretrained_path=opt.net_recog_path - ) - # loss func name: (compute_%s_loss) % loss_name - self.compute_feat_loss = perceptual_loss - self.comupte_color_loss = photo_loss - self.compute_lm_loss = landmark_loss - self.compute_reg_loss = reg_loss - self.compute_reflc_loss = reflectance_loss - - self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr) - self.optimizers = [self.optimizer] - self.parallel_names += ['net_recog'] - # Our program will automatically call to define schedulers, load networks, and print networks - - def set_input(self, input): - """Unpack input data from the dataloader and perform necessary pre-processing steps. - - Parameters: - input: a dictionary that contains the data itself and its metadata information. - """ - self.input_img = input['imgs'].to(self.device) - self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None - self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None - self.trans_m = input['M'].to(self.device) if 'M' in input else None - self.image_paths = input['im_paths'] if 'im_paths' in input else None - - def forward(self, output_coeff, device): - self.facemodel.to(device) - self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \ - self.facemodel.compute_for_render(output_coeff) - self.pred_mask, _, self.pred_face = self.renderer( - self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color) - - self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff) - - - def compute_losses(self): - """Calculate losses, gradients, and update network weights; called in every training iteration""" - - assert self.net_recog.training == False - trans_m = self.trans_m - if not self.opt.use_predef_M: - trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2]) - - pred_feat = self.net_recog(self.pred_face, trans_m) - gt_feat = self.net_recog(self.input_img, self.trans_m) - self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat) - - face_mask = self.pred_mask - if self.opt.use_crop_face: - face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf) - - face_mask = face_mask.detach() - self.loss_color = self.opt.w_color * self.comupte_color_loss( - self.pred_face, self.input_img, self.atten_mask * face_mask) - - loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt) - self.loss_reg = self.opt.w_reg * loss_reg - self.loss_gamma = self.opt.w_gamma * loss_gamma - - self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm) - - self.loss_reflc = self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, 
self.facemodel.skin_mask) - - self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \ - + self.loss_lm + self.loss_reflc - - - def optimize_parameters(self, isTrain=True): - self.forward() - self.compute_losses() - """Update network weights; it will be called in every training iteration.""" - if isTrain: - self.optimizer.zero_grad() - self.loss_all.backward() - self.optimizer.step() - - def compute_visuals(self): - with torch.no_grad(): - input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy() - output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img - output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy() - - if self.gt_lm is not None: - gt_lm_numpy = self.gt_lm.cpu().numpy() - pred_lm_numpy = self.pred_lm.detach().cpu().numpy() - output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b') - output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r') - - output_vis_numpy = np.concatenate((input_img_numpy, - output_vis_numpy_raw, output_vis_numpy), axis=-2) - else: - output_vis_numpy = np.concatenate((input_img_numpy, - output_vis_numpy_raw), axis=-2) - - self.output_vis = torch.tensor( - output_vis_numpy / 255., dtype=torch.float32 - ).permute(0, 3, 1, 2).to(self.device) - - def save_mesh(self, name): - - recon_shape = self.pred_vertex # get reconstructed shape - recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space - recon_shape = recon_shape.cpu().numpy()[0] - recon_color = self.pred_color - recon_color = recon_color.cpu().numpy()[0] - tri = self.facemodel.face_buf.cpu().numpy() - mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8)) - mesh.export(name) - - def save_coeff(self,name): - - pred_coeffs = {key:self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict} - pred_lm = self.pred_lm.cpu().numpy() - pred_lm = np.stack([pred_lm[:,:,0],self.input_img.shape[2]-1-pred_lm[:,:,1]],axis=2) # transfer to image coordinate - pred_coeffs['lm68'] = pred_lm - savemat(name,pred_coeffs) - - - diff --git a/sadtalker_video2pose/src/face3d/models/losses.py b/sadtalker_video2pose/src/face3d/models/losses.py deleted file mode 100644 index 01d9da84f28d54e772bebd2385ae5a7fedd10f7d..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/losses.py +++ /dev/null @@ -1,113 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -from kornia.geometry import warp_affine -import torch.nn.functional as F - -def resize_n_crop(image, M, dsize=112): - # image: (b, c, h, w) - # M : (b, 2, 3) - return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True) - -### perceptual level loss -class PerceptualLoss(nn.Module): - def __init__(self, recog_net, input_size=112): - super(PerceptualLoss, self).__init__() - self.recog_net = recog_net - self.preprocess = lambda x: 2 * x - 1 - self.input_size=input_size - def forward(imageA, imageB, M): - """ - 1 - cosine distance - Parameters: - imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order - imageB --same as imageA - """ - - imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size)) - imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size)) - - # freeze bn - self.recog_net.eval() - - id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2) - id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2) - cosine_d = 
torch.sum(id_featureA * id_featureB, dim=-1) - # assert torch.sum((cosine_d > 1).float()) == 0 - return torch.sum(1 - cosine_d) / cosine_d.shape[0] - -def perceptual_loss(id_featureA, id_featureB): - cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) - # assert torch.sum((cosine_d > 1).float()) == 0 - return torch.sum(1 - cosine_d) / cosine_d.shape[0] - -### image level loss -def photo_loss(imageA, imageB, mask, eps=1e-6): - """ - l2 norm (with sqrt, to ensure backward stability, use eps, otherwise NaN may occur) - Parameters: - imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order - imageB --same as imageA - """ - loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask - loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device)) - return loss - -def landmark_loss(predict_lm, gt_lm, weight=None): - """ - weighted mse loss - Parameters: - predict_lm --torch.tensor (B, 68, 2) - gt_lm --torch.tensor (B, 68, 2) - weight --numpy.array (1, 68) - """ - if weight is None: - weight = np.ones([68]) - weight[28:31] = 20 - weight[-8:] = 20 - weight = np.expand_dims(weight, 0) - weight = torch.tensor(weight).to(predict_lm.device) - loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight - loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1]) - return loss - - -### regularization -def reg_loss(coeffs_dict, opt=None): - """ - l2 norm without the sqrt, from yu's implementation (mse) - tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss - Parameters: - coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans - - """ - # coefficient regularization to ensure plausible 3d faces - if opt: - w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex - else: - w_id, w_exp, w_tex = 1, 1, 1 - creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \ - w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \ - w_tex * torch.sum(coeffs_dict['tex'] ** 2) - creg_loss = creg_loss / coeffs_dict['id'].shape[0] - - # gamma regularization to ensure a nearly-monochromatic light - gamma = coeffs_dict['gamma'].reshape([-1, 3, 9]) - gamma_mean = torch.mean(gamma, dim=1, keepdims=True) - gamma_loss = torch.mean((gamma - gamma_mean) ** 2) - - return creg_loss, gamma_loss - -def reflectance_loss(texture, mask): - """ - minimize texture variance (mse), albedo regularization to ensure a uniform skin albedo - Parameters: - texture --torch.tensor, (B, N, 3) - mask --torch.tensor, (N), 1 or 0 - - """ - mask = mask.reshape([1, mask.shape[0], 1]) - texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask) - loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask)) - return loss - - diff --git a/sadtalker_video2pose/src/face3d/models/networks.py b/sadtalker_video2pose/src/face3d/models/networks.py deleted file mode 100644 index 1e69eba1ade2e6431e7e7fd526ea68b8f63e7152..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/networks.py +++ /dev/null @@ -1,521 +0,0 @@ -"""This script defines deep neural networks for Deep3DFaceRecon_pytorch -""" - -import os -import numpy as np -import torch.nn.functional as F -from torch.nn import init -import functools -from torch.optim import lr_scheduler -import torch -from torch import Tensor -import torch.nn as nn -try: - from torch.hub import load_state_dict_from_url -except ImportError: - from torch.utils.model_zoo import load_url as load_state_dict_from_url -from typing import Type, Any, Callable,
Union, List, Optional -from .arcface_torch.backbones import get_model -from kornia.geometry import warp_affine - -def resize_n_crop(image, M, dsize=112): - # image: (b, c, h, w) - # M : (b, 2, 3) - return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True) - -def filter_state_dict(state_dict, remove_name='fc'): - new_state_dict = {} - for key in state_dict: - if remove_name in key: - continue - new_state_dict[key] = state_dict[key] - return new_state_dict - -def get_scheduler(optimizer, opt): - """Return a learning rate scheduler - - Parameters: - optimizer -- the optimizer of the network - opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  - opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine - - For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. - See https://pytorch.org/docs/stable/optim.html for more details. - """ - if opt.lr_policy == 'linear': - def lambda_rule(epoch): - lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1) - return lr_l - scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) - elif opt.lr_policy == 'step': - scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.2) - elif opt.lr_policy == 'plateau': - scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) - elif opt.lr_policy == 'cosine': - scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) - else: - return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) - return scheduler - - -def define_net_recon(net_recon, use_last_fc=False, init_path=None): - return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path) - -def define_net_recog(net_recog, pretrained_path=None): - net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path) - net.eval() - return net - -class ReconNetWrapper(nn.Module): - fc_dim=257 - def __init__(self, net_recon, use_last_fc=False, init_path=None): - super(ReconNetWrapper, self).__init__() - self.use_last_fc = use_last_fc - if net_recon not in func_dict: - return NotImplementedError('network [%s] is not implemented', net_recon) - func, last_dim = func_dict[net_recon] - backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim) - if init_path and os.path.isfile(init_path): - state_dict = filter_state_dict(torch.load(init_path, map_location='cpu')) - backbone.load_state_dict(state_dict) - print("loading init net_recon %s from %s" %(net_recon, init_path)) - self.backbone = backbone - if not use_last_fc: - self.final_layers = nn.ModuleList([ - conv1x1(last_dim, 80, bias=True), # id layer - conv1x1(last_dim, 64, bias=True), # exp layer - conv1x1(last_dim, 80, bias=True), # tex layer - conv1x1(last_dim, 3, bias=True), # angle layer - conv1x1(last_dim, 27, bias=True), # gamma layer - conv1x1(last_dim, 2, bias=True), # tx, ty - conv1x1(last_dim, 1, bias=True) # tz - ]) - for m in self.final_layers: - nn.init.constant_(m.weight, 0.) - nn.init.constant_(m.bias, 0.) 
- - def forward(self, x): - x = self.backbone(x) - if not self.use_last_fc: - output = [] - for layer in self.final_layers: - output.append(layer(x)) - x = torch.flatten(torch.cat(output, dim=1), 1) - return x - - -class RecogNetWrapper(nn.Module): - def __init__(self, net_recog, pretrained_path=None, input_size=112): - super(RecogNetWrapper, self).__init__() - net = get_model(name=net_recog, fp16=False) - if pretrained_path: - state_dict = torch.load(pretrained_path, map_location='cpu') - net.load_state_dict(state_dict) - print("loading pretrained net_recog %s from %s" %(net_recog, pretrained_path)) - for param in net.parameters(): - param.requires_grad = False - self.net = net - self.preprocess = lambda x: 2 * x - 1 - self.input_size=input_size - - def forward(self, image, M): - image = self.preprocess(resize_n_crop(image, M, self.input_size)) - id_feature = F.normalize(self.net(image), dim=-1, p=2) - return id_feature - - -# adapted from https://github.com/pytorch/vision/edit/master/torchvision/models/resnet.py -__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', - 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', - 'wide_resnet50_2', 'wide_resnet101_2'] - - -model_urls = { - 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', - 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', - 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', - 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', - 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', - 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', - 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', - 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', - 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', -} - - -def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=dilation, groups=groups, bias=False, dilation=dilation) - - -def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d: - """1x1 convolution""" - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias) - - -class BasicBlock(nn.Module): - expansion: int = 1 - - def __init__( - self, - inplanes: int, - planes: int, - stride: int = 1, - downsample: Optional[nn.Module] = None, - groups: int = 1, - base_width: int = 64, - dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None - ) -> None: - super(BasicBlock, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - if groups != 1 or base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') - if dilation > 1: - raise NotImplementedError("Dilation > 1 not supported in BasicBlock") - # Both self.conv1 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = norm_layer(planes) - self.relu = nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = norm_layer(planes) - self.downsample = downsample - self.stride = stride - - def forward(self, x: Tensor) -> Tensor: - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = 
self.bn2(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - -class Bottleneck(nn.Module): - # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) - # while original implementation places the stride at the first 1x1 convolution(self.conv1) - # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. - # This variant is also known as ResNet V1.5 and improves accuracy according to - # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. - - expansion: int = 4 - - def __init__( - self, - inplanes: int, - planes: int, - stride: int = 1, - downsample: Optional[nn.Module] = None, - groups: int = 1, - base_width: int = 64, - dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None - ) -> None: - super(Bottleneck, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - width = int(planes * (base_width / 64.)) * groups - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv1x1(inplanes, width) - self.bn1 = norm_layer(width) - self.conv2 = conv3x3(width, width, stride, groups, dilation) - self.bn2 = norm_layer(width) - self.conv3 = conv1x1(width, planes * self.expansion) - self.bn3 = norm_layer(planes * self.expansion) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x: Tensor) -> Tensor: - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - -class ResNet(nn.Module): - - def __init__( - self, - block: Type[Union[BasicBlock, Bottleneck]], - layers: List[int], - num_classes: int = 1000, - zero_init_residual: bool = False, - use_last_fc: bool = False, - groups: int = 1, - width_per_group: int = 64, - replace_stride_with_dilation: Optional[List[bool]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None - ) -> None: - super(ResNet, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - self._norm_layer = norm_layer - - self.inplanes = 64 - self.dilation = 1 - if replace_stride_with_dilation is None: - # each element in the tuple indicates if we should replace - # the 2x2 stride with a dilated convolution instead - replace_stride_with_dilation = [False, False, False] - if len(replace_stride_with_dilation) != 3: - raise ValueError("replace_stride_with_dilation should be None " - "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) - self.use_last_fc = use_last_fc - self.groups = groups - self.base_width = width_per_group - self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, - bias=False) - self.bn1 = norm_layer(self.inplanes) - self.relu = nn.ReLU(inplace=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer(block, 64, layers[0]) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, - dilate=replace_stride_with_dilation[0]) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, - dilate=replace_stride_with_dilation[1]) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, - dilate=replace_stride_with_dilation[2]) - self.avgpool = 
nn.AdaptiveAvgPool2d((1, 1)) - - if self.use_last_fc: - self.fc = nn.Linear(512 * block.expansion, num_classes) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - - - # Zero-initialize the last BN in each residual branch, - # so that the residual branch starts with zeros, and each residual block behaves like an identity. - # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 - if zero_init_residual: - for m in self.modules(): - if isinstance(m, Bottleneck): - nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] - elif isinstance(m, BasicBlock): - nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] - - def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, - stride: int = 1, dilate: bool = False) -> nn.Sequential: - norm_layer = self._norm_layer - downsample = None - previous_dilation = self.dilation - if dilate: - self.dilation *= stride - stride = 1 - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - conv1x1(self.inplanes, planes * block.expansion, stride), - norm_layer(planes * block.expansion), - ) - - layers = [] - layers.append(block(self.inplanes, planes, stride, downsample, self.groups, - self.base_width, previous_dilation, norm_layer)) - self.inplanes = planes * block.expansion - for _ in range(1, blocks): - layers.append(block(self.inplanes, planes, groups=self.groups, - base_width=self.base_width, dilation=self.dilation, - norm_layer=norm_layer)) - - return nn.Sequential(*layers) - - def _forward_impl(self, x: Tensor) -> Tensor: - # See note [TorchScript super()] - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - x = self.avgpool(x) - if self.use_last_fc: - x = torch.flatten(x, 1) - x = self.fc(x) - return x - - def forward(self, x: Tensor) -> Tensor: - return self._forward_impl(x) - - -def _resnet( - arch: str, - block: Type[Union[BasicBlock, Bottleneck]], - layers: List[int], - pretrained: bool, - progress: bool, - **kwargs: Any -) -> ResNet: - model = ResNet(block, layers, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], - progress=progress) - model.load_state_dict(state_dict) - return model - - -def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNet-18 model from - `"Deep Residual Learning for Image Recognition" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, - **kwargs) - - -def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNet-34 model from - `"Deep Residual Learning for Image Recognition" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, - **kwargs) - - -def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNet-50 model from - `"Deep Residual Learning for Image Recognition" `_. 
- - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, - **kwargs) - - -def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNet-101 model from - `"Deep Residual Learning for Image Recognition" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, - **kwargs) - - -def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNet-152 model from - `"Deep Residual Learning for Image Recognition" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, - **kwargs) - - -def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNeXt-50 32x4d model from - `"Aggregated Residual Transformation for Deep Neural Networks" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - kwargs['groups'] = 32 - kwargs['width_per_group'] = 4 - return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], - pretrained, progress, **kwargs) - - -def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""ResNeXt-101 32x8d model from - `"Aggregated Residual Transformation for Deep Neural Networks" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - kwargs['groups'] = 32 - kwargs['width_per_group'] = 8 - return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], - pretrained, progress, **kwargs) - - -def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""Wide ResNet-50-2 model from - `"Wide Residual Networks" `_. - - The model is the same as ResNet except for the bottleneck number of channels - which is twice larger in every block. The number of channels in outer 1x1 - convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 - channels, and in Wide ResNet-50-2 has 2048-1024-2048. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - kwargs['width_per_group'] = 64 * 2 - return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], - pretrained, progress, **kwargs) - - -def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - r"""Wide ResNet-101-2 model from - `"Wide Residual Networks" `_. - - The model is the same as ResNet except for the bottleneck number of channels - which is twice larger in every block. The number of channels in outer 1x1 - convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 - channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
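The channel arithmetic described in the two wide-ResNet docstrings follows directly from the `width = int(planes * (base_width / 64.)) * groups` line in `Bottleneck.__init__` above. A minimal standalone sketch of that arithmetic, using the standard stage-4 values (`planes=512`) as an illustration:

```python
# Illustrative check of the bottleneck width arithmetic described above;
# mirrors width = int(planes * (base_width / 64.)) * groups from Bottleneck.__init__.
def bottleneck_width(planes: int, base_width: int = 64, groups: int = 1) -> int:
    return int(planes * (base_width / 64.)) * groups

# Plain ResNet-50, last stage: 512-wide 3x3 conv, expanded to 512*4 = 2048
# by the final 1x1 conv (the 2048-512-2048 pattern mentioned above).
assert bottleneck_width(512, base_width=64) == 512

# Wide ResNet-50-2 passes width_per_group = 64 * 2 = 128, doubling the 3x3
# width to 1024 while the outer 1x1 channels stay at 2048 (2048-1024-2048).
assert bottleneck_width(512, base_width=128) == 1024

# ResNeXt-50 32x4d: 32 groups of width 4 give 1024 grouped 3x3 channels.
assert bottleneck_width(512, base_width=4, groups=32) == 1024
```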
- - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - kwargs['width_per_group'] = 64 * 2 - return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], - pretrained, progress, **kwargs) - - -func_dict = { - 'resnet18': (resnet18, 512), - 'resnet50': (resnet50, 2048) -} diff --git a/sadtalker_video2pose/src/face3d/models/template_model.py b/sadtalker_video2pose/src/face3d/models/template_model.py deleted file mode 100644 index 75860272a06312bfa4de382729dce5136a480a7f..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/models/template_model.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Model class template - -This module provides a template for users to implement custom models. -You can specify '--model template' to use this model. -The class name should be consistent with both the filename and its model option. -The filename should be _dataset.py -The class name should be Dataset.py -It implements a simple image-to-image translation baseline based on regression loss. -Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss: - min_ ||netG(data_A) - data_B||_1 -You need to implement the following functions: - : Add model-specific options and rewrite default values for existing options. - <__init__>: Initialize this model class. - : Unpack input data and perform data pre-processing. - : Run forward pass. This will be called by both and . - : Update network weights; it will be called in every training iteration. -""" -import numpy as np -import torch -from .base_model import BaseModel -from . import networks - - -class TemplateModel(BaseModel): - @staticmethod - def modify_commandline_options(parser, is_train=True): - """Add new model-specific options and rewrite default values for existing options. - - Parameters: - parser -- the option parser - is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options. - - Returns: - the modified parser. - """ - parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset. - if is_train: - parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model. - - return parser - - def __init__(self, opt): - """Initialize this model class. - - Parameters: - opt -- training/test options - - A few things can be done here. - - (required) call the initialization function of BaseModel - - define loss function, visualization images, model names, and optimizers - """ - BaseModel.__init__(self, opt) # call the initialization method of BaseModel - # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk. - self.loss_names = ['loss_G'] - # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images. - self.visual_names = ['data_A', 'data_B', 'output'] - # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks. - # you can use opt.isTrain to specify different behaviors for training and test. 
For example, some networks will not be used during test, and you don't need to load them. - self.model_names = ['G'] - # define networks; you can use opt.isTrain to specify different behaviors for training and test. - self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids) - if self.isTrain: # only defined during training time - # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss. - # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device) - self.criterionLoss = torch.nn.L1Loss() - # define and initialize optimizers. You can define one optimizer for each network. - # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. - self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) - self.optimizers = [self.optimizer] - - # Our program will automatically call to define schedulers, load networks, and print networks - - def set_input(self, input): - """Unpack input data from the dataloader and perform necessary pre-processing steps. - - Parameters: - input: a dictionary that contains the data itself and its metadata information. - """ - AtoB = self.opt.direction == 'AtoB' # use to swap data_A and data_B - self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A - self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B - self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths - - def forward(self): - """Run forward pass. This will be called by both functions and .""" - self.output = self.netG(self.data_A) # generate output image given the input data_A - - def backward(self): - """Calculate losses, gradients, and update network weights; called in every training iteration""" - # caculate the intermediate results if necessary; here self.output has been computed during function - # calculate loss given the input and intermediate results - self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression - self.loss_G.backward() # calculate gradients of network G w.r.t. 
loss_G - - def optimize_parameters(self): - """Update network weights; it will be called in every training iteration.""" - self.forward() # first call forward to calculate intermediate results - self.optimizer.zero_grad() # clear network G's existing gradients - self.backward() # calculate gradients for network G - self.optimizer.step() # update gradients for network G diff --git a/sadtalker_video2pose/src/face3d/options/__init__.py b/sadtalker_video2pose/src/face3d/options/__init__.py deleted file mode 100644 index 06559aa558cf178b946c4523b28b098d1dfad606..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/options/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""This package options includes option modules: training options, test options, and basic options (used in both training and test).""" diff --git a/sadtalker_video2pose/src/face3d/options/base_options.py b/sadtalker_video2pose/src/face3d/options/base_options.py deleted file mode 100644 index 9a6db3f776b11a3946eaed1a41aae732ff3a15d9..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/options/base_options.py +++ /dev/null @@ -1,169 +0,0 @@ -"""This script contains base options for Deep3DFaceRecon_pytorch -""" - -import argparse -import os -from util import util -import numpy as np -import torch -import face3d.models as models -import face3d.data as data - - -class BaseOptions(): - """This class defines options used during both training and test time. - - It also implements several helper functions such as parsing, printing, and saving the options. - It also gathers additional options defined in functions in both dataset class and model class. - """ - - def __init__(self, cmd_line=None): - """Reset the class; indicates the class hasn't been initailized""" - self.initialized = False - self.cmd_line = None - if cmd_line is not None: - self.cmd_line = cmd_line.split() - - def initialize(self, parser): - """Define the common options that are used in both training and test.""" - # basic parameters - parser.add_argument('--name', type=str, default='face_recon', help='name of the experiment. It decides where to store samples and models') - parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') - parser.add_argument('--checkpoints_dir', type=str, default='./ckpts/sad_talkers', help='models are saved here') - parser.add_argument('--vis_batch_nums', type=float, default=1, help='batch nums of images for visulization') - parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='batch nums of images for evaluation') - parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether use distributed data parallel') - parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port') - parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether use batch to show losses') - parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether add image to tensorboard') - parser.add_argument('--world_size', type=int, default=1, help='batch nums of images for evaluation') - - # model parameters - parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.') - - # additional parameters - parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model') - parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') - parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') - - self.initialized = True - return parser - - def gather_options(self): - """Initialize our parser with basic options(only once). - Add additional model-specific and dataset-specific options. - These options are defined in the function - in model and dataset classes. - """ - if not self.initialized: # check if it has been initialized - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser = self.initialize(parser) - - # get the basic options - if self.cmd_line is None: - opt, _ = parser.parse_known_args() - else: - opt, _ = parser.parse_known_args(self.cmd_line) - - # set cuda visible devices - os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids - - # modify model-related parser options - model_name = opt.model - model_option_setter = models.get_option_setter(model_name) - parser = model_option_setter(parser, self.isTrain) - if self.cmd_line is None: - opt, _ = parser.parse_known_args() # parse again with new defaults - else: - opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults - - # modify dataset-related parser options - if opt.dataset_mode: - dataset_name = opt.dataset_mode - dataset_option_setter = data.get_option_setter(dataset_name) - parser = dataset_option_setter(parser, self.isTrain) - - # save and return the parser - self.parser = parser - if self.cmd_line is None: - return parser.parse_args() - else: - return parser.parse_args(self.cmd_line) - - def print_options(self, opt): - """Print and save options - - It will print both current options and default values(if different). 
- It will save options into a text file / [checkpoints_dir] / opt.txt - """ - message = '' - message += '----------------- Options ---------------\n' - for k, v in sorted(vars(opt).items()): - comment = '' - default = self.parser.get_default(k) - if v != default: - comment = '\t[default: %s]' % str(default) - message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) - message += '----------------- End -------------------' - print(message) - - # save to the disk - expr_dir = os.path.join(opt.checkpoints_dir, opt.name) - util.mkdirs(expr_dir) - file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) - try: - with open(file_name, 'wt') as opt_file: - opt_file.write(message) - opt_file.write('\n') - except PermissionError as error: - print("permission error {}".format(error)) - pass - - def parse(self): - """Parse our options, create checkpoints directory suffix, and set up gpu device.""" - opt = self.gather_options() - opt.isTrain = self.isTrain # train or test - - # process opt.suffix - if opt.suffix: - suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' - opt.name = opt.name + suffix - - - # set gpu ids - str_ids = opt.gpu_ids.split(',') - gpu_ids = [] - for str_id in str_ids: - id = int(str_id) - if id >= 0: - gpu_ids.append(id) - opt.world_size = len(gpu_ids) - # if len(opt.gpu_ids) > 0: - # torch.cuda.set_device(gpu_ids[0]) - if opt.world_size == 1: - opt.use_ddp = False - - if opt.phase != 'test': - # set continue_train automatically - if opt.pretrained_name is None: - model_dir = os.path.join(opt.checkpoints_dir, opt.name) - else: - model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name) - if os.path.isdir(model_dir): - model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')] - if os.path.isdir(model_dir) and len(model_pths) != 0: - opt.continue_train= True - - # update the latest epoch count - if opt.continue_train: - if opt.epoch == 'latest': - epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i] - if len(epoch_counts) != 0: - opt.epoch_count = max(epoch_counts) + 1 - else: - opt.epoch_count = int(opt.epoch) + 1 - - - self.print_options(opt) - self.opt = opt - return self.opt diff --git a/sadtalker_video2pose/src/face3d/options/inference_options.py b/sadtalker_video2pose/src/face3d/options/inference_options.py deleted file mode 100644 index 80b9466776e120e0fe3d164217df5071c2114cef..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/options/inference_options.py +++ /dev/null @@ -1,23 +0,0 @@ -from face3d.options.base_options import BaseOptions - - -class InferenceOptions(BaseOptions): - """This class includes test options. - - It also includes shared options defined in BaseOptions. - """ - - def initialize(self, parser): - parser = BaseOptions.initialize(self, parser) # define shared options - parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') - parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. 
[None | flist]') - - parser.add_argument('--input_dir', type=str, help='the folder of the input files') - parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files') - parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients') - parser.add_argument('--save_split_files', action='store_true', help='save split files or not') - parser.add_argument('--inference_batch_size', type=int, default=8) - - # Dropout and Batchnorm has different behavior during training and test. - self.isTrain = False - return parser diff --git a/sadtalker_video2pose/src/face3d/options/test_options.py b/sadtalker_video2pose/src/face3d/options/test_options.py deleted file mode 100644 index f81c0c6eee0549e6fa8762dc4fc4b8573b887fe4..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/options/test_options.py +++ /dev/null @@ -1,21 +0,0 @@ -"""This script contains the test options for Deep3DFaceRecon_pytorch -""" - -from .base_options import BaseOptions - - -class TestOptions(BaseOptions): - """This class includes test options. - - It also includes shared options defined in BaseOptions. - """ - - def initialize(self, parser): - parser = BaseOptions.initialize(self, parser) # define shared options - parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') - parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') - parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.') - - # Dropout and Batchnorm has different behavior during training and test. - self.isTrain = False - return parser diff --git a/sadtalker_video2pose/src/face3d/options/train_options.py b/sadtalker_video2pose/src/face3d/options/train_options.py deleted file mode 100644 index 1100b0e35cc8ef563f41f6b8219510edbef53233..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/options/train_options.py +++ /dev/null @@ -1,53 +0,0 @@ -"""This script contains the training options for Deep3DFaceRecon_pytorch -""" - -from .base_options import BaseOptions -from util import util - -class TrainOptions(BaseOptions): - """This class includes training options. - - It also includes shared options defined in BaseOptions. - """ - - def initialize(self, parser): - parser = BaseOptions.initialize(self, parser) - # dataset parameters - # for train - parser.add_argument('--data_root', type=str, default='./', help='dataset root') - parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set') - parser.add_argument('--batch_size', type=int, default=32) - parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]') - parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') - parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') - parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') - parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]') - parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether use data augmentation') - - # for val - parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set') - parser.add_argument('--batch_size_val', type=int, default=32) - - - # visualization parameters - parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen') - parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') - - # network saving and loading parameters - parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') - parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') - parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq') - parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') - parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') - parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') - parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') - parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint') - - # training parameters - parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate') - parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') - parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. 
[linear | step | plateau | cosine]') - parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epoches') - - self.isTrain = True - return parser diff --git a/sadtalker_video2pose/src/face3d/util/BBRegressorParam_r.mat b/sadtalker_video2pose/src/face3d/util/BBRegressorParam_r.mat deleted file mode 100644 index a0da99af145c400a5216d9f6fb251d9412565921..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/util/BBRegressorParam_r.mat +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:3a5a07b8ce75a39d96b918dc0fc6e110a72e090da16f5f056a0ef7bfbc3f4560 -size 22019 diff --git a/sadtalker_video2pose/src/face3d/util/__init__.py b/sadtalker_video2pose/src/face3d/util/__init__.py deleted file mode 100644 index 1c67833cc634a2ca310b883ae253b08687665f40..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/util/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""This package includes a miscellaneous collection of useful helper functions.""" -from src.face3d.util import * - diff --git a/sadtalker_video2pose/src/face3d/util/detect_lm68.py b/sadtalker_video2pose/src/face3d/util/detect_lm68.py deleted file mode 100644 index 8a2cfd22b342de5c872ff07fc1c2a9920c2985b7..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/util/detect_lm68.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -import cv2 -import numpy as np -from scipy.io import loadmat -import tensorflow as tf -from util.preprocess import align_for_lm -from shutil import move - -mean_face = np.loadtxt('util/test_mean_face.txt') -mean_face = mean_face.reshape([68, 2]) - -def save_label(labels, save_path): - np.savetxt(save_path, labels) - -def draw_landmarks(img, landmark, save_name): - landmark = landmark - lm_img = np.zeros([img.shape[0], img.shape[1], 3]) - lm_img[:] = img.astype(np.float32) - landmark = np.round(landmark).astype(np.int32) - - for i in range(len(landmark)): - for j in range(-1, 1): - for k in range(-1, 1): - if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \ - img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \ - landmark[i, 0]+k > 0 and \ - landmark[i, 0]+k < img.shape[1]: - lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k, - :] = np.array([0, 0, 255]) - lm_img = lm_img.astype(np.uint8) - - cv2.imwrite(save_name, lm_img) - - -def load_data(img_name, txt_name): - return cv2.imread(img_name), np.loadtxt(txt_name) - -# create tensorflow graph for landmark detector -def load_lm_graph(graph_filename): - with tf.gfile.GFile(graph_filename, 'rb') as f: - graph_def = tf.GraphDef() - graph_def.ParseFromString(f.read()) - - with tf.Graph().as_default() as graph: - tf.import_graph_def(graph_def, name='net') - img_224 = graph.get_tensor_by_name('net/input_imgs:0') - output_lm = graph.get_tensor_by_name('net/lm:0') - lm_sess = tf.Session(graph=graph) - - return lm_sess,img_224,output_lm - -# landmark detection -def detect_68p(img_path,sess,input_op,output_op): - print('detecting landmarks......') - names = [i for i in sorted(os.listdir( - img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] - vis_path = os.path.join(img_path, 'vis') - remove_path = os.path.join(img_path, 'remove') - save_path = os.path.join(img_path, 'landmarks') - if not os.path.isdir(vis_path): - os.makedirs(vis_path) - if not os.path.isdir(remove_path): - os.makedirs(remove_path) - if not os.path.isdir(save_path): - os.makedirs(save_path) - - for i in range(0, len(names)): - name = names[i] 
- print('%05d' % (i), ' ', name) - full_image_name = os.path.join(img_path, name) - txt_name = '.'.join(name.split('.')[:-1]) + '.txt' - full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image - - # if an image does not have detected 5 facial landmarks, remove it from the training list - if not os.path.isfile(full_txt_name): - move(full_image_name, os.path.join(remove_path, name)) - continue - - # load data - img, five_points = load_data(full_image_name, full_txt_name) - input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection - - # if the alignment fails, remove corresponding image from the training list - if scale == 0: - move(full_txt_name, os.path.join( - remove_path, txt_name)) - move(full_image_name, os.path.join(remove_path, name)) - continue - - # detect landmarks - input_img = np.reshape( - input_img, [1, 224, 224, 3]).astype(np.float32) - landmark = sess.run( - output_op, feed_dict={input_op: input_img}) - - # transform back to original image coordinate - landmark = landmark.reshape([68, 2]) + mean_face - landmark[:, 1] = 223 - landmark[:, 1] - landmark = landmark / scale - landmark[:, 0] = landmark[:, 0] + bbox[0] - landmark[:, 1] = landmark[:, 1] + bbox[1] - landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1] - - if i % 100 == 0: - draw_landmarks(img, landmark, os.path.join(vis_path, name)) - save_label(landmark, os.path.join(save_path, txt_name)) diff --git a/sadtalker_video2pose/src/face3d/util/generate_list.py b/sadtalker_video2pose/src/face3d/util/generate_list.py deleted file mode 100644 index ebe93fcc5c61fbc79f4cd004a8d1bdd10ece16eb..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/util/generate_list.py +++ /dev/null @@ -1,34 +0,0 @@ -"""This script is to generate training list files for Deep3DFaceRecon_pytorch -""" - -import os - -# save path to training data -def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''): - save_path = os.path.join(save_folder, mode) - if not os.path.isdir(save_path): - os.makedirs(save_path) - with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd: - fd.writelines([i + '\n' for i in lms_list]) - - with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd: - fd.writelines([i + '\n' for i in imgs_list]) - - with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd: - fd.writelines([i + '\n' for i in msks_list]) - -# check if the path is valid -def check_list(rlms_list, rimgs_list, rmsks_list): - lms_list, imgs_list, msks_list = [], [], [] - for i in range(len(rlms_list)): - flag = 'false' - lm_path = rlms_list[i] - im_path = rimgs_list[i] - msk_path = rmsks_list[i] - if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path): - flag = 'true' - lms_list.append(rlms_list[i]) - imgs_list.append(rimgs_list[i]) - msks_list.append(rmsks_list[i]) - print(i, rlms_list[i], flag) - return lms_list, imgs_list, msks_list diff --git a/sadtalker_video2pose/src/face3d/util/html.py b/sadtalker_video2pose/src/face3d/util/html.py deleted file mode 100644 index c0c4e6a66ba5a34e30cee3beb13e21465c72ef38..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/util/html.py +++ /dev/null @@ -1,86 +0,0 @@ -import dominate -from dominate.tags import meta, h3, table, tr, td, p, a, img, br -import os - - -class HTML: - """This HTML class allows us to save images and write texts into a single HTML file. 
- - It consists of functions such as (add a text header to the HTML file), - (add a row of images to the HTML file), and (save the HTML to the disk). - It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. - """ - - def __init__(self, web_dir, title, refresh=0): - """Initialize the HTML classes - - Parameters: - web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0: - with self.doc.head: - meta(http_equiv="refresh", content=str(refresh)) - - def get_image_dir(self): - """Return the directory that stores images""" - return self.img_dir - - def add_header(self, text): - """Insert a header to the HTML file - - Parameters: - text (str) -- the header text - """ - with self.doc: - h3(text) - - def add_images(self, ims, txts, links, width=400): - """add images to the HTML file - - Parameters: - ims (str list) -- a list of image paths - txts (str list) -- a list of image names shown on the website - links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page - """ - self.t = table(border=1, style="table-layout: fixed;") # Insert a table - self.doc.add(self.t) - with self.t: - with tr(): - for im, txt, link in zip(ims, txts, links): - with td(style="word-wrap: break-word;", halign="center", valign="top"): - with p(): - with a(href=os.path.join('images', link)): - img(style="width:%dpx" % width, src=os.path.join('images', im)) - br() - p(txt) - - def save(self): - """save the current content to the HMTL file""" - html_file = '%s/index.html' % self.web_dir - f = open(html_file, 'wt') - f.write(self.doc.render()) - f.close() - - -if __name__ == '__main__': # we show an example usage here. 
- html = HTML('web/', 'test_html') - html.add_header('hello world') - - ims, txts, links = [], [], [] - for n in range(4): - ims.append('image_%d.png' % n) - txts.append('text_%d' % n) - links.append('image_%d.png' % n) - html.add_images(ims, txts, links) - html.save() diff --git a/sadtalker_video2pose/src/face3d/util/load_mats.py b/sadtalker_video2pose/src/face3d/util/load_mats.py deleted file mode 100644 index b7ea0a7877e80035883138415c102910d896bb61..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/util/load_mats.py +++ /dev/null @@ -1,120 +0,0 @@ -"""This script is to load 3D face model for Deep3DFaceRecon_pytorch -""" - -import numpy as np -from PIL import Image -from scipy.io import loadmat, savemat -from array import array -import os.path as osp - -# load expression basis -def LoadExpBasis(bfm_folder='BFM'): - n_vertex = 53215 - Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb') - exp_dim = array('i') - exp_dim.fromfile(Expbin, 1) - expMU = array('f') - expPC = array('f') - expMU.fromfile(Expbin, 3*n_vertex) - expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex) - Expbin.close() - - expPC = np.array(expPC) - expPC = np.reshape(expPC, [exp_dim[0], -1]) - expPC = np.transpose(expPC) - - expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt')) - - return expPC, expEV - - -# transfer original BFM09 to our face model -def transferBFM09(bfm_folder='BFM'): - print('Transfer BFM09 to BFM_model_front......') - original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat')) - shapePC = original_BFM['shapePC'] # shape basis - shapeEV = original_BFM['shapeEV'] # corresponding eigen value - shapeMU = original_BFM['shapeMU'] # mean face - texPC = original_BFM['texPC'] # texture basis - texEV = original_BFM['texEV'] # eigen value - texMU = original_BFM['texMU'] # mean texture - - expPC, expEV = LoadExpBasis(bfm_folder) - - # transfer BFM09 to our face model - - idBase = shapePC*np.reshape(shapeEV, [-1, 199]) - idBase = idBase/1e5 # unify the scale to decimeter - idBase = idBase[:, :80] # use only first 80 basis - - exBase = expPC*np.reshape(expEV, [-1, 79]) - exBase = exBase/1e5 # unify the scale to decimeter - exBase = exBase[:, :64] # use only first 64 basis - - texBase = texPC*np.reshape(texEV, [-1, 199]) - texBase = texBase[:, :80] # use only first 80 basis - - # our face model is cropped along face landmarks and contains only 35709 vertex. - # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex. - # thus we select corresponding vertex to get our face model. 
- - index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat')) - index_exp = index_exp['idx'].astype(np.int32) - 1 # starts from 0 (to 53215) - - index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat')) - index_shape = index_shape['trimIndex'].astype( - np.int32) - 1 # starts from 0 (to 53490) - index_shape = index_shape[index_exp] - - idBase = np.reshape(idBase, [-1, 3, 80]) - idBase = idBase[index_shape, :, :] - idBase = np.reshape(idBase, [-1, 80]) - - texBase = np.reshape(texBase, [-1, 3, 80]) - texBase = texBase[index_shape, :, :] - texBase = np.reshape(texBase, [-1, 80]) - - exBase = np.reshape(exBase, [-1, 3, 64]) - exBase = exBase[index_exp, :, :] - exBase = np.reshape(exBase, [-1, 64]) - - meanshape = np.reshape(shapeMU, [-1, 3])/1e5 - meanshape = meanshape[index_shape, :] - meanshape = np.reshape(meanshape, [1, -1]) - - meantex = np.reshape(texMU, [-1, 3]) - meantex = meantex[index_shape, :] - meantex = np.reshape(meantex, [1, -1]) - - # other info contains triangles, region used for computing photometric loss, - # region used for skin texture regularization, and 68 landmarks index etc. - other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat')) - frontmask2_idx = other_info['frontmask2_idx'] - skinmask = other_info['skinmask'] - keypoints = other_info['keypoints'] - point_buf = other_info['point_buf'] - tri = other_info['tri'] - tri_mask2 = other_info['tri_mask2'] - - # save our face model - savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase, - 'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask}) - - -# load landmarks for standard face, which is used for image preprocessing -def load_lm3d(bfm_folder): - - Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat')) - Lm3D = Lm3D['lm'] - - # calculate 5 facial landmarks using 68 landmarks - lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 - Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean( - Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0) - Lm3D = Lm3D[[1, 2, 0, 3, 4], :] - - return Lm3D - - -if __name__ == '__main__': - transferBFM09() \ No newline at end of file diff --git a/sadtalker_video2pose/src/face3d/util/nvdiffrast.py b/sadtalker_video2pose/src/face3d/util/nvdiffrast.py deleted file mode 100644 index 4b345db30085de501b6718ad5b49bb5f9144dd29..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/util/nvdiffrast.py +++ /dev/null @@ -1,126 +0,0 @@ -"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch - Attention, antialiasing step is missing in current version. 
-""" -import pytorch3d.ops -import torch -import torch.nn.functional as F -import kornia -from kornia.geometry.camera import pixel2cam -import numpy as np -from typing import List -from scipy.io import loadmat -from torch import nn - -from pytorch3d.structures import Meshes -from pytorch3d.renderer import ( - look_at_view_transform, - FoVPerspectiveCameras, - DirectionalLights, - RasterizationSettings, - MeshRenderer, - MeshRasterizer, - SoftPhongShader, - TexturesUV, -) - -# def ndc_projection(x=0.1, n=1.0, f=50.0): -# return np.array([[n/x, 0, 0, 0], -# [ 0, n/-x, 0, 0], -# [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], -# [ 0, 0, -1, 0]]).astype(np.float32) - -class MeshRenderer(nn.Module): - def __init__(self, - rasterize_fov, - znear=0.1, - zfar=10, - rasterize_size=224): - super(MeshRenderer, self).__init__() - - # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear - # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul( - # torch.diag(torch.tensor([1., -1, -1, 1]))) - self.rasterize_size = rasterize_size - self.fov = rasterize_fov - self.znear = znear - self.zfar = zfar - - self.rasterizer = None - - def forward(self, vertex, tri, feat=None): - """ - Return: - mask -- torch.tensor, size (B, 1, H, W) - depth -- torch.tensor, size (B, 1, H, W) - features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None - - Parameters: - vertex -- torch.tensor, size (B, N, 3) - tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles - feat(optional) -- torch.tensor, size (B, N ,C), features - """ - device = vertex.device - rsize = int(self.rasterize_size) - # ndc_proj = self.ndc_proj.to(device) - # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v - if vertex.shape[-1] == 3: - vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1) - vertex[..., 0] = -vertex[..., 0] - - - # vertex_ndc = vertex @ ndc_proj.t() - if self.rasterizer is None: - self.rasterizer = MeshRasterizer() - print("create rasterizer on device cuda:%d"%device.index) - - # ranges = None - # if isinstance(tri, List) or len(tri.shape) == 3: - # vum = vertex_ndc.shape[1] - # fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device) - # fstartidx = torch.cumsum(fnum, dim=0) - fnum - # ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu() - # for i in range(tri.shape[0]): - # tri[i] = tri[i] + i*vum - # vertex_ndc = torch.cat(vertex_ndc, dim=0) - # tri = torch.cat(tri, dim=0) - - # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3] - tri = tri.type(torch.int32).contiguous() - - # rasterize - cameras = FoVPerspectiveCameras( - device=device, - fov=self.fov, - znear=self.znear, - zfar=self.zfar, - ) - - raster_settings = RasterizationSettings( - image_size=rsize - ) - - # print(vertex.shape, tri.shape) - mesh = Meshes(vertex.contiguous()[...,:3], tri.unsqueeze(0).repeat((vertex.shape[0],1,1))) - - fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings) - rast_out = fragments.pix_to_face.squeeze(-1) - depth = fragments.zbuf - - # render depth - depth = depth.permute(0, 3, 1, 2) - mask = (rast_out > 0).float().unsqueeze(1) - depth = mask * depth - - - image = None - if feat is not None: - attributes = feat.reshape(-1,3)[mesh.faces_packed()] - image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face, - fragments.bary_coords, - attributes) - # print(image.shape) - image = image.squeeze(-2).permute(0, 3, 1, 2) - image = mask * image - - 
return mask, depth, image - diff --git a/sadtalker_video2pose/src/face3d/util/preprocess.py b/sadtalker_video2pose/src/face3d/util/preprocess.py deleted file mode 100644 index 82b36443fe4c84c1ad6366897a8e7d4e8b63b2b6..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/util/preprocess.py +++ /dev/null @@ -1,134 +0,0 @@ -"""This script contains the image preprocessing code for Deep3DFaceRecon_pytorch -""" - -import numpy as np -from scipy.io import loadmat -from PIL import Image -import cv2 -import os -from skimage import transform as trans -import torch -import warnings -warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) -warnings.filterwarnings("ignore", category=FutureWarning) - - -# calculating least square problem for image alignment -def POS(xp, x): - npts = xp.shape[1] - - A = np.zeros([2*npts, 8]) - - A[0:2*npts-1:2, 0:3] = x.transpose() - A[0:2*npts-1:2, 3] = 1 - - A[1:2*npts:2, 4:7] = x.transpose() - A[1:2*npts:2, 7] = 1 - - b = np.reshape(xp.transpose(), [2*npts, 1]) - - k, _, _, _ = np.linalg.lstsq(A, b) - - R1 = k[0:3] - R2 = k[4:7] - sTx = k[3] - sTy = k[7] - s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2 - t = np.stack([sTx, sTy], axis=0) - - return t, s - -# # resize and crop images for face reconstruction -# def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None): -# w0, h0 = img.size -# w = (w0*s).astype(np.int32) -# h = (h0*s).astype(np.int32) -# left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32) -# right = left + target_size -# up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32) -# below = up + target_size - -# img = img.resize((w, h), resample=Image.BICUBIC) -# img = img.crop((left, up, right, below)) - -# if mask is not None: -# mask = mask.resize((w, h), resample=Image.BICUBIC) -# mask = mask.crop((left, up, right, below)) - -# lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] - -# t[1] + h0/2], axis=1)*s -# lm = lm - np.reshape( -# np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2]) - -# return img, lm, mask - - -# resize and crop images for face reconstruction -def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None): - w0, h0 = img.size - w = (w0*s).astype(np.int32) - h = (h0*s).astype(np.int32) - left = np.round(w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32) - right = left + target_size - up = np.round(h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32) - below = up + target_size - - img = img.resize((w, h), resample=Image.BICUBIC) - img = img.crop((left, up, right, below)) - # import pdb; pdb.set_trace() - if mask is not None: - mask = mask.resize((w, h), resample=Image.BICUBIC) - mask = mask.crop((left, up, right, below)) - - lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] - - t[1] + h0/2], axis=1)*s - lm = lm - np.reshape( - np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2]) - - # orig_left, orig_up, orig_crop_size = (left,up,target_size)/s - - return img, lm, mask, left, up, target_size - -# utils for face reconstruction -def extract_5p(lm): - lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 - lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean( - lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0) - lm5p = lm5p[[1, 2, 0, 3, 4], :] - return lm5p - -# utils for face reconstruction -def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.): - """ - Return: - transparams --numpy.array (raw_W, raw_H, scale, tx, ty) - img_new --PIL.Image 
(target_size, target_size, 3) - lm_new --numpy.array (68, 2), y direction is opposite to v direction - mask_new --PIL.Image (target_size, target_size) - - Parameters: - img --PIL.Image (raw_H, raw_W, 3) - lm --numpy.array (68, 2), y direction is opposite to v direction - lm3D --numpy.array (5, 3) - mask --PIL.Image (raw_H, raw_W, 3) - """ - - w0, h0 = img.size - if lm.shape[0] != 5: - lm5p = extract_5p(lm) - else: - lm5p = lm - - # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face - t, s = POS(lm5p.transpose(), lm3D.transpose()) - s = rescale_factor/s - - # processing the image - - # processing the image - img_new, lm_new, mask_new, orig_left, orig_up, orig_crop_size = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask) - trans_params = np.array([w0, h0, s, t[0], t[1], orig_left, orig_up, orig_crop_size]) - # img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask) - # trans_params = np.array([w0, h0, s, t[0], t[1]]) - - return trans_params, img_new, lm_new, mask_new diff --git a/sadtalker_video2pose/src/face3d/util/skin_mask.py b/sadtalker_video2pose/src/face3d/util/skin_mask.py deleted file mode 100644 index ed764759038f77b35d45448b344d4347498ca427..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/util/skin_mask.py +++ /dev/null @@ -1,125 +0,0 @@ -"""This script is to generate skin attention mask for Deep3DFaceRecon_pytorch -""" - -import math -import numpy as np -import os -import cv2 - -class GMM: - def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv): - self.dim = dim # feature dimension - self.num = num # number of Gaussian components - self.w = w # weights of Gaussian components (a list of scalars) - self.mu= mu # mean of Gaussian components (a list of 1xdim vectors) - self.cov = cov # covariance matrix of Gaussian components (a list of dimxdim matrices) - self.cov_det = cov_det # pre-computed determinet of covariance matrices (a list of scalars) - self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices) - - self.factor = [0]*num - for i in range(self.num): - self.factor[i] = (2*math.pi)**(self.dim/2) * self.cov_det[i]**0.5 - - def likelihood(self, data): - assert(data.shape[1] == self.dim) - N = data.shape[0] - lh = np.zeros(N) - - for i in range(self.num): - data_ = data - self.mu[i] - - tmp = np.matmul(data_,self.cov_inv[i]) * data_ - tmp = np.sum(tmp,axis=1) - power = -0.5 * tmp - - p = np.array([math.exp(power[j]) for j in range(N)]) - p = p/self.factor[i] - lh += p*self.w[i] - - return lh - - -def _rgb2ycbcr(rgb): - m = np.array([[65.481, 128.553, 24.966], - [-37.797, -74.203, 112], - [112, -93.786, -18.214]]) - shape = rgb.shape - rgb = rgb.reshape((shape[0] * shape[1], 3)) - ycbcr = np.dot(rgb, m.transpose() / 255.) - ycbcr[:, 0] += 16. - ycbcr[:, 1:] += 128. - return ycbcr.reshape(shape) - - -def _bgr2ycbcr(bgr): - rgb = bgr[..., ::-1] - return _rgb2ycbcr(rgb) - - -gmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415] -gmm_skin_mu = [np.array([113.71862, 103.39613, 164.08226]), - np.array([150.19858, 105.18467, 155.51428]), - np.array([183.92976, 107.62468, 152.71820]), - np.array([114.90524, 113.59782, 151.38217])] -gmm_skin_cov_det = [5692842.5, 5851930.5, 2329131., 1585971.] 
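For reference, a tiny numeric sketch (with hypothetical likelihood values) of how the skin and non-skin GMM likelihoods are combined into the per-pixel posterior that `skinmask` computes below, using the `prior_skin = 0.8` prior defined in this module:

```python
# Bayes combination of per-pixel GMM likelihoods into a skin posterior.
# The likelihood values below are made-up; the priors match this module.
prior_skin, prior_nonskin = 0.8, 0.2
lh_skin, lh_nonskin = 2.0e-5, 5.0e-6           # hypothetical outputs of gmm_skin / gmm_nonskin

post_skin = (prior_skin * lh_skin) / (prior_skin * lh_skin + prior_nonskin * lh_nonskin)
print(round(post_skin, 3))                     # 0.941 -> about 240/255 in the output mask
```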
-gmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998],[0.0020450759, 0.017700525, 0.0051420014],[-0.00060243998, 0.0051420014, 0.0081308950]]), - np.array([[0.0027110141, 0.0011036990, 0.0023122299],[0.0011036990, 0.010707724, 0.010742856],[0.0023122299, 0.010742856, 0.017481629]]), - np.array([[0.0048026871, 0.00022935172, 0.0077668377],[0.00022935172, 0.011729696, 0.0081661865],[0.0077668377, 0.0081661865, 0.025374353]]), - np.array([[0.0011989699, 0.0022453172, -0.0010748957],[0.0022453172, 0.047758564, 0.020332102],[-0.0010748957, 0.020332102, 0.024502251]])] - -gmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv) - -gmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393] -gmm_nonskin_mu = [np.array([99.200851, 112.07533, 140.20602]), - np.array([110.91392, 125.52969, 130.19237]), - np.array([129.75864, 129.96107, 126.96808]), - np.array([112.29587, 128.85121, 129.05431])] -gmm_nonskin_cov_det = [458703648., 6466488., 90611376., 133097.63] -gmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916],[0.00071197288, 0.0025935620, 0.00076557708],[0.00023958916, 0.00076557708, 0.0015042332]]), - np.array([[0.00024650150, 0.00045542428, 0.00015019422],[0.00045542428, 0.026412144, 0.018419769],[0.00015019422, 0.018419769, 0.037497383]]), - np.array([[0.00037054974, 0.00038146760, 0.00040408765],[0.00038146760, 0.0085505722, 0.0079136286],[0.00040408765, 0.0079136286, 0.010982352]]), - np.array([[0.00013709733, 0.00051228428, 0.00012777430],[0.00051228428, 0.28237113, 0.10528370],[0.00012777430, 0.10528370, 0.23468947]])] - -gmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv) - -prior_skin = 0.8 -prior_nonskin = 1 - prior_skin - - -# calculate skin attention mask -def skinmask(imbgr): - im = _bgr2ycbcr(imbgr) - - data = im.reshape((-1,3)) - - lh_skin = gmm_skin.likelihood(data) - lh_nonskin = gmm_nonskin.likelihood(data) - - tmp1 = prior_skin * lh_skin - tmp2 = prior_nonskin * lh_nonskin - post_skin = tmp1 / (tmp1+tmp2) # posterior probability - - post_skin = post_skin.reshape((im.shape[0],im.shape[1])) - - post_skin = np.round(post_skin*255) - post_skin = post_skin.astype(np.uint8) - post_skin = np.tile(np.expand_dims(post_skin,2),[1,1,3]) # reshape to H*W*3 - - return post_skin - - -def get_skin_mask(img_path): - print('generating skin masks......') - names = [i for i in sorted(os.listdir( - img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] - save_path = os.path.join(img_path, 'mask') - if not os.path.isdir(save_path): - os.makedirs(save_path) - - for i in range(0, len(names)): - name = names[i] - print('%05d' % (i), ' ', name) - full_image_name = os.path.join(img_path, name) - img = cv2.imread(full_image_name).astype(np.float32) - skin_img = skinmask(img) - cv2.imwrite(os.path.join(save_path, name), skin_img.astype(np.uint8)) diff --git a/sadtalker_video2pose/src/face3d/util/test_mean_face.txt b/sadtalker_video2pose/src/face3d/util/test_mean_face.txt deleted file mode 100644 index 1637648acf5a61cbc71b317c845414bb16d0150c..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/util/test_mean_face.txt +++ /dev/null @@ -1,136 +0,0 @@ --5.228591537475585938e+01 -2.078247070312500000e-01 --5.064269638061523438e+01 --1.315765380859375000e+01 --4.952939224243164062e+01 --2.592591094970703125e+01 --4.793047332763671875e+01 --3.832135772705078125e+01 --4.512159729003906250e+01 --5.059623336791992188e+01 
--3.917720794677734375e+01 --6.043736648559570312e+01 --2.929953765869140625e+01 --6.861183166503906250e+01 --1.719801330566406250e+01 --7.572736358642578125e+01 --1.961936950683593750e+00 --7.862001037597656250e+01 -1.467941284179687500e+01 --7.607844543457031250e+01 -2.744073486328125000e+01 --6.915261840820312500e+01 -3.855677795410156250e+01 --5.950350570678710938e+01 -4.478240966796875000e+01 --4.867547225952148438e+01 -4.714337158203125000e+01 --3.800830078125000000e+01 -4.940315246582031250e+01 --2.496297454833984375e+01 -5.117234802246093750e+01 --1.241538238525390625e+01 -5.190507507324218750e+01 -8.244247436523437500e-01 --4.150688934326171875e+01 -2.386329650878906250e+01 --3.570307159423828125e+01 -3.017010498046875000e+01 --2.790358734130859375e+01 -3.212951660156250000e+01 --1.941773223876953125e+01 -3.156523132324218750e+01 --1.138106536865234375e+01 -2.841992187500000000e+01 -5.993263244628906250e+00 -2.895182800292968750e+01 -1.343590545654296875e+01 -3.189880371093750000e+01 -2.203153991699218750e+01 -3.302221679687500000e+01 -2.992478942871093750e+01 -3.099150085449218750e+01 -3.628388977050781250e+01 -2.765748596191406250e+01 --1.933914184570312500e+00 -1.405374145507812500e+01 --2.153038024902343750e+00 -5.772636413574218750e+00 --2.270050048828125000e+00 --2.121643066406250000e+00 --2.218330383300781250e+00 --1.068978118896484375e+01 --1.187252044677734375e+01 --1.997912597656250000e+01 --6.879402160644531250e+00 --2.143579864501953125e+01 --1.227821350097656250e+00 --2.193494415283203125e+01 -4.623237609863281250e+00 --2.152721405029296875e+01 -9.721397399902343750e+00 --1.953671264648437500e+01 --3.648714447021484375e+01 -9.811126708984375000e+00 --3.130242919921875000e+01 -1.422447967529296875e+01 --2.212834930419921875e+01 -1.493019866943359375e+01 --1.500880432128906250e+01 -1.073588562011718750e+01 --2.095037078857421875e+01 -9.054298400878906250e+00 --3.050099182128906250e+01 -8.704177856445312500e+00 -1.173237609863281250e+01 -1.054329681396484375e+01 -1.856353759765625000e+01 -1.535009765625000000e+01 -2.893331909179687500e+01 -1.451992797851562500e+01 -3.452944946289062500e+01 -1.065280151367187500e+01 -2.875990295410156250e+01 -8.654792785644531250e+00 -1.942100524902343750e+01 -9.422447204589843750e+00 --2.204488372802734375e+01 --3.983994293212890625e+01 --1.324458312988281250e+01 --3.467377471923828125e+01 --6.749649047851562500e+00 --3.092894744873046875e+01 --9.183349609375000000e-01 --3.196458435058593750e+01 -4.220649719238281250e+00 --3.090406036376953125e+01 -1.089889526367187500e+01 --3.497008514404296875e+01 -1.874589538574218750e+01 --4.065438079833984375e+01 -1.124106597900390625e+01 --4.438417816162109375e+01 -5.181709289550781250e+00 --4.649170684814453125e+01 --1.158607482910156250e+00 --4.680406951904296875e+01 --7.918922424316406250e+00 --4.671575164794921875e+01 --1.452505493164062500e+01 --4.416526031494140625e+01 --2.005007171630859375e+01 --3.997841644287109375e+01 --1.054919433593750000e+01 --3.849683380126953125e+01 --1.051826477050781250e+00 --3.794863128662109375e+01 -6.412681579589843750e+00 --3.804645538330078125e+01 -1.627674865722656250e+01 --4.039697265625000000e+01 -6.373878479003906250e+00 --4.087213897705078125e+01 --8.551712036132812500e-01 --4.157129669189453125e+01 --1.014953613281250000e+01 --4.128469085693359375e+01 diff --git a/sadtalker_video2pose/src/face3d/util/util.py b/sadtalker_video2pose/src/face3d/util/util.py deleted file mode 100644 index 
79c7517ee66c8830a73fa86ab5e5c3513f11d869..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/util/util.py +++ /dev/null @@ -1,208 +0,0 @@ -"""This script contains basic utilities for Deep3DFaceRecon_pytorch -""" -from __future__ import print_function -import numpy as np -import torch -from PIL import Image -import os -import importlib -import argparse -from argparse import Namespace -import torchvision - - -def str2bool(v): - if isinstance(v, bool): - return v - if v.lower() in ('yes', 'true', 't', 'y', '1'): - return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): - return False - else: - raise argparse.ArgumentTypeError('Boolean value expected.') - - -def copyconf(default_opt, **kwargs): - conf = Namespace(**vars(default_opt)) - for key in kwargs: - setattr(conf, key, kwargs[key]) - return conf - -def genvalconf(train_opt, **kwargs): - conf = Namespace(**vars(train_opt)) - attr_dict = train_opt.__dict__ - for key, value in attr_dict.items(): - if 'val' in key and key.split('_')[0] in attr_dict: - setattr(conf, key.split('_')[0], value) - - for key in kwargs: - setattr(conf, key, kwargs[key]) - - return conf - -def find_class_in_module(target_cls_name, module): - target_cls_name = target_cls_name.replace('_', '').lower() - clslib = importlib.import_module(module) - cls = None - for name, clsobj in clslib.__dict__.items(): - if name.lower() == target_cls_name: - cls = clsobj - - assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name) - - return cls - - -def tensor2im(input_image, imtype=np.uint8): - """"Converts a Tensor array into a numpy image array. - - Parameters: - input_image (tensor) -- the input image tensor array, range(0, 1) - imtype (type) -- the desired type of the converted numpy array - """ - if not isinstance(input_image, np.ndarray): - if isinstance(input_image, torch.Tensor): # get the data from a variable - image_tensor = input_image.data - else: - return input_image - image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy() # convert it into a numpy array - if image_numpy.shape[0] == 1: # grayscale to RGB - image_numpy = np.tile(image_numpy, (3, 1, 1)) - image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: tranpose and scaling - else: # if it is a numpy array, do nothing - image_numpy = input_image - return image_numpy.astype(imtype) - - -def diagnose_network(net, name='network'): - """Calculate and print the mean of average absolute(gradients) - - Parameters: - net (torch network) -- Torch network - name (str) -- the name of the network - """ - mean = 0.0 - count = 0 - for param in net.parameters(): - if param.grad is not None: - mean += torch.mean(torch.abs(param.grad.data)) - count += 1 - if count > 0: - mean = mean / count - print(name) - print(mean) - - -def save_image(image_numpy, image_path, aspect_ratio=1.0): - """Save a numpy image to the disk - - Parameters: - image_numpy (numpy array) -- input numpy array - image_path (str) -- the path of the image - """ - - image_pil = Image.fromarray(image_numpy) - h, w, _ = image_numpy.shape - - if aspect_ratio is None: - pass - elif aspect_ratio > 1.0: - image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) - elif aspect_ratio < 1.0: - image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) - image_pil.save(image_path) - - -def print_numpy(x, val=True, shp=False): - """Print the mean, min, max, median, std, and size of a numpy array - - 
Parameters: - val (bool) -- if print the values of the numpy array - shp (bool) -- if print the shape of the numpy array - """ - x = x.astype(np.float64) - if shp: - print('shape,', x.shape) - if val: - x = x.flatten() - print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( - np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) - - -def mkdirs(paths): - """create empty directories if they don't exist - - Parameters: - paths (str list) -- a list of directory paths - """ - if isinstance(paths, list) and not isinstance(paths, str): - for path in paths: - mkdir(path) - else: - mkdir(paths) - - -def mkdir(path): - """create a single empty directory if it didn't exist - - Parameters: - path (str) -- a single directory path - """ - if not os.path.exists(path): - os.makedirs(path) - - -def correct_resize_label(t, size): - device = t.device - t = t.detach().cpu() - resized = [] - for i in range(t.size(0)): - one_t = t[i, :1] - one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0)) - one_np = one_np[:, :, 0] - one_image = Image.fromarray(one_np).resize(size, Image.NEAREST) - resized_t = torch.from_numpy(np.array(one_image)).long() - resized.append(resized_t) - return torch.stack(resized, dim=0).to(device) - - -def correct_resize(t, size, mode=Image.BICUBIC): - device = t.device - t = t.detach().cpu() - resized = [] - for i in range(t.size(0)): - one_t = t[i:i + 1] - one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC) - resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0 - resized.append(resized_t) - return torch.stack(resized, dim=0).to(device) - -def draw_landmarks(img, landmark, color='r', step=2): - """ - Return: - img -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255) - - - Parameters: - img -- numpy.array, (B, H, W, 3), RGB order, range (0, 255) - landmark -- numpy.array, (B, 68, 2), y direction is opposite to v direction - color -- str, 'r' or 'b' (red or blue) - """ - if color =='r': - c = np.array([255., 0, 0]) - else: - c = np.array([0, 0, 255.]) - - _, H, W, _ = img.shape - img, landmark = img.copy(), landmark.copy() - landmark[..., 1] = H - 1 - landmark[..., 1] - landmark = np.round(landmark).astype(np.int32) - for i in range(landmark.shape[1]): - x, y = landmark[:, i, 0], landmark[:, i, 1] - for j in range(-step, step): - for k in range(-step, step): - u = np.clip(x + j, 0, W - 1) - v = np.clip(y + k, 0, H - 1) - for m in range(landmark.shape[0]): - img[m, v[m], u[m]] = c - return img diff --git a/sadtalker_video2pose/src/face3d/util/visualizer.py b/sadtalker_video2pose/src/face3d/util/visualizer.py deleted file mode 100644 index c4a8b755e054a4a34d003962a723ef189726a7a0..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/util/visualizer.py +++ /dev/null @@ -1,227 +0,0 @@ -"""This script defines the visualizer for Deep3DFaceRecon_pytorch -""" - -import numpy as np -import os -import sys -import ntpath -import time -from . import util, html -from subprocess import Popen, PIPE -from torch.utils.tensorboard import SummaryWriter - -def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): - """Save images to the disk. 
- - Parameters: - webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) - visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs - image_path (str) -- the string is used to create image paths - aspect_ratio (float) -- the aspect ratio of saved images - width (int) -- the images will be resized to width x width - - This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. - """ - image_dir = webpage.get_image_dir() - short_path = ntpath.basename(image_path[0]) - name = os.path.splitext(short_path)[0] - - webpage.add_header(name) - ims, txts, links = [], [], [] - - for label, im_data in visuals.items(): - im = util.tensor2im(im_data) - image_name = '%s/%s.png' % (label, name) - os.makedirs(os.path.join(image_dir, label), exist_ok=True) - save_path = os.path.join(image_dir, image_name) - util.save_image(im, save_path, aspect_ratio=aspect_ratio) - ims.append(image_name) - txts.append(label) - links.append(image_name) - webpage.add_images(ims, txts, links, width=width) - - -class Visualizer(): - """This class includes several functions that can display/save images and print/save logging information. - - It uses a Python library tensprboardX for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. - """ - - def __init__(self, opt): - """Initialize the Visualizer class - - Parameters: - opt -- stores all the experiment flags; needs to be a subclass of BaseOptions - Step 1: Cache the training/test options - Step 2: create a tensorboard writer - Step 3: create an HTML object for saveing HTML filters - Step 4: create a logging file to store training losses - """ - self.opt = opt # cache the option - self.use_html = opt.isTrain and not opt.no_html - self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, 'logs', opt.name)) - self.win_size = opt.display_winsize - self.name = opt.name - self.saved = False - if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/ - self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') - self.img_dir = os.path.join(self.web_dir, 'images') - print('create web directory %s...' % self.web_dir) - util.mkdirs([self.web_dir, self.img_dir]) - # create a logging file to store training losses - self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') - with open(self.log_name, "a") as log_file: - now = time.strftime("%c") - log_file.write('================ Training Loss (%s) ================\n' % now) - - def reset(self): - """Reset the self.saved status""" - self.saved = False - - - def display_current_results(self, visuals, total_iters, epoch, save_result): - """Display current results on tensorboad; save current results to an HTML file. - - Parameters: - visuals (OrderedDict) - - dictionary of images to display or save - total_iters (int) -- total iterations - epoch (int) - - the current epoch - save_result (bool) - - if save the current results to an HTML file - """ - for label, image in visuals.items(): - self.writer.add_image(label, util.tensor2im(image), total_iters, dataformats='HWC') - - if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. 
- self.saved = True - # save images to the disk - for label, image in visuals.items(): - image_numpy = util.tensor2im(image) - img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) - util.save_image(image_numpy, img_path) - - # update website - webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0) - for n in range(epoch, 0, -1): - webpage.add_header('epoch [%d]' % n) - ims, txts, links = [], [], [] - - for label, image_numpy in visuals.items(): - image_numpy = util.tensor2im(image) - img_path = 'epoch%.3d_%s.png' % (n, label) - ims.append(img_path) - txts.append(label) - links.append(img_path) - webpage.add_images(ims, txts, links, width=self.win_size) - webpage.save() - - def plot_current_losses(self, total_iters, losses): - # G_loss_collection = {} - # D_loss_collection = {} - # for name, value in losses.items(): - # if 'G' in name or 'NCE' in name or 'idt' in name: - # G_loss_collection[name] = value - # else: - # D_loss_collection[name] = value - # self.writer.add_scalars('G_collec', G_loss_collection, total_iters) - # self.writer.add_scalars('D_collec', D_loss_collection, total_iters) - for name, value in losses.items(): - self.writer.add_scalar(name, value, total_iters) - - # losses: same format as |losses| of plot_current_losses - def print_current_losses(self, epoch, iters, losses, t_comp, t_data): - """print current losses on console; also save the losses to the disk - - Parameters: - epoch (int) -- current epoch - iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) - losses (OrderedDict) -- training losses stored in the format of (name, float) pairs - t_comp (float) -- computational time per data point (normalized by batch_size) - t_data (float) -- data loading time per data point (normalized by batch_size) - """ - message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) - for k, v in losses.items(): - message += '%s: %.3f ' % (k, v) - - print(message) # print the message - with open(self.log_name, "a") as log_file: - log_file.write('%s\n' % message) # save the message - - -class MyVisualizer: - def __init__(self, opt): - """Initialize the Visualizer class - - Parameters: - opt -- stores all the experiment flags; needs to be a subclass of BaseOptions - Step 1: Cache the training/test options - Step 2: create a tensorboard writer - Step 3: create an HTML object for saveing HTML filters - Step 4: create a logging file to store training losses - """ - self.opt = opt # cache the optio - self.name = opt.name - self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'results') - - if opt.phase != 'test': - self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'logs')) - # create a logging file to store training losses - self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') - with open(self.log_name, "a") as log_file: - now = time.strftime("%c") - log_file.write('================ Training Loss (%s) ================\n' % now) - - - def display_current_results(self, visuals, total_iters, epoch, dataset='train', save_results=False, count=0, name=None, - add_image=True): - """Display current results on tensorboad; save current results to an HTML file. 
- - Parameters: - visuals (OrderedDict) - - dictionary of images to display or save - total_iters (int) -- total iterations - epoch (int) - - the current epoch - dataset (str) - - 'train' or 'val' or 'test' - """ - # if (not add_image) and (not save_results): return - - for label, image in visuals.items(): - for i in range(image.shape[0]): - image_numpy = util.tensor2im(image[i]) - if add_image: - self.writer.add_image(label + '%s_%02d'%(dataset, i + count), - image_numpy, total_iters, dataformats='HWC') - - if save_results: - save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d'%(epoch, total_iters)) - if not os.path.isdir(save_path): - os.makedirs(save_path) - - if name is not None: - img_path = os.path.join(save_path, '%s.png' % name) - else: - img_path = os.path.join(save_path, '%s_%03d.png' % (label, i + count)) - util.save_image(image_numpy, img_path) - - - def plot_current_losses(self, total_iters, losses, dataset='train'): - for name, value in losses.items(): - self.writer.add_scalar(name + '/%s'%dataset, value, total_iters) - - # losses: same format as |losses| of plot_current_losses - def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset='train'): - """print current losses on console; also save the losses to the disk - - Parameters: - epoch (int) -- current epoch - iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) - losses (OrderedDict) -- training losses stored in the format of (name, float) pairs - t_comp (float) -- computational time per data point (normalized by batch_size) - t_data (float) -- data loading time per data point (normalized by batch_size) - """ - message = '(dataset: %s, epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % ( - dataset, epoch, iters, t_comp, t_data) - for k, v in losses.items(): - message += '%s: %.3f ' % (k, v) - - print(message) # print the message - with open(self.log_name, "a") as log_file: - log_file.write('%s\n' % message) # save the message diff --git a/sadtalker_video2pose/src/face3d/visualize.py b/sadtalker_video2pose/src/face3d/visualize.py deleted file mode 100644 index cb8791ec30fb8f748aefc82cf4385444754825a4..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/visualize.py +++ /dev/null @@ -1,133 +0,0 @@ -# check the sync of 3dmm feature and the audio -import shutil -import cv2 -import numpy as np -from src.face3d.models.bfm import ParametricFaceModel -from src.face3d.models.facerecon_model import FaceReconModel -import torch -import subprocess, platform -import scipy.io as scio -from tqdm import tqdm - - -def draw_landmarks(image, landmarks): - for i, point in enumerate(landmarks): - cv2.circle(image, (int(point[0]), int(point[1])), 2, (0, 255, 0), -1) - cv2.putText(image, str(i), (int(point[0]), int(point[1])), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1) - return image - -# draft -def gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, save_lmk_path, crop_info, extended_crop = False): - - coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm'] - info = scio.loadmat(first_frame_coeff)['trans_params'][0] - print(info) - - coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm'] - - # print(coeff_pred.shape) - # print(coeff_pred[1:, 64:].shape) - - if args.still: - coeff_pred[1:, 64:] = np.stack([coeff_pred[0, 64:]]*coeff_pred[1:, 64:].shape[0]) - - # assert False - - coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257 - - coeff_full[:, 80:144] = coeff_pred[:, 0:64] - 
coeff_full[:, 224:227] = coeff_pred[:, 64:67] # 3 dim translation - coeff_full[:, 254:] = coeff_pred[:, 67:] # 3 dim translation - - if len(crop_info) != 3: - print("you didn't crop the image") - return - else: - r_w, r_h = crop_info[0] - clx, cly, crx, cry = crop_info[1] - lx, ly, rx, ry = crop_info[2] - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - if extended_crop: - oy1, oy2, ox1, ox2 = cly, cry, clx, crx - else: - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - tmp_video_path = '/tmp/face3dtmp.mp4' - facemodel = FaceReconModel(args) - im0 = cv2.imread(args.source_image) - - video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224)) - - # since we resize the video, we first need to resize the landmark to the cropped size resolution - # then, we need to add it back to the original video - x_scale, y_scale = (ox2 - ox1)/256 , (oy2 - oy1)/256 - - W, H = im0.shape[0], im0.shape[1] - - _, _, s, _, _, orig_left, orig_up, orig_crop_size =(info[0], info[1], info[2], info[3], info[4], info[5], info[6], info[7]) - orig_left, orig_up, orig_crop_size = [int(x) for x in (orig_left, orig_up, orig_crop_size)] - - landmark_scale = np.array([[x_scale, y_scale]]) - landmark_shift = np.array([[orig_left, orig_up]]) - landmark_shift2 = np.array([[ox1, oy1]]) - - - landmarks = [] - - for k in tqdm(range(coeff_first.shape[0]), '1st:'): - cur_coeff_full = torch.tensor(coeff_first, device=device) - - facemodel.forward(cur_coeff_full, device) - - predicted_landmark = facemodel.pred_lm # TODO. - predicted_landmark = predicted_landmark.cpu().numpy().squeeze() - - predicted_landmark[:, 1] = 224 - predicted_landmark[:, 1] - - predicted_landmark = ((predicted_landmark + landmark_shift) / s[0] * landmark_scale) + landmark_shift2 - - landmarks.append(predicted_landmark) - - print(orig_up, orig_left, orig_crop_size, s) - - for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'): - cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device) - - facemodel.forward(cur_coeff_full, device) - - predicted_landmark = facemodel.pred_lm # TODO. - predicted_landmark = predicted_landmark.cpu().numpy().squeeze() - - predicted_landmark[:, 1] = 224 - predicted_landmark[:, 1] - - predicted_landmark = ((predicted_landmark + landmark_shift) / s[0] * landmark_scale) + landmark_shift2 - - landmarks.append(predicted_landmark) - - rendered_img = facemodel.pred_face - rendered_img = 255. 
* rendered_img.cpu().numpy().squeeze().transpose(1,2,0) - out_img = rendered_img[:, :, :3].astype(np.uint8) - - video.write(np.uint8(out_img[:,:,::-1])) - - video.release() - - # visualize landmarks - video = cv2.VideoWriter(save_lmk_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (im0.shape[0], im0.shape[1])) - - for k in tqdm(range(len(landmarks)), 'face3d vis:'): - # im = draw_landmarks(im0.copy(), landmarks[k]) - im = draw_landmarks(np.uint8(np.ones_like(im0)*255), landmarks[k]) - video.write(im) - video.release() - - shutil.copyfile(args.source_image, save_lmk_path.replace('.mp4', '.png')) - - np.save(save_lmk_path.replace('.mp4', '.npy'), landmarks) - - command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path) - subprocess.call(command, shell=platform.system() != 'Windows') - diff --git a/sadtalker_video2pose/src/face3d/visualize_fromvideo.py b/sadtalker_video2pose/src/face3d/visualize_fromvideo.py deleted file mode 100644 index 44d74872695739df70ce9009351b7cd78a8cb779..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/visualize_fromvideo.py +++ /dev/null @@ -1,133 +0,0 @@ -# check the sync of 3dmm feature and the audio -import shutil -import cv2 -import numpy as np -from src.face3d.models.bfm import ParametricFaceModel -from src.face3d.models.facerecon_model import FaceReconModel -import torch -import subprocess, platform -import scipy.io as scio -from tqdm import tqdm - - -def draw_landmarks(image, landmarks): - for i, point in enumerate(landmarks): - cv2.circle(image, (int(point[0]), int(point[1])), 2, (0, 255, 0), -1) - cv2.putText(image, str(i), (int(point[0]), int(point[1])), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1) - return image - -# draft -def gen_composed_video(args, device, first_frame_coeff, coeff_path, save_path, save_lmk_path, crop_info, extended_crop = False): - - coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm'] - info = scio.loadmat(first_frame_coeff)['trans_params'][0] - print(info) - - coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm'] - - # print(coeff_pred.shape) - # print(coeff_pred[1:, 64:].shape) - - if args.still: - coeff_pred[1:, 64:] = np.stack([coeff_pred[0, 64:]]*coeff_pred[1:, 64:].shape[0]) - - # assert False - - coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257 - - coeff_full[:, 80:144] = coeff_pred[:, 0:64] - coeff_full[:, 224:227] = coeff_pred[:, 64:67] # 3 dim translation - coeff_full[:, 254:] = coeff_pred[:, 67:] # 3 dim translation - - if len(crop_info) != 3: - print("you didn't crop the image") - return - else: - r_w, r_h = crop_info[0] - clx, cly, crx, cry = crop_info[1] - lx, ly, rx, ry = crop_info[2] - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - if extended_crop: - oy1, oy2, ox1, ox2 = cly, cry, clx, crx - else: - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - # tmp_video_path = '/tmp/face3dtmp.mp4' - facemodel = FaceReconModel(args) - im0 = cv2.imread(args.source_image) - - video = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224)) - - # since we resize the video, we first need to resize the landmark to the cropped size resolution - # then, we need to add it back to the original video - x_scale, y_scale = (ox2 - ox1)/256 , (oy2 - oy1)/256 - - W, H = im0.shape[0], im0.shape[1] - - _, _, s, _, _, orig_left, orig_up, orig_crop_size =(info[0], info[1], info[2], info[3], 
info[4], info[5], info[6], info[7]) - orig_left, orig_up, orig_crop_size = [int(x) for x in (orig_left, orig_up, orig_crop_size)] - - landmark_scale = np.array([[x_scale, y_scale]]) - landmark_shift = np.array([[orig_left, orig_up]]) - landmark_shift2 = np.array([[ox1, oy1]]) - - - landmarks = [] - - for k in tqdm(range(coeff_first.shape[0]), '1st:'): - cur_coeff_full = torch.tensor(coeff_first, device=device) - - facemodel.forward(cur_coeff_full, device) - - predicted_landmark = facemodel.pred_lm # TODO. - predicted_landmark = predicted_landmark.cpu().numpy().squeeze() - - predicted_landmark[:, 1] = 224 - predicted_landmark[:, 1] - - predicted_landmark = ((predicted_landmark + landmark_shift) / s[0] * landmark_scale) + landmark_shift2 - - landmarks.append(predicted_landmark) - - print(orig_up, orig_left, orig_crop_size, s) - - for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'): - cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device) - - facemodel.forward(cur_coeff_full, device) - - predicted_landmark = facemodel.pred_lm # TODO. - predicted_landmark = predicted_landmark.cpu().numpy().squeeze() - - predicted_landmark[:, 1] = 224 - predicted_landmark[:, 1] - - predicted_landmark = ((predicted_landmark + landmark_shift) / s[0] * landmark_scale) + landmark_shift2 - - landmarks.append(predicted_landmark) - - rendered_img = facemodel.pred_face - rendered_img = 255. * rendered_img.cpu().numpy().squeeze().transpose(1,2,0) - out_img = rendered_img[:, :, :3].astype(np.uint8) - - video.write(np.uint8(out_img[:,:,::-1])) - - video.release() - - # visualize landmarks - video = cv2.VideoWriter(save_lmk_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (im0.shape[0], im0.shape[1])) - - for k in tqdm(range(len(landmarks)), 'face3d vis:'): - # im = draw_landmarks(im0.copy(), landmarks[k]) - im = draw_landmarks(np.uint8(np.ones_like(im0)*255), landmarks[k]) - video.write(im) - video.release() - - shutil.copyfile(args.source_image, save_lmk_path.replace('.mp4', '.png')) - - np.save(save_lmk_path.replace('.mp4', '.npy'), landmarks) - - # command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path) - # subprocess.call(command, shell=platform.system() != 'Windows') - diff --git a/sadtalker_video2pose/src/face3d/visualize_old.py b/sadtalker_video2pose/src/face3d/visualize_old.py deleted file mode 100644 index b4a37b388320344fd96b4778b60679440fe584c3..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/face3d/visualize_old.py +++ /dev/null @@ -1,110 +0,0 @@ -# check the sync of 3dmm feature and the audio -import shutil -import cv2 -import numpy as np -from src.face3d.models.bfm import ParametricFaceModel -from src.face3d.models.facerecon_model import FaceReconModel -import torch -import subprocess, platform -import scipy.io as scio -from tqdm import tqdm - - -def draw_landmarks(image, landmarks): - for i, point in enumerate(landmarks): - cv2.circle(image, (int(point[0]), int(point[1])), 2, (0, 255, 0), -1) - cv2.putText(image, str(i), (int(point[0]), int(point[1])), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1) - return image - -# draft -def gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, save_lmk_path, crop_info, extended_crop = False): - - coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm'] - info = scio.loadmat(first_frame_coeff)['trans_params'][0] - print(info) - - coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm'] - - coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], 
axis=0) # 257 - - coeff_full[:, 80:144] = coeff_pred[:, 0:64] - coeff_full[:, 224:227] = coeff_pred[:, 64:67] # 3 dim translation - coeff_full[:, 254:] = coeff_pred[:, 67:] # 3 dim translation - - if len(crop_info) != 3: - print("you didn't crop the image") - return - else: - r_w, r_h = crop_info[0] - clx, cly, crx, cry = crop_info[1] - lx, ly, rx, ry = crop_info[2] - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - if extended_crop: - oy1, oy2, ox1, ox2 = cly, cry, clx, crx - else: - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - tmp_video_path = '/tmp/face3dtmp.mp4' - facemodel = FaceReconModel(args) - im0 = cv2.imread(args.source_image) - - video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224)) - - # since we resize the video, we first need to resize the landmark to the cropped size resolution - # then, we need to add it back to the original video - x_scale, y_scale = (ox2 - ox1)/256 , (oy2 - oy1)/256 - - W, H = im0.shape[0], im0.shape[1] - - _, _, s, _, _, orig_left, orig_up, orig_crop_size =(info[0], info[1], info[2], info[3], info[4], info[5], info[6], info[7]) - orig_left, orig_up, orig_crop_size = [int(x) for x in (orig_left, orig_up, orig_crop_size)] - - landmark_scale = np.array([[x_scale, y_scale]]) - landmark_shift = np.array([[orig_left, orig_up]]) - landmark_shift2 = np.array([[ox1, oy1]]) - - landmarks = [] - - print(orig_up, orig_left, orig_crop_size, s) - - for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'): - cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device) - - facemodel.forward(cur_coeff_full, device) - - predicted_landmark = facemodel.pred_lm # TODO. - predicted_landmark = predicted_landmark.cpu().numpy().squeeze() - - predicted_landmark[:, 1] = 224 - predicted_landmark[:, 1] - - predicted_landmark = ((predicted_landmark + landmark_shift) / s[0] * landmark_scale) + landmark_shift2 - - landmarks.append(predicted_landmark) - - rendered_img = facemodel.pred_face - rendered_img = 255. 
* rendered_img.cpu().numpy().squeeze().transpose(1,2,0) - out_img = rendered_img[:, :, :3].astype(np.uint8) - - video.write(np.uint8(out_img[:,:,::-1])) - - video.release() - - # visualize landmarks - video = cv2.VideoWriter(save_lmk_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (im0.shape[0], im0.shape[1])) - - for k in tqdm(range(len(landmarks)), 'face3d vis:'): - # im = draw_landmarks(im0.copy(), landmarks[k]) - im = draw_landmarks(np.uint8(np.ones_like(im0)*255), landmarks[k]) - video.write(im) - video.release() - - shutil.copyfile(args.source_image, save_lmk_path.replace('.mp4', '.png')) - - np.save(save_lmk_path.replace('.mp4', '.npy'), landmarks) - - command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path) - subprocess.call(command, shell=platform.system() != 'Windows') - diff --git a/sadtalker_video2pose/src/facerender/animate.py b/sadtalker_video2pose/src/facerender/animate.py deleted file mode 100644 index 45fcb45edb4169166b851a066c8aaf08063ed1c6..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/facerender/animate.py +++ /dev/null @@ -1,261 +0,0 @@ -import os -import cv2 -import yaml -import numpy as np -import warnings -from skimage import img_as_ubyte -import safetensors -import safetensors.torch -warnings.filterwarnings('ignore') - - -import imageio -import torch -import torchvision - - -from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector -from src.facerender.modules.mapping import MappingNet -from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator -from src.facerender.modules.make_animation import make_animation - -from pydub import AudioSegment -from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list -from src.utils.paste_pic import paste_pic -from src.utils.videoio import save_video_with_watermark - -try: - import webui # in webui - in_webui = True -except: - in_webui = False - -class AnimateFromCoeff(): - - def __init__(self, sadtalker_path, device): - - with open(sadtalker_path['facerender_yaml']) as f: - config = yaml.safe_load(f) - - generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'], - **config['model_params']['common_params']) - kp_extractor = KPDetector(**config['model_params']['kp_detector_params'], - **config['model_params']['common_params']) - he_estimator = HEEstimator(**config['model_params']['he_estimator_params'], - **config['model_params']['common_params']) - mapping = MappingNet(**config['model_params']['mapping_params']) - - generator.to(device) - kp_extractor.to(device) - he_estimator.to(device) - mapping.to(device) - for param in generator.parameters(): - param.requires_grad = False - for param in kp_extractor.parameters(): - param.requires_grad = False - for param in he_estimator.parameters(): - param.requires_grad = False - for param in mapping.parameters(): - param.requires_grad = False - - if sadtalker_path is not None: - if 'checkpoint' in sadtalker_path: # use safe tensor - self.load_cpk_facevid2vid_safetensor(sadtalker_path['checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=None) - else: - self.load_cpk_facevid2vid(sadtalker_path['free_view_checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator) - else: - raise AttributeError("Checkpoint should be specified for video head pose estimator.") - - if sadtalker_path['mappingnet_checkpoint'] is not None: - 
self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping) - else: - raise AttributeError("Checkpoint should be specified for video head pose estimator.") - - self.kp_extractor = kp_extractor - self.generator = generator - self.he_estimator = he_estimator - self.mapping = mapping - - self.kp_extractor.eval() - self.generator.eval() - self.he_estimator.eval() - self.mapping.eval() - - self.device = device - - def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=None, - kp_detector=None, he_estimator=None, - device="cpu"): - - checkpoint = safetensors.torch.load_file(checkpoint_path) - - if generator is not None: - x_generator = {} - for k,v in checkpoint.items(): - if 'generator' in k: - x_generator[k.replace('generator.', '')] = v - generator.load_state_dict(x_generator) - if kp_detector is not None: - x_generator = {} - for k,v in checkpoint.items(): - if 'kp_extractor' in k: - x_generator[k.replace('kp_extractor.', '')] = v - kp_detector.load_state_dict(x_generator) - if he_estimator is not None: - x_generator = {} - for k,v in checkpoint.items(): - if 'he_estimator' in k: - x_generator[k.replace('he_estimator.', '')] = v - he_estimator.load_state_dict(x_generator) - - return None - - def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None, - kp_detector=None, he_estimator=None, optimizer_generator=None, - optimizer_discriminator=None, optimizer_kp_detector=None, - optimizer_he_estimator=None, device="cpu"): - checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) - if generator is not None: - generator.load_state_dict(checkpoint['generator']) - if kp_detector is not None: - kp_detector.load_state_dict(checkpoint['kp_detector']) - if he_estimator is not None: - he_estimator.load_state_dict(checkpoint['he_estimator']) - if discriminator is not None: - try: - discriminator.load_state_dict(checkpoint['discriminator']) - except: - print ('No discriminator in the state-dict. Dicriminator will be randomly initialized') - if optimizer_generator is not None: - optimizer_generator.load_state_dict(checkpoint['optimizer_generator']) - if optimizer_discriminator is not None: - try: - optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) - except RuntimeError as e: - print ('No discriminator optimizer in the state-dict. 
Optimizer will be not initialized') - if optimizer_kp_detector is not None: - optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector']) - if optimizer_he_estimator is not None: - optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator']) - - return checkpoint['epoch'] - - def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None, - optimizer_mapping=None, optimizer_discriminator=None, device='cpu'): - checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) - if mapping is not None: - mapping.load_state_dict(checkpoint['mapping']) - if discriminator is not None: - discriminator.load_state_dict(checkpoint['discriminator']) - if optimizer_mapping is not None: - optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping']) - if optimizer_discriminator is not None: - optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) - - return checkpoint['epoch'] - - def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256): - - source_image=x['source_image'].type(torch.FloatTensor) - source_semantics=x['source_semantics'].type(torch.FloatTensor) - target_semantics=x['target_semantics_list'].type(torch.FloatTensor) - source_image=source_image.to(self.device) - source_semantics=source_semantics.to(self.device) - target_semantics=target_semantics.to(self.device) - if 'yaw_c_seq' in x: - yaw_c_seq = x['yaw_c_seq'].type(torch.FloatTensor) - yaw_c_seq = x['yaw_c_seq'].to(self.device) - else: - yaw_c_seq = None - if 'pitch_c_seq' in x: - pitch_c_seq = x['pitch_c_seq'].type(torch.FloatTensor) - pitch_c_seq = x['pitch_c_seq'].to(self.device) - else: - pitch_c_seq = None - if 'roll_c_seq' in x: - roll_c_seq = x['roll_c_seq'].type(torch.FloatTensor) - roll_c_seq = x['roll_c_seq'].to(self.device) - else: - roll_c_seq = None - - frame_num = x['frame_num'] - - predictions_video = make_animation(source_image, source_semantics, target_semantics, - self.generator, self.kp_extractor, self.he_estimator, self.mapping, - yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp = True) - - predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:]) - predictions_video = predictions_video[:frame_num] - - video = [] - for idx in range(predictions_video.shape[0]): - image = predictions_video[idx] - image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32) - video.append(image) - result = img_as_ubyte(video) - - ### the generated video is 256x256, so we keep the aspect ratio, - original_size = crop_info[0] - if original_size: - result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ] - - video_name = x['video_name'] + '.mp4' - path = os.path.join(video_save_dir, 'temp_'+video_name) - - # print(path) - - imageio.mimsave(path, result, fps=float(25)) - - av_path = os.path.join(video_save_dir, video_name) - return_path = av_path - - audio_path = x['audio_path'] - audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0] - new_audio_path = os.path.join(video_save_dir, audio_name+'.wav') - start_time = 0 - # cog will not keep the .mp3 filename - sound = AudioSegment.from_file(audio_path) - frames = frame_num - end_time = start_time + frames*1/25*1000 - word1=sound.set_frame_rate(16000) - word = word1[start_time:end_time] - word.export(new_audio_path, format="wav") - - save_video_with_watermark(path, new_audio_path, av_path, watermark= False) - print(f'The generated 
video is named {video_save_dir}/{video_name}') - - if 'full' in preprocess.lower(): - # only add watermark to the full image. - video_name_full = x['video_name'] + '_full.mp4' - full_video_path = os.path.join(video_save_dir, video_name_full) - return_path = full_video_path - paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False) - print(f'The generated video is named {video_save_dir}/{video_name_full}') - else: - full_video_path = av_path - - #### paste back then enhancers - if enhancer: - video_name_enhancer = x['video_name'] + '_enhanced.mp4' - enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer) - av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer) - return_path = av_path_enhancer - - try: - enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer) - imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25)) - except: - enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer) - imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25)) - - save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False) - print(f'The generated video is named {video_save_dir}/{video_name_enhancer}') - - - # os.remove(enhanced_path) - - # os.remove(path) - # os.remove(new_audio_path) - - return return_path - diff --git a/sadtalker_video2pose/src/facerender/modules/dense_motion.py b/sadtalker_video2pose/src/facerender/modules/dense_motion.py deleted file mode 100644 index 4c30417870e79bc005ea47a8f383c3aa406df563..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/facerender/modules/dense_motion.py +++ /dev/null @@ -1,121 +0,0 @@ -from torch import nn -import torch.nn.functional as F -import torch -from src.facerender.modules.util import Hourglass, make_coordinate_grid, kp2gaussian - -from src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d - - -class DenseMotionNetwork(nn.Module): - """ - Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving - """ - - def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, - estimate_occlusion_map=False): - super(DenseMotionNetwork, self).__init__() - # self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(feature_channel+1), max_features=max_features, num_blocks=num_blocks) - self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) - - self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) - - self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) - self.norm = BatchNorm3d(compress, affine=True) - - if estimate_occlusion_map: - # self.occlusion = nn.Conv2d(reshape_channel*reshape_depth, 1, kernel_size=7, padding=3) - self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3) - else: - self.occlusion = None - - self.num_kp = num_kp - - - def create_sparse_motions(self, feature, kp_driving, kp_source): - bs, _, d, h, w = feature.shape - identity_grid = make_coordinate_grid((d, h, w), type=kp_source['value'].type()) - identity_grid = identity_grid.view(1, 1, d, h, w, 3) - coordinate_grid = identity_grid - 
kp_driving['value'].view(bs, self.num_kp, 1, 1, 1, 3) - - # if 'jacobian' in kp_driving: - if 'jacobian' in kp_driving and kp_driving['jacobian'] is not None: - jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian'])) - jacobian = jacobian.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3) - jacobian = jacobian.repeat(1, 1, d, h, w, 1, 1) - coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1)) - coordinate_grid = coordinate_grid.squeeze(-1) - - - driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3) - - #adding background feature - identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1) - sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) #bs num_kp+1 d h w 3 - - # sparse_motions = driving_to_source - - return sparse_motions - - def create_deformed_feature(self, feature, sparse_motions): - bs, _, d, h, w = feature.shape - feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w) - feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w) - sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3) !!!! - sparse_deformed = F.grid_sample(feature_repeat, sparse_motions) - sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w) - return sparse_deformed - - def create_heatmap_representations(self, feature, kp_driving, kp_source): - spatial_size = feature.shape[3:] - gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) - gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) - heatmap = gaussian_driving - gaussian_source - - # adding background feature - zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type()) - heatmap = torch.cat([zeros, heatmap], dim=1) - heatmap = heatmap.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w) - return heatmap - - def forward(self, feature, kp_driving, kp_source): - bs, _, d, h, w = feature.shape - - feature = self.compress(feature) - feature = self.norm(feature) - feature = F.relu(feature) - - out_dict = dict() - sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) - deformed_feature = self.create_deformed_feature(feature, sparse_motion) - - heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source) - - input_ = torch.cat([heatmap, deformed_feature], dim=2) - input_ = input_.view(bs, -1, d, h, w) - - # input = deformed_feature.view(bs, -1, d, h, w) # (bs, num_kp+1 * c, d, h, w) - - prediction = self.hourglass(input_) - - - mask = self.mask(prediction) - mask = F.softmax(mask, dim=1) - out_dict['mask'] = mask - mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w) - - zeros_mask = torch.zeros_like(mask) - mask = torch.where(mask < 1e-3, zeros_mask, mask) - - sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w) - deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w) - deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3) - - out_dict['deformation'] = deformation - - if self.occlusion: - bs, c, d, h, w = prediction.shape - prediction = prediction.view(bs, -1, h, w) - occlusion_map = torch.sigmoid(self.occlusion(prediction)) - out_dict['occlusion_map'] = occlusion_map - - return out_dict diff --git 
a/sadtalker_video2pose/src/facerender/modules/discriminator.py b/sadtalker_video2pose/src/facerender/modules/discriminator.py deleted file mode 100644 index cc0a2b460d2175a958d7b230b7e5233d7d7c7f92..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/facerender/modules/discriminator.py +++ /dev/null @@ -1,90 +0,0 @@ -from torch import nn -import torch.nn.functional as F -from facerender.modules.util import kp2gaussian -import torch - - -class DownBlock2d(nn.Module): - """ - Simple block for processing video (encoder). - """ - - def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False): - super(DownBlock2d, self).__init__() - self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) - - if sn: - self.conv = nn.utils.spectral_norm(self.conv) - - if norm: - self.norm = nn.InstanceNorm2d(out_features, affine=True) - else: - self.norm = None - self.pool = pool - - def forward(self, x): - out = x - out = self.conv(out) - if self.norm: - out = self.norm(out) - out = F.leaky_relu(out, 0.2) - if self.pool: - out = F.avg_pool2d(out, (2, 2)) - return out - - -class Discriminator(nn.Module): - """ - Discriminator similar to Pix2Pix - """ - - def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512, - sn=False, **kwargs): - super(Discriminator, self).__init__() - - down_blocks = [] - for i in range(num_blocks): - down_blocks.append( - DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)), - min(max_features, block_expansion * (2 ** (i + 1))), - norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn)) - - self.down_blocks = nn.ModuleList(down_blocks) - self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1) - if sn: - self.conv = nn.utils.spectral_norm(self.conv) - - def forward(self, x): - feature_maps = [] - out = x - - for down_block in self.down_blocks: - feature_maps.append(down_block(out)) - out = feature_maps[-1] - prediction_map = self.conv(out) - - return feature_maps, prediction_map - - -class MultiScaleDiscriminator(nn.Module): - """ - Multi-scale (scale) discriminator - """ - - def __init__(self, scales=(), **kwargs): - super(MultiScaleDiscriminator, self).__init__() - self.scales = scales - discs = {} - for scale in scales: - discs[str(scale).replace('.', '-')] = Discriminator(**kwargs) - self.discs = nn.ModuleDict(discs) - - def forward(self, x): - out_dict = {} - for scale, disc in self.discs.items(): - scale = str(scale).replace('-', '.') - key = 'prediction_' + scale - feature_maps, prediction_map = disc(x[key]) - out_dict['feature_maps_' + scale] = feature_maps - out_dict['prediction_map_' + scale] = prediction_map - return out_dict diff --git a/sadtalker_video2pose/src/facerender/modules/generator.py b/sadtalker_video2pose/src/facerender/modules/generator.py deleted file mode 100644 index 2b94dde7a37c5ddf0f74dd0317a5db3507ab0729..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/facerender/modules/generator.py +++ /dev/null @@ -1,255 +0,0 @@ -import torch -from torch import nn -import torch.nn.functional as F -from src.facerender.modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, ResBlock3d, SPADEResnetBlock -from src.facerender.modules.dense_motion import DenseMotionNetwork - - -class OcclusionAwareGenerator(nn.Module): - """ - Generator follows NVIDIA architecture. 
- """ - - def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth, - num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): - super(OcclusionAwareGenerator, self).__init__() - - if dense_motion_params is not None: - self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel, - estimate_occlusion_map=estimate_occlusion_map, - **dense_motion_params) - else: - self.dense_motion_network = None - - self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(7, 7), padding=(3, 3)) - - down_blocks = [] - for i in range(num_down_blocks): - in_features = min(max_features, block_expansion * (2 ** i)) - out_features = min(max_features, block_expansion * (2 ** (i + 1))) - down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) - self.down_blocks = nn.ModuleList(down_blocks) - - self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) - - self.reshape_channel = reshape_channel - self.reshape_depth = reshape_depth - - self.resblocks_3d = torch.nn.Sequential() - for i in range(num_resblocks): - self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) - - out_features = block_expansion * (2 ** (num_down_blocks)) - self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True) - self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1) - - self.resblocks_2d = torch.nn.Sequential() - for i in range(num_resblocks): - self.resblocks_2d.add_module('2dr' + str(i), ResBlock2d(out_features, kernel_size=3, padding=1)) - - up_blocks = [] - for i in range(num_down_blocks): - in_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i))) - out_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i - 1))) - up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) - self.up_blocks = nn.ModuleList(up_blocks) - - self.final = nn.Conv2d(block_expansion, image_channel, kernel_size=(7, 7), padding=(3, 3)) - self.estimate_occlusion_map = estimate_occlusion_map - self.image_channel = image_channel - - def deform_input(self, inp, deformation): - _, d_old, h_old, w_old, _ = deformation.shape - _, _, d, h, w = inp.shape - if d_old != d or h_old != h or w_old != w: - deformation = deformation.permute(0, 4, 1, 2, 3) - deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear') - deformation = deformation.permute(0, 2, 3, 4, 1) - return F.grid_sample(inp, deformation) - - def forward(self, source_image, kp_driving, kp_source): - # Encoding (downsampling) part - out = self.first(source_image) - for i in range(len(self.down_blocks)): - out = self.down_blocks[i](out) - out = self.second(out) - bs, c, h, w = out.shape - # print(out.shape) - feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w) - feature_3d = self.resblocks_3d(feature_3d) - - # Transforming feature representation according to deformation and occlusion - output_dict = {} - if self.dense_motion_network is not None: - dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving, - kp_source=kp_source) - output_dict['mask'] = dense_motion['mask'] - - if 'occlusion_map' in dense_motion: - occlusion_map = dense_motion['occlusion_map'] - output_dict['occlusion_map'] = 
occlusion_map - else: - occlusion_map = None - deformation = dense_motion['deformation'] - out = self.deform_input(feature_3d, deformation) - - bs, c, d, h, w = out.shape - out = out.view(bs, c*d, h, w) - out = self.third(out) - out = self.fourth(out) - - if occlusion_map is not None: - if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: - occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') - out = out * occlusion_map - - # output_dict["deformed"] = self.deform_input(source_image, deformation) # 3d deformation cannot deform 2d image - - # Decoding part - out = self.resblocks_2d(out) - for i in range(len(self.up_blocks)): - out = self.up_blocks[i](out) - out = self.final(out) - out = F.sigmoid(out) - - output_dict["prediction"] = out - - return output_dict - - -class SPADEDecoder(nn.Module): - def __init__(self): - super().__init__() - ic = 256 - oc = 64 - norm_G = 'spadespectralinstance' - label_nc = 256 - - self.fc = nn.Conv2d(ic, 2 * ic, 3, padding=1) - self.G_middle_0 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) - self.G_middle_1 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) - self.G_middle_2 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) - self.G_middle_3 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) - self.G_middle_4 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) - self.G_middle_5 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) - self.up_0 = SPADEResnetBlock(2 * ic, ic, norm_G, label_nc) - self.up_1 = SPADEResnetBlock(ic, oc, norm_G, label_nc) - self.conv_img = nn.Conv2d(oc, 3, 3, padding=1) - self.up = nn.Upsample(scale_factor=2) - - def forward(self, feature): - seg = feature - x = self.fc(feature) - x = self.G_middle_0(x, seg) - x = self.G_middle_1(x, seg) - x = self.G_middle_2(x, seg) - x = self.G_middle_3(x, seg) - x = self.G_middle_4(x, seg) - x = self.G_middle_5(x, seg) - x = self.up(x) - x = self.up_0(x, seg) # 256, 128, 128 - x = self.up(x) - x = self.up_1(x, seg) # 64, 256, 256 - - x = self.conv_img(F.leaky_relu(x, 2e-1)) - # x = torch.tanh(x) - x = F.sigmoid(x) - - return x - - -class OcclusionAwareSPADEGenerator(nn.Module): - - def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth, - num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): - super(OcclusionAwareSPADEGenerator, self).__init__() - - if dense_motion_params is not None: - self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel, - estimate_occlusion_map=estimate_occlusion_map, - **dense_motion_params) - else: - self.dense_motion_network = None - - self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1)) - - down_blocks = [] - for i in range(num_down_blocks): - in_features = min(max_features, block_expansion * (2 ** i)) - out_features = min(max_features, block_expansion * (2 ** (i + 1))) - down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) - self.down_blocks = nn.ModuleList(down_blocks) - - self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) - - self.reshape_channel = reshape_channel - self.reshape_depth = reshape_depth - - self.resblocks_3d = torch.nn.Sequential() - for i in range(num_resblocks): - self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) - - out_features = 
block_expansion * (2 ** (num_down_blocks)) - self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True) - self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1) - - self.estimate_occlusion_map = estimate_occlusion_map - self.image_channel = image_channel - - self.decoder = SPADEDecoder() - - def deform_input(self, inp, deformation): - _, d_old, h_old, w_old, _ = deformation.shape - _, _, d, h, w = inp.shape - if d_old != d or h_old != h or w_old != w: - deformation = deformation.permute(0, 4, 1, 2, 3) - deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear') - deformation = deformation.permute(0, 2, 3, 4, 1) - return F.grid_sample(inp, deformation) - - def forward(self, source_image, kp_driving, kp_source): - # Encoding (downsampling) part - out = self.first(source_image) - for i in range(len(self.down_blocks)): - out = self.down_blocks[i](out) - out = self.second(out) - bs, c, h, w = out.shape - # print(out.shape) - feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w) - feature_3d = self.resblocks_3d(feature_3d) - - # Transforming feature representation according to deformation and occlusion - output_dict = {} - if self.dense_motion_network is not None: - dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving, - kp_source=kp_source) - output_dict['mask'] = dense_motion['mask'] - - # import pdb; pdb.set_trace() - - if 'occlusion_map' in dense_motion: - occlusion_map = dense_motion['occlusion_map'] - output_dict['occlusion_map'] = occlusion_map - else: - occlusion_map = None - deformation = dense_motion['deformation'] - out = self.deform_input(feature_3d, deformation) - - bs, c, d, h, w = out.shape - out = out.view(bs, c*d, h, w) - out = self.third(out) - out = self.fourth(out) - - # occlusion_map = torch.where(occlusion_map < 0.95, 0, occlusion_map) - - if occlusion_map is not None: - if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: - occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') - out = out * occlusion_map - - # Decoding part - out = self.decoder(out) - - output_dict["prediction"] = out - - return output_dict \ No newline at end of file diff --git a/sadtalker_video2pose/src/facerender/modules/keypoint_detector.py b/sadtalker_video2pose/src/facerender/modules/keypoint_detector.py deleted file mode 100644 index e56800c7b1e94bb3cbf97200cd3f059ce9d29cf3..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/facerender/modules/keypoint_detector.py +++ /dev/null @@ -1,179 +0,0 @@ -from torch import nn -import torch -import torch.nn.functional as F - -from src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d -from src.facerender.modules.util import KPHourglass, make_coordinate_grid, AntiAliasInterpolation2d, ResBottleneck - - -class KPDetector(nn.Module): - """ - Detecting canonical keypoints. Return keypoint position and jacobian near each keypoint. 
- """ - - def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, reshape_channel, reshape_depth, - num_blocks, temperature, estimate_jacobian=False, scale_factor=1, single_jacobian_map=False): - super(KPDetector, self).__init__() - - self.predictor = KPHourglass(block_expansion, in_features=image_channel, - max_features=max_features, reshape_features=reshape_channel, reshape_depth=reshape_depth, num_blocks=num_blocks) - - # self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=7, padding=3) - self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=3, padding=1) - - if estimate_jacobian: - self.num_jacobian_maps = 1 if single_jacobian_map else num_kp - # self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=7, padding=3) - self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=3, padding=1) - ''' - initial as: - [[1 0 0] - [0 1 0] - [0 0 1]] - ''' - self.jacobian.weight.data.zero_() - self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) - else: - self.jacobian = None - - self.temperature = temperature - self.scale_factor = scale_factor - if self.scale_factor != 1: - self.down = AntiAliasInterpolation2d(image_channel, self.scale_factor) - - def gaussian2kp(self, heatmap): - """ - Extract the mean from a heatmap - """ - shape = heatmap.shape - heatmap = heatmap.unsqueeze(-1) - grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) - value = (heatmap * grid).sum(dim=(2, 3, 4)) - kp = {'value': value} - - return kp - - def forward(self, x): - if self.scale_factor != 1: - x = self.down(x) - - feature_map = self.predictor(x) - prediction = self.kp(feature_map) - - final_shape = prediction.shape - heatmap = prediction.view(final_shape[0], final_shape[1], -1) - heatmap = F.softmax(heatmap / self.temperature, dim=2) - heatmap = heatmap.view(*final_shape) - - out = self.gaussian2kp(heatmap) - - if self.jacobian is not None: - jacobian_map = self.jacobian(feature_map) - jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 9, final_shape[2], - final_shape[3], final_shape[4]) - heatmap = heatmap.unsqueeze(2) - - jacobian = heatmap * jacobian_map - jacobian = jacobian.view(final_shape[0], final_shape[1], 9, -1) - jacobian = jacobian.sum(dim=-1) - jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 3, 3) - out['jacobian'] = jacobian - - return out - - -class HEEstimator(nn.Module): - """ - Estimating head pose and expression. 
- """ - - def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, num_bins=66, estimate_jacobian=True): - super(HEEstimator, self).__init__() - - self.conv1 = nn.Conv2d(in_channels=image_channel, out_channels=block_expansion, kernel_size=7, padding=3, stride=2) - self.norm1 = BatchNorm2d(block_expansion, affine=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - - self.conv2 = nn.Conv2d(in_channels=block_expansion, out_channels=256, kernel_size=1) - self.norm2 = BatchNorm2d(256, affine=True) - - self.block1 = nn.Sequential() - for i in range(3): - self.block1.add_module('b1_'+ str(i), ResBottleneck(in_features=256, stride=1)) - - self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1) - self.norm3 = BatchNorm2d(512, affine=True) - self.block2 = ResBottleneck(in_features=512, stride=2) - - self.block3 = nn.Sequential() - for i in range(3): - self.block3.add_module('b3_'+ str(i), ResBottleneck(in_features=512, stride=1)) - - self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1) - self.norm4 = BatchNorm2d(1024, affine=True) - self.block4 = ResBottleneck(in_features=1024, stride=2) - - self.block5 = nn.Sequential() - for i in range(5): - self.block5.add_module('b5_'+ str(i), ResBottleneck(in_features=1024, stride=1)) - - self.conv5 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1) - self.norm5 = BatchNorm2d(2048, affine=True) - self.block6 = ResBottleneck(in_features=2048, stride=2) - - self.block7 = nn.Sequential() - for i in range(2): - self.block7.add_module('b7_'+ str(i), ResBottleneck(in_features=2048, stride=1)) - - self.fc_roll = nn.Linear(2048, num_bins) - self.fc_pitch = nn.Linear(2048, num_bins) - self.fc_yaw = nn.Linear(2048, num_bins) - - self.fc_t = nn.Linear(2048, 3) - - self.fc_exp = nn.Linear(2048, 3*num_kp) - - def forward(self, x): - out = self.conv1(x) - out = self.norm1(out) - out = F.relu(out) - out = self.maxpool(out) - - out = self.conv2(out) - out = self.norm2(out) - out = F.relu(out) - - out = self.block1(out) - - out = self.conv3(out) - out = self.norm3(out) - out = F.relu(out) - out = self.block2(out) - - out = self.block3(out) - - out = self.conv4(out) - out = self.norm4(out) - out = F.relu(out) - out = self.block4(out) - - out = self.block5(out) - - out = self.conv5(out) - out = self.norm5(out) - out = F.relu(out) - out = self.block6(out) - - out = self.block7(out) - - out = F.adaptive_avg_pool2d(out, 1) - out = out.view(out.shape[0], -1) - - yaw = self.fc_roll(out) - pitch = self.fc_pitch(out) - roll = self.fc_yaw(out) - t = self.fc_t(out) - exp = self.fc_exp(out) - - return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp} - diff --git a/sadtalker_video2pose/src/facerender/modules/make_animation.py b/sadtalker_video2pose/src/facerender/modules/make_animation.py deleted file mode 100644 index 42c8c53dcc04da8354d05c98c2bc0d88bf067fb2..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/facerender/modules/make_animation.py +++ /dev/null @@ -1,170 +0,0 @@ -from scipy.spatial import ConvexHull -import torch -import torch.nn.functional as F -import numpy as np -from tqdm import tqdm - -def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False, - use_relative_movement=False, use_relative_jacobian=False): - if adapt_movement_scale: - source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume - driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume - 
adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) - else: - adapt_movement_scale = 1 - - kp_new = {k: v for k, v in kp_driving.items()} - - if use_relative_movement: - kp_value_diff = (kp_driving['value'] - kp_driving_initial['value']) - kp_value_diff *= adapt_movement_scale - kp_new['value'] = kp_value_diff + kp_source['value'] - - if use_relative_jacobian: - jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian'])) - kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian']) - - return kp_new - -def headpose_pred_to_degree(pred): - device = pred.device - idx_tensor = [idx for idx in range(66)] - idx_tensor = torch.FloatTensor(idx_tensor).type_as(pred).to(device) - pred = F.softmax(pred) - degree = torch.sum(pred*idx_tensor, 1) * 3 - 99 - return degree - -def get_rotation_matrix(yaw, pitch, roll): - yaw = yaw / 180 * 3.14 - pitch = pitch / 180 * 3.14 - roll = roll / 180 * 3.14 - - roll = roll.unsqueeze(1) - pitch = pitch.unsqueeze(1) - yaw = yaw.unsqueeze(1) - - pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch), - torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch), - torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1) - pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3) - - yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw), - torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw), - -torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1) - yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3) - - roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll), - torch.sin(roll), torch.cos(roll), torch.zeros_like(roll), - torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1) - roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3) - - rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat) - - return rot_mat - -def keypoint_transformation(kp_canonical, he, wo_exp=False): - kp = kp_canonical['value'] # (bs, k, 3) - yaw, pitch, roll= he['yaw'], he['pitch'], he['roll'] - yaw = headpose_pred_to_degree(yaw) - pitch = headpose_pred_to_degree(pitch) - roll = headpose_pred_to_degree(roll) - - if 'yaw_in' in he: - yaw = he['yaw_in'] - if 'pitch_in' in he: - pitch = he['pitch_in'] - if 'roll_in' in he: - roll = he['roll_in'] - - rot_mat = get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3) - - t, exp = he['t'], he['exp'] - if wo_exp: - exp = exp*0 - - # keypoint rotation - kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp) - - # keypoint translation - t[:, 0] = t[:, 0]*0 - t[:, 2] = t[:, 2]*0 - t = t.unsqueeze(1).repeat(1, kp.shape[1], 1) - kp_t = kp_rotated + t - - # add expression deviation - exp = exp.view(exp.shape[0], -1, 3) - kp_transformed = kp_t + exp - - return {'value': kp_transformed} - - - -def make_animation(source_image, source_semantics, target_semantics, - generator, kp_detector, he_estimator, mapping, - yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None, - use_exp=True, use_half=False): - with torch.no_grad(): - predictions = [] - - kp_canonical = kp_detector(source_image) - he_source = mapping(source_semantics) - kp_source = keypoint_transformation(kp_canonical, he_source) - - for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'): - # still check the dimension - # print(target_semantics.shape, source_semantics.shape) - target_semantics_frame = target_semantics[:, frame_idx] - he_driving = mapping(target_semantics_frame) - if 
yaw_c_seq is not None: - he_driving['yaw_in'] = yaw_c_seq[:, frame_idx] - if pitch_c_seq is not None: - he_driving['pitch_in'] = pitch_c_seq[:, frame_idx] - if roll_c_seq is not None: - he_driving['roll_in'] = roll_c_seq[:, frame_idx] - - kp_driving = keypoint_transformation(kp_canonical, he_driving) - - kp_norm = kp_driving - out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm) - ''' - source_image_new = out['prediction'].squeeze(1) - kp_canonical_new = kp_detector(source_image_new) - he_source_new = he_estimator(source_image_new) - kp_source_new = keypoint_transformation(kp_canonical_new, he_source_new, wo_exp=True) - kp_driving_new = keypoint_transformation(kp_canonical_new, he_driving, wo_exp=True) - out = generator(source_image_new, kp_source=kp_source_new, kp_driving=kp_driving_new) - ''' - predictions.append(out['prediction']) - predictions_ts = torch.stack(predictions, dim=1) - return predictions_ts - -class AnimateModel(torch.nn.Module): - """ - Merge all generator related updates into single model for better multi-gpu usage - """ - - def __init__(self, generator, kp_extractor, mapping): - super(AnimateModel, self).__init__() - self.kp_extractor = kp_extractor - self.generator = generator - self.mapping = mapping - - self.kp_extractor.eval() - self.generator.eval() - self.mapping.eval() - - def forward(self, x): - - source_image = x['source_image'] - source_semantics = x['source_semantics'] - target_semantics = x['target_semantics'] - yaw_c_seq = x['yaw_c_seq'] - pitch_c_seq = x['pitch_c_seq'] - roll_c_seq = x['roll_c_seq'] - - predictions_video = make_animation(source_image, source_semantics, target_semantics, - self.generator, self.kp_extractor, - self.mapping, use_exp = True, - yaw_c_seq=yaw_c_seq, pitch_c_seq=pitch_c_seq, roll_c_seq=roll_c_seq) - - return predictions_video \ No newline at end of file diff --git a/sadtalker_video2pose/src/facerender/modules/mapping.py b/sadtalker_video2pose/src/facerender/modules/mapping.py deleted file mode 100644 index 5ac98dd9e177b949f71f8f47029b66d67ece05b4..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/facerender/modules/mapping.py +++ /dev/null @@ -1,47 +0,0 @@ -import numpy as np - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class MappingNet(nn.Module): - def __init__(self, coeff_nc, descriptor_nc, layer, num_kp, num_bins): - super( MappingNet, self).__init__() - - self.layer = layer - nonlinearity = nn.LeakyReLU(0.1) - - self.first = nn.Sequential( - torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True)) - - for i in range(layer): - net = nn.Sequential(nonlinearity, - torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3)) - setattr(self, 'encoder' + str(i), net) - - self.pooling = nn.AdaptiveAvgPool1d(1) - self.output_nc = descriptor_nc - - self.fc_roll = nn.Linear(descriptor_nc, num_bins) - self.fc_pitch = nn.Linear(descriptor_nc, num_bins) - self.fc_yaw = nn.Linear(descriptor_nc, num_bins) - self.fc_t = nn.Linear(descriptor_nc, 3) - self.fc_exp = nn.Linear(descriptor_nc, 3*num_kp) - - def forward(self, input_3dmm): - out = self.first(input_3dmm) - for i in range(self.layer): - model = getattr(self, 'encoder' + str(i)) - out = model(out) + out[:,:,3:-3] - out = self.pooling(out) - out = out.view(out.shape[0], -1) - #print('out:', out.shape) - - yaw = self.fc_yaw(out) - pitch = self.fc_pitch(out) - roll = self.fc_roll(out) - t = self.fc_t(out) - exp = self.fc_exp(out) - - return {'yaw': yaw, 'pitch': pitch, 
'roll': roll, 't': t, 'exp': exp} \ No newline at end of file diff --git a/sadtalker_video2pose/src/facerender/modules/util.py b/sadtalker_video2pose/src/facerender/modules/util.py deleted file mode 100644 index f3bfb1f26427b491f032ca9952db41cdeb793d70..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/facerender/modules/util.py +++ /dev/null @@ -1,564 +0,0 @@ -from torch import nn - -import torch.nn.functional as F -import torch - -from src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d -from src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d - -import torch.nn.utils.spectral_norm as spectral_norm - - -def kp2gaussian(kp, spatial_size, kp_variance): - """ - Transform a keypoint into gaussian like representation - """ - mean = kp['value'] - - coordinate_grid = make_coordinate_grid(spatial_size, mean.type()) - number_of_leading_dimensions = len(mean.shape) - 1 - shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape - coordinate_grid = coordinate_grid.view(*shape) - repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1) - coordinate_grid = coordinate_grid.repeat(*repeats) - - # Preprocess kp shape - shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3) - mean = mean.view(*shape) - - mean_sub = (coordinate_grid - mean) - - out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance) - - return out - -def make_coordinate_grid_2d(spatial_size, type): - """ - Create a meshgrid [-1,1] x [-1,1] of given spatial_size. - """ - h, w = spatial_size - x = torch.arange(w).type(type) - y = torch.arange(h).type(type) - - x = (2 * (x / (w - 1)) - 1) - y = (2 * (y / (h - 1)) - 1) - - yy = y.view(-1, 1).repeat(1, w) - xx = x.view(1, -1).repeat(h, 1) - - meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) - - return meshed - - -def make_coordinate_grid(spatial_size, type): - d, h, w = spatial_size - x = torch.arange(w).type(type) - y = torch.arange(h).type(type) - z = torch.arange(d).type(type) - - x = (2 * (x / (w - 1)) - 1) - y = (2 * (y / (h - 1)) - 1) - z = (2 * (z / (d - 1)) - 1) - - yy = y.view(1, -1, 1).repeat(d, 1, w) - xx = x.view(1, 1, -1).repeat(d, h, 1) - zz = z.view(-1, 1, 1).repeat(1, h, w) - - meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3) - - return meshed - - -class ResBottleneck(nn.Module): - def __init__(self, in_features, stride): - super(ResBottleneck, self).__init__() - self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features//4, kernel_size=1) - self.conv2 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features//4, kernel_size=3, padding=1, stride=stride) - self.conv3 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features, kernel_size=1) - self.norm1 = BatchNorm2d(in_features//4, affine=True) - self.norm2 = BatchNorm2d(in_features//4, affine=True) - self.norm3 = BatchNorm2d(in_features, affine=True) - - self.stride = stride - if self.stride != 1: - self.skip = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=1, stride=stride) - self.norm4 = BatchNorm2d(in_features, affine=True) - - def forward(self, x): - out = self.conv1(x) - out = self.norm1(out) - out = F.relu(out) - out = self.conv2(out) - out = self.norm2(out) - out = F.relu(out) - out = self.conv3(out) - out = self.norm3(out) - if self.stride != 1: - x = self.skip(x) - x = self.norm4(x) - out += x - out = F.relu(out) - return out - - -class ResBlock2d(nn.Module): - """ - Res block, preserve spatial resolution. 
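The keypoint_transformation helper earlier in this diff composes the predicted pose into the canonical keypoints as kp' = R·kp + t + exp. A compact, self-contained restatement of that composition (transform_keypoints is an illustrative name):

import torch

def transform_keypoints(kp_canonical, rot_mat, t, exp):
    # kp_canonical: (B, K, 3); rot_mat: (B, 3, 3); t: (B, 3); exp: (B, 3*K)
    kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp_canonical)   # rotate each keypoint
    kp_translated = kp_rotated + t.unsqueeze(1)                        # broadcast the translation
    return kp_translated + exp.view(exp.shape[0], -1, 3)               # add the expression deviation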
- """ - - def __init__(self, in_features, kernel_size, padding): - super(ResBlock2d, self).__init__() - self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, - padding=padding) - self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, - padding=padding) - self.norm1 = BatchNorm2d(in_features, affine=True) - self.norm2 = BatchNorm2d(in_features, affine=True) - - def forward(self, x): - out = self.norm1(x) - out = F.relu(out) - out = self.conv1(out) - out = self.norm2(out) - out = F.relu(out) - out = self.conv2(out) - out += x - return out - - -class ResBlock3d(nn.Module): - """ - Res block, preserve spatial resolution. - """ - - def __init__(self, in_features, kernel_size, padding): - super(ResBlock3d, self).__init__() - self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, - padding=padding) - self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, - padding=padding) - self.norm1 = BatchNorm3d(in_features, affine=True) - self.norm2 = BatchNorm3d(in_features, affine=True) - - def forward(self, x): - out = self.norm1(x) - out = F.relu(out) - out = self.conv1(out) - out = self.norm2(out) - out = F.relu(out) - out = self.conv2(out) - out += x - return out - - -class UpBlock2d(nn.Module): - """ - Upsampling block for use in decoder. - """ - - def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): - super(UpBlock2d, self).__init__() - - self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, - padding=padding, groups=groups) - self.norm = BatchNorm2d(out_features, affine=True) - - def forward(self, x): - out = F.interpolate(x, scale_factor=2) - out = self.conv(out) - out = self.norm(out) - out = F.relu(out) - return out - -class UpBlock3d(nn.Module): - """ - Upsampling block for use in decoder. - """ - - def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): - super(UpBlock3d, self).__init__() - - self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, - padding=padding, groups=groups) - self.norm = BatchNorm3d(out_features, affine=True) - - def forward(self, x): - # out = F.interpolate(x, scale_factor=(1, 2, 2), mode='trilinear') - out = F.interpolate(x, scale_factor=(1, 2, 2)) - out = self.conv(out) - out = self.norm(out) - out = F.relu(out) - return out - - -class DownBlock2d(nn.Module): - """ - Downsampling block for use in encoder. - """ - - def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): - super(DownBlock2d, self).__init__() - self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, - padding=padding, groups=groups) - self.norm = BatchNorm2d(out_features, affine=True) - self.pool = nn.AvgPool2d(kernel_size=(2, 2)) - - def forward(self, x): - out = self.conv(x) - out = self.norm(out) - out = F.relu(out) - out = self.pool(out) - return out - - -class DownBlock3d(nn.Module): - """ - Downsampling block for use in encoder. 
- """ - - def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): - super(DownBlock3d, self).__init__() - ''' - self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, - padding=padding, groups=groups, stride=(1, 2, 2)) - ''' - self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, - padding=padding, groups=groups) - self.norm = BatchNorm3d(out_features, affine=True) - self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2)) - - def forward(self, x): - out = self.conv(x) - out = self.norm(out) - out = F.relu(out) - out = self.pool(out) - return out - - -class SameBlock2d(nn.Module): - """ - Simple block, preserve spatial resolution. - """ - - def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False): - super(SameBlock2d, self).__init__() - self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, - kernel_size=kernel_size, padding=padding, groups=groups) - self.norm = BatchNorm2d(out_features, affine=True) - if lrelu: - self.ac = nn.LeakyReLU() - else: - self.ac = nn.ReLU() - - def forward(self, x): - out = self.conv(x) - out = self.norm(out) - out = self.ac(out) - return out - - -class Encoder(nn.Module): - """ - Hourglass Encoder - """ - - def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): - super(Encoder, self).__init__() - - down_blocks = [] - for i in range(num_blocks): - down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), - min(max_features, block_expansion * (2 ** (i + 1))), - kernel_size=3, padding=1)) - self.down_blocks = nn.ModuleList(down_blocks) - - def forward(self, x): - outs = [x] - for down_block in self.down_blocks: - outs.append(down_block(outs[-1])) - return outs - - -class Decoder(nn.Module): - """ - Hourglass Decoder - """ - - def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): - super(Decoder, self).__init__() - - up_blocks = [] - - for i in range(num_blocks)[::-1]: - in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) - out_filters = min(max_features, block_expansion * (2 ** i)) - up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1)) - - self.up_blocks = nn.ModuleList(up_blocks) - # self.out_filters = block_expansion - self.out_filters = block_expansion + in_features - - self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1) - self.norm = BatchNorm3d(self.out_filters, affine=True) - - def forward(self, x): - out = x.pop() - # for up_block in self.up_blocks[:-1]: - for up_block in self.up_blocks: - out = up_block(out) - skip = x.pop() - out = torch.cat([out, skip], dim=1) - # out = self.up_blocks[-1](out) - out = self.conv(out) - out = self.norm(out) - out = F.relu(out) - return out - - -class Hourglass(nn.Module): - """ - Hourglass architecture. - """ - - def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): - super(Hourglass, self).__init__() - self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) - self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) - self.out_filters = self.decoder.out_filters - - def forward(self, x): - return self.decoder(self.encoder(x)) - - -class KPHourglass(nn.Module): - """ - Hourglass architecture. 
- """ - - def __init__(self, block_expansion, in_features, reshape_features, reshape_depth, num_blocks=3, max_features=256): - super(KPHourglass, self).__init__() - - self.down_blocks = nn.Sequential() - for i in range(num_blocks): - self.down_blocks.add_module('down'+ str(i), DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), - min(max_features, block_expansion * (2 ** (i + 1))), - kernel_size=3, padding=1)) - - in_filters = min(max_features, block_expansion * (2 ** num_blocks)) - self.conv = nn.Conv2d(in_channels=in_filters, out_channels=reshape_features, kernel_size=1) - - self.up_blocks = nn.Sequential() - for i in range(num_blocks): - in_filters = min(max_features, block_expansion * (2 ** (num_blocks - i))) - out_filters = min(max_features, block_expansion * (2 ** (num_blocks - i - 1))) - self.up_blocks.add_module('up'+ str(i), UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1)) - - self.reshape_depth = reshape_depth - self.out_filters = out_filters - - def forward(self, x): - out = self.down_blocks(x) - out = self.conv(out) - bs, c, h, w = out.shape - out = out.view(bs, c//self.reshape_depth, self.reshape_depth, h, w) - out = self.up_blocks(out) - - return out - - - -class AntiAliasInterpolation2d(nn.Module): - """ - Band-limited downsampling, for better preservation of the input signal. - """ - def __init__(self, channels, scale): - super(AntiAliasInterpolation2d, self).__init__() - sigma = (1 / scale - 1) / 2 - kernel_size = 2 * round(sigma * 4) + 1 - self.ka = kernel_size // 2 - self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka - - kernel_size = [kernel_size, kernel_size] - sigma = [sigma, sigma] - # The gaussian kernel is the product of the - # gaussian function of each dimension. - kernel = 1 - meshgrids = torch.meshgrid( - [ - torch.arange(size, dtype=torch.float32) - for size in kernel_size - ] - ) - for size, std, mgrid in zip(kernel_size, sigma, meshgrids): - mean = (size - 1) / 2 - kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) - - # Make sure sum of values in gaussian kernel equals 1. 
- kernel = kernel / torch.sum(kernel) - # Reshape to depthwise convolutional weight - kernel = kernel.view(1, 1, *kernel.size()) - kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) - - self.register_buffer('weight', kernel) - self.groups = channels - self.scale = scale - inv_scale = 1 / scale - self.int_inv_scale = int(inv_scale) - - def forward(self, input): - if self.scale == 1.0: - return input - - out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) - out = F.conv2d(out, weight=self.weight, groups=self.groups) - out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale] - - return out - - -class SPADE(nn.Module): - def __init__(self, norm_nc, label_nc): - super().__init__() - - self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) - nhidden = 128 - - self.mlp_shared = nn.Sequential( - nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1), - nn.ReLU()) - self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) - self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) - - def forward(self, x, segmap): - normalized = self.param_free_norm(x) - segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') - actv = self.mlp_shared(segmap) - gamma = self.mlp_gamma(actv) - beta = self.mlp_beta(actv) - out = normalized * (1 + gamma) + beta - return out - - -class SPADEResnetBlock(nn.Module): - def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1): - super().__init__() - # Attributes - self.learned_shortcut = (fin != fout) - fmiddle = min(fin, fout) - self.use_se = use_se - # create conv layers - self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation) - self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation) - if self.learned_shortcut: - self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) - # apply spectral norm if specified - if 'spectral' in norm_G: - self.conv_0 = spectral_norm(self.conv_0) - self.conv_1 = spectral_norm(self.conv_1) - if self.learned_shortcut: - self.conv_s = spectral_norm(self.conv_s) - # define normalization layers - self.norm_0 = SPADE(fin, label_nc) - self.norm_1 = SPADE(fmiddle, label_nc) - if self.learned_shortcut: - self.norm_s = SPADE(fin, label_nc) - - def forward(self, x, seg1): - x_s = self.shortcut(x, seg1) - dx = self.conv_0(self.actvn(self.norm_0(x, seg1))) - dx = self.conv_1(self.actvn(self.norm_1(dx, seg1))) - out = x_s + dx - return out - - def shortcut(self, x, seg1): - if self.learned_shortcut: - x_s = self.conv_s(self.norm_s(x, seg1)) - else: - x_s = x - return x_s - - def actvn(self, x): - return F.leaky_relu(x, 2e-1) - -class audio2image(nn.Module): - def __init__(self, generator, kp_extractor, he_estimator_video, he_estimator_audio, train_params): - super().__init__() - # Attributes - self.generator = generator - self.kp_extractor = kp_extractor - self.he_estimator_video = he_estimator_video - self.he_estimator_audio = he_estimator_audio - self.train_params = train_params - - def headpose_pred_to_degree(self, pred): - device = pred.device - idx_tensor = [idx for idx in range(66)] - idx_tensor = torch.FloatTensor(idx_tensor).to(device) - pred = F.softmax(pred) - degree = torch.sum(pred*idx_tensor, 1) * 3 - 99 - - return degree - - def get_rotation_matrix(self, yaw, pitch, roll): - yaw = yaw / 180 * 3.14 - pitch = pitch / 180 * 3.14 - roll = roll / 180 * 3.14 - - roll = roll.unsqueeze(1) - pitch = pitch.unsqueeze(1) - yaw = yaw.unsqueeze(1) - - roll_mat = torch.cat([torch.ones_like(roll), 
torch.zeros_like(roll), torch.zeros_like(roll), - torch.zeros_like(roll), torch.cos(roll), -torch.sin(roll), - torch.zeros_like(roll), torch.sin(roll), torch.cos(roll)], dim=1) - roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3) - - pitch_mat = torch.cat([torch.cos(pitch), torch.zeros_like(pitch), torch.sin(pitch), - torch.zeros_like(pitch), torch.ones_like(pitch), torch.zeros_like(pitch), - -torch.sin(pitch), torch.zeros_like(pitch), torch.cos(pitch)], dim=1) - pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3) - - yaw_mat = torch.cat([torch.cos(yaw), -torch.sin(yaw), torch.zeros_like(yaw), - torch.sin(yaw), torch.cos(yaw), torch.zeros_like(yaw), - torch.zeros_like(yaw), torch.zeros_like(yaw), torch.ones_like(yaw)], dim=1) - yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3) - - rot_mat = torch.einsum('bij,bjk,bkm->bim', roll_mat, pitch_mat, yaw_mat) - - return rot_mat - - def keypoint_transformation(self, kp_canonical, he): - kp = kp_canonical['value'] # (bs, k, 3) - yaw, pitch, roll = he['yaw'], he['pitch'], he['roll'] - t, exp = he['t'], he['exp'] - - yaw = self.headpose_pred_to_degree(yaw) - pitch = self.headpose_pred_to_degree(pitch) - roll = self.headpose_pred_to_degree(roll) - - rot_mat = self.get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3) - - # keypoint rotation - kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp) - - - - # keypoint translation - t = t.unsqueeze_(1).repeat(1, kp.shape[1], 1) - kp_t = kp_rotated + t - - # add expression deviation - exp = exp.view(exp.shape[0], -1, 3) - kp_transformed = kp_t + exp - - return {'value': kp_transformed} - - def forward(self, source_image, target_audio): - pose_source = self.he_estimator_video(source_image) - pose_generated = self.he_estimator_audio(target_audio) - kp_canonical = self.kp_extractor(source_image) - kp_source = self.keypoint_transformation(kp_canonical, pose_source) - kp_transformed_generated = self.keypoint_transformation(kp_canonical, pose_generated) - generated = self.generator(source_image, kp_source=kp_source, kp_driving=kp_transformed_generated) - return generated \ No newline at end of file diff --git a/sadtalker_video2pose/src/facerender/pirender/base_function.py b/sadtalker_video2pose/src/facerender/pirender/base_function.py deleted file mode 100644 index 650fb7de1b95fc34e4b7c17b2526c1f450a577a0..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/facerender/pirender/base_function.py +++ /dev/null @@ -1,368 +0,0 @@ -import sys -import math - -import torch -from torch import nn -from torch.nn import functional as F -from torch.autograd import Function -from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm - - -class LayerNorm2d(nn.Module): - def __init__(self, n_out, affine=True): - super(LayerNorm2d, self).__init__() - self.n_out = n_out - self.affine = affine - - if self.affine: - self.weight = nn.Parameter(torch.ones(n_out, 1, 1)) - self.bias = nn.Parameter(torch.zeros(n_out, 1, 1)) - - def forward(self, x): - normalized_shape = x.size()[1:] - if self.affine: - return F.layer_norm(x, normalized_shape, \ - self.weight.expand(normalized_shape), - self.bias.expand(normalized_shape)) - - else: - return F.layer_norm(x, normalized_shape) - -class ADAINHourglass(nn.Module): - def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers, decoder_layers, nonlinearity, use_spect): - super(ADAINHourglass, self).__init__() - self.encoder = ADAINEncoder(image_nc, pose_nc, ngf, img_f, encoder_layers, nonlinearity, use_spect) - self.decoder = ADAINDecoder(pose_nc, ngf, img_f, 
encoder_layers, decoder_layers, True, nonlinearity, use_spect) - self.output_nc = self.decoder.output_nc - - def forward(self, x, z): - return self.decoder(self.encoder(x, z), z) - - - -class ADAINEncoder(nn.Module): - def __init__(self, image_nc, pose_nc, ngf, img_f, layers, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(ADAINEncoder, self).__init__() - self.layers = layers - self.input_layer = nn.Conv2d(image_nc, ngf, kernel_size=7, stride=1, padding=3) - for i in range(layers): - in_channels = min(ngf * (2**i), img_f) - out_channels = min(ngf *(2**(i+1)), img_f) - model = ADAINEncoderBlock(in_channels, out_channels, pose_nc, nonlinearity, use_spect) - setattr(self, 'encoder' + str(i), model) - self.output_nc = out_channels - - def forward(self, x, z): - out = self.input_layer(x) - out_list = [out] - for i in range(self.layers): - model = getattr(self, 'encoder' + str(i)) - out = model(out, z) - out_list.append(out) - return out_list - -class ADAINDecoder(nn.Module): - """docstring for ADAINDecoder""" - def __init__(self, pose_nc, ngf, img_f, encoder_layers, decoder_layers, skip_connect=True, - nonlinearity=nn.LeakyReLU(), use_spect=False): - - super(ADAINDecoder, self).__init__() - self.encoder_layers = encoder_layers - self.decoder_layers = decoder_layers - self.skip_connect = skip_connect - use_transpose = True - - for i in range(encoder_layers-decoder_layers, encoder_layers)[::-1]: - in_channels = min(ngf * (2**(i+1)), img_f) - in_channels = in_channels*2 if i != (encoder_layers-1) and self.skip_connect else in_channels - out_channels = min(ngf * (2**i), img_f) - model = ADAINDecoderBlock(in_channels, out_channels, out_channels, pose_nc, use_transpose, nonlinearity, use_spect) - setattr(self, 'decoder' + str(i), model) - - self.output_nc = out_channels*2 if self.skip_connect else out_channels - - def forward(self, x, z): - out = x.pop() if self.skip_connect else x - for i in range(self.encoder_layers-self.decoder_layers, self.encoder_layers)[::-1]: - model = getattr(self, 'decoder' + str(i)) - out = model(out, z) - out = torch.cat([out, x.pop()], 1) if self.skip_connect else out - return out - -class ADAINEncoderBlock(nn.Module): - def __init__(self, input_nc, output_nc, feature_nc, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(ADAINEncoderBlock, self).__init__() - kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1} - kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1} - - self.conv_0 = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_down), use_spect) - self.conv_1 = spectral_norm(nn.Conv2d(output_nc, output_nc, **kwargs_fine), use_spect) - - - self.norm_0 = ADAIN(input_nc, feature_nc) - self.norm_1 = ADAIN(output_nc, feature_nc) - self.actvn = nonlinearity - - def forward(self, x, z): - x = self.conv_0(self.actvn(self.norm_0(x, z))) - x = self.conv_1(self.actvn(self.norm_1(x, z))) - return x - -class ADAINDecoderBlock(nn.Module): - def __init__(self, input_nc, output_nc, hidden_nc, feature_nc, use_transpose=True, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(ADAINDecoderBlock, self).__init__() - # Attributes - self.actvn = nonlinearity - hidden_nc = min(input_nc, output_nc) if hidden_nc is None else hidden_nc - - kwargs_fine = {'kernel_size':3, 'stride':1, 'padding':1} - if use_transpose: - kwargs_up = {'kernel_size':3, 'stride':2, 'padding':1, 'output_padding':1} - else: - kwargs_up = {'kernel_size':3, 'stride':1, 'padding':1} - - # create conv layers - self.conv_0 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, **kwargs_fine), 
use_spect) - if use_transpose: - self.conv_1 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, **kwargs_up), use_spect) - self.conv_s = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, **kwargs_up), use_spect) - else: - self.conv_1 = nn.Sequential(spectral_norm(nn.Conv2d(hidden_nc, output_nc, **kwargs_up), use_spect), - nn.Upsample(scale_factor=2)) - self.conv_s = nn.Sequential(spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_up), use_spect), - nn.Upsample(scale_factor=2)) - # define normalization layers - self.norm_0 = ADAIN(input_nc, feature_nc) - self.norm_1 = ADAIN(hidden_nc, feature_nc) - self.norm_s = ADAIN(input_nc, feature_nc) - - def forward(self, x, z): - x_s = self.shortcut(x, z) - dx = self.conv_0(self.actvn(self.norm_0(x, z))) - dx = self.conv_1(self.actvn(self.norm_1(dx, z))) - out = x_s + dx - return out - - def shortcut(self, x, z): - x_s = self.conv_s(self.actvn(self.norm_s(x, z))) - return x_s - - -def spectral_norm(module, use_spect=True): - """use spectral normal layer to stable the training process""" - if use_spect: - return SpectralNorm(module) - else: - return module - - -class ADAIN(nn.Module): - def __init__(self, norm_nc, feature_nc): - super().__init__() - - self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) - - nhidden = 128 - use_bias=True - - self.mlp_shared = nn.Sequential( - nn.Linear(feature_nc, nhidden, bias=use_bias), - nn.ReLU() - ) - self.mlp_gamma = nn.Linear(nhidden, norm_nc, bias=use_bias) - self.mlp_beta = nn.Linear(nhidden, norm_nc, bias=use_bias) - - def forward(self, x, feature): - - # Part 1. generate parameter-free normalized activations - normalized = self.param_free_norm(x) - - # Part 2. produce scaling and bias conditioned on feature - feature = feature.view(feature.size(0), -1) - actv = self.mlp_shared(feature) - gamma = self.mlp_gamma(actv) - beta = self.mlp_beta(actv) - - # apply scale and bias - gamma = gamma.view(*gamma.size()[:2], 1,1) - beta = beta.view(*beta.size()[:2], 1,1) - out = normalized * (1 + gamma) + beta - return out - - -class FineEncoder(nn.Module): - """docstring for Encoder""" - def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(FineEncoder, self).__init__() - self.layers = layers - self.first = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect) - for i in range(layers): - in_channels = min(ngf*(2**i), img_f) - out_channels = min(ngf*(2**(i+1)), img_f) - model = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect) - setattr(self, 'down' + str(i), model) - self.output_nc = out_channels - - def forward(self, x): - x = self.first(x) - out=[x] - for i in range(self.layers): - model = getattr(self, 'down'+str(i)) - x = model(x) - out.append(x) - return out - -class FineDecoder(nn.Module): - """docstring for FineDecoder""" - def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(FineDecoder, self).__init__() - self.layers = layers - for i in range(layers)[::-1]: - in_channels = min(ngf*(2**(i+1)), img_f) - out_channels = min(ngf*(2**i), img_f) - up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect) - res = FineADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect) - jump = Jump(out_channels, norm_layer, nonlinearity, use_spect) - - setattr(self, 'up' + str(i), up) - setattr(self, 'res' + str(i), res) - setattr(self, 'jump' + 
str(i), jump) - - self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'tanh') - - self.output_nc = out_channels - - def forward(self, x, z): - out = x.pop() - for i in range(self.layers)[::-1]: - res_model = getattr(self, 'res' + str(i)) - up_model = getattr(self, 'up' + str(i)) - jump_model = getattr(self, 'jump' + str(i)) - out = res_model(out, z) - out = up_model(out) - out = jump_model(x.pop()) + out - out_image = self.final(out) - return out_image - -class FirstBlock2d(nn.Module): - """ - Downsampling block for use in encoder. - """ - def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(FirstBlock2d, self).__init__() - kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3} - conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect) - - if type(norm_layer) == type(None): - self.model = nn.Sequential(conv, nonlinearity) - else: - self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity) - - - def forward(self, x): - out = self.model(x) - return out - -class DownBlock2d(nn.Module): - def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(DownBlock2d, self).__init__() - - - kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1} - conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect) - pool = nn.AvgPool2d(kernel_size=(2, 2)) - - if type(norm_layer) == type(None): - self.model = nn.Sequential(conv, nonlinearity, pool) - else: - self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity, pool) - - def forward(self, x): - out = self.model(x) - return out - -class UpBlock2d(nn.Module): - def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(UpBlock2d, self).__init__() - kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1} - conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect) - if type(norm_layer) == type(None): - self.model = nn.Sequential(conv, nonlinearity) - else: - self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity) - - def forward(self, x): - out = self.model(F.interpolate(x, scale_factor=2)) - return out - -class FineADAINResBlocks(nn.Module): - def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(FineADAINResBlocks, self).__init__() - self.num_block = num_block - for i in range(num_block): - model = FineADAINResBlock2d(input_nc, feature_nc, norm_layer, nonlinearity, use_spect) - setattr(self, 'res'+str(i), model) - - def forward(self, x, z): - for i in range(self.num_block): - model = getattr(self, 'res'+str(i)) - x = model(x, z) - return x - -class Jump(nn.Module): - def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(Jump, self).__init__() - kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1} - conv = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect) - - if type(norm_layer) == type(None): - self.model = nn.Sequential(conv, nonlinearity) - else: - self.model = nn.Sequential(conv, norm_layer(input_nc), nonlinearity) - - def forward(self, x): - out = self.model(x) - return out - -class FineADAINResBlock2d(nn.Module): - """ - Define an Residual block for different types - """ - def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): - super(FineADAINResBlock2d, 
self).__init__() - - kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1} - - self.conv1 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect) - self.conv2 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect) - self.norm1 = ADAIN(input_nc, feature_nc) - self.norm2 = ADAIN(input_nc, feature_nc) - - self.actvn = nonlinearity - - - def forward(self, x, z): - dx = self.actvn(self.norm1(self.conv1(x), z)) - dx = self.norm2(self.conv2(x), z) - out = dx + x - return out - -class FinalBlock2d(nn.Module): - """ - Define the output layer - """ - def __init__(self, input_nc, output_nc, use_spect=False, tanh_or_sigmoid='tanh'): - super(FinalBlock2d, self).__init__() - - kwargs = {'kernel_size': 7, 'stride': 1, 'padding':3} - conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect) - - if tanh_or_sigmoid == 'sigmoid': - out_nonlinearity = nn.Sigmoid() - else: - out_nonlinearity = nn.Tanh() - - self.model = nn.Sequential(conv, out_nonlinearity) - def forward(self, x): - out = self.model(x) - return out \ No newline at end of file diff --git a/sadtalker_video2pose/src/facerender/pirender/config.py b/sadtalker_video2pose/src/facerender/pirender/config.py deleted file mode 100644 index 29dc2d1b9008dbf2dc3c0a307212471621bae8da..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/facerender/pirender/config.py +++ /dev/null @@ -1,211 +0,0 @@ -import collections -import functools -import os -import re - -import yaml - -class AttrDict(dict): - """Dict as attribute trick.""" - - def __init__(self, *args, **kwargs): - super(AttrDict, self).__init__(*args, **kwargs) - self.__dict__ = self - for key, value in self.__dict__.items(): - if isinstance(value, dict): - self.__dict__[key] = AttrDict(value) - elif isinstance(value, (list, tuple)): - if isinstance(value[0], dict): - self.__dict__[key] = [AttrDict(item) for item in value] - else: - self.__dict__[key] = value - - def yaml(self): - """Convert object to yaml dict and return.""" - yaml_dict = {} - for key, value in self.__dict__.items(): - if isinstance(value, AttrDict): - yaml_dict[key] = value.yaml() - elif isinstance(value, list): - if isinstance(value[0], AttrDict): - new_l = [] - for item in value: - new_l.append(item.yaml()) - yaml_dict[key] = new_l - else: - yaml_dict[key] = value - else: - yaml_dict[key] = value - return yaml_dict - - def __repr__(self): - """Print all variables.""" - ret_str = [] - for key, value in self.__dict__.items(): - if isinstance(value, AttrDict): - ret_str.append('{}:'.format(key)) - child_ret_str = value.__repr__().split('\n') - for item in child_ret_str: - ret_str.append(' ' + item) - elif isinstance(value, list): - if isinstance(value[0], AttrDict): - ret_str.append('{}:'.format(key)) - for item in value: - # Treat as AttrDict above. - child_ret_str = item.__repr__().split('\n') - for item in child_ret_str: - ret_str.append(' ' + item) - else: - ret_str.append('{}: {}'.format(key, value)) - else: - ret_str.append('{}: {}'.format(key, value)) - return '\n'.join(ret_str) - - -class Config(AttrDict): - r"""Configuration class. This should include every human specifiable - hyperparameter values for your training.""" - - def __init__(self, filename=None, args=None, verbose=False, is_train=True): - super(Config, self).__init__() - # Set default parameters. - # Logging. 
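ADAIN above is the feature-conditioned counterpart of SPADE: instance-normalize the activations, then scale and shift them with gamma/beta vectors predicted from the driving descriptor. A stripped-down sketch (TinyADAIN is an illustrative name, not part of the file):

import torch
import torch.nn as nn

class TinyADAIN(nn.Module):
    def __init__(self, norm_nc, feature_nc, nhidden=128):
        super().__init__()
        self.norm = nn.InstanceNorm2d(norm_nc, affine=False)
        self.shared = nn.Sequential(nn.Linear(feature_nc, nhidden), nn.ReLU())
        self.gamma = nn.Linear(nhidden, norm_nc)
        self.beta = nn.Linear(nhidden, norm_nc)

    def forward(self, x, feature):
        actv = self.shared(feature.view(feature.size(0), -1))
        gamma = self.gamma(actv)[..., None, None]         # (B, C, 1, 1) modulation scale
        beta = self.beta(actv)[..., None, None]
        return self.norm(x) * (1 + gamma) + beta

# TinyADAIN(64, 70)(torch.randn(2, 64, 32, 32), torch.randn(2, 70)) -> shape (2, 64, 32, 32)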
- - large_number = 1000000000 - self.snapshot_save_iter = large_number - self.snapshot_save_epoch = large_number - self.snapshot_save_start_iter = 0 - self.snapshot_save_start_epoch = 0 - self.image_save_iter = large_number - self.eval_epoch = large_number - self.start_eval_epoch = large_number - self.eval_epoch = large_number - self.max_epoch = large_number - self.max_iter = large_number - self.logging_iter = 100 - self.image_to_tensorboard=False - self.which_iter = 0 # args.which_iter - self.resume = False - - self.checkpoints_dir = '/Users/shadowcun/Downloads/' - self.name = 'face' - self.phase = 'train' if is_train else 'test' - - # Networks. - self.gen = AttrDict(type='generators.dummy') - self.dis = AttrDict(type='discriminators.dummy') - - # Optimizers. - self.gen_optimizer = AttrDict(type='adam', - lr=0.0001, - adam_beta1=0.0, - adam_beta2=0.999, - eps=1e-8, - lr_policy=AttrDict(iteration_mode=False, - type='step', - step_size=large_number, - gamma=1)) - self.dis_optimizer = AttrDict(type='adam', - lr=0.0001, - adam_beta1=0.0, - adam_beta2=0.999, - eps=1e-8, - lr_policy=AttrDict(iteration_mode=False, - type='step', - step_size=large_number, - gamma=1)) - # Data. - self.data = AttrDict(name='dummy', - type='datasets.images', - num_workers=0) - self.test_data = AttrDict(name='dummy', - type='datasets.images', - num_workers=0, - test=AttrDict(is_lmdb=False, - roots='', - batch_size=1)) - self.trainer = AttrDict( - model_average=False, - model_average_beta=0.9999, - model_average_start_iteration=1000, - model_average_batch_norm_estimation_iteration=30, - model_average_remove_sn=True, - image_to_tensorboard=False, - hparam_to_tensorboard=False, - distributed_data_parallel='pytorch', - delay_allreduce=True, - gan_relativistic=False, - gen_step=1, - dis_step=1) - - # # Cudnn. - self.cudnn = AttrDict(deterministic=False, - benchmark=True) - - # Others. - self.pretrained_weight = '' - self.inference_args = AttrDict() - - - # Update with given configurations. - assert os.path.exists(filename), 'File {} not exist.'.format(filename) - loader = yaml.SafeLoader - loader.add_implicit_resolver( - u'tag:yaml.org,2002:float', - re.compile(u'''^(?: - [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? - |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) - |\\.[0-9_]+(?:[eE][-+][0-9]+)? - |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* - |[-+]?\\.(?:inf|Inf|INF) - |\\.(?:nan|NaN|NAN))$''', re.X), - list(u'-+0123456789.')) - try: - with open(filename, 'r') as f: - cfg_dict = yaml.load(f, Loader=loader) - except EnvironmentError: - print('Please check the file with name of "%s"', filename) - recursive_update(self, cfg_dict) - - # Put common opts in both gen and dis. 
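The custom implicit resolver registered on yaml.SafeLoader above exists because stock PyYAML does not recognize exponent-only literals such as 1e-8 as floats, so learning rates and epsilons would load as strings. A small demonstration of the difference (the regex here is simplified for the demo, not the full pattern used above):

import re
import yaml

print(yaml.safe_load('lr: 1e-8'))            # {'lr': '1e-8'}  -- parsed as a string by default

loader = yaml.SafeLoader
loader.add_implicit_resolver(
    u'tag:yaml.org,2002:float',
    re.compile(r'^[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)$'),
    list(u'-+0123456789.'))
print(yaml.load('lr: 1e-8', Loader=loader))  # {'lr': 1e-08}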
- if 'common' in cfg_dict: - self.common = AttrDict(**cfg_dict['common']) - self.gen.common = self.common - self.dis.common = self.common - - - if verbose: - print(' config '.center(80, '-')) - print(self.__repr__()) - print(''.center(80, '-')) - - -def rsetattr(obj, attr, val): - """Recursively find object and set value""" - pre, _, post = attr.rpartition('.') - return setattr(rgetattr(obj, pre) if pre else obj, post, val) - - -def rgetattr(obj, attr, *args): - """Recursively find object and return value""" - - def _getattr(obj, attr): - r"""Get attribute.""" - return getattr(obj, attr, *args) - - return functools.reduce(_getattr, [obj] + attr.split('.')) - - -def recursive_update(d, u): - """Recursively update AttrDict d with AttrDict u""" - for key, value in u.items(): - if isinstance(value, collections.abc.Mapping): - d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value) - elif isinstance(value, (list, tuple)): - if isinstance(value[0], dict): - d.__dict__[key] = [AttrDict(item) for item in value] - else: - d.__dict__[key] = value - else: - d.__dict__[key] = value - return d diff --git a/sadtalker_video2pose/src/facerender/pirender/face_model.py b/sadtalker_video2pose/src/facerender/pirender/face_model.py deleted file mode 100644 index 0f83e2fc5d8c66cf9bd2e2c5549773e11e0f8a44..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/facerender/pirender/face_model.py +++ /dev/null @@ -1,178 +0,0 @@ -import functools -import torch -import torch.nn as nn -from .base_function import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder - -def convert_flow_to_deformation(flow): - r"""convert flow fields to deformations. - - Args: - flow (tensor): Flow field obtained by the model - Returns: - deformation (tensor): The deformation used for warpping - """ - b,c,h,w = flow.shape - flow_norm = 2 * torch.cat([flow[:,:1,...]/(w-1),flow[:,1:,...]/(h-1)], 1) - grid = make_coordinate_grid(flow) - deformation = grid + flow_norm.permute(0,2,3,1) - return deformation - -def make_coordinate_grid(flow): - r"""obtain coordinate grid with the same size as the flow filed. 
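convert_flow_to_deformation above rescales a pixel-space flow field into grid_sample's normalized [-1, 1] coordinate system and adds it to an identity grid. A self-contained restatement of that conversion (flow_to_grid is an illustrative name):

import torch
import torch.nn.functional as F

def flow_to_grid(flow):
    # flow: (B, 2, H, W) in pixels; returns a sampling grid (B, H, W, 2) in [-1, 1].
    b, _, h, w = flow.shape
    norm_flow = 2 * torch.cat([flow[:, :1] / (w - 1), flow[:, 1:] / (h - 1)], dim=1)
    ys = torch.linspace(-1, 1, h, device=flow.device)
    xs = torch.linspace(-1, 1, w, device=flow.device)
    yy, xx = torch.meshgrid(ys, xs, indexing='ij')
    base = torch.stack([xx, yy], dim=-1).expand(b, -1, -1, -1)   # identity grid, x first
    return base + norm_flow.permute(0, 2, 3, 1)

# Warping then amounts to F.grid_sample(source_image, flow_to_grid(flow)).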
- - Args: - flow (tensor): Flow field obtained by the model - Returns: - grid (tensor): The grid with the same size as the input flow - """ - b,c,h,w = flow.shape - - x = torch.arange(w).to(flow) - y = torch.arange(h).to(flow) - - x = (2 * (x / (w - 1)) - 1) - y = (2 * (y / (h - 1)) - 1) - - yy = y.view(-1, 1).repeat(1, w) - xx = x.view(1, -1).repeat(h, 1) - - meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) - meshed = meshed.expand(b, -1, -1, -1) - return meshed - - -def warp_image(source_image, deformation): - r"""warp the input image according to the deformation - - Args: - source_image (tensor): source images to be warpped - deformation (tensor): deformations used to warp the images; value in range (-1, 1) - Returns: - output (tensor): the warpped images - """ - _, h_old, w_old, _ = deformation.shape - _, _, h, w = source_image.shape - if h_old != h or w_old != w: - deformation = deformation.permute(0, 3, 1, 2) - deformation = torch.nn.functional.interpolate(deformation, size=(h, w), mode='bilinear') - deformation = deformation.permute(0, 2, 3, 1) - return torch.nn.functional.grid_sample(source_image, deformation) - - -class FaceGenerator(nn.Module): - def __init__( - self, - mapping_net, - warpping_net, - editing_net, - common - ): - super(FaceGenerator, self).__init__() - self.mapping_net = MappingNet(**mapping_net) - self.warpping_net = WarpingNet(**warpping_net, **common) - self.editing_net = EditingNet(**editing_net, **common) - - def forward( - self, - input_image, - driving_source, - stage=None - ): - if stage == 'warp': - descriptor = self.mapping_net(driving_source) - output = self.warpping_net(input_image, descriptor) - else: - descriptor = self.mapping_net(driving_source) - output = self.warpping_net(input_image, descriptor) - output['fake_image'] = self.editing_net(input_image, output['warp_image'], descriptor) - return output - -class MappingNet(nn.Module): - def __init__(self, coeff_nc, descriptor_nc, layer): - super( MappingNet, self).__init__() - - self.layer = layer - nonlinearity = nn.LeakyReLU(0.1) - - self.first = nn.Sequential( - torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True)) - - for i in range(layer): - net = nn.Sequential(nonlinearity, - torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3)) - setattr(self, 'encoder' + str(i), net) - - self.pooling = nn.AdaptiveAvgPool1d(1) - self.output_nc = descriptor_nc - - def forward(self, input_3dmm): - out = self.first(input_3dmm) - for i in range(self.layer): - model = getattr(self, 'encoder' + str(i)) - out = model(out) + out[:,:,3:-3] - out = self.pooling(out) - return out - -class WarpingNet(nn.Module): - def __init__( - self, - image_nc, - descriptor_nc, - base_nc, - max_nc, - encoder_layer, - decoder_layer, - use_spect - ): - super( WarpingNet, self).__init__() - - nonlinearity = nn.LeakyReLU(0.1) - norm_layer = functools.partial(LayerNorm2d, affine=True) - kwargs = {'nonlinearity':nonlinearity, 'use_spect':use_spect} - - self.descriptor_nc = descriptor_nc - self.hourglass = ADAINHourglass(image_nc, self.descriptor_nc, base_nc, - max_nc, encoder_layer, decoder_layer, **kwargs) - - self.flow_out = nn.Sequential(norm_layer(self.hourglass.output_nc), - nonlinearity, - nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3)) - - self.pool = nn.AdaptiveAvgPool2d(1) - - def forward(self, input_image, descriptor): - final_output={} - output = self.hourglass(input_image, descriptor) - final_output['flow_field'] = self.flow_out(output) 
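warp_image above relies on grid_sample's convention that an identity grid spanning [-1, 1] maps every output pixel back onto itself. A quick, self-contained sanity check of that convention (align_corners=True makes the linspace endpoints coincide with pixel centres):

import torch
import torch.nn.functional as F

b, c, h, w = 1, 3, 8, 8
image = torch.rand(b, c, h, w)
ys = torch.linspace(-1, 1, h)
xs = torch.linspace(-1, 1, w)
yy, xx = torch.meshgrid(ys, xs, indexing='ij')
identity = torch.stack([xx, yy], dim=-1).expand(b, -1, -1, -1)   # (B, H, W, 2), x first
warped = F.grid_sample(image, identity, align_corners=True)
print(torch.allclose(warped, image, atol=1e-6))                   # True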
- - deformation = convert_flow_to_deformation(final_output['flow_field']) - final_output['warp_image'] = warp_image(input_image, deformation) - return final_output - - -class EditingNet(nn.Module): - def __init__( - self, - image_nc, - descriptor_nc, - layer, - base_nc, - max_nc, - num_res_blocks, - use_spect): - super(EditingNet, self).__init__() - - nonlinearity = nn.LeakyReLU(0.1) - norm_layer = functools.partial(LayerNorm2d, affine=True) - kwargs = {'norm_layer':norm_layer, 'nonlinearity':nonlinearity, 'use_spect':use_spect} - self.descriptor_nc = descriptor_nc - - # encoder part - self.encoder = FineEncoder(image_nc*2, base_nc, max_nc, layer, **kwargs) - self.decoder = FineDecoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs) - - def forward(self, input_image, warp_image, descriptor): - x = torch.cat([input_image, warp_image], 1) - x = self.encoder(x) - gen_image = self.decoder(x, descriptor) - return gen_image diff --git a/sadtalker_video2pose/src/facerender/pirender_animate.py b/sadtalker_video2pose/src/facerender/pirender_animate.py deleted file mode 100644 index 07d4ccf0918f09dcfa422a85694bd17bf42d11ff..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/facerender/pirender_animate.py +++ /dev/null @@ -1,266 +0,0 @@ -import os -import uuid -import cv2 -from tqdm import tqdm -import yaml -import numpy as np -import warnings -from skimage import img_as_ubyte -import safetensors -import safetensors.torch -warnings.filterwarnings('ignore') - - -import imageio -import torch -import torchvision - -from src.facerender.pirender.config import Config -from src.facerender.pirender.face_model import FaceGenerator - -from pydub import AudioSegment -from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list -from src.utils.paste_pic import paste_pic -from src.utils.videoio import save_video_with_watermark -from src.utils.flow_util import vis_flow -from scipy.io import savemat,loadmat - -try: - import webui # in webui - in_webui = True -except: - in_webui = False - -expession = loadmat('expression.mat') -control_dict = {} -for item in ['expression_center', 'expression_mouth', 'expression_eyebrow', 'expression_eyes']: - control_dict[item] = torch.tensor(expession[item])[0] - -class AnimateFromCoeff_PIRender(): - - def __init__(self, sadtalker_path, device): - - opt = Config(sadtalker_path['pirender_yaml_path'], None, is_train=False) - opt.device = device - self.net_G_ema = FaceGenerator(**opt.gen.param).to(opt.device) - checkpoint_path = sadtalker_path['pirender_checkpoint'] - checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) - self.net_G_ema.load_state_dict(checkpoint['net_G_ema'], strict=False) - print('load [net_G] and [net_G_ema] from {}'.format(checkpoint_path)) - self.net_G = self.net_G_ema.eval() - self.device = device - - - def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256): - - source_image=x['source_image'].type(torch.FloatTensor) - source_semantics=x['source_semantics'].type(torch.FloatTensor) - target_semantics=x['target_semantics_list'].type(torch.FloatTensor) - - num = 16 - - # import pdb; pdb.set_trace() - # target_semantics_ - current = target_semantics[0, 0, :64, 0] - for control_k in range(len(control_dict.keys())): - listx = list(control_dict.keys()) - control_v = control_dict[listx[control_k]] - for i in range(num): - expression = (control_v-current)*i/(num-1)+current - target_semantics[:, 
(control_k*num + i):(control_k*num + i+1), :64, :] = expression[None, None, :, None] - - source_image=source_image.to(self.device) - source_semantics=source_semantics.to(self.device) - target_semantics=target_semantics.to(self.device) - frame_num = x['frame_num'] - - with torch.no_grad(): - predictions_video = [] - for i in tqdm(range(target_semantics.shape[1]), 'FaceRender:'): - predictions_video.append(self.net_G(source_image, target_semantics[:, i])['fake_image']) - - predictions_video = torch.stack(predictions_video, dim=1) - predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:]) - - video = [] - for idx in range(len(predictions_video)): - image = predictions_video[idx] - image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32) - video.append(image) - result = img_as_ubyte(video) - - ### the generated video is 256x256, so we keep the aspect ratio, - original_size = crop_info[0] - if original_size: - result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ] - - video_name = x['video_name'] + '.mp4' - path = os.path.join(video_save_dir, 'temp_'+video_name) - - imageio.mimsave(path, result, fps=float(25)) - - av_path = os.path.join(video_save_dir, video_name) - return_path = av_path - - audio_path = x['audio_path'] - audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0] - new_audio_path = os.path.join(video_save_dir, audio_name+'.wav') - start_time = 0 - # cog will not keep the .mp3 filename - sound = AudioSegment.from_file(audio_path) - frames = frame_num - end_time = start_time + frames*1/25*1000 - word1=sound.set_frame_rate(16000) - word = word1[start_time:end_time] - word.export(new_audio_path, format="wav") - - save_video_with_watermark(path, new_audio_path, av_path, watermark= False) - print(f'The generated video is named {video_save_dir}/{video_name}') - - if 'full' in preprocess.lower(): - # only add watermark to the full image. 
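The audio handling above uses pydub to resample the driving audio to 16 kHz and trim it to the rendered frame count (pydub slices in milliseconds). A small sketch of that step; the function name and file paths are illustrative only:

from pydub import AudioSegment

def trim_audio_to_frames(audio_path, out_path, num_frames, fps=25):
    # Resample, cut to num_frames worth of audio at the given fps, and export as WAV.
    sound = AudioSegment.from_file(audio_path).set_frame_rate(16000)
    duration_ms = int(num_frames * 1000 / fps)
    sound[:duration_ms].export(out_path, format="wav")

# trim_audio_to_frames("driving.wav", "trimmed.wav", num_frames=200)  # hypothetical paths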
- video_name_full = x['video_name'] + '_full.mp4' - full_video_path = os.path.join(video_save_dir, video_name_full) - return_path = full_video_path - paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False) - print(f'The generated video is named {video_save_dir}/{video_name_full}') - else: - full_video_path = av_path - - #### paste back then enhancers - if enhancer: - video_name_enhancer = x['video_name'] + '_enhanced.mp4' - enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer) - av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer) - return_path = av_path_enhancer - - try: - enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer) - imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25)) - except: - enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer) - imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25)) - - save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False) - print(f'The generated video is named {video_save_dir}/{video_name_enhancer}') - os.remove(enhanced_path) - - os.remove(path) - os.remove(new_audio_path) - - return return_path - - def generate_flow(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256): - - source_image=x['source_image'].type(torch.FloatTensor) - source_semantics=x['source_semantics'].type(torch.FloatTensor) - target_semantics=x['target_semantics_list'].type(torch.FloatTensor) - - - num = 16 - - current = target_semantics[0, 0, :64, 0] - for control_k in range(len(control_dict.keys())): - listx = list(control_dict.keys()) - control_v = control_dict[listx[control_k]] - for i in range(num): - expression = (control_v-current)*i/(num-1)+current - target_semantics[:, (control_k*num + i):(control_k*num + i+1), :64, :] = expression[None, None, :, None] - - source_image=source_image.to(self.device) - source_semantics=source_semantics.to(self.device) - target_semantics=target_semantics.to(self.device) - frame_num = x['frame_num'] - - with torch.no_grad(): - predictions_video = [] - for i in tqdm(range(target_semantics.shape[1]), 'FaceRender:'): - predictions_video.append(self.net_G(source_image, target_semantics[:, i])['flow_field']) - - predictions_video = torch.stack(predictions_video, dim=1) - predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:]) - - video = [] - for idx in range(len(predictions_video)): - image = predictions_video[idx] - image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32) - video.append(image) - - results = np.stack(video, axis=0) - - ### the generated video is 256x256, so we keep the aspect ratio, - # original_size = crop_info[0] - # if original_size: - # result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ] - # results = np.stack(result, axis=0) - - x_name = os.path.basename(pic_path) - save_name = os.path.join(video_save_dir, x_name + '.flo') - save_name_flow_vis = os.path.join(video_save_dir, x_name + '.mp4') - - flow_full = paste_flow(results, pic_path, save_name, crop_info, extended_crop= True if 'ext' in preprocess.lower() else False) - - flow_viz = [] - for kk in range(flow_full.shape[0]): - tmp = vis_flow(flow_full[kk]) - flow_viz.append(tmp) - 
flow_viz = np.stack(flow_viz) - - torchvision.io.write_video(save_name_flow_vis, flow_viz, fps=20, video_codec='h264', options={'crf': '10'}) - - return save_name_flow_vis - - -def paste_flow(flows, pic_path, save_name, crop_info, extended_crop=False): - - if not os.path.isfile(pic_path): - raise ValueError('pic_path must be a valid path to video/image file') - elif pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']: - # loader for first frame - full_img = cv2.imread(pic_path) - else: - # loader for videos - video_stream = cv2.VideoCapture(pic_path) - fps = video_stream.get(cv2.CAP_PROP_FPS) - full_frames = [] - while 1: - still_reading, frame = video_stream.read() - if not still_reading: - video_stream.release() - break - break - full_img = frame - frame_h = full_img.shape[0] - frame_w = full_img.shape[1] - - # full images, we only use it as reference for zero init image. - - if len(crop_info) != 3: - print("you didn't crop the image") - return - else: - r_w, r_h = crop_info[0] - clx, cly, crx, cry = crop_info[1] - lx, ly, rx, ry = crop_info[2] - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - if extended_crop: - oy1, oy2, ox1, ox2 = cly, cry, clx, crx - else: - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - # out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h)) - # template = np.zeros((frame_h, frame_w, 2)) # full flows - out_tmp = [] - for crop_frame in tqdm(flows, 'seamlessClone:'): - p = cv2.resize(crop_frame, (ox2-ox1, oy2 - oy1), interpolation=cv2.INTER_LANCZOS4) - - gen_img = np.zeros((frame_h, frame_w, 2)) - # gen_img = cv2.seamlessClone(p, template, mask, location, cv2.NORMAL_CLONE) - gen_img[oy1:oy2,ox1:ox2] = p - out_tmp.append(gen_img) - - np.save(save_name, np.stack(out_tmp)) - return np.stack(out_tmp) \ No newline at end of file diff --git a/sadtalker_video2pose/src/facerender/pirender_animate_control.py b/sadtalker_video2pose/src/facerender/pirender_animate_control.py deleted file mode 100644 index 1c357f35577816c8d6731627afd505c6dd8efdca..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/facerender/pirender_animate_control.py +++ /dev/null @@ -1,251 +0,0 @@ -import os -import uuid -import cv2 -from tqdm import tqdm -import yaml -import numpy as np -import warnings -from skimage import img_as_ubyte -import safetensors -import safetensors.torch -warnings.filterwarnings('ignore') - - -import imageio -import torch -import torchvision - -from src.facerender.pirender.config import Config -from src.facerender.pirender.face_model import FaceGenerator - -from pydub import AudioSegment -from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list -from src.utils.paste_pic import paste_pic -from src.utils.videoio import save_video_with_watermark -from src.utils.flow_util import vis_flow - -from scipy.io import savemat,loadmat - -try: - import webui # in webui - in_webui = True -except: - in_webui = False - -expession = loadmat('expression.mat') -control_dict = {} -for item in ['expression_center', 'expression_mouth', 'expression_eyebrow', 'expression_eyes']: - control_dict[item] = torch.tensor(expession[item])[0] - -class AnimateFromCoeff_PIRender(): - - def __init__(self, sadtalker_path, device): - - opt = Config(sadtalker_path['pirender_yaml_path'], None, is_train=False) - opt.device = device - self.net_G_ema = FaceGenerator(**opt.gen.param).to(opt.device) - checkpoint_path = 
sadtalker_path['pirender_checkpoint'] - checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) - self.net_G_ema.load_state_dict(checkpoint['net_G_ema'], strict=False) - print('load [net_G] and [net_G_ema] from {}'.format(checkpoint_path)) - self.net_G = self.net_G_ema.eval() - self.device = device - - - def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256): - - source_image=x['source_image'].type(torch.FloatTensor) - source_semantics=x['source_semantics'].type(torch.FloatTensor) - target_semantics=x['target_semantics_list'].type(torch.FloatTensor) - num = 10 - - # target_semantics_ - current = target_semantics['target_semantics_list'][0, :64, 0] - for control in control_dict: - for i in range(num): - expression = (control_dict[control]-current)*i/(num-1)+current - target_semantics['target_semantics_list'][:, :64, :] = expression[None, :, None] - - source_image=source_image.to(self.device) - source_semantics=source_semantics.to(self.device) - target_semantics=target_semantics.to(self.device) - frame_num = x['frame_num'] - - with torch.no_grad(): - predictions_video = [] - for i in tqdm(range(target_semantics.shape[1]), 'FaceRender:'): - predictions_video.append(self.net_G(source_image, target_semantics[:, i])['fake_image']) - - predictions_video = torch.stack(predictions_video, dim=1) - predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:]) - - video = [] - for idx in range(len(predictions_video)): - image = predictions_video[idx] - image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32) - video.append(image) - result = img_as_ubyte(video) - - ### the generated video is 256x256, so we keep the aspect ratio, - original_size = crop_info[0] - if original_size: - result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ] - - video_name = x['video_name'] + '.mp4' - path = os.path.join(video_save_dir, 'temp_'+video_name) - - imageio.mimsave(path, result, fps=float(25)) - - av_path = os.path.join(video_save_dir, video_name) - return_path = av_path - - audio_path = x['audio_path'] - audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0] - new_audio_path = os.path.join(video_save_dir, audio_name+'.wav') - start_time = 0 - # cog will not keep the .mp3 filename - sound = AudioSegment.from_file(audio_path) - frames = frame_num - end_time = start_time + frames*1/25*1000 - word1=sound.set_frame_rate(16000) - word = word1[start_time:end_time] - word.export(new_audio_path, format="wav") - - save_video_with_watermark(path, new_audio_path, av_path, watermark= False) - print(f'The generated video is named {video_save_dir}/{video_name}') - - if 'full' in preprocess.lower(): - # only add watermark to the full image. 
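[Editorial note] Both renderer classes cut the driving audio to the rendered frame count before muxing: the clip is resampled to 16 kHz with pydub and sliced to frame_num / 25 seconds. A small, self-contained sketch of that step (trim_audio_to_frames is a hypothetical name):

```python
from pydub import AudioSegment

def trim_audio_to_frames(audio_path, out_path, frame_num, fps=25, sr=16000):
    """Resample the driving audio and cut it to the duration of frame_num
    rendered frames, as done above before muxing audio and video."""
    sound = AudioSegment.from_file(audio_path).set_frame_rate(sr)
    end_ms = int(frame_num / fps * 1000)   # frame count -> milliseconds
    sound[:end_ms].export(out_path, format="wav")
    return out_path
```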
- video_name_full = x['video_name'] + '_full.mp4' - full_video_path = os.path.join(video_save_dir, video_name_full) - return_path = full_video_path - paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False) - print(f'The generated video is named {video_save_dir}/{video_name_full}') - else: - full_video_path = av_path - - #### paste back then enhancers - if enhancer: - video_name_enhancer = x['video_name'] + '_enhanced.mp4' - enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer) - av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer) - return_path = av_path_enhancer - - try: - enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer) - imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25)) - except: - enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer) - imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25)) - - save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False) - print(f'The generated video is named {video_save_dir}/{video_name_enhancer}') - os.remove(enhanced_path) - - os.remove(path) - os.remove(new_audio_path) - - return return_path - - def generate_flow(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256): - - source_image=x['source_image'].type(torch.FloatTensor) - source_semantics=x['source_semantics'].type(torch.FloatTensor) - target_semantics=x['target_semantics_list'].type(torch.FloatTensor) - source_image=source_image.to(self.device) - source_semantics=source_semantics.to(self.device) - target_semantics=target_semantics.to(self.device) - frame_num = x['frame_num'] - - with torch.no_grad(): - predictions_video = [] - for i in tqdm(range(target_semantics.shape[1]), 'FaceRender:'): - predictions_video.append(self.net_G(source_image, target_semantics[:, i])['flow_field']) - - predictions_video = torch.stack(predictions_video, dim=1) - predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:]) - - video = [] - for idx in range(len(predictions_video)): - image = predictions_video[idx] - image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32) - video.append(image) - - results = np.stack(video, axis=0) - - ### the generated video is 256x256, so we keep the aspect ratio, - # original_size = crop_info[0] - # if original_size: - # result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ] - # results = np.stack(result, axis=0) - - x_name = os.path.basename(pic_path) - save_name = os.path.join(video_save_dir, x_name + '.flo') - save_name_flow_vis = os.path.join(video_save_dir, x_name + '.mp4') - - flow_full = paste_flow(results, pic_path, save_name, crop_info, extended_crop= True if 'ext' in preprocess.lower() else False) - - flow_viz = [] - for kk in range(flow_full.shape[0]): - tmp = vis_flow(flow_full[kk]) - flow_viz.append(tmp) - flow_viz = np.stack(flow_viz) - - torchvision.io.write_video(save_name_flow_vis, flow_viz, fps=20, video_codec='h264', options={'crf': '10'}) - - return save_name_flow_vis - - -def paste_flow(flows, pic_path, save_name, crop_info, extended_crop=False): - - if not os.path.isfile(pic_path): - raise ValueError('pic_path must be a valid path to video/image file') - elif 
pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']: - # loader for first frame - full_img = cv2.imread(pic_path) - else: - # loader for videos - video_stream = cv2.VideoCapture(pic_path) - fps = video_stream.get(cv2.CAP_PROP_FPS) - full_frames = [] - while 1: - still_reading, frame = video_stream.read() - if not still_reading: - video_stream.release() - break - break - full_img = frame - frame_h = full_img.shape[0] - frame_w = full_img.shape[1] - - # full images, we only use it as reference for zero init image. - - if len(crop_info) != 3: - print("you didn't crop the image") - return - else: - r_w, r_h = crop_info[0] - clx, cly, crx, cry = crop_info[1] - lx, ly, rx, ry = crop_info[2] - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - if extended_crop: - oy1, oy2, ox1, ox2 = cly, cry, clx, crx - else: - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - # out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h)) - # template = np.zeros((frame_h, frame_w, 2)) # full flows - out_tmp = [] - for crop_frame in tqdm(flows, 'seamlessClone:'): - p = cv2.resize(crop_frame, (ox2-ox1, oy2 - oy1), interpolation=cv2.INTER_LANCZOS4) - - gen_img = np.zeros((frame_h, frame_w, 2)) - # gen_img = cv2.seamlessClone(p, template, mask, location, cv2.NORMAL_CLONE) - gen_img[oy1:oy2,ox1:ox2] = p - out_tmp.append(gen_img) - - np.save(save_name, np.stack(out_tmp)) - return np.stack(out_tmp) \ No newline at end of file diff --git a/sadtalker_video2pose/src/facerender/sync_batchnorm/__init__.py b/sadtalker_video2pose/src/facerender/sync_batchnorm/__init__.py deleted file mode 100644 index 48871cdcdc882c903501ecc6d70fcb1b50bd7e9f..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/facerender/sync_batchnorm/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# -*- coding: utf-8 -*- -# File : __init__.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. - -from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d -from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/sadtalker_video2pose/src/facerender/sync_batchnorm/batchnorm.py b/sadtalker_video2pose/src/facerender/sync_batchnorm/batchnorm.py deleted file mode 100644 index b4cc2ccd2f0c904cbe433fb6136f443f0fa86fa6..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/facerender/sync_batchnorm/batchnorm.py +++ /dev/null @@ -1,315 +0,0 @@ -# -*- coding: utf-8 -*- -# File : batchnorm.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. 
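[Editorial note] paste_flow above resizes the per-frame flow predicted on the 256x256 face crop back to the crop's size in the original frame and writes it into a zero-initialized full-resolution canvas. A condensed sketch of that geometry, assuming box holds the (ox1, oy1, ox2, oy2) crop coordinates derived from crop_info; paste_crop_flow is a hypothetical helper:

```python
import cv2
import numpy as np

def paste_crop_flow(crop_flow, frame_h, frame_w, box):
    """Resize a flow field predicted on the face crop and paste it into a
    zero-initialized full-frame canvas, following paste_flow above.
    `box` is (ox1, oy1, ox2, oy2), the crop location in the full frame."""
    ox1, oy1, ox2, oy2 = box
    resized = cv2.resize(crop_flow, (ox2 - ox1, oy2 - oy1),
                         interpolation=cv2.INTER_LANCZOS4)
    full = np.zeros((frame_h, frame_w, 2), dtype=np.float32)
    full[oy1:oy2, ox1:ox2] = resized
    return full
```

Note that, as in the original, the flow vectors themselves are not rescaled when the crop is resized; only the spatial grid is interpolated.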
- -import collections - -import torch -import torch.nn.functional as F - -from torch.nn.modules.batchnorm import _BatchNorm -from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast - -from .comm import SyncMaster - -__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] - - -def _sum_ft(tensor): - """sum over the first and last dimention""" - return tensor.sum(dim=0).sum(dim=-1) - - -def _unsqueeze_ft(tensor): - """add new dementions at the front and the tail""" - return tensor.unsqueeze(0).unsqueeze(-1) - - -_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) -_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) - - -class _SynchronizedBatchNorm(_BatchNorm): - def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): - super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) - - self._sync_master = SyncMaster(self._data_parallel_master) - - self._is_parallel = False - self._parallel_id = None - self._slave_pipe = None - - def forward(self, input): - # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. - if not (self._is_parallel and self.training): - return F.batch_norm( - input, self.running_mean, self.running_var, self.weight, self.bias, - self.training, self.momentum, self.eps) - - # Resize the input to (B, C, -1). - input_shape = input.size() - input = input.view(input.size(0), self.num_features, -1) - - # Compute the sum and square-sum. - sum_size = input.size(0) * input.size(2) - input_sum = _sum_ft(input) - input_ssum = _sum_ft(input ** 2) - - # Reduce-and-broadcast the statistics. - if self._parallel_id == 0: - mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) - else: - mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) - - # Compute the output. - if self.affine: - # MJY:: Fuse the multiplication for speed. - output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) - else: - output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) - - # Reshape it. - return output.view(input_shape) - - def __data_parallel_replicate__(self, ctx, copy_id): - self._is_parallel = True - self._parallel_id = copy_id - - # parallel_id == 0 means master device. - if self._parallel_id == 0: - ctx.sync_master = self._sync_master - else: - self._slave_pipe = ctx.sync_master.register_slave(copy_id) - - def _data_parallel_master(self, intermediates): - """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" - - # Always using same "device order" makes the ReduceAdd operation faster. 
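[Editorial note] The synchronized forward pass above reduces each replica's per-channel sum and squared sum; the master then converts those into a mean and inverse standard deviation (see _compute_mean_std further below). The underlying algebra as a standalone sketch (mean_inv_std_from_sums is a hypothetical name; the real method additionally updates the running statistics with the unbiased variance):

```python
import torch

def mean_inv_std_from_sums(sum_, ssum, size, eps=1e-5):
    """Turn a per-channel sum and squared sum over `size` samples into the
    mean and inverse std used for normalization, as _compute_mean_std does."""
    mean = sum_ / size
    sumvar = ssum - sum_ * mean          # sum of squared deviations from the mean
    bias_var = sumvar / size             # biased variance, used for the forward pass
    inv_std = bias_var.clamp(min=eps) ** -0.5
    return mean, inv_std
```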
- # Thanks to:: Tete Xiao (http://tetexiao.com/) - intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) - - to_reduce = [i[1][:2] for i in intermediates] - to_reduce = [j for i in to_reduce for j in i] # flatten - target_gpus = [i[1].sum.get_device() for i in intermediates] - - sum_size = sum([i[1].sum_size for i in intermediates]) - sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) - mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) - - broadcasted = Broadcast.apply(target_gpus, mean, inv_std) - - outputs = [] - for i, rec in enumerate(intermediates): - outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) - - return outputs - - def _compute_mean_std(self, sum_, ssum, size): - """Compute the mean and standard-deviation with sum and square-sum. This method - also maintains the moving average on the master device.""" - assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' - mean = sum_ / size - sumvar = ssum - sum_ * mean - unbias_var = sumvar / (size - 1) - bias_var = sumvar / size - - self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data - self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data - - return mean, bias_var.clamp(self.eps) ** -0.5 - - -class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): - r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a - mini-batch. - - .. math:: - - y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta - - This module differs from the built-in PyTorch BatchNorm1d as the mean and - standard-deviation are reduced across all devices during training. - - For example, when one uses `nn.DataParallel` to wrap the network during - training, PyTorch's implementation normalize the tensor on each device using - the statistics only on that device, which accelerated the computation and - is also easy to implement, but the statistics might be inaccurate. - Instead, in this synchronized version, the statistics will be computed - over all training samples distributed on multiple devices. - - Note that, for one-GPU or CPU-only case, this module behaves exactly same - as the built-in PyTorch implementation. - - The mean and standard-deviation are calculated per-dimension over - the mini-batches and gamma and beta are learnable parameter vectors - of size C (where C is the input size). - - During training, this layer keeps a running estimate of its computed mean - and variance. The running sum is kept with a default momentum of 0.1. - - During evaluation, this running mean/variance is used for normalization. - - Because the BatchNorm is done over the `C` dimension, computing statistics - on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm - - Args: - num_features: num_features from an expected input of size - `batch_size x num_features [x width]` - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Default: 0.1 - affine: a boolean value that when set to ``True``, gives the layer learnable - affine parameters. 
Default: ``True`` - - Shape: - - Input: :math:`(N, C)` or :math:`(N, C, L)` - - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) - - Examples: - >>> # With Learnable Parameters - >>> m = SynchronizedBatchNorm1d(100) - >>> # Without Learnable Parameters - >>> m = SynchronizedBatchNorm1d(100, affine=False) - >>> input = torch.autograd.Variable(torch.randn(20, 100)) - >>> output = m(input) - """ - - def _check_input_dim(self, input): - if input.dim() != 2 and input.dim() != 3: - raise ValueError('expected 2D or 3D input (got {}D input)' - .format(input.dim())) - super(SynchronizedBatchNorm1d, self)._check_input_dim(input) - - -class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): - r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch - of 3d inputs - - .. math:: - - y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta - - This module differs from the built-in PyTorch BatchNorm2d as the mean and - standard-deviation are reduced across all devices during training. - - For example, when one uses `nn.DataParallel` to wrap the network during - training, PyTorch's implementation normalize the tensor on each device using - the statistics only on that device, which accelerated the computation and - is also easy to implement, but the statistics might be inaccurate. - Instead, in this synchronized version, the statistics will be computed - over all training samples distributed on multiple devices. - - Note that, for one-GPU or CPU-only case, this module behaves exactly same - as the built-in PyTorch implementation. - - The mean and standard-deviation are calculated per-dimension over - the mini-batches and gamma and beta are learnable parameter vectors - of size C (where C is the input size). - - During training, this layer keeps a running estimate of its computed mean - and variance. The running sum is kept with a default momentum of 0.1. - - During evaluation, this running mean/variance is used for normalization. - - Because the BatchNorm is done over the `C` dimension, computing statistics - on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm - - Args: - num_features: num_features from an expected input of - size batch_size x num_features x height x width - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Default: 0.1 - affine: a boolean value that when set to ``True``, gives the layer learnable - affine parameters. Default: ``True`` - - Shape: - - Input: :math:`(N, C, H, W)` - - Output: :math:`(N, C, H, W)` (same shape as input) - - Examples: - >>> # With Learnable Parameters - >>> m = SynchronizedBatchNorm2d(100) - >>> # Without Learnable Parameters - >>> m = SynchronizedBatchNorm2d(100, affine=False) - >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) - >>> output = m(input) - """ - - def _check_input_dim(self, input): - if input.dim() != 4: - raise ValueError('expected 4D input (got {}D input)' - .format(input.dim())) - super(SynchronizedBatchNorm2d, self)._check_input_dim(input) - - -class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): - r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch - of 4d inputs - - .. math:: - - y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta - - This module differs from the built-in PyTorch BatchNorm3d as the mean and - standard-deviation are reduced across all devices during training. 
- - For example, when one uses `nn.DataParallel` to wrap the network during - training, PyTorch's implementation normalize the tensor on each device using - the statistics only on that device, which accelerated the computation and - is also easy to implement, but the statistics might be inaccurate. - Instead, in this synchronized version, the statistics will be computed - over all training samples distributed on multiple devices. - - Note that, for one-GPU or CPU-only case, this module behaves exactly same - as the built-in PyTorch implementation. - - The mean and standard-deviation are calculated per-dimension over - the mini-batches and gamma and beta are learnable parameter vectors - of size C (where C is the input size). - - During training, this layer keeps a running estimate of its computed mean - and variance. The running sum is kept with a default momentum of 0.1. - - During evaluation, this running mean/variance is used for normalization. - - Because the BatchNorm is done over the `C` dimension, computing statistics - on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm - or Spatio-temporal BatchNorm - - Args: - num_features: num_features from an expected input of - size batch_size x num_features x depth x height x width - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Default: 0.1 - affine: a boolean value that when set to ``True``, gives the layer learnable - affine parameters. Default: ``True`` - - Shape: - - Input: :math:`(N, C, D, H, W)` - - Output: :math:`(N, C, D, H, W)` (same shape as input) - - Examples: - >>> # With Learnable Parameters - >>> m = SynchronizedBatchNorm3d(100) - >>> # Without Learnable Parameters - >>> m = SynchronizedBatchNorm3d(100, affine=False) - >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) - >>> output = m(input) - """ - - def _check_input_dim(self, input): - if input.dim() != 5: - raise ValueError('expected 5D input (got {}D input)' - .format(input.dim())) - super(SynchronizedBatchNorm3d, self)._check_input_dim(input) diff --git a/sadtalker_video2pose/src/facerender/sync_batchnorm/comm.py b/sadtalker_video2pose/src/facerender/sync_batchnorm/comm.py deleted file mode 100644 index b66ec4aea213edf4330beda0a8c8b93d6db77a60..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/facerender/sync_batchnorm/comm.py +++ /dev/null @@ -1,137 +0,0 @@ -# -*- coding: utf-8 -*- -# File : comm.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. - -import queue -import collections -import threading - -__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] - - -class FutureResult(object): - """A thread-safe future implementation. Used only as one-to-one pipe.""" - - def __init__(self): - self._result = None - self._lock = threading.Lock() - self._cond = threading.Condition(self._lock) - - def put(self, result): - with self._lock: - assert self._result is None, 'Previous result has\'t been fetched.' 
- self._result = result - self._cond.notify() - - def get(self): - with self._lock: - if self._result is None: - self._cond.wait() - - res = self._result - self._result = None - return res - - -_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) -_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) - - -class SlavePipe(_SlavePipeBase): - """Pipe for master-slave communication.""" - - def run_slave(self, msg): - self.queue.put((self.identifier, msg)) - ret = self.result.get() - self.queue.put(True) - return ret - - -class SyncMaster(object): - """An abstract `SyncMaster` object. - - - During the replication, as the data parallel will trigger an callback of each module, all slave devices should - call `register(id)` and obtain an `SlavePipe` to communicate with the master. - - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, - and passed to a registered callback. - - After receiving the messages, the master device should gather the information and determine to message passed - back to each slave devices. - """ - - def __init__(self, master_callback): - """ - - Args: - master_callback: a callback to be invoked after having collected messages from slave devices. - """ - self._master_callback = master_callback - self._queue = queue.Queue() - self._registry = collections.OrderedDict() - self._activated = False - - def __getstate__(self): - return {'master_callback': self._master_callback} - - def __setstate__(self, state): - self.__init__(state['master_callback']) - - def register_slave(self, identifier): - """ - Register an slave device. - - Args: - identifier: an identifier, usually is the device id. - - Returns: a `SlavePipe` object which can be used to communicate with the master device. - - """ - if self._activated: - assert self._queue.empty(), 'Queue is not clean before next initialization.' - self._activated = False - self._registry.clear() - future = FutureResult() - self._registry[identifier] = _MasterRegistry(future) - return SlavePipe(identifier, self._queue, future) - - def run_master(self, master_msg): - """ - Main entry for the master device in each forward pass. - The messages were first collected from each devices (including the master device), and then - an callback will be invoked to compute the message to be sent back to each devices - (including the master device). - - Args: - master_msg: the message that the master want to send to itself. This will be placed as the first - message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. - - Returns: the message to be sent back to the master device. - - """ - self._activated = True - - intermediates = [(0, master_msg)] - for i in range(self.nr_slaves): - intermediates.append(self._queue.get()) - - results = self._master_callback(intermediates) - assert results[0][0] == 0, 'The first result should belongs to the master.' 
- - for i, res in results: - if i == 0: - continue - self._registry[i].result.put(res) - - for i in range(self.nr_slaves): - assert self._queue.get() is True - - return results[0][1] - - @property - def nr_slaves(self): - return len(self._registry) diff --git a/sadtalker_video2pose/src/facerender/sync_batchnorm/replicate.py b/sadtalker_video2pose/src/facerender/sync_batchnorm/replicate.py deleted file mode 100644 index 9b97380d9c5fbe75c4b3583d3668ccd6a2848699..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/facerender/sync_batchnorm/replicate.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- -# File : replicate.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. - -import functools - -from torch.nn.parallel.data_parallel import DataParallel - -__all__ = [ - 'CallbackContext', - 'execute_replication_callbacks', - 'DataParallelWithCallback', - 'patch_replication_callback' -] - - -class CallbackContext(object): - pass - - -def execute_replication_callbacks(modules): - """ - Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. - - The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` - - Note that, as all modules are isomorphism, we assign each sub-module with a context - (shared among multiple copies of this module on different devices). - Through this context, different copies can share some information. - - We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback - of any slave copies. - """ - master_copy = modules[0] - nr_modules = len(list(master_copy.modules())) - ctxs = [CallbackContext() for _ in range(nr_modules)] - - for i, module in enumerate(modules): - for j, m in enumerate(module.modules()): - if hasattr(m, '__data_parallel_replicate__'): - m.__data_parallel_replicate__(ctxs[j], i) - - -class DataParallelWithCallback(DataParallel): - """ - Data Parallel with a replication callback. - - An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by - original `replicate` function. - The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` - - Examples: - > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) - > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) - # sync_bn.__data_parallel_replicate__ will be invoked. - """ - - def replicate(self, module, device_ids): - modules = super(DataParallelWithCallback, self).replicate(module, device_ids) - execute_replication_callbacks(modules) - return modules - - -def patch_replication_callback(data_parallel): - """ - Monkey-patch an existing `DataParallel` object. Add the replication callback. - Useful when you have customized `DataParallel` implementation. 
- - Examples: - > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) - > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) - > patch_replication_callback(sync_bn) - # this is equivalent to - > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) - > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) - """ - - assert isinstance(data_parallel, DataParallel) - - old_replicate = data_parallel.replicate - - @functools.wraps(old_replicate) - def new_replicate(module, device_ids): - modules = old_replicate(module, device_ids) - execute_replication_callbacks(modules) - return modules - - data_parallel.replicate = new_replicate diff --git a/sadtalker_video2pose/src/facerender/sync_batchnorm/unittest.py b/sadtalker_video2pose/src/facerender/sync_batchnorm/unittest.py deleted file mode 100644 index 9716d035495097fb086ec050ab0bc9b76b9d28a0..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/facerender/sync_batchnorm/unittest.py +++ /dev/null @@ -1,29 +0,0 @@ -# -*- coding: utf-8 -*- -# File : unittest.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. - -import unittest - -import numpy as np -from torch.autograd import Variable - - -def as_numpy(v): - if isinstance(v, Variable): - v = v.data - return v.cpu().numpy() - - -class TorchTestCase(unittest.TestCase): - def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): - npa, npb = as_numpy(a), as_numpy(b) - self.assertTrue( - np.allclose(npa, npb, atol=atol), - 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) - ) diff --git a/sadtalker_video2pose/src/generate_batch.py b/sadtalker_video2pose/src/generate_batch.py deleted file mode 100644 index 2fcaff51276d489aa76c15e4979864a4d4f74aa4..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/generate_batch.py +++ /dev/null @@ -1,120 +0,0 @@ -import os - -from tqdm import tqdm -import torch -import numpy as np -import random -import scipy.io as scio -import src.utils.audio as audio - -def crop_pad_audio(wav, audio_length): - if len(wav) > audio_length: - wav = wav[:audio_length] - elif len(wav) < audio_length: - wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0) - return wav - -def parse_audio_length(audio_length, sr, fps): - bit_per_frames = sr / fps - - num_frames = int(audio_length / bit_per_frames) - audio_length = int(num_frames * bit_per_frames) - - return audio_length, num_frames - -def generate_blink_seq(num_frames): - ratio = np.zeros((num_frames,1)) - frame_id = 0 - while frame_id in range(num_frames): - start = 80 - if frame_id+start+9<=num_frames - 1: - ratio[frame_id+start:frame_id+start+9, 0] = [0.5,0.6,0.7,0.9,1, 0.9, 0.7,0.6,0.5] - frame_id = frame_id+start+9 - else: - break - return ratio - -def generate_blink_seq_randomly(num_frames): - ratio = np.zeros((num_frames,1)) - if num_frames<=20: - return ratio - frame_id = 0 - while frame_id in range(num_frames): - start = random.choice(range(min(10,num_frames), min(int(num_frames/2), 70))) - if frame_id+start+5<=num_frames - 1: - ratio[frame_id+start:frame_id+start+5, 0] = [0.5, 0.9, 1.0, 0.9, 0.5] - frame_id = frame_id+start+5 - else: - break - return ratio - -def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, 
length_of_audio=False, use_blink=True): - - syncnet_mel_step_size = 16 - fps = 25 - - pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0] - audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0] - - - if idlemode: - num_frames = int(length_of_audio * 25) - indiv_mels = np.zeros((num_frames, 80, 16)) - else: - wav = audio.load_wav(audio_path, 16000) - wav_length, num_frames = parse_audio_length(len(wav), 16000, 25) - wav = crop_pad_audio(wav, wav_length) - orig_mel = audio.melspectrogram(wav).T - spec = orig_mel.copy() # nframes 80 - indiv_mels = [] - - for i in tqdm(range(num_frames), 'mel:'): - start_frame_num = i-2 - start_idx = int(80. * (start_frame_num / float(fps))) - end_idx = start_idx + syncnet_mel_step_size - seq = list(range(start_idx, end_idx)) - seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ] - m = spec[seq, :] - indiv_mels.append(m.T) - indiv_mels = np.asarray(indiv_mels) # T 80 16 - - ratio = generate_blink_seq_randomly(num_frames) # T - source_semantics_path = first_coeff_path - source_semantics_dict = scio.loadmat(source_semantics_path) - ref_coeff = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70 - ref_coeff = np.repeat(ref_coeff, num_frames, axis=0) - - if ref_eyeblink_coeff_path is not None: - ratio[:num_frames] = 0 - refeyeblink_coeff_dict = scio.loadmat(ref_eyeblink_coeff_path) - refeyeblink_coeff = refeyeblink_coeff_dict['coeff_3dmm'][:,:64] - refeyeblink_num_frames = refeyeblink_coeff.shape[0] - if refeyeblink_num_frames frame_num: - new_degree_list = new_degree_list[:frame_num] - elif len(new_degree_list) < frame_num: - for _ in range(frame_num-len(new_degree_list)): - new_degree_list.append(new_degree_list[-1]) - print(len(new_degree_list)) - print(frame_num) - - remainder = frame_num%batch_size - if remainder!=0: - for _ in range(batch_size-remainder): - new_degree_list.append(new_degree_list[-1]) - new_degree_np = np.array(new_degree_list).reshape(batch_size, -1) - return new_degree_np - diff --git a/sadtalker_video2pose/src/gradio_demo.py b/sadtalker_video2pose/src/gradio_demo.py deleted file mode 100644 index 9a2399fc44704b544ef39bb908d32a21da9fae17..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/gradio_demo.py +++ /dev/null @@ -1,170 +0,0 @@ -import torch, uuid -import os, sys, shutil, platform -from src.facerender.pirender_animate import AnimateFromCoeff_PIRender -from src.utils.preprocess import CropAndExtract -from src.test_audio2coeff import Audio2Coeff -from src.facerender.animate import AnimateFromCoeff -from src.generate_batch import get_data -from src.generate_facerender_batch import get_facerender_data - -from src.utils.init_path import init_path - -from pydub import AudioSegment - - -def mp3_to_wav(mp3_filename,wav_filename,frame_rate): - mp3_file = AudioSegment.from_file(file=mp3_filename) - mp3_file.set_frame_rate(frame_rate).export(wav_filename,format="wav") - - -class SadTalker(): - - def __init__(self, checkpoint_path='checkpoints', config_path='src/config', lazy_load=False): - - if torch.cuda.is_available(): - device = "cuda" - elif platform.system() == 'Darwin': # macos - device = "mps" - else: - device = "cpu" - - self.device = device - - os.environ['TORCH_HOME']= checkpoint_path - - self.checkpoint_path = checkpoint_path - self.config_path = config_path - - - def test(self, source_image, driven_audio, preprocess='crop', - still_mode=False, use_enhancer=False, batch_size=1, size=256, - pose_style = 0, - facerender='facevid2vid', - exp_scale=1.0, - use_ref_video = False, - 
ref_video = None, - ref_info = None, - use_idle_mode = False, - length_of_audio = 0, use_blink=True, - result_dir='./results/'): - - self.sadtalker_paths = init_path(self.checkpoint_path, self.config_path, size, False, preprocess) - print(self.sadtalker_paths) - - self.audio_to_coeff = Audio2Coeff(self.sadtalker_paths, self.device) - self.preprocess_model = CropAndExtract(self.sadtalker_paths, self.device) - - if facerender == 'facevid2vid' and self.device != 'mps': - self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device) - elif facerender == 'pirender' or self.device == 'mps': - self.animate_from_coeff = AnimateFromCoeff_PIRender(self.sadtalker_paths, self.device) - facerender = 'pirender' - else: - raise(RuntimeError('Unknown model: {}'.format(facerender))) - - - time_tag = str(uuid.uuid4()) - save_dir = os.path.join(result_dir, time_tag) - os.makedirs(save_dir, exist_ok=True) - - input_dir = os.path.join(save_dir, 'input') - os.makedirs(input_dir, exist_ok=True) - - print(source_image) - pic_path = os.path.join(input_dir, os.path.basename(source_image)) - shutil.move(source_image, input_dir) - - if driven_audio is not None and os.path.isfile(driven_audio): - audio_path = os.path.join(input_dir, os.path.basename(driven_audio)) - - #### mp3 to wav - if '.mp3' in audio_path: - mp3_to_wav(driven_audio, audio_path.replace('.mp3', '.wav'), 16000) - audio_path = audio_path.replace('.mp3', '.wav') - else: - shutil.move(driven_audio, input_dir) - - elif use_idle_mode: - audio_path = os.path.join(input_dir, 'idlemode_'+str(length_of_audio)+'.wav') ## generate audio from this new audio_path - from pydub import AudioSegment - one_sec_segment = AudioSegment.silent(duration=1000*length_of_audio) #duration in milliseconds - one_sec_segment.export(audio_path, format="wav") - else: - print(use_ref_video, ref_info) - assert use_ref_video == True and ref_info == 'all' - - if use_ref_video and ref_info == 'all': # full ref mode - ref_video_videoname = os.path.basename(ref_video) - audio_path = os.path.join(save_dir, ref_video_videoname+'.wav') - print('new audiopath:',audio_path) - # if ref_video contains audio, set the audio from ref_video. 
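[Editorial note] For reference, get_data in generate_batch.py above aligns the mel spectrogram with the video frames by taking a 16-column window per frame, starting two frames early and clamping indices at the clip boundaries. A sketch of that indexing, assuming roughly 80 mel frames per second under the default hop length (mel_window_for_frame is a hypothetical helper):

```python
import numpy as np

def mel_window_for_frame(orig_mel, frame_idx, fps=25, mel_step=16):
    """Select the mel window for one video frame, starting two frames early
    and clamping indices at the clip boundaries, as get_data does above."""
    start_frame = frame_idx - 2
    start_idx = int(80.0 * (start_frame / float(fps)))   # ~80 mel frames per second
    idx = [min(max(j, 0), orig_mel.shape[0] - 1)
           for j in range(start_idx, start_idx + mel_step)]
    return orig_mel[idx, :].T                             # shape (80, mel_step)
```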
- cmd = r"ffmpeg -y -hide_banner -loglevel error -i %s %s"%(ref_video, audio_path) - os.system(cmd) - - os.makedirs(save_dir, exist_ok=True) - - #crop image and extract 3dmm from image - first_frame_dir = os.path.join(save_dir, 'first_frame_dir') - os.makedirs(first_frame_dir, exist_ok=True) - first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir, preprocess, True, size) - - if first_coeff_path is None: - raise AttributeError("No face is detected") - - if use_ref_video: - print('using ref video for genreation') - ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0] - ref_video_frame_dir = os.path.join(save_dir, ref_video_videoname) - os.makedirs(ref_video_frame_dir, exist_ok=True) - print('3DMM Extraction for the reference video providing pose') - ref_video_coeff_path, _, _ = self.preprocess_model.generate(ref_video, ref_video_frame_dir, preprocess, source_image_flag=False) - else: - ref_video_coeff_path = None - - if use_ref_video: - if ref_info == 'pose': - ref_pose_coeff_path = ref_video_coeff_path - ref_eyeblink_coeff_path = None - elif ref_info == 'blink': - ref_pose_coeff_path = None - ref_eyeblink_coeff_path = ref_video_coeff_path - elif ref_info == 'pose+blink': - ref_pose_coeff_path = ref_video_coeff_path - ref_eyeblink_coeff_path = ref_video_coeff_path - elif ref_info == 'all': - ref_pose_coeff_path = None - ref_eyeblink_coeff_path = None - else: - raise('error in refinfo') - else: - ref_pose_coeff_path = None - ref_eyeblink_coeff_path = None - - #audio2ceoff - if use_ref_video and ref_info == 'all': - coeff_path = ref_video_coeff_path # self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path) - else: - batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path, still=still_mode, \ - idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink) # longer audio? 
- coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path) - - #coeff2video - data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, \ - preprocess=preprocess, size=size, expression_scale = exp_scale, facemodel=facerender) - return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess, img_size=size) - video_name = data['video_name'] - print(f'The generated video is named {video_name} in {save_dir}') - - del self.preprocess_model - del self.audio_to_coeff - del self.animate_from_coeff - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() - - import gc; gc.collect() - - return return_path - - \ No newline at end of file diff --git a/sadtalker_video2pose/src/test_audio2coeff.py b/sadtalker_video2pose/src/test_audio2coeff.py deleted file mode 100644 index d0f5ca9195bbc980c93fa3e37c6d06cc32953aee..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/test_audio2coeff.py +++ /dev/null @@ -1,123 +0,0 @@ -import os -import torch -import numpy as np -from scipy.io import savemat, loadmat -from yacs.config import CfgNode as CN -from scipy.signal import savgol_filter - -import safetensors -import safetensors.torch - -from src.audio2pose_models.audio2pose import Audio2Pose -from src.audio2exp_models.networks import SimpleWrapperV2 -from src.audio2exp_models.audio2exp import Audio2Exp -from src.utils.safetensor_helper import load_x_from_safetensor - -def load_cpk(checkpoint_path, model=None, optimizer=None, device="cpu"): - checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) - if model is not None: - model.load_state_dict(checkpoint['model']) - if optimizer is not None: - optimizer.load_state_dict(checkpoint['optimizer']) - - return checkpoint['epoch'] - -class Audio2Coeff(): - - def __init__(self, sadtalker_path, device): - #load config - fcfg_pose = open(sadtalker_path['audio2pose_yaml_path']) - cfg_pose = CN.load_cfg(fcfg_pose) - cfg_pose.freeze() - fcfg_exp = open(sadtalker_path['audio2exp_yaml_path']) - cfg_exp = CN.load_cfg(fcfg_exp) - cfg_exp.freeze() - - # load audio2pose_model - self.audio2pose_model = Audio2Pose(cfg_pose, None, device=device) - self.audio2pose_model = self.audio2pose_model.to(device) - self.audio2pose_model.eval() - for param in self.audio2pose_model.parameters(): - param.requires_grad = False - - try: - if sadtalker_path['use_safetensor']: - checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint']) - self.audio2pose_model.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2pose')) - else: - load_cpk(sadtalker_path['audio2pose_checkpoint'], model=self.audio2pose_model, device=device) - except: - raise Exception("Failed in loading audio2pose_checkpoint") - - # load audio2exp_model - netG = SimpleWrapperV2() - netG = netG.to(device) - for param in netG.parameters(): - netG.requires_grad = False - netG.eval() - try: - if sadtalker_path['use_safetensor']: - checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint']) - netG.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2exp')) - else: - load_cpk(sadtalker_path['audio2exp_checkpoint'], model=netG, device=device) - except: - raise Exception("Failed in loading audio2exp_checkpoint") - self.audio2exp_model = Audio2Exp(netG, cfg_exp, device=device, prepare_training_loss=False) - self.audio2exp_model = 
self.audio2exp_model.to(device) - for param in self.audio2exp_model.parameters(): - param.requires_grad = False - self.audio2exp_model.eval() - - self.device = device - - def generate(self, batch, coeff_save_dir, pose_style, ref_pose_coeff_path=None): - - with torch.no_grad(): - #test - results_dict_exp= self.audio2exp_model.test(batch) - exp_pred = results_dict_exp['exp_coeff_pred'] #bs T 64 - - #for class_id in range(1): - #class_id = 0#(i+10)%45 - #class_id = random.randint(0,46) #46 styles can be selected - batch['class'] = torch.LongTensor([pose_style]).to(self.device) - results_dict_pose = self.audio2pose_model.test(batch) - pose_pred = results_dict_pose['pose_pred'] #bs T 6 - - pose_len = pose_pred.shape[1] - if pose_len<13: - pose_len = int((pose_len-1)/2)*2+1 - pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), pose_len, 2, axis=1)).to(self.device) - else: - pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), 13, 2, axis=1)).to(self.device) - - coeffs_pred = torch.cat((exp_pred, pose_pred), dim=-1) #bs T 70 - - coeffs_pred_numpy = coeffs_pred[0].clone().detach().cpu().numpy() - - if ref_pose_coeff_path is not None: - coeffs_pred_numpy = self.using_refpose(coeffs_pred_numpy, ref_pose_coeff_path) - - savemat(os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])), - {'coeff_3dmm': coeffs_pred_numpy}) - - return os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])) - - def using_refpose(self, coeffs_pred_numpy, ref_pose_coeff_path): - num_frames = coeffs_pred_numpy.shape[0] - refpose_coeff_dict = loadmat(ref_pose_coeff_path) - refpose_coeff = refpose_coeff_dict['coeff_3dmm'][:,64:70] - refpose_num_frames = refpose_coeff.shape[0] - if refpose_num_frames= 0 - if hp.symmetric_mels: - return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value - else: - return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) - -def _denormalize(D): - if hp.allow_clipping_in_normalization: - if hp.symmetric_mels: - return (((np.clip(D, -hp.max_abs_value, - hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) - + hp.min_level_db) - else: - return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) - - if hp.symmetric_mels: - return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) - else: - return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) diff --git a/sadtalker_video2pose/src/utils/croper.py b/sadtalker_video2pose/src/utils/croper.py deleted file mode 100644 index 578372debdb8d2b99fe93d3d2ba2dfacf7cbb0ad..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/utils/croper.py +++ /dev/null @@ -1,145 +0,0 @@ -import os -import cv2 -import time -import glob -import argparse -import scipy -import numpy as np -from PIL import Image -import torch -from tqdm import tqdm -from itertools import cycle - -from src.face3d.extract_kp_videos_safe import KeypointExtractor -from facexlib.alignment import landmark_98_to_68 - -import numpy as np -from PIL import Image - -class Preprocesser: - def __init__(self, device='cuda'): - self.predictor = KeypointExtractor(device) - - def get_landmark(self, img_np): - """get landmark with dlib - :return: np.array shape=(68, 2) - """ - with torch.no_grad(): - dets = self.predictor.det_net.detect_faces(img_np, 0.97) - - if len(dets) == 0: - return None - det = dets[0] - - img = 
img_np[int(det[1]):int(det[3]), int(det[0]):int(det[2]), :] - lm = landmark_98_to_68(self.predictor.detector.get_landmarks(img)) # [0] - - #### keypoints to the original location - lm[:,0] += int(det[0]) - lm[:,1] += int(det[1]) - - return lm - - def align_face(self, img, lm, output_size=1024): - """ - :param filepath: str - :return: PIL Image - """ - lm_chin = lm[0: 17] # left-right - lm_eyebrow_left = lm[17: 22] # left-right - lm_eyebrow_right = lm[22: 27] # left-right - lm_nose = lm[27: 31] # top-down - lm_nostrils = lm[31: 36] # top-down - lm_eye_left = lm[36: 42] # left-clockwise - lm_eye_right = lm[42: 48] # left-clockwise - lm_mouth_outer = lm[48: 60] # left-clockwise - lm_mouth_inner = lm[60: 68] # left-clockwise - - # Calculate auxiliary vectors. - eye_left = np.mean(lm_eye_left, axis=0) - eye_right = np.mean(lm_eye_right, axis=0) - eye_avg = (eye_left + eye_right) * 0.5 - eye_to_eye = eye_right - eye_left - mouth_left = lm_mouth_outer[0] - mouth_right = lm_mouth_outer[6] - mouth_avg = (mouth_left + mouth_right) * 0.5 - eye_to_mouth = mouth_avg - eye_avg - - # Choose oriented crop rectangle. - x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] # Addition of binocular difference and double mouth difference - x /= np.hypot(*x) # hypot函数计算直角三角形的斜边长,用斜边长对三角形两条直边做归一化 - x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) # 双眼差和眼嘴差,选较大的作为基准尺度 - y = np.flipud(x) * [-1, 1] - c = eye_avg + eye_to_mouth * 0.1 - quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) # 定义四边形,以面部基准位置为中心上下左右平移得到四个顶点 - qsize = np.hypot(*x) * 2 # 定义四边形的大小(边长),为基准尺度的2倍 - - # Shrink. - # 如果计算出的四边形太大了,就按比例缩小它 - shrink = int(np.floor(qsize / output_size * 0.5)) - if shrink > 1: - rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) - img = img.resize(rsize, Image.ANTIALIAS) - quad /= shrink - qsize /= shrink - else: - rsize = (int(np.rint(float(img.size[0]))), int(np.rint(float(img.size[1])))) - - # Crop. - border = max(int(np.rint(qsize * 0.1)), 3) - crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), - int(np.ceil(max(quad[:, 1])))) - crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), - min(crop[3] + border, img.size[1])) - if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: - # img = img.crop(crop) - quad -= crop[0:2] - - # Pad. - pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), - int(np.ceil(max(quad[:, 1])))) - pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), - max(pad[3] - img.size[1] + border, 0)) - # if enable_padding and max(pad) > border - 4: - # pad = np.maximum(pad, int(np.rint(qsize * 0.3))) - # img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') - # h, w, _ = img.shape - # y, x, _ = np.ogrid[:h, :w, :1] - # mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), - # 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) - # blur = qsize * 0.02 - # img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) - # img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) - # img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') - # quad += pad[:2] - - # Transform. 
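[Editorial note] align_face above builds an oriented crop rectangle from the 68-point landmarks; the inline comments (in Chinese) note that x is normalized by its hypotenuse length, scaled by the larger of twice the inter-ocular distance and 1.8 times the eye-to-mouth distance, and that the quad's side length qsize is twice that base scale. The same construction as a standalone sketch (face_quad_from_landmarks is a hypothetical name):

```python
import numpy as np

def face_quad_from_landmarks(lm):
    """Build the oriented crop quadrilateral from 68-point landmarks,
    following the eye/mouth geometry of align_face above."""
    eye_left, eye_right = lm[36:42].mean(0), lm[42:48].mean(0)
    eye_avg = (eye_left + eye_right) * 0.5
    eye_to_eye = eye_right - eye_left
    mouth_avg = (lm[48] + lm[54]) * 0.5            # outer mouth corners
    eye_to_mouth = mouth_avg - eye_avg

    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
    x /= np.hypot(*x)                               # normalize by its length
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
    y = np.flipud(x) * [-1, 1]                      # perpendicular axis
    c = eye_avg + eye_to_mouth * 0.1                # crop centre, nudged toward the mouth
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    qsize = np.hypot(*x) * 2                        # side length of the crop square
    return quad, qsize
```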
- quad = (quad + 0.5).flatten() - lx = max(min(quad[0], quad[2]), 0) - ly = max(min(quad[1], quad[7]), 0) - rx = min(max(quad[4], quad[6]), img.size[0]) - ry = min(max(quad[3], quad[5]), img.size[0]) - - # Save aligned image. - return rsize, crop, [lx, ly, rx, ry] - - def crop(self, img_np_list, still=False, xsize=512): # first frame for all video - # print(img_np_list) - img_np = img_np_list[0] - lm = self.get_landmark(img_np) - - if lm is None: - raise 'can not detect the landmark from source image' - rsize, crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=xsize) - clx, cly, crx, cry = crop - lx, ly, rx, ry = quad - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - for _i in range(len(img_np_list)): - _inp = img_np_list[_i] - _inp = cv2.resize(_inp, (rsize[0], rsize[1])) - _inp = _inp[cly:cry, clx:crx] - if not still: - _inp = _inp[ly:ry, lx:rx] - img_np_list[_i] = _inp - return img_np_list, crop, quad - diff --git a/sadtalker_video2pose/src/utils/face_enhancer.py b/sadtalker_video2pose/src/utils/face_enhancer.py deleted file mode 100644 index 2664560a1d7199e81f1a50093f29d02de91d4bcc..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/utils/face_enhancer.py +++ /dev/null @@ -1,123 +0,0 @@ -import os -import torch - -from gfpgan import GFPGANer - -from tqdm import tqdm - -from src.utils.videoio import load_video_to_cv2 - -import cv2 - - -class GeneratorWithLen(object): - """ From https://stackoverflow.com/a/7460929 """ - - def __init__(self, gen, length): - self.gen = gen - self.length = length - - def __len__(self): - return self.length - - def __iter__(self): - return self.gen - -def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'): - gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler) - return list(gen) - -def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'): - """ Provide a generator with a __len__ method so that it can passed to functions that - call len()""" - - if os.path.isfile(images): # handle video to images - # TODO: Create a generator version of load_video_to_cv2 - images = load_video_to_cv2(images) - - gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler) - gen_with_len = GeneratorWithLen(gen, len(images)) - return gen_with_len - -def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'): - """ Provide a generator function so that all of the enhanced images don't need - to be stored in memory at the same time. This can save tons of RAM compared to - the enhancer function. 
""" - - print('face enhancer....') - if not isinstance(images, list) and os.path.isfile(images): # handle video to images - images = load_video_to_cv2(images) - - # ------------------------ set up GFPGAN restorer ------------------------ - if method == 'gfpgan': - arch = 'clean' - channel_multiplier = 2 - model_name = 'GFPGANv1.4' - url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth' - elif method == 'RestoreFormer': - arch = 'RestoreFormer' - channel_multiplier = 2 - model_name = 'RestoreFormer' - url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth' - elif method == 'codeformer': # TODO: - arch = 'CodeFormer' - channel_multiplier = 2 - model_name = 'CodeFormer' - url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' - else: - raise ValueError(f'Wrong model version {method}.') - - - # ------------------------ set up background upsampler ------------------------ - if bg_upsampler == 'realesrgan': - if not torch.cuda.is_available(): # CPU - import warnings - warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. ' - 'If you really want to use it, please modify the corresponding codes.') - bg_upsampler = None - else: - from basicsr.archs.rrdbnet_arch import RRDBNet - from realesrgan import RealESRGANer - model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) - bg_upsampler = RealESRGANer( - scale=2, - model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', - model=model, - tile=400, - tile_pad=10, - pre_pad=0, - half=True) # need to set False in CPU mode - else: - bg_upsampler = None - - # determine model paths - model_path = os.path.join('gfpgan/weights', model_name + '.pth') - - if not os.path.isfile(model_path): - model_path = os.path.join('checkpoints', model_name + '.pth') - - if not os.path.isfile(model_path): - # download pre-trained models from url - model_path = url - - restorer = GFPGANer( - model_path=model_path, - upscale=2, - arch=arch, - channel_multiplier=channel_multiplier, - bg_upsampler=bg_upsampler) - - # ------------------------ restore ------------------------ - for idx in tqdm(range(len(images)), 'Face Enhancer:'): - - img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR) - - # restore faces and background if necessary - cropped_faces, restored_faces, r_img = restorer.enhance( - img, - has_aligned=False, - only_center_face=False, - paste_back=True) - - r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB) - yield r_img diff --git a/sadtalker_video2pose/src/utils/flow_util.py b/sadtalker_video2pose/src/utils/flow_util.py deleted file mode 100644 index f25046bab67cc8fbbb59efd02f48d7b6f22fc580..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/utils/flow_util.py +++ /dev/null @@ -1,221 +0,0 @@ -import torch -import sys - - -def convert_flow_to_deformation(flow): - r"""convert flow fields to deformations. - - Args: - flow (tensor): Flow field obtained by the model - Returns: - deformation (tensor): The deformation used for warpping - """ - b,c,h,w = flow.shape - flow_norm = 2 * torch.cat([flow[:,:1,...]/(w-1),flow[:,1:,...]/(h-1)], 1) - grid = make_coordinate_grid(flow) - # print(grid.shape, flow_norm.shape) - deformation = grid + flow_norm.permute(0,2,3,1) - return deformation - -def make_coordinate_grid(flow): - r"""obtain coordinate grid with the same size as the flow filed. 
- - Args: - flow (tensor): Flow field obtained by the model - Returns: - grid (tensor): The grid with the same size as the input flow - """ - b,c,h,w = flow.shape - - x = torch.arange(w).to(flow) - y = torch.arange(h).to(flow) - - x = (2 * (x / (w - 1)) - 1) - y = (2 * (y / (h - 1)) - 1) - - yy = y.view(-1, 1).repeat(1, w) - xx = x.view(1, -1).repeat(h, 1) - - meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) - meshed = meshed.expand(b, -1, -1, -1) - return meshed - - -def warp_image(source_image, deformation): - r"""warp the input image according to the deformation - - Args: - source_image (tensor): source images to be warpped - deformation (tensor): deformations used to warp the images; value in range (-1, 1) - Returns: - output (tensor): the warpped images - """ - _, h_old, w_old, _ = deformation.shape - _, _, h, w = source_image.shape - if h_old != h or w_old != w: - deformation = deformation.permute(0, 3, 1, 2) - deformation = torch.nn.functional.interpolate(deformation, size=(h, w), mode='bilinear') - deformation = deformation.permute(0, 2, 3, 1) - return torch.nn.functional.grid_sample(source_image, deformation) - - - -# visualize flow -import numpy as np - -__all__ = ['load_flow', 'save_flow', 'vis_flow'] - - -def load_flow(path): - with open(path, 'rb') as f: - magic = float(np.fromfile(f, np.float32, count=1)[0]) - if magic == 202021.25: - w, h = np.fromfile(f, np.int32, count=1)[0], np.fromfile(f, np.int32, count=1)[0] - data = np.fromfile(f, np.float32, count=h * w * 2) - data.resize((h, w, 2)) - return data - return None - - -def save_flow(path, flow): - magic = np.array([202021.25], np.float32) - h, w = flow.shape[:2] - h, w = np.array([h], np.int32), np.array([w], np.int32) - - with open(path, 'wb') as f: - magic.tofile(f) - w.tofile(f) - h.tofile(f) - flow.tofile(f) - - - -def makeColorwheel(): - # color encoding scheme - - # adapted from the color circle idea described at - # http://members.shaw.ca/quadibloc/other/colint.htm - - RY = 15 - YG = 6 - GC = 4 - CB = 11 - BM = 13 - MR = 6 - - ncols = RY + YG + GC + CB + BM + MR - - colorwheel = np.zeros([ncols, 3]) # r g b - - col = 0 - # RY - colorwheel[0:RY, 0] = 255 - colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY, 1) / RY) - col += RY - - # YG - colorwheel[col:YG + col, 0] = 255 - np.floor(255 * np.arange(0, YG, 1) / YG) - colorwheel[col:YG + col, 1] = 255 - col += YG - - # GC - colorwheel[col:GC + col, 1] = 255 - colorwheel[col:GC + col, 2] = np.floor(255 * np.arange(0, GC, 1) / GC) - col += GC - - # CB - colorwheel[col:CB + col, 1] = 255 - np.floor(255 * np.arange(0, CB, 1) / CB) - colorwheel[col:CB + col, 2] = 255 - col += CB - - # BM - colorwheel[col:BM + col, 2] = 255 - colorwheel[col:BM + col, 0] = np.floor(255 * np.arange(0, BM, 1) / BM) - col += BM - - # MR - colorwheel[col:MR + col, 2] = 255 - np.floor(255 * np.arange(0, MR, 1) / MR) - colorwheel[col:MR + col, 0] = 255 - return colorwheel - - -def computeColor(u, v): - colorwheel = makeColorwheel() - nan_u = np.isnan(u) - nan_v = np.isnan(v) - nan_u = np.where(nan_u) - nan_v = np.where(nan_v) - - u[nan_u] = 0 - u[nan_v] = 0 - v[nan_u] = 0 - v[nan_v] = 0 - - ncols = colorwheel.shape[0] - radius = np.sqrt(u ** 2 + v ** 2) - a = np.arctan2(-v, -u) / np.pi - fk = (a + 1) / 2 * (ncols - 1) # -1~1 maped to 1~ncols - k0 = fk.astype(np.uint8) # 1, 2, ..., ncols - k1 = k0 + 1 - k1[k1 == ncols] = 0 - f = fk - k0 - - img = np.empty([k1.shape[0], k1.shape[1], 3]) - ncolors = colorwheel.shape[1] - for i in range(ncolors): - tmp = colorwheel[:, i] - col0 = tmp[k0] 
/ 255 - col1 = tmp[k1] / 255 - col = (1 - f) * col0 + f * col1 - idx = radius <= 1 - col[idx] = 1 - radius[idx] * (1 - col[idx]) # increase saturation with radius - col[~idx] *= 0.75 # out of range - img[:, :, 2 - i] = np.floor(255 * col).astype(np.uint8) - - return img.astype(np.uint8) - - -def vis_flow(flow): - eps = sys.float_info.epsilon - UNKNOWN_FLOW_THRESH = 1e9 - UNKNOWN_FLOW = 1e10 - - u = flow[:, :, 0] - v = flow[:, :, 1] - - maxu = -999 - maxv = -999 - - minu = 999 - minv = 999 - - maxrad = -1 - # fix unknown flow - greater_u = np.where(u > UNKNOWN_FLOW_THRESH) - greater_v = np.where(v > UNKNOWN_FLOW_THRESH) - u[greater_u] = 0 - u[greater_v] = 0 - v[greater_u] = 0 - v[greater_v] = 0 - - maxu = max([maxu, np.amax(u)]) - minu = min([minu, np.amin(u)]) - - maxv = max([maxv, np.amax(v)]) - minv = min([minv, np.amin(v)]) - rad = np.sqrt(np.multiply(u, u) + np.multiply(v, v)) - maxrad = max([maxrad, np.amax(rad)]) - # print('max flow: %.4f flow range: u = %.3f .. %.3f; v = %.3f .. %.3f\n' % (maxrad, minu, maxu, minv, maxv)) - - u = u / (maxrad + eps) - v = v / (maxrad + eps) - img = computeColor(u, v) - return img[:, :, [2, 1, 0]] - - -def test_visualize_flow(): - flow = load_flow('out.flo') - img = vis_flow(flow) - - import cv2 - cv2.imwrite("img.png", img) diff --git a/sadtalker_video2pose/src/utils/hparams.py b/sadtalker_video2pose/src/utils/hparams.py deleted file mode 100644 index 83c312d767c35b9adc988157243efc02129fdb84..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/utils/hparams.py +++ /dev/null @@ -1,160 +0,0 @@ -from glob import glob -import os - -class HParams: - def __init__(self, **kwargs): - self.data = {} - - for key, value in kwargs.items(): - self.data[key] = value - - def __getattr__(self, key): - if key not in self.data: - raise AttributeError("'HParams' object has no attribute %s" % key) - return self.data[key] - - def set_hparam(self, key, value): - self.data[key] = value - - -# Default hyperparameters -hparams = HParams( - num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality - # network - rescale=True, # Whether to rescale audio prior to preprocessing - rescaling_max=0.9, # Rescaling value - - # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction - # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder - # Does not work if n_ffit is not multiple of hop_size!! - use_lws=False, - - n_fft=800, # Extra window size is filled with 0 paddings to match this parameter - hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) - win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) - sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) - - frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) - - # Mel and Linear spectrograms normalization/scaling and clipping - signal_normalization=True, - # Whether to normalize mel spectrograms to some predefined range (following below parameters) - allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True - symmetric_mels=True, - # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, - # faster and cleaner convergence) - max_abs_value=4., - # max absolute value of data. 
If symmetric, data will be [-max, max] else [0, max] (Must not - # be too big to avoid gradient explosion, - # not too small for fast convergence) - # Contribution by @begeekmyfriend - # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude - # levels. Also allows for better G&L phase reconstruction) - preemphasize=True, # whether to apply filter - preemphasis=0.97, # filter coefficient. - - # Limits - min_level_db=-100, - ref_level_db=20, - fmin=55, - # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To - # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) - fmax=7600, # To be increased/reduced depending on data. - - ###################### Our training parameters ################################# - img_size=96, - fps=25, - - batch_size=16, - initial_learning_rate=1e-4, - nepochs=300000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs - num_workers=20, - checkpoint_interval=3000, - eval_interval=3000, - writer_interval=300, - save_optimizer_state=True, - - syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. - syncnet_batch_size=64, - syncnet_lr=1e-4, - syncnet_eval_interval=1000, - syncnet_checkpoint_interval=10000, - - disc_wt=0.07, - disc_initial_learning_rate=1e-4, -) - - - -# Default hyperparameters -hparamsdebug = HParams( - num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality - # network - rescale=True, # Whether to rescale audio prior to preprocessing - rescaling_max=0.9, # Rescaling value - - # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction - # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder - # Does not work if n_ffit is not multiple of hop_size!! - use_lws=False, - - n_fft=800, # Extra window size is filled with 0 paddings to match this parameter - hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) - win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) - sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) - - frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) - - # Mel and Linear spectrograms normalization/scaling and clipping - signal_normalization=True, - # Whether to normalize mel spectrograms to some predefined range (following below parameters) - allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True - symmetric_mels=True, - # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, - # faster and cleaner convergence) - max_abs_value=4., - # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not - # be too big to avoid gradient explosion, - # not too small for fast convergence) - # Contribution by @begeekmyfriend - # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude - # levels. Also allows for better G&L phase reconstruction) - preemphasize=True, # whether to apply filter - preemphasis=0.97, # filter coefficient. - - # Limits - min_level_db=-100, - ref_level_db=20, - fmin=55, - # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To - # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) - fmax=7600, # To be increased/reduced depending on data. 
- - ###################### Our training parameters ################################# - img_size=96, - fps=25, - - batch_size=2, - initial_learning_rate=1e-3, - nepochs=100000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs - num_workers=0, - checkpoint_interval=10000, - eval_interval=10, - writer_interval=5, - save_optimizer_state=True, - - syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. - syncnet_batch_size=64, - syncnet_lr=1e-4, - syncnet_eval_interval=10000, - syncnet_checkpoint_interval=10000, - - disc_wt=0.07, - disc_initial_learning_rate=1e-4, -) - - -def hparams_debug_string(): - values = hparams.values() - hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"] - return "Hyperparameters:\n" + "\n".join(hp) diff --git a/sadtalker_video2pose/src/utils/init_path.py b/sadtalker_video2pose/src/utils/init_path.py deleted file mode 100644 index 65239fe3281798b2472f7ca0557a96157d9de930..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/utils/init_path.py +++ /dev/null @@ -1,49 +0,0 @@ -import os -import glob - -def init_path(checkpoint_dir, config_dir, size=512, old_version=False, preprocess='crop'): - - if old_version: - #### load all the checkpoint of `pth` - sadtalker_paths = { - 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'), - 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'), - 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'), - 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'), - 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth') - } - - use_safetensor = False - elif len(glob.glob(os.path.join(checkpoint_dir, '*.safetensors'))): - print('using safetensor as default') - sadtalker_paths = { - "checkpoint":os.path.join(checkpoint_dir, 'SadTalker_V0.0.2_'+str(size)+'.safetensors'), - } - use_safetensor = True - else: - print("WARNING: The new version of the model will be updated by safetensor, you may need to download it mannully. 
We run the old version of the checkpoint this time!") - use_safetensor = False - - sadtalker_paths = { - 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'), - 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'), - 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'), - 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'), - 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth') - } - - sadtalker_paths['dir_of_BFM_fitting'] = os.path.join(config_dir) # , 'BFM_Fitting' - sadtalker_paths['audio2pose_yaml_path'] = os.path.join(config_dir, 'auido2pose.yaml') - sadtalker_paths['audio2exp_yaml_path'] = os.path.join(config_dir, 'auido2exp.yaml') - sadtalker_paths['pirender_yaml_path'] = os.path.join(config_dir, 'facerender_pirender.yaml') - sadtalker_paths['pirender_checkpoint'] = os.path.join(checkpoint_dir, 'epoch_00190_iteration_000400000_checkpoint.pt') - sadtalker_paths['use_safetensor'] = use_safetensor # os.path.join(config_dir, 'auido2exp.yaml') - - if 'full' in preprocess: - sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00109-model.pth.tar') - sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender_still.yaml') - else: - sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00229-model.pth.tar') - sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender.yaml') - - return sadtalker_paths \ No newline at end of file diff --git a/sadtalker_video2pose/src/utils/model2safetensor.py b/sadtalker_video2pose/src/utils/model2safetensor.py deleted file mode 100644 index c5b76e3d67a06fdbf6646590d44b8c225bc73d79..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/utils/model2safetensor.py +++ /dev/null @@ -1,141 +0,0 @@ -import torch -import yaml -import os - -import safetensors -from safetensors.torch import save_file -from yacs.config import CfgNode as CN -import sys - -sys.path.append('/apdcephfs/private_shadowcun/SadTalker') - -from src.face3d.models import networks - -from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector -from src.facerender.modules.mapping import MappingNet -from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator - -from src.audio2pose_models.audio2pose import Audio2Pose -from src.audio2exp_models.networks import SimpleWrapperV2 -from src.test_audio2coeff import load_cpk - -size = 256 -############ face vid2vid -config_path = os.path.join('src', 'config', 'facerender.yaml') -current_root_path = '.' 
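# --- Illustrative sketch (not part of the deleted model2safetensor.py) ------
# Minimal, self-contained version of the packing/unpacking pattern this script
# relies on: several sub-networks are wrapped in one nn.Module so a single
# .safetensors file holds all weights, and a sub-module is later recovered by
# filtering on its key prefix (the same idea as load_cpk_facevid2vid_safetensor
# and load_x_from_safetensor below). Module names and the output filename here
# are toy placeholders.
import torch
from safetensors.torch import save_file, load_file

class Bundle(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(4, 4)   # stands in for e.g. kp_extractor
        self.decoder = torch.nn.Linear(4, 4)   # stands in for e.g. generator

bundle = Bundle()
save_file(bundle.state_dict(), "bundle.safetensors")   # keys: encoder.*, decoder.*

# Recover one sub-module: keep only keys under "decoder." and strip the prefix.
flat = load_file("bundle.safetensors")
decoder_sd = {k[len("decoder."):]: v for k, v in flat.items() if k.startswith("decoder.")}
decoder = torch.nn.Linear(4, 4)
decoder.load_state_dict(decoder_sd)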
- -path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth') -net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='') -checkpoint = torch.load(path_of_net_recon_model, map_location='cpu') -net_recon.load_state_dict(checkpoint['net_recon']) - -with open(config_path) as f: - config = yaml.safe_load(f) - -generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'], - **config['model_params']['common_params']) -kp_extractor = KPDetector(**config['model_params']['kp_detector_params'], - **config['model_params']['common_params']) -he_estimator = HEEstimator(**config['model_params']['he_estimator_params'], - **config['model_params']['common_params']) -mapping = MappingNet(**config['model_params']['mapping_params']) - -def load_cpk_facevid2vid(checkpoint_path, generator=None, discriminator=None, - kp_detector=None, he_estimator=None, optimizer_generator=None, - optimizer_discriminator=None, optimizer_kp_detector=None, - optimizer_he_estimator=None, device="cpu"): - - checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) - if generator is not None: - generator.load_state_dict(checkpoint['generator']) - if kp_detector is not None: - kp_detector.load_state_dict(checkpoint['kp_detector']) - if he_estimator is not None: - he_estimator.load_state_dict(checkpoint['he_estimator']) - if discriminator is not None: - try: - discriminator.load_state_dict(checkpoint['discriminator']) - except: - print ('No discriminator in the state-dict. Dicriminator will be randomly initialized') - if optimizer_generator is not None: - optimizer_generator.load_state_dict(checkpoint['optimizer_generator']) - if optimizer_discriminator is not None: - try: - optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) - except RuntimeError as e: - print ('No discriminator optimizer in the state-dict. 
Optimizer will be not initialized') - if optimizer_kp_detector is not None: - optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector']) - if optimizer_he_estimator is not None: - optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator']) - - return checkpoint['epoch'] - - -def load_cpk_facevid2vid_safetensor(checkpoint_path, generator=None, - kp_detector=None, he_estimator=None, - device="cpu"): - - checkpoint = safetensors.torch.load_file(checkpoint_path) - - if generator is not None: - x_generator = {} - for k,v in checkpoint.items(): - if 'generator' in k: - x_generator[k.replace('generator.', '')] = v - generator.load_state_dict(x_generator) - if kp_detector is not None: - x_generator = {} - for k,v in checkpoint.items(): - if 'kp_extractor' in k: - x_generator[k.replace('kp_extractor.', '')] = v - kp_detector.load_state_dict(x_generator) - if he_estimator is not None: - x_generator = {} - for k,v in checkpoint.items(): - if 'he_estimator' in k: - x_generator[k.replace('he_estimator.', '')] = v - he_estimator.load_state_dict(x_generator) - - return None - -free_view_checkpoint = '/apdcephfs/private_shadowcun/SadTalker/checkpoints/facevid2vid_'+str(size)+'-model.pth.tar' -load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator) - -wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth') - -audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth') -audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml') - -audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth') -audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml') - -fcfg_pose = open(audio2pose_yaml_path) -cfg_pose = CN.load_cfg(fcfg_pose) -cfg_pose.freeze() -audio2pose_model = Audio2Pose(cfg_pose, wav2lip_checkpoint) -audio2pose_model.eval() -load_cpk(audio2pose_checkpoint, model=audio2pose_model, device='cpu') - -# load audio2exp_model -netG = SimpleWrapperV2() -netG.eval() -load_cpk(audio2exp_checkpoint, model=netG, device='cpu') - -class SadTalker(torch.nn.Module): - def __init__(self, kp_extractor, generator, netG, audio2pose, face_3drecon): - super(SadTalker, self).__init__() - self.kp_extractor = kp_extractor - self.generator = generator - self.audio2exp = netG - self.audio2pose = audio2pose - self.face_3drecon = face_3drecon - - -model = SadTalker(kp_extractor, generator, netG, audio2pose_model, net_recon) - -# here, we want to convert it to safetensor -save_file(model.state_dict(), "checkpoints/SadTalker_V0.0.2_"+str(size)+".safetensors") - -### test -load_cpk_facevid2vid_safetensor('checkpoints/SadTalker_V0.0.2_'+str(size)+'.safetensors', kp_detector=kp_extractor, generator=generator, he_estimator=None) \ No newline at end of file diff --git a/sadtalker_video2pose/src/utils/paste_pic.py b/sadtalker_video2pose/src/utils/paste_pic.py deleted file mode 100644 index 4da8952e6933698fec6c7cf35042cb5b1f0dcba5..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/utils/paste_pic.py +++ /dev/null @@ -1,69 +0,0 @@ -import cv2, os -import numpy as np -from tqdm import tqdm -import uuid - -from src.utils.videoio import save_video_with_watermark - -def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop=False): - - if not os.path.isfile(pic_path): - raise ValueError('pic_path must be a valid path to video/image 
file') - elif pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']: - # loader for first frame - full_img = cv2.imread(pic_path) - else: - # loader for videos - video_stream = cv2.VideoCapture(pic_path) - fps = video_stream.get(cv2.CAP_PROP_FPS) - full_frames = [] - while 1: - still_reading, frame = video_stream.read() - if not still_reading: - video_stream.release() - break - break - full_img = frame - frame_h = full_img.shape[0] - frame_w = full_img.shape[1] - - video_stream = cv2.VideoCapture(video_path) - fps = video_stream.get(cv2.CAP_PROP_FPS) - crop_frames = [] - while 1: - still_reading, frame = video_stream.read() - if not still_reading: - video_stream.release() - break - crop_frames.append(frame) - - if len(crop_info) != 3: - print("you didn't crop the image") - return - else: - r_w, r_h = crop_info[0] - clx, cly, crx, cry = crop_info[1] - lx, ly, rx, ry = crop_info[2] - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - if extended_crop: - oy1, oy2, ox1, ox2 = cly, cry, clx, crx - else: - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - - tmp_path = str(uuid.uuid4())+'.mp4' - out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h)) - for crop_frame in tqdm(crop_frames, 'seamlessClone:'): - p = cv2.resize(crop_frame.astype(np.uint8), (ox2-ox1, oy2 - oy1)) - - mask = 255*np.ones(p.shape, p.dtype) - location = ((ox1+ox2) // 2, (oy1+oy2) // 2) - gen_img = cv2.seamlessClone(p, full_img, mask, location, cv2.NORMAL_CLONE) - out_tmp.write(gen_img) - - out_tmp.release() - - save_video_with_watermark(tmp_path, new_audio_path, full_video_path, watermark=False) - os.remove(tmp_path) diff --git a/sadtalker_video2pose/src/utils/preprocess.py b/sadtalker_video2pose/src/utils/preprocess.py deleted file mode 100644 index 4956c00d273467f8a0c020312401158b06c4fecd..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/utils/preprocess.py +++ /dev/null @@ -1,170 +0,0 @@ -import numpy as np -import cv2, os, sys, torch -from tqdm import tqdm -from PIL import Image - -# 3dmm extraction -import safetensors -import safetensors.torch -from src.face3d.util.preprocess import align_img -from src.face3d.util.load_mats import load_lm3d -from src.face3d.models import networks - -from scipy.io import loadmat, savemat -from src.utils.croper import Preprocesser - - -import warnings - -from src.utils.safetensor_helper import load_x_from_safetensor -warnings.filterwarnings("ignore") - -def split_coeff(coeffs): - """ - Return: - coeffs_dict -- a dict of torch.tensors - - Parameters: - coeffs -- torch.tensor, size (B, 256) - """ - id_coeffs = coeffs[:, :80] - exp_coeffs = coeffs[:, 80: 144] - tex_coeffs = coeffs[:, 144: 224] - angles = coeffs[:, 224: 227] - gammas = coeffs[:, 227: 254] - translations = coeffs[:, 254:] - return { - 'id': id_coeffs, - 'exp': exp_coeffs, - 'tex': tex_coeffs, - 'angle': angles, - 'gamma': gammas, - 'trans': translations - } - - -class CropAndExtract(): - def __init__(self, sadtalker_path, device): - - self.propress = Preprocesser(device) - self.net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device) - - if sadtalker_path['use_safetensor']: - checkpoint = safetensors.torch.load_file(sadtalker_path['checkpoint']) - self.net_recon.load_state_dict(load_x_from_safetensor(checkpoint, 'face_3drecon')) - else: - checkpoint = torch.load(sadtalker_path['path_of_net_recon_model'], 
map_location=torch.device(device)) - self.net_recon.load_state_dict(checkpoint['net_recon']) - - self.net_recon.eval() - self.lm3d_std = load_lm3d(sadtalker_path['dir_of_BFM_fitting']) - self.device = device - - def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False, pic_size=256): - - pic_name = os.path.splitext(os.path.split(input_path)[-1])[0] - - landmarks_path = os.path.join(save_dir, pic_name+'_landmarks.txt') - coeff_path = os.path.join(save_dir, pic_name+'.mat') - png_path = os.path.join(save_dir, pic_name+'.png') - - #load input - if not os.path.isfile(input_path): - raise ValueError('input_path must be a valid path to video/image file') - elif input_path.split('.')[-1] in ['jpg', 'png', 'jpeg']: - # loader for first frame - full_frames = [cv2.imread(input_path)] - fps = 25 - else: - # loader for videos - video_stream = cv2.VideoCapture(input_path) - fps = video_stream.get(cv2.CAP_PROP_FPS) - full_frames = [] - while 1: - still_reading, frame = video_stream.read() - if not still_reading: - video_stream.release() - break - full_frames.append(frame) - if source_image_flag: - break - - x_full_frames= [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in full_frames] - - #### crop images as the - if 'crop' in crop_or_resize.lower(): # default crop - x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512) - clx, cly, crx, cry = crop - lx, ly, rx, ry = quad - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad) - elif 'full' in crop_or_resize.lower(): - x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512) - clx, cly, crx, cry = crop - lx, ly, rx, ry = quad - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad) - else: # resize mode - oy1, oy2, ox1, ox2 = 0, x_full_frames[0].shape[0], 0, x_full_frames[0].shape[1] - crop_info = ((ox2 - ox1, oy2 - oy1), None, None) - - frames_pil = [Image.fromarray(cv2.resize(frame,(pic_size, pic_size))) for frame in x_full_frames] - if len(frames_pil) == 0: - print('No face is detected in the input file') - return None, None - - # save crop info - for frame in frames_pil: - cv2.imwrite(png_path, cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)) - - # 2. get the landmark according to the detected face. - if not os.path.isfile(landmarks_path): - lm = self.propress.predictor.extract_keypoint(frames_pil, landmarks_path) - else: - print(' Using saved landmarks.') - lm = np.loadtxt(landmarks_path).astype(np.float32) - lm = lm.reshape([len(x_full_frames), -1, 2]) - - if not os.path.isfile(coeff_path): - # load 3dmm paramter generator from Deep3DFaceRecon_pytorch - video_coeffs, full_coeffs = [], [] - for idx in tqdm(range(len(frames_pil)), desc='3DMM Extraction In Video:'): - frame = frames_pil[idx] - W,H = frame.size - lm1 = lm[idx].reshape([-1, 2]) - - if np.mean(lm1) == -1: - lm1 = (self.lm3d_std[:, :2]+1)/2. 
- lm1 = np.concatenate( - [lm1[:, :1]*W, lm1[:, 1:2]*H], 1 - ) - else: - lm1[:, -1] = H - 1 - lm1[:, -1] - - trans_params, im1, lm1, _ = align_img(frame, lm1, self.lm3d_std) - - trans_params_m = np.array([float(item) for item in np.hsplit(trans_params, len(trans_params))]).astype(np.float32) - im_t = torch.tensor(np.array(im1)/255., dtype=torch.float32).permute(2, 0, 1).to(self.device).unsqueeze(0) - - with torch.no_grad(): - full_coeff = self.net_recon(im_t) - coeffs = split_coeff(full_coeff) - - pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs} - - pred_coeff = np.concatenate([ - pred_coeff['exp'], - pred_coeff['angle'], - pred_coeff['trans'], - trans_params_m[2:][None], - ], 1) - video_coeffs.append(pred_coeff) - full_coeffs.append(full_coeff.cpu().numpy()) - - semantic_npy = np.array(video_coeffs)[:,0] - - savemat(coeff_path, {'coeff_3dmm': semantic_npy, 'full_3dmm': np.array(full_coeffs)[0], 'trans_params': trans_params}) - - return coeff_path, png_path, crop_info diff --git a/sadtalker_video2pose/src/utils/preprocess_fromvideo.py b/sadtalker_video2pose/src/utils/preprocess_fromvideo.py deleted file mode 100644 index e1e6c34055e557b6b39c5c8c1a5fd08842d17f57..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/utils/preprocess_fromvideo.py +++ /dev/null @@ -1,195 +0,0 @@ -import numpy as np -import cv2, os, sys, torch -from tqdm import tqdm -from PIL import Image - -# 3dmm extraction -import safetensors -import safetensors.torch -from src.face3d.util.preprocess import align_img -from src.face3d.util.load_mats import load_lm3d -from src.face3d.models import networks - -from scipy.io import loadmat, savemat -from src.utils.croper import Preprocesser - - -import warnings - -from src.utils.safetensor_helper import load_x_from_safetensor -warnings.filterwarnings("ignore") - - -def smooth_3dmm_params(params, window_size=5): - # 创建一个新的数组来存储平滑后的参数 - smoothed_params = np.zeros_like(params) - - # 对每个参数进行平滑处理 - for i in range(params.shape[1]): - - # 在参数周围创建一个滑动窗口 - window = np.ones(int(window_size))/float(window_size) - smoothed_param = np.convolve(params[:, i], window, 'same') - - # 将平滑后的参数存储在新数组中 - smoothed_params[:, i] = smoothed_param - - return smoothed_params - - - -def split_coeff(coeffs): - """ - Return: - coeffs_dict -- a dict of torch.tensors - - Parameters: - coeffs -- torch.tensor, size (B, 256) - """ - id_coeffs = coeffs[:, :80] - exp_coeffs = coeffs[:, 80: 144] - tex_coeffs = coeffs[:, 144: 224] - angles = coeffs[:, 224: 227] - gammas = coeffs[:, 227: 254] - translations = coeffs[:, 254:] - return { - 'id': id_coeffs, - 'exp': exp_coeffs, - 'tex': tex_coeffs, - 'angle': angles, - 'gamma': gammas, - 'trans': translations - } - - -class CropAndExtract(): - def __init__(self, sadtalker_path, device): - - self.propress = Preprocesser(device) - self.net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device) - - if sadtalker_path['use_safetensor']: - checkpoint = safetensors.torch.load_file(sadtalker_path['checkpoint']) - self.net_recon.load_state_dict(load_x_from_safetensor(checkpoint, 'face_3drecon')) - else: - checkpoint = torch.load(sadtalker_path['path_of_net_recon_model'], map_location=torch.device(device)) - self.net_recon.load_state_dict(checkpoint['net_recon']) - - self.net_recon.eval() - self.lm3d_std = load_lm3d(sadtalker_path['dir_of_BFM_fitting']) - self.device = device - - def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False, pic_size=256, if_smooth=False): - - 
pic_name = os.path.splitext(os.path.split(input_path)[-1])[0] - - landmarks_path = os.path.join(save_dir, pic_name+'_landmarks.txt') - coeff_path = os.path.join(save_dir, pic_name+'.mat') - png_path = os.path.join(save_dir, pic_name+'.png') - - #load input - if not os.path.isfile(input_path): - raise ValueError('input_path must be a valid path to video/image file') - elif input_path.split('.')[-1] in ['jpg', 'png', 'jpeg']: - # loader for first frame - full_frames = [cv2.imread(input_path)] - fps = 25 - else: - # loader for videos - video_stream = cv2.VideoCapture(input_path) - fps = video_stream.get(cv2.CAP_PROP_FPS) - full_frames = [] - while 1: - still_reading, frame = video_stream.read() - if not still_reading: - video_stream.release() - break - full_frames.append(frame) - if source_image_flag: - break - - x_full_frames= [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in full_frames] - - # print(x_full_frames) - - #### crop images as the - if 'crop' in crop_or_resize.lower(): # default crop - x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512) - clx, cly, crx, cry = crop - lx, ly, rx, ry = quad - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad) - elif 'full' in crop_or_resize.lower(): - x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512) - clx, cly, crx, cry = crop - lx, ly, rx, ry = quad - lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) - oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx - crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad) - else: # resize mode - oy1, oy2, ox1, ox2 = 0, x_full_frames[0].shape[0], 0, x_full_frames[0].shape[1] - crop_info = ((ox2 - ox1, oy2 - oy1), None, None) - - frames_pil = [Image.fromarray(cv2.resize(frame,(pic_size, pic_size))) for frame in x_full_frames] - if len(frames_pil) == 0: - print('No face is detected in the input file') - return None, None - - # save crop info - for frame in frames_pil: - cv2.imwrite(png_path, cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)) - - # 2. get the landmark according to the detected face. - if not os.path.isfile(landmarks_path): - lm = self.propress.predictor.extract_keypoint(frames_pil, landmarks_path) - else: - print(' Using saved landmarks.') - lm = np.loadtxt(landmarks_path).astype(np.float32) - lm = lm.reshape([len(x_full_frames), -1, 2]) - - if not os.path.isfile(coeff_path): - # load 3dmm paramter generator from Deep3DFaceRecon_pytorch - video_coeffs, full_coeffs = [], [] - for idx in tqdm(range(len(frames_pil)), desc='3DMM Extraction In Video:'): - frame = frames_pil[idx] - W,H = frame.size - lm1 = lm[idx].reshape([-1, 2]) - - if np.mean(lm1) == -1: - lm1 = (self.lm3d_std[:, :2]+1)/2. 
- lm1 = np.concatenate( - [lm1[:, :1]*W, lm1[:, 1:2]*H], 1 - ) - else: - lm1[:, -1] = H - 1 - lm1[:, -1] - - trans_params, im1, lm1, _ = align_img(frame, lm1, self.lm3d_std) - - trans_params_m = np.array([float(item) for item in np.hsplit(trans_params, len(trans_params))]).astype(np.float32) - im_t = torch.tensor(np.array(im1)/255., dtype=torch.float32).permute(2, 0, 1).to(self.device).unsqueeze(0) - - with torch.no_grad(): - full_coeff = self.net_recon(im_t) - coeffs = split_coeff(full_coeff) - - pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs} - - pred_coeff = np.concatenate([ - pred_coeff['exp'], - pred_coeff['angle'], - pred_coeff['trans'], - # trans_params_m[2:][None], - ], 1) - video_coeffs.append(pred_coeff) - full_coeffs.append(full_coeff.cpu().numpy()) - - semantic_npy = np.array(video_coeffs)[:,0] - - if if_smooth: - # pass - semantic_npy[:, -6:] = smooth_3dmm_params(semantic_npy[:, -6:], window_size=3) - - savemat(coeff_path, {'coeff_3dmm': semantic_npy, 'full_3dmm': np.array(full_coeffs)[0], 'trans_params': trans_params}) - - return coeff_path, png_path, crop_info diff --git a/sadtalker_video2pose/src/utils/safetensor_helper.py b/sadtalker_video2pose/src/utils/safetensor_helper.py deleted file mode 100644 index 164ed9621eba24e0b3050ca663fcb60123517158..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/utils/safetensor_helper.py +++ /dev/null @@ -1,8 +0,0 @@ - - -def load_x_from_safetensor(checkpoint, key): - x_generator = {} - for k,v in checkpoint.items(): - if key in k: - x_generator[k.replace(key+'.', '')] = v - return x_generator \ No newline at end of file diff --git a/sadtalker_video2pose/src/utils/text2speech.py b/sadtalker_video2pose/src/utils/text2speech.py deleted file mode 100644 index a0fe21daf74fcd01767b17378b7076c9dd424248..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/utils/text2speech.py +++ /dev/null @@ -1,20 +0,0 @@ -import os -import tempfile -from TTS.api import TTS - - -class TTSTalker(): - def __init__(self) -> None: - model_name = TTS.list_models()[0] - self.tts = TTS(model_name) - - def test(self, text, language='en'): - - tempf = tempfile.NamedTemporaryFile( - delete = False, - suffix = ('.'+'wav'), - ) - - self.tts.tts_to_file(text, speaker=self.tts.speakers[0], language=language, file_path=tempf.name) - - return tempf.name \ No newline at end of file diff --git a/sadtalker_video2pose/src/utils/videoio.py b/sadtalker_video2pose/src/utils/videoio.py deleted file mode 100644 index d604ae5b098006f3e59cf3c0133779ffd1cc9d5a..0000000000000000000000000000000000000000 --- a/sadtalker_video2pose/src/utils/videoio.py +++ /dev/null @@ -1,41 +0,0 @@ -import shutil -import uuid - -import os - -import cv2 - -def load_video_to_cv2(input_path): - video_stream = cv2.VideoCapture(input_path) - fps = video_stream.get(cv2.CAP_PROP_FPS) - full_frames = [] - while 1: - still_reading, frame = video_stream.read() - if not still_reading: - video_stream.release() - break - full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) - return full_frames - -def save_video_with_watermark(video, audio, save_path, watermark=False): - temp_file = str(uuid.uuid4())+'.mp4' - cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -vcodec mpeg4 "%s"' % (video, audio, temp_file) - os.system(cmd) - - if watermark is False: - shutil.move(temp_file, save_path) - else: - # watermark - try: - ##### check if stable-diffusion-webui - import webui - from modules import paths - watarmark_path = 
paths.script_path+"/extensions/SadTalker/docs/sadtalker_logo.png" - except: - # get the root path of sadtalker. - dir_path = os.path.dirname(os.path.realpath(__file__)) - watarmark_path = dir_path+"/../../docs/sadtalker_logo.png" - - cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -filter_complex "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10" "%s"' % (temp_file, watarmark_path, save_path) - os.system(cmd) - os.remove(temp_file) \ No newline at end of file diff --git a/utils/flow_viz.py b/utils/flow_viz.py deleted file mode 100644 index 73c0a357d91e785127b2b9513b2a6951f4ceaf1e..0000000000000000000000000000000000000000 --- a/utils/flow_viz.py +++ /dev/null @@ -1,291 +0,0 @@ -# MIT License -# -# Copyright (c) 2018 Tom Runia -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to conditions. -# -# Author: Tom Runia -# Date Created: 2018-08-03 - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -from PIL import Image -import torch - - -def make_colorwheel(): - ''' - Generates a color wheel for optical flow visualization as presented in: - Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) - URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf - According to the C++ source code of Daniel Scharstein - According to the Matlab source code of Deqing Sun - ''' - - RY = 15 - YG = 6 - GC = 4 - CB = 11 - BM = 13 - MR = 6 - - ncols = RY + YG + GC + CB + BM + MR - colorwheel = np.zeros((ncols, 3)) - col = 0 - - # RY - colorwheel[0:RY, 0] = 255 - colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) - col = col + RY - # YG - colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) - colorwheel[col:col + YG, 1] = 255 - col = col + YG - # GC - colorwheel[col:col + GC, 1] = 255 - colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) - col = col + GC - # CB - colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) - colorwheel[col:col + CB, 2] = 255 - col = col + CB - # BM - colorwheel[col:col + BM, 2] = 255 - colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) - col = col + BM - # MR - colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) - colorwheel[col:col + MR, 0] = 255 - return colorwheel - - -def flow_compute_color(u, v, convert_to_bgr=False): - ''' - Applies the flow color wheel to (possibly clipped) flow components u and v. 
- According to the C++ source code of Daniel Scharstein - According to the Matlab source code of Deqing Sun - :param u: np.ndarray, input horizontal flow - :param v: np.ndarray, input vertical flow - :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB - :return: - ''' - - flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) - - colorwheel = make_colorwheel() # shape [55x3] - ncols = colorwheel.shape[0] - - rad = np.sqrt(np.square(u) + np.square(v)) - a = np.arctan2(-v, -u) / np.pi - - fk = (a + 1) / 2 * (ncols - 1) + 1 - k0 = np.floor(fk).astype(np.int32) - k1 = k0 + 1 - k1[k1 == ncols] = 1 - f = fk - k0 - - for i in range(colorwheel.shape[1]): - tmp = colorwheel[:, i] - col0 = tmp[k0] / 255.0 - col1 = tmp[k1] / 255.0 - col = (1 - f) * col0 + f * col1 - - idx = (rad <= 1) - col[idx] = 1 - rad[idx] * (1 - col[idx]) - col[~idx] = col[~idx] * 0.75 # out of range? - - # Note the 2-i => BGR instead of RGB - ch_idx = 2 - i if convert_to_bgr else i - flow_image[:, :, ch_idx] = np.floor(255 * col) - - return flow_image - - -def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False): - ''' - Expects a two dimensional flow image of shape [H,W,2] - According to the C++ source code of Daniel Scharstein - According to the Matlab source code of Deqing Sun - :param flow_uv: np.ndarray of shape [H,W,2] - :param clip_flow: float, maximum clipping value for flow - :return: - ''' - - assert flow_uv.ndim == 3, 'input flow must have three dimensions' - assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' - - if clip_flow is not None: - flow_uv = np.clip(flow_uv, 0, clip_flow) - - u = flow_uv[:, :, 0] - v = flow_uv[:, :, 1] - - rad = np.sqrt(np.square(u) + np.square(v)) - rad_max = np.max(rad) - - epsilon = 1e-5 - u = u / (rad_max + epsilon) - v = v / (rad_max + epsilon) - - return flow_compute_color(u, v, convert_to_bgr) - - -UNKNOWN_FLOW_THRESH = 1e7 -SMALLFLOW = 0.0 -LARGEFLOW = 1e8 - - -def make_color_wheel(): - """ - Generate color wheel according Middlebury color code - :return: Color wheel - """ - RY = 15 - YG = 6 - GC = 4 - CB = 11 - BM = 13 - MR = 6 - - ncols = RY + YG + GC + CB + BM + MR - - colorwheel = np.zeros([ncols, 3]) - - col = 0 - - # RY - colorwheel[0:RY, 0] = 255 - colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY)) - col += RY - - # YG - colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG)) - colorwheel[col:col + YG, 1] = 255 - col += YG - - # GC - colorwheel[col:col + GC, 1] = 255 - colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC)) - col += GC - - # CB - colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB)) - colorwheel[col:col + CB, 2] = 255 - col += CB - - # BM - colorwheel[col:col + BM, 2] = 255 - colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM)) - col += + BM - - # MR - colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) - colorwheel[col:col + MR, 0] = 255 - - return colorwheel - - -def compute_color(u, v): - """ - compute optical flow color map - :param u: optical flow horizontal map - :param v: optical flow vertical map - :return: optical flow in color code - """ - [h, w] = u.shape - img = np.zeros([h, w, 3]) - nanIdx = np.isnan(u) | np.isnan(v) - u[nanIdx] = 0 - v[nanIdx] = 0 - - colorwheel = make_color_wheel() - ncols = np.size(colorwheel, 0) - - rad = np.sqrt(u ** 2 + v ** 2) - - a = np.arctan2(-v, -u) / np.pi - - fk = (a + 1) / 2 * 
(ncols - 1) + 1 - - k0 = np.floor(fk).astype(int) - - k1 = k0 + 1 - k1[k1 == ncols + 1] = 1 - f = fk - k0 - - for i in range(0, np.size(colorwheel, 1)): - tmp = colorwheel[:, i] - col0 = tmp[k0 - 1] / 255 - col1 = tmp[k1 - 1] / 255 - col = (1 - f) * col0 + f * col1 - - idx = rad <= 1 - col[idx] = 1 - rad[idx] * (1 - col[idx]) - notidx = np.logical_not(idx) - - col[notidx] *= 0.75 - img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx))) - - return img - - -# from https://github.com/gengshan-y/VCN -def flow_to_image(flow): - """ - Convert flow into middlebury color code image - :param flow: optical flow map - :return: optical flow image in middlebury color - """ - u = flow[:, :, 0] - v = flow[:, :, 1] - - # maxu = -999. - # maxv = -999. - # minu = 999. - # minv = 999. - - idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) - u[idxUnknow] = 0 - v[idxUnknow] = 0 - - # maxu = max(maxu, np.max(u)) - # minu = min(minu, np.min(u)) - - # maxv = max(maxv, np.max(v)) - # minv = min(minv, np.min(v)) - - rad = torch.sqrt(u ** 2 + v ** 2) - maxrad = max(-1, torch.max(rad).cpu().numpy()) - - u = u / (maxrad + np.finfo(float).eps) - v = v / (maxrad + np.finfo(float).eps) - - img = compute_color(u.cpu().numpy(), v.cpu().numpy()) - - idx = np.repeat(idxUnknow[:, :, np.newaxis].cpu().numpy(), 3, axis=2) - img[idx] = 0 - - return np.uint8(img) - - -def save_vis_flow_tofile(flow, output_path): - vis_flow = flow_to_image(flow) - Image.fromarray(vis_flow).save(output_path) - - -def flow_tensor_to_image(flow): - """Used for tensorboard visualization""" - flow = flow.permute(1, 2, 0) # [H, W, 2] - flow = flow.detach().cpu().numpy() - flow = flow_to_image(flow) # [H, W, 3] - flow = np.transpose(flow, (2, 0, 1)) # [3, H, W] - - return flow diff --git a/utils/scheduling_euler_discrete_karras_fix.py b/utils/scheduling_euler_discrete_karras_fix.py deleted file mode 100644 index 2de68461afb061e2bc5efb3efeb8e54c81b09ca6..0000000000000000000000000000000000000000 --- a/utils/scheduling_euler_discrete_karras_fix.py +++ /dev/null @@ -1,556 +0,0 @@ -# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.utils import BaseOutput, logging -from diffusers.utils.torch_utils import randn_tensor -from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin -import torch.nn.functional as F - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete -class EulerDiscreteSchedulerOutput(BaseOutput): - """ - Output class for the scheduler's `step` function output. 
- - Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - The predicted denoised sample `(x_{0})` based on the model output from the current timestep. - `pred_original_sample` can be used to preview progress or for guidance. - """ - - prev_sample: torch.FloatTensor - pred_original_sample: Optional[torch.FloatTensor] = None - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - if alpha_transform_type == "cosine": - - def alpha_bar_fn(t): - return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 - - elif alpha_transform_type == "exp": - - def alpha_bar_fn(t): - return math.exp(t * -12.0) - - else: - raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr -def rescale_zero_terminal_snr(betas): - """ - Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) - - - Args: - betas (`torch.FloatTensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.FloatTensor`: rescaled betas with zero terminal SNR - """ - # Convert betas to alphas_bar_sqrt - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - -class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): - """ - Euler scheduler. - - This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic - methods the library implements for all schedulers such as loading and saving. 
- - Args: - num_train_timesteps (`int`, defaults to 1000): - The number of diffusion steps to train the model. - beta_start (`float`, defaults to 0.0001): - The starting `beta` value of inference. - beta_end (`float`, defaults to 0.02): - The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): - The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear` or `scaled_linear`. - trained_betas (`np.ndarray`, *optional*): - Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. - prediction_type (`str`, defaults to `epsilon`, *optional*): - Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), - `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen - Video](https://imagen.research.google/video/paper.pdf) paper). - interpolation_type(`str`, defaults to `"linear"`, *optional*): - The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be on of - `"linear"` or `"log_linear"`. - use_karras_sigmas (`bool`, *optional*, defaults to `False`): - Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, - the sigmas are determined according to a sequence of noise levels {σi}. - timestep_spacing (`str`, defaults to `"linspace"`): - The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. - steps_offset (`int`, defaults to 0): - An offset added to the inference steps. You can use a combination of `offset=1` and - `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable - Diffusion. - rescale_betas_zero_snr (`bool`, defaults to `False`): - Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and - dark samples instead of limiting it to samples with medium brightness. Loosely related to - [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - prediction_type: str = "epsilon", - interpolation_type: str = "linear", - use_karras_sigmas: Optional[bool] = False, - sigma_min: Optional[float] = None, - sigma_max: Optional[float] = None, - timestep_spacing: str = "linspace", - timestep_type: str = "discrete", # can be "discrete" or "continuous" - steps_offset: int = 0, - rescale_betas_zero_snr: bool = False, - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. 
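# Illustrative end-to-end sketch (not part of the deleted file) of how this vendored
# EulerDiscreteScheduler is typically driven once constructed. The beta values, step count
# and latent shape are made up, and the random noise_pred stands in for a denoising UNet.
import torch
scheduler = EulerDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
scheduler.set_timesteps(25)
sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)   # divide by sqrt(sigma^2 + 1)
    noise_pred = torch.randn_like(model_input)              # placeholder epsilon prediction
    sample = scheduler.step(noise_pred, t, sample).prev_sample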
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.betas) - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - if rescale_betas_zero_snr: - # Close to 0 without being 0 so first sigma is not inf - # FP16 smallest positive subnormal works well here - self.alphas_cumprod[-1] = 2**-24 - - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() - - sigmas = sigmas[::-1].copy() - - if self.use_karras_sigmas: - log_sigmas = np.log(sigmas) - sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_train_timesteps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) - - sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) - - # setable values - self.num_inference_steps = None - - # TODO: Support the full EDM scalings for all prediction types and timestep types - if timestep_type == "continuous" and prediction_type == "v_prediction": - self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]) - else: - self.timesteps = torch.from_numpy(timesteps.astype(np.float32)) - - self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) - - self.is_scale_input_called = False - self.use_karras_sigmas = use_karras_sigmas - - self._step_index = None - - @property - def init_noise_sigma(self): - # standard deviation of the initial noise distribution - max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max() - if self.config.timestep_spacing in ["linspace", "trailing"]: - return max_sigma - - return (max_sigma**2 + 1) ** 0.5 - - @property - def step_index(self): - """ - The index counter for current timestep. It will increae 1 after each scheduler step. - """ - return self._step_index - - def scale_model_input( - self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] - ) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. - - Args: - sample (`torch.FloatTensor`): - The input sample. - timestep (`int`, *optional*): - The current timestep in the diffusion chain. - - Returns: - `torch.FloatTensor`: - A scaled input sample. - """ - if self.step_index is None: - self._init_step_index(timestep) - - sigma = self.sigmas[self.step_index] - sample = sample / ((sigma**2 + 1) ** 0.5) - - self.is_scale_input_called = True - return sample - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the discrete timesteps used for the diffusion chain (to be run before inference). - - Args: - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
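# Illustrative check (not part of the deleted file) of the sigma parameterization above:
# sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), the initial latent is scaled by
# init_noise_sigma, and scale_model_input divides the sample by sqrt(sigma^2 + 1) to match
# the Euler formulation. Shapes and step count are arbitrary toy values.
import torch
sched = EulerDiscreteScheduler()
sched.set_timesteps(10)
t0 = sched.timesteps[0]
x = torch.randn(1, 4, 8, 8) * sched.init_noise_sigma
x_in = sched.scale_model_input(x, t0)            # also initializes the internal step index
sigma0 = sched.sigmas[sched.step_index]
assert torch.allclose(x_in, x / (sigma0 ** 2 + 1) ** 0.5)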
- """ - self.num_inference_steps = num_inference_steps - - # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "linspace": - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[ - ::-1 - ].copy() - elif self.config.timestep_spacing == "leading": - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) - timesteps += self.config.steps_offset - elif self.config.timestep_spacing == "trailing": - step_ratio = self.config.num_train_timesteps / self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) - timesteps -= 1 - else: - raise ValueError( - f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." - ) - - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - log_sigmas = np.log(sigmas) - - if self.config.interpolation_type == "linear": - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - elif self.config.interpolation_type == "log_linear": - sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy() - else: - raise ValueError( - f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either" - " 'linear' or 'log_linear'" - ) - - if self.use_karras_sigmas: - sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) - - sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) - - # TODO: Support the full EDM scalings for all prediction types and timestep types - if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction": - self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device) - else: - self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device) - - self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) - self._step_index = None - - def _sigma_to_t(self, sigma, log_sigmas): - # get log sigma - log_sigma = np.log(np.maximum(sigma, 1e-10)) - - # get distribution - dists = log_sigma - log_sigmas[:, np.newaxis] - - # get sigmas range - low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) - high_idx = low_idx + 1 - - low = log_sigmas[low_idx] - high = log_sigmas[high_idx] - - # interpolate sigmas - w = (low - log_sigma) / (low - high) - w = np.clip(w, 0, 1) - - # transform interpolation to time range - t = (1 - w) * low_idx + w * high_idx - t = t.reshape(sigma.shape) - return t - - # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 - def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: - """Constructs the noise schedule of Karras et al. 
(2022).""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min - else: - sigma_min = None - - if hasattr(self.config, "sigma_max"): - sigma_max = self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - rho = 7.0 # 7.0 is the value used in the paper - ramp = np.linspace(0, 1, num_inference_steps) - min_inv_rho = sigma_min ** (1 / rho) - max_inv_rho = sigma_max ** (1 / rho) - sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho - return sigmas - - def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - - index_candidates = (self.timesteps == timestep).nonzero() - - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - if len(index_candidates) > 1: - step_index = index_candidates[1] - else: - step_index = index_candidates[0] - - self._step_index = step_index.item() - - def step( - self, - model_output: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - sample: torch.FloatTensor, - s_churn: float = 0.0, - s_tmin: float = 0.0, - s_tmax: float = float("inf"), - s_noise: float = 1.0, - generator: Optional[torch.Generator] = None, - return_dict: bool = True, - ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: - """ - Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor`): - The direct output from learned diffusion model. - timestep (`float`): - The current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - A current instance of a sample created by the diffusion process. - s_churn (`float`): - s_tmin (`float`): - s_tmax (`float`): - s_noise (`float`, defaults to 1.0): - Scaling factor for noise added to the sample. - generator (`torch.Generator`, *optional*): - A random number generator. - return_dict (`bool`): - Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or - tuple. - - Returns: - [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is - returned, otherwise a tuple is returned where the first element is the sample tensor. - """ - - if ( - isinstance(timestep, int) - or isinstance(timestep, torch.IntTensor) - or isinstance(timestep, torch.LongTensor) - ): - raise ValueError( - ( - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" - " one of the `scheduler.timesteps` as a timestep." - ), - ) - - if not self.is_scale_input_called: - logger.warning( - "The `scale_model_input` function should be called before `step` to ensure correct denoising. " - "See `StableDiffusionPipeline` for a usage example." 
- ) - - if self.step_index is None: - self._init_step_index(timestep) - - # Upcast to avoid precision issues when computing prev_sample - sample = sample.to(torch.float32) - - sigma = self.sigmas[self.step_index] - - gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 - - noise = randn_tensor( - model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator - ) - - eps = noise * s_noise - sigma_hat = sigma * (gamma + 1) - - if gamma > 0: - sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 - - # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - # NOTE: "original_sample" should not be an expected prediction_type but is left in for - # backwards compatibility - if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample": - pred_original_sample = model_output - elif self.config.prediction_type == "epsilon": - pred_original_sample = sample - sigma_hat * model_output - elif self.config.prediction_type == "v_prediction": - # denoised = model_output * c_out + input * c_skip - pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" - ) - - # 2. Convert to an ODE derivative - derivative = (sample - pred_original_sample) / sigma_hat - - dt = self.sigmas[self.step_index + 1] - sigma_hat - - prev_sample = sample + derivative * dt - - # Cast sample back to model compatible dtype - prev_sample = prev_sample.to(model_output.dtype) - - # upon completion increase step index by one - self._step_index += 1 - - if not return_dict: - return (prev_sample,) - - return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.FloatTensor, - ) -> torch.FloatTensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - else: - schedule_timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - noisy_samples = original_samples + noise * sigma - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/utils/utils.py b/utils/utils.py deleted file mode 100644 index b296648f28598cd5f8d0fd9b0613b9173e1b9aad..0000000000000000000000000000000000000000 --- a/utils/utils.py +++ /dev/null @@ -1,269 +0,0 @@ -# -*- coding:utf-8 -*- -import os -import sys -import shutil -import logging -import colorlog -from tqdm import tqdm -import time -import yaml -import random -import importlib -from PIL import Image -from warnings import simplefilter -import imageio -import math -import collections -import json -import numpy as np -import torch -import torch.nn as nn -from torch.optim 
import Adam -import torch.nn.functional as F -from torch.utils.data import DataLoader, Dataset -from einops import rearrange, repeat -import torch.distributed as dist -from torchvision import datasets, transforms, utils - -logging.getLogger().setLevel(logging.WARNING) -simplefilter(action='ignore', category=FutureWarning) - -def get_logger(filename=None): - """ - examples: - logger = get_logger('try_logging.txt') - - logger.debug("Do something.") - logger.info("Start print log.") - logger.warning("Something may fail.") - try: - raise ValueError() - except ValueError: - logger.error("Error", exc_info=True) - - tips: - Do not log large tensors, since colored output will not make them any more readable. - """ - logger = logging.getLogger('utils') - level = logging.DEBUG - logger.setLevel(level=level) - # Disable propagate to avoid duplicate logging. - logger.propagate = False - # Omit %(levelname)s since colorlog already encodes the level as color. - format_str = '[%(asctime)s <%(filename)s:%(lineno)d> %(funcName)s] %(message)s' - - streamHandler = logging.StreamHandler() - streamHandler.setLevel(level) - coloredFormatter = colorlog.ColoredFormatter( - '%(log_color)s' + format_str, - datefmt='%Y-%m-%d %H:%M:%S', - reset=True, - log_colors={ - 'DEBUG': 'cyan', - # 'INFO': 'white', - 'WARNING': 'yellow', - 'ERROR': 'red', - 'CRITICAL': 'red,bg_white', - } - ) - - streamHandler.setFormatter(coloredFormatter) - logger.addHandler(streamHandler) - - if filename: - fileHandler = logging.FileHandler(filename) - fileHandler.setLevel(level) - formatter = logging.Formatter(format_str) - fileHandler.setFormatter(formatter) - logger.addHandler(fileHandler) - - # Avoid duplicate logging under torch.distributed: only rank 0 should log. - try: - class UniqueLogger: - def __init__(self, logger): - self.logger = logger - self.local_rank = torch.distributed.get_rank() - - def info(self, msg, *args, **kwargs): - if self.local_rank == 0: - return self.logger.info(msg, *args, **kwargs) - - def warning(self, msg, *args, **kwargs): - if self.local_rank == 0: - return self.logger.warning(msg, *args, **kwargs) - - logger = UniqueLogger(logger) - # get_rank() raises if torch.distributed is not initialized (e.g. CPU or single-process runs).
- except Exception: - pass - return logger - - -logger = get_logger() - -def split_filename(filename): - absname = os.path.abspath(filename) - dirname, basename = os.path.split(absname) - split_tmp = basename.rsplit('.', maxsplit=1) - if len(split_tmp) == 2: - rootname, extname = split_tmp - elif len(split_tmp) == 1: - rootname = split_tmp[0] - extname = None - else: - raise ValueError("programming error!") - return dirname, rootname, extname - -def data2file(data, filename, type=None, override=False, printable=False, **kwargs): - dirname, rootname, extname = split_filename(filename) - print_did_not_save_flag = True - if type: - extname = type - if not os.path.exists(dirname): - os.makedirs(dirname, exist_ok=True) - - if not os.path.exists(filename) or override: - if extname in ['jpg', 'png', 'jpeg']: - utils.save_image(data, filename, **kwargs) - elif extname == 'gif': - imageio.mimsave(filename, data, format='GIF', duration=kwargs.get('duration'), loop=0) - elif extname == 'txt': - if kwargs is None: - kwargs = {} - max_step = kwargs.get('max_step') - if max_step is None: - max_step = np.Infinity - - with open(filename, 'w', encoding='utf-8') as f: - for i, e in enumerate(data): - if i < max_step: - f.write(str(e) + '\n') - else: - break - else: - raise ValueError('Do not support this type') - if printable: logger.info('Saved data to %s' % os.path.abspath(filename)) - else: - if print_did_not_save_flag: logger.info( - 'Did not save data to %s because file exists and override is False' % os.path.abspath( - filename)) - - -def file2data(filename, type=None, printable=True, **kwargs): - dirname, rootname, extname = split_filename(filename) - print_load_flag = True - if type: - extname = type - - if extname in ['pth', 'ckpt']: - data = torch.load(filename, map_location=kwargs.get('map_location')) - elif extname == 'txt': - top = kwargs.get('top', None) - with open(filename, encoding='utf-8') as f: - if top: - data = [f.readline() for _ in range(top)] - else: - data = [e for e in f.read().split('\n') if e] - elif extname == 'yaml': - with open(filename, 'r') as f: - data = yaml.load(f) - else: - raise ValueError('type can only support h5, npy, json, txt') - if printable: - if print_load_flag: - logger.info('Loaded data from %s' % os.path.abspath(filename)) - return data - - -def ensure_dirname(dirname, override=False): - if os.path.exists(dirname) and override: - logger.info('Removing dirname: %s' % os.path.abspath(dirname)) - try: - shutil.rmtree(dirname) - except OSError as e: - raise ValueError('Failed to delete %s because %s' % (dirname, e)) - - if not os.path.exists(dirname): - logger.info('Making dirname: %s' % os.path.abspath(dirname)) - os.makedirs(dirname, exist_ok=True) - - -def import_filename(filename): - spec = importlib.util.spec_from_file_location("mymodule", filename) - module = importlib.util.module_from_spec(spec) - sys.modules[spec.name] = module - spec.loader.exec_module(module) - return module - - -def adaptively_load_state_dict(target, state_dict): - target_dict = target.state_dict() - - try: - common_dict = {k: v for k, v in state_dict.items() if k in target_dict and v.size() == target_dict[k].size()} - except Exception as e: - logger.warning('load error %s', e) - common_dict = {k: v for k, v in state_dict.items() if k in target_dict} - - if 'param_groups' in common_dict and common_dict['param_groups'][0]['params'] != \ - target.state_dict()['param_groups'][0]['params']: - logger.warning('Detected mismatch params, auto adapte state_dict to current') - 
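# Note added for clarity: the branch continuing below handles optimizer state_dicts. When
# the checkpoint's 'param_groups' reference different parameter ids than the target
# optimizer, the saved group is remapped onto the target's current parameter list before
# load_state_dict is called, so optimizer state can still be restored after a model change.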
common_dict['param_groups'][0]['params'] = target.state_dict()['param_groups'][0]['params'] - target_dict.update(common_dict) - target.load_state_dict(target_dict) - - missing_keys = [k for k in target_dict.keys() if k not in common_dict] - unexpected_keys = [k for k in state_dict.keys() if k not in common_dict] - - if len(unexpected_keys) != 0: - logger.warning( - f"Some weights of state_dict were not used in target: {unexpected_keys}" - ) - if len(missing_keys) != 0: - logger.warning( - f"Some weights of target are missing from state_dict: {missing_keys}" - ) - if len(unexpected_keys) == 0 and len(missing_keys) == 0: - logger.warning("Strictly loaded state_dict.") - -def set_seed(seed=42): - random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - -def image2pil(filename): - return Image.open(filename) - - -def image2arr(filename): - pil = image2pil(filename) - return pil2arr(pil) - - -# Format conversion helpers (PIL <-> numpy). -def pil2arr(pil): - if isinstance(pil, list): - arr = np.array( - [np.array(e.convert('RGB').getdata(), dtype=np.uint8).reshape(e.size[1], e.size[0], 3) for e in pil]) - else: - arr = np.array(pil) - return arr - - -def arr2pil(arr): - if arr.ndim == 3: - return Image.fromarray(arr.astype('uint8'), 'RGB') - elif arr.ndim == 4: - return [Image.fromarray(e.astype('uint8'), 'RGB') for e in list(arr)] - else: - raise ValueError('arr must have ndim of 3 or 4, but got %s' % arr.ndim) - -def notebook_show(*images): - from IPython.display import Image - from IPython.display import display - display(*[Image(e) for e in images]) \ No newline at end of file
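# Closing usage sketch (not part of the deleted utils.py) for the helpers above; the file
# names are hypothetical and an RGB input image is assumed. set_seed fixes the Python,
# NumPy and PyTorch RNGs; image2arr/arr2pil convert between PIL images and uint8 arrays;
# data2file dispatches on the file extension (here, writing a list of frames as a GIF).
set_seed(42)
arr = image2arr('example.png')    # hypothetical RGB image -> HxWx3 uint8 array
pil = arr2pil(arr)                # back to a PIL.Image
data2file([arr, arr], 'preview.gif', override=True, duration=0.1)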