|
import os |
|
import math |
|
import cv2 |
|
import trimesh |
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
import nvdiffrast.torch as dr |
|
from mesh import Mesh, safe_normalize |
|
|
|
def scale_img_nhwc(x, size, mag='bilinear', min='bilinear'):
    """Resize a batch of channel-last images with ``torch.nn.functional.interpolate``.

    Args:
        x: tensor laid out as (N, H, W, C) (it is permuted to NCHW internally).
        size: target (height, width).
        mag: interpolation mode used when the image is enlarged or keeps its size.
            'bilinear' / 'bicubic' are run with ``align_corners=True``.
        min: interpolation mode used when the image is shrunk along at least one
            axis. (The name shadows the builtin but is kept for backward
            compatibility with keyword callers.)

    Returns:
        Contiguous tensor of shape (N, size[0], size[1], C).

    Raises:
        AssertionError: if the image would be enlarged along one axis and shrunk
            along the other.
    """
    in_h, in_w = x.shape[1], x.shape[2]
    out_h, out_w = size[0], size[1]

    # An axis that keeps its size is compatible with either direction, so both
    # tests are non-strict. The previous strict comparison routed e.g.
    # (4, 4) -> (4, 2) to the magnification filter and rejected (4, 4) -> (4, 8).
    shrinking = in_h >= out_h and in_w >= out_w
    growing = in_h <= out_h and in_w <= out_w
    if not (shrinking or growing):
        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        raise AssertionError("Trying to magnify image in one dimension and minify in the other")

    y = x.permute(0, 3, 1, 2)  # NHWC -> NCHW, the layout interpolate expects

    if shrinking and not growing:
        # Strictly smaller along at least one axis: use the minification filter.
        y = torch.nn.functional.interpolate(y, size, mode=min)
    elif mag in ('bilinear', 'bicubic'):
        y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True)
    else:
        y = torch.nn.functional.interpolate(y, size, mode=mag)

    return y.permute(0, 2, 3, 1).contiguous()
|
|
|
def scale_img_hwc(x, size, mag='bilinear', min='bilinear'):
    """Resize a single channel-last (H, W, C) image to ``size`` = (height, width).

    Adds a leading batch axis, delegates to ``scale_img_nhwc``, then drops it again.
    """
    batched = x.unsqueeze(0)
    resized = scale_img_nhwc(batched, size, mag, min)
    return resized.squeeze(0)
|
|
|
def scale_img_nhw(x, size, mag='bilinear', min='bilinear'):
    """Resize a batch of single-channel (N, H, W) images to ``size`` = (height, width).

    Appends a trailing channel axis, delegates to ``scale_img_nhwc``, then removes it.
    """
    with_channel = x.unsqueeze(-1)
    resized = scale_img_nhwc(with_channel, size, mag, min)
    return resized.squeeze(-1)
|
|
|
def scale_img_hw(x, size, mag='bilinear', min='bilinear'):
    """Resize a single-channel (H, W) image to ``size`` = (height, width).

    Wraps the image as a (1, H, W, 1) batch for ``scale_img_nhwc`` and unwraps the result.
    """
    wrapped = x.unsqueeze(0).unsqueeze(-1)
    resized = scale_img_nhwc(wrapped, size, mag, min)
    return resized.squeeze(-1).squeeze(0)
|
|
|
def trunc_rev_sigmoid(x, eps=1e-6):
    """Inverse of the sigmoid (logit) with the input clamped to [eps, 1 - eps].

    The clamp keeps the result finite for inputs at or beyond 0 and 1.
    """
    clamped = x.clamp(min=eps, max=1 - eps)
    odds = clamped / (1 - clamped)
    return torch.log(odds)
|
|
|
def make_divisible(x, m=8):
    """Round ``x`` up to the nearest multiple of ``m`` and return it as an int."""
    multiples = math.ceil(x / m)
    return int(multiples * m)
