# Page-header residue from the file listing (not part of the module):
# "pengc02's picture", "all", "44925e5"
from calendar import c
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
import yaml
import shutil
import collections
import torch
import torch.utils.data
import torch.nn.functional as F
import numpy as np
import cv2 as cv
import glob
import datetime
import trimesh
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import importlib
# import config
from omegaconf import OmegaConf
import json
# AnimatableGaussians part
from AnimatableGaussians.network.lpips import LPIPS
from AnimatableGaussians.dataset.dataset_pose import PoseDataset
import AnimatableGaussians.utils.net_util as net_util
import AnimatableGaussians.utils.visualize_util as visualize_util
from AnimatableGaussians.utils.renderer import Renderer
from AnimatableGaussians.utils.net_util import to_cuda
from AnimatableGaussians.utils.obj_io import save_mesh_as_ply
from AnimatableGaussians.gaussians.obj_io import save_gaussians_as_ply
import AnimatableGaussians.config as ag_config
# Gaussian-Head-Avatar part
from GHA.config.config import config_reenactment
from GHA.lib.dataset.Dataset import ReenactmentDataset
from GHA.lib.dataset.DataLoaderX import DataLoaderX
from GHA.lib.module.GaussianHeadModule import GaussianHeadModule
from GHA.lib.module.SuperResolutionModule import SuperResolutionModule
from GHA.lib.module.CameraModule import CameraModule
from GHA.lib.recorder.Recorder import ReenactmentRecorder
from GHA.lib.apps.Reenactment import Reenactment
# cat utils
from calc_offline_rendering_param import calc_offline_rendering_param
import ipdb
class Avatar:
    """Test-time driver for a body avatar (AnimatableGaussians) and a head avatar (GHA).

    The constructor reads three sections from ``config``:

    * ``config.animatablegaussians`` -- body options; used to build ``AvatarNet``.
    * ``config.gha`` -- head options; ``config_path`` is loaded through
      ``config_reenactment``.
    * ``config.cat`` -- paths forwarded to ``calc_offline_rendering_param``.

    NOTE(review): the original source had lost all indentation.  Block nesting
    was reconstructed from syntax; the places where more than one nesting was
    syntactically possible are marked with ``NOTE(review)`` below.
    """

    # Sub-directories of the output directory created before rendering starts.
    _OUTPUT_SUBDIRS = (
        'live_skeleton',
        'rgb_map',
        'rgb_map_wo_hand',
        'torso_map',
        'mask_map',
        'posed_gaussians',
        'posed_params',
        'full_body_mask',
        'hand_only_mask',
    )

    def __init__(self, config):
        """Build the body network and load the head configuration.

        Args:
            config: configuration object exposing ``animatablegaussians``,
                ``gha`` and ``cat`` attributes (the script entry point loads
                it with ``OmegaConf.load``).
        """
        self.config = config
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        # AnimatableGaussians (body) part.
        self.body = config.animatablegaussians
        self.body.mode = 'test'
        ag_config.set_opt(self.body)
        avatar_module = self.body['model'].get('module', 'AnimatableGaussians.network.avatar')
        print('Import AvatarNet from %s' % avatar_module)
        AvatarNet = importlib.import_module(avatar_module).AvatarNet
        self.avatar_net = AvatarNet(self.body.model).to(self.device)
        self.random_bg_color = self.body['train'].get('random_bg_color', True)
        self.bg_color = (1., 1., 1.)
        self.bg_color_cuda = torch.from_numpy(np.asarray(self.bg_color)).to(torch.float32).to(self.device)
        self.loss_weight = self.body['train']['loss_weight']
        self.finetune_color = self.body['train']['finetune_color']
        print('# Parameter number of AvatarNet is %d' % sum(p.numel() for p in self.avatar_net.parameters()))

        # Gaussian-Head-Avatar (head) part.
        self.head = config.gha
        self.head_config = config_reenactment()
        self.head_config.load(self.head.config_path)
        self.head_config = self.head_config.get_cfg()

        # Body/head stitching ("cat") part.
        self.cat = config.cat

    @torch.no_grad()
    def test_body(self):
        """Render every pose of the test sequence and write the results to disk.

        Reads ``self.body['test']`` for ``pose_data`` (required), ``output_dir``
        (required) and the optional keys ``n_pca``, ``sigma_pca``, ``img_scale``,
        ``view_setting``, ``render_view_idx``, ``global_orient``, ``fix_hand``,
        ``render_skeleton``, ``save_tex_map`` and ``save_ply``.

        Per frame it writes the rendered images and masks below
        ``<output_dir>/vanilla`` (or ``<output_dir>/pca_<n>_sigma_<s>``).  The
        camera of every frame is recorded and dumped once, when the loop ends
        (also when it ends with an exception), to ``cfg_args``,
        ``cameras.json`` and ``camera_info.json``.

        Raises:
            ValueError: if ``pose_data`` or ``output_dir`` is missing from the
                test config, or ``view_setting`` is not recognised.
        """
        self.avatar_net.eval()
        test_cfg = self.body['test']

        # Validate the config before the (potentially slow) dataset loading.
        if 'pose_data' not in test_cfg:
            raise ValueError('No pose data in test config')
        output_dir = test_cfg.get('output_dir', None)
        if output_dir is None:
            raise ValueError('No output_dir in test config')

        dataset_cls_name = self.body.get('dataset', 'MvRgbDatasetAvatarReX')
        dataset_module = importlib.import_module('AnimatableGaussians.dataset.dataset_mv_rgb')
        training_dataset = getattr(dataset_module, dataset_cls_name)(**self.body['train']['data'], training=False)

        n_pca = test_cfg.get('n_pca', -1)
        use_pca = n_pca >= 1
        if use_pca:
            training_dataset.compute_pca(n_components=n_pca)

        self.dataset = PoseDataset(**test_cfg['pose_data'], smpl_shape=training_dataset.smpl_data['betas'][0])

        if use_pca:
            # NOTE(review): the directory name falls back to sigma 1.0 while
            # _apply_pose_pca falls back to 2.0 -- kept as in the original,
            # confirm which default is intended.
            output_dir += '/pca_%d_sigma_%.2f' % (n_pca, float(test_cfg.get('sigma_pca', 1.)))
        else:
            output_dir += '/vanilla'
        print('# Output dir: \033[1;31m%s\033[0m' % output_dir)
        for subdir in self._OUTPUT_SUBDIRS:
            os.makedirs(output_dir + '/' + subdir, exist_ok=True)

        item_0 = self.dataset.getitem(0, training=False)
        object_center = item_0['live_bounds'].mean(0)
        global_orient = item_0['global_orient']
        if isinstance(global_orient, torch.Tensor):
            global_orient = global_orient.cpu().numpy()
        # Work on a copy so that zeroing entries does not modify the dataset item.
        global_orient = np.array(global_orient)
        # Zero entries 0 and 2 (original comment: "set x and z to 0").
        global_orient[0] = 0
        global_orient[2] = 0
        global_orient = cv.Rodrigues(global_orient)[0]
        camera_orient = global_orient if test_cfg.get('global_orient', False) else None

        data_num = len(self.dataset)
        if test_cfg.get('fix_hand', False):
            self.avatar_net.generate_mean_hands()

        # Set to True to print per-stage timings measured with CUDA events.
        log_time = False
        time_start = time_start_all = time_end = None
        if log_time:
            time_start = torch.cuda.Event(enable_timing=True)
            time_start_all = torch.cuda.Event(enable_timing=True)
            time_end = torch.cuda.Event(enable_timing=True)

        def lap(message):
            """Print the time since the previous lap and restart the stage timer."""
            if not log_time:
                return
            time_end.record()
            torch.cuda.synchronize()
            print(message % (time_start.elapsed_time(time_end) / 1000.))
            time_start.record()

        # These config values do not change between frames.
        img_scale = test_cfg.get('img_scale', 1.0)
        view_setting = test_cfg.get('view_setting', 'free')
        follow_subject = view_setting.startswith('moving') or view_setting == 'free_moving'
        render_skeleton = test_cfg.get('render_skeleton', False)
        getitem_func = self.dataset.getitem_fast if hasattr(self.dataset, 'getitem_fast') else self.dataset.getitem

        extr_list = []
        intr_list = []
        img_h_list = []
        img_w_list = []
        geo_renderer = None

        try:
            for idx in tqdm(range(data_num), desc='Rendering avatars...'):
                if log_time:
                    time_start.record()
                    time_start_all.record()

                extr, intr, img_h, img_w = self._camera_for_view(
                    view_setting, idx, img_scale, object_center, camera_orient)
                # Record the camera of every frame, whatever the view setting,
                # so entry i of the dumped camera files describes frame i.
                extr_list.append(extr)
                intr_list.append(intr)
                img_h_list.append(img_h)
                img_w_list.append(img_w)

                item = getitem_func(
                    idx,
                    training=False,
                    extr=extr,
                    intr=intr,
                    img_w=img_w,
                    img_h=img_h,
                )
                items = to_cuda(item, add_batch=False)

                if follow_subject:
                    # Shift the camera target along axis 0 only; the original
                    # left the updates of the other two axes commented out.
                    current_center = items['live_bounds'].cpu().numpy().mean(0)
                    delta = current_center - object_center
                    object_center[0] += delta[0]
                lap('Loading data costs %.4f secs')

                if render_skeleton:
                    geo_renderer = self._save_skeleton(item, output_dir, geo_renderer)
                lap('Rendering skeletons costs %.4f secs')

                if 'smpl_pos_map' not in items:
                    self.avatar_net.get_pose_map(items)
                if use_pca:
                    self._apply_pose_pca(items, training_dataset)
                lap('Rendering pose conditions costs %.4f secs')

                output = self.avatar_net.render(items, bg_color=self.bg_color, use_pca=use_pca)
                output_wo_hand = self.avatar_net.render_wo_hand(items, bg_color=self.bg_color, use_pca=use_pca)
                mask_output = self.avatar_net.render_mask(items, bg_color=self.bg_color, use_pca=use_pca)
                lap('Rendering avatar costs %.4f secs')

                self._save_frame_outputs(output_dir, item, output, output_wo_hand, mask_output, training_dataset)
                lap('Saving images costs %.4f secs')
                if log_time:
                    print('Animating one frame costs %.4f secs' % (time_start_all.elapsed_time(time_end) / 1000.))

                # NOTE(review): kept once per frame; the flattened source does
                # not show whether this call sat inside or after the loop.
                torch.cuda.empty_cache()
        finally:
            # Dump once instead of rewriting the growing files on every frame.
            # Doing it in `finally` keeps the cameras of the frames processed
            # so far on disk when a run is interrupted.
            self._dump_camera_info(output_dir, extr_list, intr_list, img_h_list, img_w_list)

    @staticmethod
    def _sweep_rot_y(idx, frames_per_cycle, max_degree):
        """Return the Y rotation in radians, in [0, 2*pi), of a back-and-forth sweep.

        The first half of each cycle ramps from ``-max_degree`` towards
        ``+max_degree``; the second half ramps back.  Negative angles are
        wrapped by adding ``2*pi``.
        """
        half_cycle = frames_per_cycle // 2
        step = 2 * max_degree / half_cycle
        phase = idx % half_cycle
        if idx % frames_per_cycle < frames_per_cycle / 2:
            degrees = -max_degree + step * phase
        else:
            degrees = max_degree - step * phase
        rot_y = degrees * np.pi / 180
        if rot_y < 0:
            rot_y = rot_y + 2 * np.pi
        return rot_y

    def _camera_for_view(self, view_setting, idx, img_scale, object_center, global_orient):
        """Return ``(extr, intr, img_h, img_w)`` for frame ``idx``.

        ``view_setting`` is matched in this order: exactly ``'camera'``, then
        the prefixes ``'free'``, ``'degree120'``, ``'degree90'``, ``'front'``,
        ``'back'``, ``'moving'`` and ``'cano'``.  A ``'bird'`` suffix tilts the
        synthetic views around X.  ``global_orient`` is forwarded to
        ``calc_free_mv`` unchanged (``None`` disables it).

        Raises:
            ValueError: if ``view_setting`` matches none of the above.
        """
        bird = view_setting.endswith('bird')

        def free_camera(rot_y, rot_x):
            # Synthetic camera shared by all non-dataset views: a 1024x1024
            # image (before scaling) with focal length 1100 and principal
            # point (512, 512).
            extr = visualize_util.calc_free_mv(object_center,
                                               tar_pos=np.array([0, 0, 2.5]),
                                               rot_Y=rot_y,
                                               rot_X=rot_x,
                                               global_orient=global_orient)
            intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
            intr[:2] *= img_scale
            return extr, intr, int(1024 * img_scale), int(1024 * img_scale)

        if view_setting == 'camera':
            # One of the cameras stored on the dataset.
            cam_id = self.body['test']['render_view_idx']
            intr = self.dataset.intr_mats[cam_id].copy()
            intr[:2] *= img_scale
            extr = self.dataset.extr_mats[cam_id].copy()
            img_h = int(self.dataset.img_heights[cam_id] * img_scale)
            img_w = int(self.dataset.img_widths[cam_id] * img_scale)
            return extr, intr, img_h, img_w

        if view_setting.startswith('free'):
            # Full turn around the subject every 360 frames.
            frame_num_per_circle = 360
            rot_y = (idx % frame_num_per_circle) / float(frame_num_per_circle) * 2 * np.pi
            return free_camera(rot_y, 0.3 if bird else 0.)

        if view_setting.startswith('degree120'):
            if idx == 0:
                # Announce once; printing on every frame breaks the progress bar.
                print('we render 120 degree')
            # +-60 degree sweep, 480 frames per cycle.
            return free_camera(self._sweep_rot_y(idx, 480, 60), 0.3 if bird else 0.)

        if view_setting.startswith('degree90'):
            if idx == 0:
                print('we render 90 degree')
            # +-45 degree sweep, 360 frames per cycle.
            return free_camera(self._sweep_rot_y(idx, 360, 45), 0.3 if bird else 0.)

        if view_setting.startswith('front'):
            return free_camera(0., 0.3 if bird else 0.)

        if view_setting.startswith('back'):
            return free_camera(np.pi, 0.5 * np.pi / 4. if bird else 0.)

        if view_setting.startswith('moving'):
            # Same camera as 'front'; the caller moves `object_center` per frame.
            return free_camera(0., 0.3 if bird else 0.)

        if view_setting.startswith('cano'):
            cano_center = self.dataset.cano_bounds.mean(0)
            extr = np.identity(4, np.float32)
            extr[:3, 3] = -cano_center
            rot_x = np.identity(4, np.float32)
            rot_x[:3, :3] = cv.Rodrigues(np.array([np.pi, 0, 0], np.float32))[0]
            extr = rot_x @ extr
            f_len = 5000
            extr[2, 3] += f_len / 512
            # As in the original, this view does not apply `img_scale`.
            intr = np.array([[f_len, 0, 512], [0, f_len, 512], [0, 0, 1]], np.float32)
            return extr, intr, 1024, 1024

        raise ValueError('Invalid view setting for animation!')

    def _dump_camera_info(self, output_dir, extrs, intrs, img_heights, img_widths):
        """Write ``cfg_args``/``cameras.json`` and the raw ``camera_info.json``."""
        self.dump_renderer_info(output_dir, extrs, intrs, img_heights, img_widths)
        camera_info = [
            {
                'extr': extr.tolist(),
                'intr': intr.tolist(),
                'img_h': img_h,
                'img_w': img_w,
            }
            for extr, intr, img_h, img_w in zip(extrs, intrs, img_heights, img_widths)
        ]
        with open(os.path.join(output_dir, 'camera_info.json'), 'w') as fp:
            json.dump(camera_info, fp)

    def _save_skeleton(self, item, output_dir, geo_renderer):
        """Render the skeleton of ``item`` to ``live_skeleton/`` and return the renderer.

        ``geo_renderer`` is created on first use (when ``None``) and should be
        passed back in on the next frame so it is reused.
        """
        from AnimatableGaussians.utils.visualize_skeletons import construct_skeletons
        skel_vertices, skel_faces = construct_skeletons(item['joints'].cpu().numpy(), item['kin_parent'].cpu().numpy())
        skel_mesh = trimesh.Trimesh(skel_vertices, skel_faces, process=False)

        if geo_renderer is None:
            geo_renderer = Renderer(item['img_w'], item['img_h'], shader_name='phong_geometry', bg_color=(1, 1, 1))
        geo_renderer.set_camera(item['extr'], item['intr'])
        flat_faces = skel_faces.reshape(-1)
        geo_renderer.set_model(skel_vertices[flat_faces], skel_mesh.vertex_normals.astype(np.float32)[flat_faces])
        skel_img = geo_renderer.render()[:, :, :3]
        skel_img = (skel_img * 255).astype(np.uint8)
        cv.imwrite(output_dir + '/live_skeleton/%08d.jpg' % item['data_idx'], skel_img)
        return geo_renderer

    def _apply_pose_pca(self, items, training_dataset):
        """Add ``items['smpl_pos_map_pca']``: the position map with PCA-transformed front half.

        The first three channels of ``items['smpl_pos_map']`` are passed
        through ``training_dataset.transform_pca`` at the pixels selected by
        ``training_dataset.pos_map_mask``; the remaining channels are kept.
        """
        mask = training_dataset.pos_map_mask
        live_pos_map = items['smpl_pos_map'].permute(1, 2, 0).cpu().numpy()
        front_live_pos_map, back_live_pos_map = np.split(live_pos_map, [3], 2)
        pose_conds = front_live_pos_map[mask]
        new_pose_conds = training_dataset.transform_pca(
            pose_conds, sigma_pca=float(self.body['test'].get('sigma_pca', 2.)))
        front_live_pos_map[mask] = new_pose_conds
        live_pos_map = np.concatenate([front_live_pos_map, back_live_pos_map], 2)
        items.update({
            'smpl_pos_map_pca': torch.from_numpy(live_pos_map).to(self.device).permute(2, 0, 1)
        })

    def _save_hand_masks(self, output_dir, frame, mask_output):
        """Derive and save the hand masks from the colour-coded mask renders.

        Requires ``'full_body_rgb_map'`` and ``'hand_only_rgb_map'`` in
        ``mask_output``.  Pure-red pixels feed the ``r_hand`` outputs and
        pure-blue pixels the ``l_hand`` outputs (naming taken from the original
        code).  Also writes the sleeve masks when ``'left_hand_rgb_map'`` and
        ``'right_hand_rgb_map'`` are present.

        Returns:
            ``(r_hand_mask, l_hand_mask)`` as uint8 tensors with values 0/255.
        """
        full_body = mask_output['full_body_rgb_map']
        hand_only = mask_output['hand_only_rgb_map']

        def matches_color(image, rgb):
            # True where the squared distance to `rgb` is below 0.01.
            diff = image - torch.tensor(rgb, device=image.device)
            return (diff * diff).sum(dim=2) < 0.01

        body_red_mask = matches_color(full_body, [1., 0., 0.])
        hand_red_mask = matches_color(hand_only, [1., 0., 0.])
        body_blue_mask = matches_color(full_body, [0., 0., 1.])
        hand_blue_mask = matches_color(hand_only, [0., 0., 1.])

        # Flag a hand when the pixel count in the full-body render differs from
        # the hand-only render by more than 95% of the hand-only count.
        if_mask_r_hand = abs(body_red_mask.sum() - hand_red_mask.sum()) / hand_red_mask.sum() > 0.95
        if_mask_r_hand = if_mask_r_hand.cpu().numpy()
        if_mask_l_hand = abs(body_blue_mask.sum() - hand_blue_mask.sum()) / hand_blue_mask.sum() > 0.95
        if_mask_l_hand = if_mask_l_hand.cpu().numpy()

        # Hand pixels of the hand-only render that are missing from the
        # full-body render, i.e. the occluded part of each hand.
        red_mask = hand_red_mask ^ (hand_red_mask & body_red_mask)
        blue_mask = hand_blue_mask ^ (hand_blue_mask & body_blue_mask)
        all_mask = red_mask | blue_mask

        for subdir in ('hand_mask', 'r_hand_mask', 'l_hand_mask', 'hand_visual'):
            os.makedirs(output_dir + '/' + subdir, exist_ok=True)

        all_mask = (all_mask * 255).to(torch.uint8)
        cv.imwrite(output_dir + '/hand_mask/%08d.png' % frame, all_mask.cpu().numpy())
        r_hand_mask = (body_red_mask * 255).to(torch.uint8)
        cv.imwrite(output_dir + '/r_hand_mask/%08d.png' % frame, r_hand_mask.cpu().numpy())
        l_hand_mask = (body_blue_mask * 255).to(torch.uint8)
        cv.imwrite(output_dir + '/l_hand_mask/%08d.png' % frame, l_hand_mask.cpu().numpy())

        hand_visual = [if_mask_r_hand, if_mask_l_hand]
        with open(output_dir + '/hand_visual/%08d.npy' % frame, 'wb') as f:
            np.save(f, hand_visual)

        # Sleeve masks.
        # NOTE(review): nested here because this part reads r_hand_mask,
        # l_hand_mask and all_mask, which only exist on this path.
        if 'left_hand_rgb_map' in mask_output and 'right_hand_rgb_map' in mask_output:
            os.makedirs(output_dir + '/left_sleeve_mask', exist_ok=True)
            os.makedirs(output_dir + '/right_sleeve_mask', exist_ok=True)

            hands = (r_hand_mask > 128) | (l_hand_mask > 128) | (all_mask > 128)
            # Grow the hand region; the structuring-element size and the
            # iteration count control how far it is dilated.
            kernel = np.ones((5, 5), np.uint8)
            hands = cv.dilate(hands.cpu().numpy().astype(np.uint8), kernel, iterations=3)
            # Use a boolean mask so `~hands` is a logical NOT (the original
            # applied a bitwise NOT to a uint8 0/1 tensor).
            hands = torch.from_numpy(hands).to(self.device).bool()

            for side in ('left', 'right'):
                side_map = mask_output['%s_hand_rgb_map' % side]
                side_map.clip_(0., 1.)
                # Every non-white pixel belongs to the mask ...
                non_white = torch.tensor([1., 1., 1.], device=side_map.device) - side_map
                sleeve = (non_white * non_white).sum(dim=2) > 0.01
                # ... except where the (dilated) hands are.
                sleeve = sleeve & ~hands
                sleeve = (sleeve * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/%s_sleeve_mask/%08d.png' % (side, frame), sleeve.cpu().numpy())

        return r_hand_mask, l_hand_mask

    def _save_frame_outputs(self, output_dir, item, output, output_wo_hand, mask_output, training_dataset):
        """Write all images, masks and optional gaussian dumps of one frame.

        Files are named after ``item['data_idx']``.  Rendered tensors are
        clipped to [0, 1] in place before being converted to 8-bit images.
        """
        frame = item['data_idx']

        def save_unit_image(image, subdir, ext='png'):
            # Clip in place (as the original did), convert to uint8 and write.
            image.clip_(0., 1.)
            image_u8 = (image * 255).to(torch.uint8).cpu().numpy()
            cv.imwrite(output_dir + '/%s/%08d.%s' % (subdir, frame, ext), image_u8)
            return image_u8

        has_full_body = 'full_body_rgb_map' in mask_output
        has_hand_only = 'hand_only_rgb_map' in mask_output
        if has_full_body:
            save_unit_image(mask_output['full_body_rgb_map'], 'full_body_mask')
        if has_hand_only:
            save_unit_image(mask_output['hand_only_rgb_map'], 'hand_only_mask')

        hand_masks = None
        if has_full_body and has_hand_only:
            hand_masks = self._save_hand_masks(output_dir, frame, mask_output)

        rgb_map = save_unit_image(output['rgb_map'], 'rgb_map', 'jpg')

        # Inside the (dilated) r/l hand masks, replace the full render by the
        # render without hands.
        if 'rgb_map' in output_wo_hand and hand_masks is not None:
            rgb_map_wo_hand = output_wo_hand['rgb_map']
            rgb_map_wo_hand.clip_(0., 1.)
            rgb_map_wo_hand = (rgb_map_wo_hand * 255).to(torch.uint8).cpu().numpy()

            r_hand_mask, l_hand_mask = hand_masks
            hands = (r_hand_mask > 128).cpu().numpy() | (l_hand_mask > 128).cpu().numpy()
            kernel = np.ones((5, 5), np.uint8)
            hands = cv.dilate(hands.astype(np.uint8), kernel, iterations=3).astype(np.bool_)
            hands = np.expand_dims(hands, axis=2)
            mix = np.where(hands, rgb_map_wo_hand, rgb_map)
            cv.imwrite(output_dir + '/rgb_map_wo_hand/%08d.png' % frame, mix)

        # Single-channel maps: only channel 0 is written.
        for key in ('torso_map', 'mask_map'):
            if key in output:
                save_unit_image(output[key][:, :, 0], key)

        if self.body['test'].get('save_tex_map', False):
            os.makedirs(output_dir + '/cano_tex_map', exist_ok=True)
            save_unit_image(output['cano_tex_map'], 'cano_tex_map')

        # NOTE(review): everything below is placed under `save_ply`; the
        # flattened source does not show whether the two np.savez calls were
        # inside this branch or unconditional -- confirm against the pipeline
        # that consumes posed_gaussians/ and posed_params/.
        if self.body['test'].get('save_ply', False):
            posed_gaussians = output['posed_gaussians']
            if frame == 0:
                # A .ply is written for frame 0 only.
                save_gaussians_as_ply(output_dir + '/posed_gaussians/%08d.ply' % frame, posed_gaussians)
            for k in posed_gaussians.keys():
                if isinstance(posed_gaussians[k], torch.Tensor):
                    posed_gaussians[k] = posed_gaussians[k].detach().cpu().numpy()
            np.savez(output_dir + '/posed_gaussians/%08d.npz' % frame, **posed_gaussians)
            np.savez(output_dir + ('/posed_params/%08d.npz' % frame),
                     betas=training_dataset.smpl_data['betas'].reshape([-1]).detach().cpu().numpy(),
                     global_orient=item['global_orient'].reshape([-1]).detach().cpu().numpy(),
                     transl=item['transl'].reshape([-1]).detach().cpu().numpy(),
                     body_pose=item['body_pose'].reshape([-1]).detach().cpu().numpy())

    def dump_renderer_info(self, dump_dir, extrs, intrs, img_heights, img_widths):
        """Write ``cfg_args`` and ``cameras.json`` into ``dump_dir``.

        Args:
            dump_dir: existing directory to write into.
            extrs: sequence of 4x4 world-to-camera matrices (each is inverted
                to obtain the camera position and rotation).
            intrs: sequence of intrinsic matrices; ``[0, 0]`` is stored as
                ``fx`` and ``[1, 1]`` as ``fy``.
            img_heights: image height per camera.
            img_widths: image width per camera.
        """
        with open(os.path.join(dump_dir, 'cfg_args'), 'w') as fp:
            outstr = "Namespace(sh_degree=%d, source_path='%s', model_path='%s', images='images', resolution=-1, " \
                     "white_background=False, data_device='cuda', eval=False)" % (
                         3, self.body['train']['data']['data_dir'], dump_dir)
            fp.write(outstr)

        cam_jsons = []
        for ci, (extr, intr, img_h, img_w) in enumerate(zip(extrs, intrs, img_heights, img_widths)):
            c2w = np.linalg.inv(extr)  # extr is world-to-camera
            cam_jsons.append({
                'id': ci,
                'img_name': '%08d' % ci,
                'width': int(img_w),
                'height': int(img_h),
                'position': c2w[:3, 3].tolist(),
                'rotation': [row.tolist() for row in c2w[:3, :3]],
                'fy': float(intr[1, 1]),
                'fx': float(intr[0, 0]),
            })
        with open(os.path.join(dump_dir, 'cameras.json'), 'w') as fp:
            json.dump(cam_jsons, fp)

    def test_head(self):
        """Run head reenactment with the checkpoints named in the head config.

        Runs ``app.run(stop_fid=800)`` when
        ``self.head.offline_rendering_param_fpath`` is ``None``; otherwise runs
        ``app.run_for_offline_stitching`` with that path.
        """
        head_cfg = self.head_config
        dataset = ReenactmentDataset(head_cfg.dataset)
        dataloader = DataLoaderX(dataset, batch_size=1, shuffle=False, pin_memory=True)

        device = torch.device('cuda:%d' % head_cfg.gpu_id)

        # Checkpoints are loaded onto their stored location's CPU storage and
        # moved to `device` through the modules.
        gaussianhead_state_dict = torch.load(head_cfg.load_gaussianhead_checkpoint,
                                             map_location=lambda storage, loc: storage)
        gaussianhead = GaussianHeadModule(head_cfg.gaussianheadmodule,
                                          xyz=gaussianhead_state_dict['xyz'],
                                          feature=gaussianhead_state_dict['feature'],
                                          landmarks_3d_neutral=gaussianhead_state_dict['landmarks_3d_neutral']).to(device)
        gaussianhead.load_state_dict(gaussianhead_state_dict)

        supres = SuperResolutionModule(head_cfg.supresmodule).to(device)
        supres.load_state_dict(torch.load(head_cfg.load_supres_checkpoint,
                                          map_location=lambda storage, loc: storage))

        camera = CameraModule()
        recorder = ReenactmentRecorder(head_cfg.recorder)
        app = Reenactment(dataloader, gaussianhead, supres, camera, recorder, head_cfg.gpu_id, dataset.freeview)

        if self.head.offline_rendering_param_fpath is None:
            app.run(stop_fid=800)
        else:
            app.run_for_offline_stitching(self.head.offline_rendering_param_fpath)

    def cal_cat_param(self):
        """Call ``calc_offline_rendering_param`` with the paths from ``config.cat``."""
        calc_offline_rendering_param(
            self.cat.body_gaussian_root_dir,
            self.cat.ref_head_gaussian_path,
            self.cat.ref_head_param_path,
            self.cat.render_cam_fpath,
            self.cat.body_head_blending_param_path,
        )
if __name__ == '__main__':
    # Script entry point: build the avatar from the example config and render
    # the body sequence.  Head reenactment is left disabled, as in the original.
    avatar = Avatar(OmegaConf.load('configs/example.yaml'))
    avatar.test_body()
    # avatar.test_head()