Spaces:
Running
Running
from calendar import c | |
import os | |
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1' | |
# os.environ['TORCH_USE_CUDA_DSA'] = '1' | |
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' | |
import yaml | |
import shutil | |
import collections | |
import torch | |
import torch.utils.data | |
import torch.nn.functional as F | |
import numpy as np | |
import cv2 as cv | |
import glob | |
import datetime | |
import trimesh | |
from torch.utils.tensorboard import SummaryWriter | |
from tqdm import tqdm | |
import importlib | |
# import config | |
from omegaconf import OmegaConf | |
import json | |
# AnimatableGaussians part | |
from AnimatableGaussians.network.lpips import LPIPS | |
from AnimatableGaussians.dataset.dataset_pose import PoseDataset | |
import AnimatableGaussians.utils.net_util as net_util | |
import AnimatableGaussians.utils.visualize_util as visualize_util | |
from AnimatableGaussians.utils.renderer import Renderer | |
from AnimatableGaussians.utils.net_util import to_cuda | |
from AnimatableGaussians.utils.obj_io import save_mesh_as_ply | |
from AnimatableGaussians.gaussians.obj_io import save_gaussians_as_ply | |
import AnimatableGaussians.config as ag_config | |
# Gaussian-Head-Avatar part | |
from GHA.config.config import config_reenactment | |
from GHA.lib.dataset.Dataset import ReenactmentDataset | |
from GHA.lib.dataset.DataLoaderX import DataLoaderX | |
from GHA.lib.module.GaussianHeadModule import GaussianHeadModule | |
from GHA.lib.module.SuperResolutionModule import SuperResolutionModule | |
from GHA.lib.module.CameraModule import CameraModule | |
from GHA.lib.recorder.Recorder import ReenactmentRecorder | |
from GHA.lib.apps.Reenactment import Reenactment | |
# cat utils | |
from calc_offline_rendering_param import calc_offline_rendering_param | |
import ipdb | |
class Avatar:
    """Render a full-body avatar (AnimatableGaussians) and a head avatar (Gaussian-Head-Avatar).

    ``config`` must provide the ``animatablegaussians`` (body), ``gha`` (head) and ``cat``
    (body/head stitching) sections that ``__init__`` reads.
    """

    # Output sub-directories that test_body always creates up front.
    _BODY_OUTPUT_SUBDIRS = (
        'live_skeleton', 'rgb_map', 'rgb_map_wo_hand', 'torso_map', 'mask_map',
        'posed_gaussians', 'posed_params', 'full_body_mask', 'hand_only_mask',
    )

    def __init__(self, config):
        """Build the body network and load the head / stitching configuration.

        Args:
            config: OmegaConf config with ``animatablegaussians``, ``gha`` and ``cat`` sections.
        """
        self.config = config
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        # --- AnimatableGaussians (body) ---
        self.body = config.animatablegaussians
        self.body.mode = 'test'
        ag_config.set_opt(self.body)
        avatar_module = self.body['model'].get('module', 'AnimatableGaussians.network.avatar')
        print('Import AvatarNet from %s' % avatar_module)
        AvatarNet = importlib.import_module(avatar_module).AvatarNet
        self.avatar_net = AvatarNet(self.body.model).to(self.device)
        train_cfg = self.body['train']
        self.random_bg_color = train_cfg.get('random_bg_color', True)
        self.bg_color = (1., 1., 1.)
        self.bg_color_cuda = torch.from_numpy(np.asarray(self.bg_color)).to(torch.float32).to(self.device)
        self.loss_weight = train_cfg['loss_weight']
        self.finetune_color = train_cfg['finetune_color']
        print('# Parameter number of AvatarNet is %d' % sum(p.numel() for p in self.avatar_net.parameters()))

        # --- Gaussian-Head-Avatar (head) ---
        self.head = config.gha
        head_config = config_reenactment()
        head_config.load(self.head.config_path)
        self.head_config = head_config.get_cfg()

        # --- body/head stitching ("cat") utilities ---
        self.cat = config.cat

    @torch.no_grad()  # inference only: avoid building autograd graphs for every frame
    def test_body(self):
        """Render the body avatar for every frame of the configured test pose sequence.

        Reads ``self.body['test']`` for the pose data, output directory, view setting and
        optional features (PCA pose projection, skeleton rendering, texture / PLY dumps), and
        writes per-frame images, masks and parameters under ``<output_dir>/<pca_...|vanilla>``.
        The camera used for each rendered frame is written to ``cameras.json`` and
        ``camera_info.json`` once the loop ends (also when it is interrupted).

        Raises:
            ValueError: if the test config lacks ``pose_data`` or ``output_dir``, or the
                view setting is unknown.
        """
        self.avatar_net.eval()
        test_cfg = self.body['test']

        dataset_module = self.body.get('dataset', 'MvRgbDatasetAvatarReX')
        MvRgbDataset = getattr(importlib.import_module('AnimatableGaussians.dataset.dataset_mv_rgb'), dataset_module)
        training_dataset = MvRgbDataset(**self.body['train']['data'], training=False)
        n_pca = test_cfg.get('n_pca', -1)
        use_pca = n_pca >= 1
        # One value for both the directory name and the PCA projection so the name
        # reports the sigma that is actually applied.
        sigma_pca = float(test_cfg.get('sigma_pca', 2.))
        if use_pca:
            training_dataset.compute_pca(n_components=n_pca)
        if 'pose_data' not in test_cfg:
            raise ValueError('No pose data in test config')
        self.dataset = PoseDataset(**test_cfg['pose_data'], smpl_shape=training_dataset.smpl_data['betas'][0])

        output_dir = test_cfg.get('output_dir', None)
        if output_dir is None:
            raise ValueError('No output_dir in test config')
        if use_pca:
            output_dir += '/pca_%d_sigma_%.2f' % (n_pca, sigma_pca)
        else:
            output_dir += '/vanilla'
        print('# Output dir: \033[1;31m%s\033[0m' % output_dir)
        for sub_dir in self._BODY_OUTPUT_SUBDIRS:
            os.makedirs(os.path.join(output_dir, sub_dir), exist_ok=True)

        # The free-view cameras orbit the first frame's bounding-box centre.
        item_0 = self.dataset.getitem(0, training=False)
        object_center = item_0['live_bounds'].mean(0)
        global_orient = item_0['global_orient']
        if isinstance(global_orient, torch.Tensor):
            global_orient = global_orient.cpu().numpy()
        # Zero the x and z axis-angle components, keeping only the rotation about y.
        global_orient[0] = 0
        global_orient[2] = 0
        global_orient = cv.Rodrigues(global_orient)[0]

        if test_cfg.get('fix_hand', False):
            self.avatar_net.generate_mean_hands()

        img_scale = test_cfg.get('img_scale', 1.0)
        view_setting = test_cfg.get('view_setting', 'free')
        render_skeleton = test_cfg.get('render_skeleton', False)
        print('# View setting: %s' % view_setting)
        # Loop-invariant: the shape parameters are the same for every saved frame.
        betas = training_dataset.smpl_data['betas'].reshape([-1]).detach().cpu().numpy()

        log_time = bool(test_cfg.get('log_time', False))
        if log_time:
            time_start = torch.cuda.Event(enable_timing=True)
            time_start_all = torch.cuda.Event(enable_timing=True)
            time_end = torch.cuda.Event(enable_timing=True)

        def lap(stage):
            # Print GPU time spent since the previous lap, then restart the lap timer.
            if not log_time:
                return
            time_end.record()
            torch.cuda.synchronize()
            print('%s costs %.4f secs' % (stage, time_start.elapsed_time(time_end) / 1000.))
            time_start.record()

        getitem_func = self.dataset.getitem_fast if hasattr(self.dataset, 'getitem_fast') else self.dataset.getitem
        geo_renderer = None
        extr_list, intr_list, img_h_list, img_w_list = [], [], [], []
        try:
            for idx in tqdm(range(len(self.dataset)), desc='Rendering avatars...'):
                if log_time:
                    time_start.record()
                    time_start_all.record()

                extr, intr, img_h, img_w = self._select_camera(view_setting, idx, img_scale, object_center, global_orient)
                extr_list.append(extr)
                intr_list.append(intr)
                img_h_list.append(img_h)
                img_w_list.append(img_w)

                item = getitem_func(idx, training=False, extr=extr, intr=intr, img_w=img_w, img_h=img_h)
                items = to_cuda(item, add_batch=False)
                if view_setting.startswith('moving') or view_setting == 'free_moving':
                    # Follow the subject along x so later moving-camera frames stay centred.
                    current_center = items['live_bounds'].cpu().numpy().mean(0)
                    delta = current_center - object_center
                    object_center[0] += delta[0]
                lap('Loading data')

                if render_skeleton:
                    geo_renderer = self._render_skeleton(output_dir, item, geo_renderer)
                    lap('Rendering skeletons')

                if 'smpl_pos_map' not in items:
                    self.avatar_net.get_pose_map(items)
                if use_pca:
                    # Project the front (first 3 channels) pose map onto the training PCA basis.
                    mask = training_dataset.pos_map_mask
                    live_pos_map = items['smpl_pos_map'].permute(1, 2, 0).cpu().numpy()
                    front_live_pos_map, back_live_pos_map = np.split(live_pos_map, [3], 2)
                    pose_conds = front_live_pos_map[mask]
                    new_pose_conds = training_dataset.transform_pca(pose_conds, sigma_pca=sigma_pca)
                    front_live_pos_map[mask] = new_pose_conds
                    live_pos_map = np.concatenate([front_live_pos_map, back_live_pos_map], 2)
                    items['smpl_pos_map_pca'] = torch.from_numpy(live_pos_map).to(self.device).permute(2, 0, 1)
                lap('Rendering pose conditions')

                output = self.avatar_net.render(items, bg_color=self.bg_color, use_pca=use_pca)
                output_wo_hand = self.avatar_net.render_wo_hand(items, bg_color=self.bg_color, use_pca=use_pca)
                mask_output = self.avatar_net.render_mask(items, bg_color=self.bg_color, use_pca=use_pca)
                lap('Rendering avatar')

                self._save_frame(output_dir, item, output, output_wo_hand, mask_output, betas)
                lap('Saving images')
                if log_time:
                    print('Animating one frame costs %.4f secs' % (time_start_all.elapsed_time(time_end) / 1000.))
                torch.cuda.empty_cache()
        finally:
            # Written once (not per frame); still runs if rendering stops part-way.
            self._save_camera_info(output_dir, extr_list, intr_list, img_h_list, img_w_list)

    def _select_camera(self, view_setting, idx, img_scale, object_center, global_orient):
        """Return ``(extr, intr, img_h, img_w)`` for frame ``idx`` under ``view_setting``.

        ``camera`` uses dataset camera ``test.render_view_idx``; ``cano`` uses a fixed camera
        in canonical space. The other settings (matched by prefix) build an orbit camera
        around ``object_center`` with ``visualize_util.calc_free_mv``: ``free`` (360-frame
        orbit), ``degree120`` / ``degree90`` (+-60 / +-45 degree sweeps), ``front``,
        ``back`` and ``moving``. A ``bird`` suffix passes a non-zero ``rot_X``.

        Raises:
            ValueError: for an unknown view setting.
        """
        test_cfg = self.body['test']
        if view_setting == 'camera':
            cam_id = test_cfg['render_view_idx']
            intr = self.dataset.intr_mats[cam_id].copy()
            intr[:2] *= img_scale
            extr = self.dataset.extr_mats[cam_id].copy()
            img_h = int(self.dataset.img_heights[cam_id] * img_scale)
            img_w = int(self.dataset.img_widths[cam_id] * img_scale)
            return extr, intr, img_h, img_w

        if view_setting.startswith('cano'):
            cano_center = self.dataset.cano_bounds.mean(0)
            extr = np.identity(4, np.float32)
            extr[:3, 3] = -cano_center
            flip = np.identity(4, np.float32)
            flip[:3, :3] = cv.Rodrigues(np.array([np.pi, 0, 0], np.float32))[0]
            extr = flip @ extr
            f_len = 5000
            extr[2, 3] += f_len / 512
            intr = np.array([[f_len, 0, 512], [0, f_len, 512], [0, 0, 1]], np.float32)
            return extr, intr, 1024, 1024

        rot_x = 0.3 if view_setting.endswith('bird') else 0.
        if view_setting.startswith('free'):
            frame_num_per_circle = 360
            rot_y = (idx % frame_num_per_circle) / float(frame_num_per_circle) * 2 * np.pi
        elif view_setting.startswith('degree120'):
            rot_y = self._swing_rot_y(idx, frame_per_cycle=480, max_degree=60)
        elif view_setting.startswith('degree90'):
            rot_y = self._swing_rot_y(idx, frame_per_cycle=360, max_degree=45)
        elif view_setting.startswith('front') or view_setting.startswith('moving'):
            rot_y = 0.
        elif view_setting.startswith('back'):
            rot_y = np.pi
            rot_x = 0.5 * np.pi / 4. if view_setting.endswith('bird') else 0.
        else:
            raise ValueError('Invalid view setting for animation!')

        extr = visualize_util.calc_free_mv(object_center,
                                           tar_pos=np.array([0, 0, 2.5]),
                                           rot_Y=rot_y,
                                           rot_X=rot_x,
                                           global_orient=global_orient if test_cfg.get('global_orient', False) else None)
        intr, img_h, img_w = self._default_free_intr(img_scale)
        return extr, intr, img_h, img_w

    @staticmethod
    def _swing_rot_y(idx, frame_per_cycle, max_degree):
        """Return the y rotation (radians) of a back-and-forth sweep at frame ``idx``.

        During the first half of each ``frame_per_cycle``-frame cycle the angle rises from
        ``-max_degree`` towards ``+max_degree``; during the second half it falls back.
        Negative angles are shifted by ``2 * pi``.
        """
        frame_half_cycle = frame_per_cycle // 2
        step = 2 * max_degree / frame_half_cycle
        phase = idx % frame_half_cycle
        if idx % frame_per_cycle < frame_per_cycle / 2:
            degrees = -max_degree + step * phase
        else:
            degrees = max_degree - step * phase
        radians = degrees * np.pi / 180
        if radians < 0:
            radians = radians + 2 * np.pi
        return radians

    @staticmethod
    def _default_free_intr(img_scale):
        """Return ``(intr, img_h, img_w)`` for the 1024x1024, f=1100 orbit camera scaled by ``img_scale``."""
        intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
        intr[:2] *= img_scale
        size = int(1024 * img_scale)
        return intr, size, size

    def _render_skeleton(self, output_dir, item, geo_renderer):
        """Rasterise the posed skeleton of ``item`` to ``live_skeleton/<idx>.jpg``.

        ``geo_renderer`` is created on first use and returned so the caller can reuse it.
        """
        from AnimatableGaussians.utils.visualize_skeletons import construct_skeletons
        skel_vertices, skel_faces = construct_skeletons(item['joints'].cpu().numpy(), item['kin_parent'].cpu().numpy())
        skel_mesh = trimesh.Trimesh(skel_vertices, skel_faces, process=False)
        if geo_renderer is None:
            geo_renderer = Renderer(item['img_w'], item['img_h'], shader_name='phong_geometry', bg_color=(1, 1, 1))
        geo_renderer.set_camera(item['extr'], item['intr'])
        flat_faces = skel_faces.reshape(-1)
        geo_renderer.set_model(skel_vertices[flat_faces], skel_mesh.vertex_normals.astype(np.float32)[flat_faces])
        skel_img = geo_renderer.render()[:, :, :3]
        skel_img = (skel_img * 255).astype(np.uint8)
        cv.imwrite(output_dir + '/live_skeleton/%08d.jpg' % item['data_idx'], skel_img)
        return geo_renderer

    @staticmethod
    def _to_uint8(tensor):
        """Clamp ``tensor`` to [0, 1] IN PLACE, scale to 0..255 and return it as a uint8 numpy array."""
        tensor.clip_(0., 1.)
        return (tensor * 255).to(torch.uint8).cpu().numpy()

    @staticmethod
    def _mask_to_uint8(mask):
        """Turn a boolean torch mask into a uint8 numpy image (True -> 255, False -> 0)."""
        return (mask * 255).to(torch.uint8).cpu().numpy()

    @staticmethod
    def _sq_color_dist(img, rgb):
        """Per-pixel squared distance between ``img`` (channels in the last dim) and colour ``rgb``."""
        diff = img - torch.tensor(rgb, device=img.device)
        return (diff * diff).sum(dim=2)

    @staticmethod
    def _dilate(mask):
        """Dilate a boolean torch mask (5x5 kernel, 3 iterations); returns a numpy bool array."""
        kernel = np.ones((5, 5), np.uint8)
        return cv.dilate(mask.cpu().numpy().astype(np.uint8), kernel, iterations=3).astype(np.bool_)

    def _save_frame(self, output_dir, item, output, output_wo_hand, mask_output, betas):
        """Write every per-frame artefact produced by the three render passes of ``test_body``."""
        data_idx = item['data_idx']
        test_cfg = self.body['test']

        hand_masks = self._save_mask_outputs(output_dir, data_idx, mask_output)

        rgb_map = self._to_uint8(output['rgb_map'])
        cv.imwrite(output_dir + '/rgb_map/%08d.jpg' % data_idx, rgb_map)

        if hand_masks is not None and 'rgb_map' in output_wo_hand:
            # Inside the (dilated) right/left hand regions take the hand-less render,
            # everywhere else keep the full render.
            body_red, body_blue = hand_masks
            rgb_map_wo_hand = self._to_uint8(output_wo_hand['rgb_map'])
            paste_region = self._dilate(body_red | body_blue)
            mix = np.where(paste_region[..., None], rgb_map_wo_hand, rgb_map)
            cv.imwrite(output_dir + '/rgb_map_wo_hand/%08d.png' % data_idx, mix)

        if 'torso_map' in output:
            cv.imwrite(output_dir + '/torso_map/%08d.png' % data_idx, self._to_uint8(output['torso_map'][:, :, 0]))
        if 'mask_map' in output:
            cv.imwrite(output_dir + '/mask_map/%08d.png' % data_idx, self._to_uint8(output['mask_map'][:, :, 0]))
        if test_cfg.get('save_tex_map', False):
            os.makedirs(output_dir + '/cano_tex_map', exist_ok=True)
            cv.imwrite(output_dir + '/cano_tex_map/%08d.png' % data_idx, self._to_uint8(output['cano_tex_map']))
        if test_cfg.get('save_ply', False):
            posed_gaussians = output['posed_gaussians']
            if data_idx == 0:
                save_gaussians_as_ply(output_dir + '/posed_gaussians/%08d.ply' % data_idx, posed_gaussians)
            for key, value in list(posed_gaussians.items()):
                if isinstance(value, torch.Tensor):
                    posed_gaussians[key] = value.detach().cpu().numpy()
            np.savez(output_dir + '/posed_gaussians/%08d.npz' % data_idx, **posed_gaussians)
        # NOTE(review): the source's indentation was lost, so it is unclear whether this
        # belonged under save_ply; it is written for every frame here - confirm.
        np.savez(output_dir + ('/posed_params/%08d.npz' % data_idx),
                 betas=betas,
                 global_orient=item['global_orient'].reshape([-1]).detach().cpu().numpy(),
                 transl=item['transl'].reshape([-1]).detach().cpu().numpy(),
                 body_pose=item['body_pose'].reshape([-1]).detach().cpu().numpy())

    def _save_mask_outputs(self, output_dir, data_idx, mask_output):
        """Write the mask render passes and the hand / sleeve masks derived from them.

        Red pixels are treated as the right hand and blue pixels as the left hand.

        Returns:
            ``(body_red, body_blue)`` boolean masks of the hands as seen in the full-body
            pass when both ``full_body_rgb_map`` and ``hand_only_rgb_map`` exist, else None.
        """
        # Note: _to_uint8 clamps these maps in place, which the colour tests below rely on.
        if 'full_body_rgb_map' in mask_output:
            cv.imwrite(output_dir + '/full_body_mask/%08d.png' % data_idx,
                       self._to_uint8(mask_output['full_body_rgb_map']))
        if 'hand_only_rgb_map' in mask_output:
            cv.imwrite(output_dir + '/hand_only_mask/%08d.png' % data_idx,
                       self._to_uint8(mask_output['hand_only_rgb_map']))
        if 'full_body_rgb_map' not in mask_output or 'hand_only_rgb_map' not in mask_output:
            return None

        full_body = mask_output['full_body_rgb_map']
        hand_only = mask_output['hand_only_rgb_map']
        red, blue = (1., 0., 0.), (0., 0., 1.)
        body_red = self._sq_color_dist(full_body, red) < 0.01
        hand_red = self._sq_color_dist(hand_only, red) < 0.01
        body_blue = self._sq_color_dist(full_body, blue) < 0.01
        hand_blue = self._sq_color_dist(hand_only, blue) < 0.01
        # NOTE(review): flag is True when the full-body and hand-only pixel counts differ by
        # more than 95% of the hand-only count; a zero hand-only count gives inf/nan here.
        if_mask_r_hand = (abs(body_red.sum() - hand_red.sum()) / hand_red.sum() > 0.95).cpu().numpy()
        if_mask_l_hand = (abs(body_blue.sum() - hand_blue.sum()) / hand_blue.sum() > 0.95).cpu().numpy()
        # Hand pixels present in the hand-only pass but hidden in the full-body pass
        # (the occluded parts of the left and right hands).
        all_mask = (hand_red & ~body_red) | (hand_blue & ~body_blue)

        for sub_dir in ('hand_mask', 'r_hand_mask', 'l_hand_mask', 'hand_visual'):
            os.makedirs(output_dir + '/' + sub_dir, exist_ok=True)
        cv.imwrite(output_dir + '/hand_mask/%08d.png' % data_idx, self._mask_to_uint8(all_mask))
        cv.imwrite(output_dir + '/r_hand_mask/%08d.png' % data_idx, self._mask_to_uint8(body_red))
        cv.imwrite(output_dir + '/l_hand_mask/%08d.png' % data_idx, self._mask_to_uint8(body_blue))
        with open(output_dir + '/hand_visual/%08d.npy' % data_idx, 'wb') as f:
            np.save(f, [if_mask_r_hand, if_mask_l_hand])

        if 'left_hand_rgb_map' in mask_output and 'right_hand_rgb_map' in mask_output:
            self._save_sleeve_masks(output_dir, data_idx, mask_output, body_red | body_blue | all_mask)
        return body_red, body_blue

    def _save_sleeve_masks(self, output_dir, data_idx, mask_output, hand_mask):
        """Write left/right sleeve masks: non-white pixels of each arm pass outside the dilated hand region."""
        os.makedirs(output_dir + '/left_sleeve_mask', exist_ok=True)
        os.makedirs(output_dir + '/right_sleeve_mask', exist_ok=True)
        # Grow the hand region so the sleeve masks keep a margin away from the hands.
        exclude = torch.from_numpy(self._dilate(hand_mask)).to(self.device)
        for map_key, sub_dir in (('left_hand_rgb_map', 'left_sleeve_mask'),
                                 ('right_hand_rgb_map', 'right_sleeve_mask')):
            arm_map = mask_output[map_key]
            arm_map.clip_(0., 1.)
            sleeve = (self._sq_color_dist(arm_map, (1., 1., 1.)) > 0.01) & ~exclude
            cv.imwrite(output_dir + '/%s/%08d.png' % (sub_dir, data_idx), self._mask_to_uint8(sleeve))

    def _save_camera_info(self, output_dir, extrs, intrs, img_heights, img_widths):
        """Write the recorded per-frame cameras (cfg_args, cameras.json and camera_info.json)."""
        self.dump_renderer_info(output_dir, extrs, intrs, img_heights, img_widths)
        camera_info = [
            {'extr': extr.tolist(), 'intr': intr.tolist(), 'img_h': img_h, 'img_w': img_w}
            for extr, intr, img_h, img_w in zip(extrs, intrs, img_heights, img_widths)
        ]
        with open(os.path.join(output_dir, 'camera_info.json'), 'w') as fp:
            json.dump(camera_info, fp)

    def dump_renderer_info(self, dump_dir, extrs, intrs, img_heights, img_widths):
        """Write a ``cfg_args`` file and ``cameras.json`` describing the given cameras.

        The layout appears to mirror the 3D Gaussian Splatting output format.

        Args:
            dump_dir: directory both files are written to.
            extrs: per-camera 4x4 world-to-camera matrices (inverted to camera-to-world here).
            intrs: per-camera 3x3 intrinsics; only fx = K[0, 0] and fy = K[1, 1] are exported.
            img_heights: per-camera image heights in pixels.
            img_widths: per-camera image widths in pixels.
        """
        with open(os.path.join(dump_dir, 'cfg_args'), 'w') as fp:
            outstr = "Namespace(sh_degree=%d, source_path='%s', model_path='%s', images='images', resolution=-1, " \
                     "white_background=False, data_device='cuda', eval=False)" % (
                         3, self.body['train']['data']['data_dir'], dump_dir)
            fp.write(outstr)
        cam_jsons = []
        for ci, (w2c, intr, img_h, img_w) in enumerate(zip(extrs, intrs, img_heights, img_widths)):
            c2w = np.linalg.inv(w2c)
            cam_jsons.append({
                'id': ci,
                'img_name': '%08d' % ci,
                'width': int(img_w),
                'height': int(img_h),
                'position': c2w[:3, 3].tolist(),
                'rotation': c2w[:3, :3].tolist(),
                'fy': float(intr[1, 1]),
                'fx': float(intr[0, 0]),
            })
        with open(os.path.join(dump_dir, 'cameras.json'), 'w') as fp:
            json.dump(cam_jsons, fp)

    def test_head(self):
        """Run Gaussian-Head-Avatar reenactment with the checkpoints named in the head config.

        Calls ``Reenactment.run(stop_fid=...)`` (``gha.stop_fid``, default 800) unless
        ``gha.offline_rendering_param_fpath`` is set, in which case it calls
        ``run_for_offline_stitching`` with that path.
        """
        dataset = ReenactmentDataset(self.head_config.dataset)
        dataloader = DataLoaderX(dataset, batch_size=1, shuffle=False, pin_memory=True)
        device = torch.device('cuda:%d' % self.head_config.gpu_id)
        # Load checkpoints onto the CPU first; modules are moved to `device` afterwards.
        gaussianhead_state_dict = torch.load(self.head_config.load_gaussianhead_checkpoint,
                                             map_location=lambda storage, loc: storage)
        gaussianhead = GaussianHeadModule(self.head_config.gaussianheadmodule,
                                          xyz=gaussianhead_state_dict['xyz'],
                                          feature=gaussianhead_state_dict['feature'],
                                          landmarks_3d_neutral=gaussianhead_state_dict['landmarks_3d_neutral']).to(device)
        gaussianhead.load_state_dict(gaussianhead_state_dict)
        supres = SuperResolutionModule(self.head_config.supresmodule).to(device)
        supres.load_state_dict(torch.load(self.head_config.load_supres_checkpoint,
                                          map_location=lambda storage, loc: storage))
        camera = CameraModule()
        recorder = ReenactmentRecorder(self.head_config.recorder)
        app = Reenactment(dataloader, gaussianhead, supres, camera, recorder, self.head_config.gpu_id, dataset.freeview)
        offline_param_fpath = self.head.get('offline_rendering_param_fpath', None)
        if offline_param_fpath is None:
            app.run(stop_fid=self.head.get('stop_fid', 800))
        else:
            app.run_for_offline_stitching(offline_param_fpath)

    def cal_cat_param(self):
        """Compute body/head stitching parameters from the paths in the ``cat`` config section."""
        calc_offline_rendering_param(
            self.cat.body_gaussian_root_dir,
            self.cat.ref_head_gaussian_path,
            self.cat.ref_head_param_path,
            self.cat.render_cam_fpath,
            self.cat.body_head_blending_param_path
        )
if __name__ == '__main__':
    import sys
    # Optional first CLI argument overrides the default config path.
    config_path = sys.argv[1] if len(sys.argv) > 1 else 'configs/example.yaml'
    conf = OmegaConf.load(config_path)
    avatar = Avatar(conf)
    avatar.test_body()
    # avatar.test_head()