|
|
|
class Renderer(nn.Module):
    """Differentiable textured-mesh renderer built on nvdiffrast.

    Holds the mesh loaded from ``opt.mesh`` together with two optimisable
    parameters: per-vertex position offsets (``v_offsets``) and the albedo
    texture stored in logit space (``raw_albedo``).
    """

    def __init__(self, opt):
        """Load the mesh, create the rasterizer context and the trainable parameters.

        Args:
            opt: options object; this class reads ``mesh``, ``force_cuda_rast``,
                ``gui``, ``train_geo``, ``texture_lr`` and ``geom_lr`` from it.
        """
        super().__init__()

        self.opt = opt

        self.mesh = Mesh.load(self.opt.mesh, resize=False)

        # OpenGL rasterizer unless CUDA rasterization is forced, or a GUI is
        # requested on a non-Windows platform.
        if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
            self.glctx = dr.RasterizeGLContext()
        else:
            self.glctx = dr.RasterizeCudaContext()

        # Offsets start at zero, so the initial geometry equals the loaded mesh.
        self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v))
        # Albedo is optimised in logit space; render() maps it back through a sigmoid.
        self.raw_albedo = nn.Parameter(trunc_rev_sigmoid(self.mesh.albedo))

    def get_params(self):
        """Return optimizer parameter groups (``{'params', 'lr'}`` dicts).

        The texture is always included; vertex offsets only when ``opt.train_geo``.
        """
        params = [
            {'params': self.raw_albedo, 'lr': self.opt.texture_lr},
        ]

        if self.opt.train_geo:
            params.append({'params': self.v_offsets, 'lr': self.opt.geom_lr})

        return params

    @torch.no_grad()
    def export_mesh(self, save_path):
        """Bake the optimised geometry and texture into ``self.mesh`` and write it.

        After baking, ``v_offsets`` is reset to zero. The offsets now live in
        ``self.mesh.v``; leaving them in place would apply them a second time
        on the next ``render`` (or on a second export).

        Args:
            save_path: destination passed straight to ``Mesh.write``.
        """
        self.mesh.v = (self.mesh.v + self.v_offsets).detach()
        self.v_offsets.zero_()
        self.mesh.albedo = torch.sigmoid(self.raw_albedo.detach())
        self.mesh.write(save_path)

    @staticmethod
    def _as_float_tensor(m, device):
        """Convert a NumPy array / array-like (or a tensor) to a float32 tensor on ``device``."""
        if isinstance(m, torch.Tensor):
            return m.to(device=device, dtype=torch.float32)
        # np.array always copies, so the caller's array is never aliased.
        return torch.from_numpy(np.array(m, dtype=np.float32)).to(device)

    def _vertex_normals(self, v):
        """Per-vertex normals: sum of the normalised normals of every face touching the vertex.

        Args:
            v: vertex positions indexed by ``self.mesh.f``.

        Returns:
            Tensor shaped like ``v``. It is not normalised here; vertices whose
            accumulated normal is (near) zero get +Z instead.
        """
        faces = self.mesh.f.long()
        v0, v1, v2 = v[faces[:, 0]], v[faces[:, 1]], v[faces[:, 2]]

        # dim=-1 must be explicit: without it torch.cross picks the FIRST size-3
        # dimension, which is the face axis when the mesh has exactly 3 faces.
        face_normals = safe_normalize(torch.cross(v1 - v0, v2 - v0, dim=-1))

        vn = torch.zeros_like(v)
        for corner in range(3):
            vn.index_add_(0, faces[:, corner], face_normals)

        fallback = torch.tensor([0.0, 0.0, 1.0], dtype=vn.dtype, device=vn.device)
        return torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, fallback)

    def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1, texture_filter='linear-mipmap-linear'):
        """Rasterize the mesh from one camera.

        Args:
            pose: 4x4 camera-to-world matrix (its inverse maps vertices into the
                camera frame). NumPy array, array-like, or torch tensor.
            proj: 4x4 projection matrix mapping camera coordinates to clip space.
            h0, w0: output image height and width.
            ssaa: supersampling factor; when != 1 the mesh is rasterized at
                ``make_divisible(h0 * ssaa, 8)`` x ``make_divisible(w0 * ssaa, 8)``
                and every output is resized back to (h0, w0).
            bg_color: value blended into pixels not covered by the mesh.
            texture_filter: nvdiffrast ``filter_mode`` used to sample the albedo.

        Returns:
            dict with
              'image'  : albedo composited over ``bg_color``, clamped to [0, 1], (h0, w0, C)
              'alpha'  : coverage mask, (h0, w0, 1)
              'depth'  : negated camera-space z, (h0, w0, 1)
              'normal' : world-space normals remapped from [-1, 1] to [0, 1], (h0, w0, 3)
              'viewcos': z component of the camera-space normal, (h0, w0, 1)
        """
        if ssaa != 1:
            h = make_divisible(h0 * ssaa, 8)
            w = make_divisible(w0 * ssaa, 8)
        else:
            h, w = h0, w0

        results = {}

        # Offsets only take part when geometry is being trained.
        v = self.mesh.v + self.v_offsets if self.opt.train_geo else self.mesh.v

        pose = self._as_float_tensor(pose, v.device)
        proj = self._as_float_tensor(proj, v.device)

        # World -> camera via the inverse pose (homogeneous coords), then camera -> clip.
        v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
        v_clip = v_cam @ proj.T

        rast, rast_db = dr.rasterize(self.glctx, v_clip, self.mesh.f, (h, w))

        # Per nvdiffrast, rast channel 3 is triangle_id + 1 (0 = no triangle).
        alpha = (rast[0, ..., 3:] > 0).float()
        depth, _ = dr.interpolate(-v_cam[..., [2]], rast, self.mesh.f)
        depth = depth.squeeze(0)

        # Sample the logit-space texture (with UV derivatives for mip-mapping), then squash.
        texc, texc_db = dr.interpolate(self.mesh.vt.unsqueeze(0).contiguous(), rast, self.mesh.ft, rast_db=rast_db, diff_attrs='all')
        albedo = dr.texture(self.raw_albedo.unsqueeze(0), texc, uv_da=texc_db, filter_mode=texture_filter)
        albedo = torch.sigmoid(albedo)

        if self.opt.train_geo:
            # Normals recomputed from the deformed vertices are per *vertex*, so they
            # must be interpolated with the vertex indices `f`, not the normal indices `fn`.
            vn = self._vertex_normals(v)
            normal_faces = self.mesh.f
        else:
            vn = self.mesh.vn
            normal_faces = self.mesh.fn

        normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, normal_faces)
        normal = safe_normalize(normal[0])

        # n @ R == R^T n: rotates world-space normals into the camera frame.
        rot_normal = normal @ pose[:3, :3]
        viewcos = rot_normal[..., [2]]

        albedo = dr.antialias(albedo, rast, v_clip, self.mesh.f).squeeze(0)
        albedo = alpha * albedo + (1 - alpha) * bg_color

        if ssaa != 1:
            albedo = scale_img_hwc(albedo, (h0, w0))
            alpha = scale_img_hwc(alpha, (h0, w0))
            depth = scale_img_hwc(depth, (h0, w0))
            normal = scale_img_hwc(normal, (h0, w0))
            viewcos = scale_img_hwc(viewcos, (h0, w0))

        results['image'] = albedo.clamp(0, 1)
        results['alpha'] = alpha
        results['depth'] = depth
        results['normal'] = (normal + 1) / 2
        results['viewcos'] = viewcos

        return results