# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import math
import random

import numpy as np
import torch
import torch.nn.functional as F

def positional_encoding(p, size, pe='normal', use_pos=False):
    """Fourier-feature encoding of points `p`.

    With `pe='gauss'`, `size` is a projection matrix applied to `p`; otherwise
    `size` is the number of log-spaced frequency bands.
    """
    if pe == 'gauss':
        p_transformed = np.pi * p @ size
        p_transformed = torch.cat(
            [torch.sin(p_transformed), torch.cos(p_transformed)], dim=-1)
    else:
        p_transformed = torch.cat([torch.cat(
            [torch.sin((2 ** i) * np.pi * p),
             torch.cos((2 ** i) * np.pi * p)],
            dim=-1) for i in range(size)], dim=-1)
    if use_pos:
        p_transformed = torch.cat([p_transformed, p], -1)
    return p_transformed
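
# Illustrative sketch (not part of the original module): output width of the
# frequency encoding above for a hypothetical batch of 3-D points.
def _example_positional_encoding():
    p = torch.rand(1024, 3)                                   # points in [0, 1)
    enc = positional_encoding(p, size=10)                     # (1024, 2 * 10 * 3) = (1024, 60)
    enc_pos = positional_encoding(p, size=10, use_pos=True)   # (1024, 63): raw coords appended
    return enc, enc_pos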

def upsample(img_nerf, size, filter=None):
    up = size // img_nerf.size(-1)
    if up <= 1:
        return img_nerf

    if filter is not None:
        from torch_utils.ops import upfirdn2d
        for _ in range(int(math.log2(up))):
            img_nerf = upfirdn2d.upsample2d(img_nerf, filter, up=2)
    else:
        img_nerf = F.interpolate(img_nerf, (size, size), mode='bilinear', align_corners=False)
    return img_nerf

def downsample(img0, size, filter=None):
    down = img0.size(-1) // size
    if down <= 1:
        return img0

    if filter is not None:
        from torch_utils.ops import upfirdn2d
        for _ in range(int(math.log2(down))):
            img0 = upfirdn2d.downsample2d(img0, filter, down=2)
    else:
        img0 = F.interpolate(img0, (size, size), mode='bilinear', align_corners=False)
    return img0
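
# Illustrative sketch (hypothetical shapes): the bilinear branch of the two
# resampling helpers above; with a FIR `filter`, upfirdn2d is used instead.
def _example_resampling():
    img = torch.rand(1, 3, 32, 32)
    hi = upsample(img, 128)     # (1, 3, 128, 128)
    lo = downsample(hi, 32)     # (1, 3, 32, 32)
    return hi, lo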

def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
    """
    Normalize vector lengths.
    """
    return vectors / (torch.norm(vectors, dim=-1, keepdim=True))


def repeat_vecs(vecs, n, dim=0):
    return torch.stack(n * [vecs], dim=dim)


def get_grids(H, W, device, align=True):
    ch = 1 if align else 1 - (1 / H)
    cw = 1 if align else 1 - (1 / W)
    x, y = torch.meshgrid(torch.linspace(-cw, cw, W, device=device),
                          torch.linspace(ch, -ch, H, device=device))
    return torch.stack([x, y], -1)

def local_ensemble(pi, po, resolution):
    ii = range(resolution)
    ia = torch.tensor([max((i - 1) // 2, 0) for i in ii]).long()
    ib = torch.tensor([min((i + 1) // 2, resolution // 2 - 1) for i in ii]).long()

    ul = torch.meshgrid(ia, ia)
    ur = torch.meshgrid(ia, ib)
    ll = torch.meshgrid(ib, ia)
    lr = torch.meshgrid(ib, ib)

    d_ul, p_ul = po - pi[ul], torch.stack(ul, -1)
    d_ur, p_ur = po - pi[ur], torch.stack(ur, -1)
    d_ll, p_ll = po - pi[ll], torch.stack(ll, -1)
    d_lr, p_lr = po - pi[lr], torch.stack(lr, -1)

    c_ul = d_ul.prod(dim=-1).abs()
    c_ur = d_ur.prod(dim=-1).abs()
    c_ll = d_ll.prod(dim=-1).abs()
    c_lr = d_lr.prod(dim=-1).abs()

    D = torch.stack([d_ul, d_ur, d_ll, d_lr], 0)
    P = torch.stack([p_ul, p_ur, p_ll, p_lr], 0)
    C = torch.stack([c_ul, c_ur, c_ll, c_lr], 0)
    C = C / C.sum(dim=0, keepdim=True)
    return D, P, C

def get_initial_rays_trig(num_steps, fov, resolution, ray_start, ray_end, device='cpu'):
    """Returns sample points, z_vals, ray directions in camera space."""
    W, H = resolution
    # Create full screen NDC (-1 to +1) coords [x, y, 0, 1].
    # Y is flipped to follow image memory layouts.
    x, y = torch.meshgrid(torch.linspace(-1, 1, W, device=device),
                          torch.linspace(1, -1, H, device=device))
    x = x.T.flatten()
    y = y.T.flatten()
    z = -torch.ones_like(x, device=device) / math.tan((2 * math.pi * fov / 360) / 2)
    rays_d_cam = normalize_vecs(torch.stack([x, y, z], -1))

    z_vals = torch.linspace(ray_start, ray_end, num_steps, device=device).reshape(1, num_steps, 1).repeat(W * H, 1, 1)
    points = rays_d_cam.unsqueeze(1).repeat(1, num_steps, 1) * z_vals
    return points, z_vals, rays_d_cam
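
# Illustrative sketch (hypothetical fov and depth range): camera-space rays for
# a 64x64 image with 12 samples per ray.
def _example_initial_rays():
    points, z_vals, rays_d = get_initial_rays_trig(
        num_steps=12, fov=12, resolution=(64, 64), ray_start=0.88, ray_end=1.12)
    # points: (4096, 12, 3), z_vals: (4096, 12, 1), rays_d: (4096, 3)
    return points, z_vals, rays_d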

def sample_camera_positions(
        device, n=1, r=1, horizontal_stddev=1, vertical_stddev=1,
        horizontal_mean=math.pi * 0.5, vertical_mean=math.pi * 0.5, mode='normal'):
    """
    Samples n random locations on a sphere of radius r.
    Uses a Gaussian distribution for pitch and yaw by default.
    """
    if mode == 'uniform':
        theta = (torch.rand((n, 1), device=device) - 0.5) * 2 * horizontal_stddev + horizontal_mean
        phi = (torch.rand((n, 1), device=device) - 0.5) * 2 * vertical_stddev + vertical_mean
    elif mode == 'normal' or mode == 'gaussian':
        theta = torch.randn((n, 1), device=device) * horizontal_stddev + horizontal_mean
        phi = torch.randn((n, 1), device=device) * vertical_stddev + vertical_mean
    elif mode == 'hybrid':
        if random.random() < 0.5:
            theta = (torch.rand((n, 1), device=device) - 0.5) * 2 * horizontal_stddev * 2 + horizontal_mean
            phi = (torch.rand((n, 1), device=device) - 0.5) * 2 * vertical_stddev * 2 + vertical_mean
        else:
            theta = torch.randn((n, 1), device=device) * horizontal_stddev + horizontal_mean
            phi = torch.randn((n, 1), device=device) * vertical_stddev + vertical_mean
    else:
        # Fixed camera: no randomness, just the mean pitch/yaw.
        phi = torch.ones((n, 1), device=device, dtype=torch.float) * vertical_mean
        theta = torch.ones((n, 1), device=device, dtype=torch.float) * horizontal_mean

    phi = torch.clamp(phi, 1e-5, math.pi - 1e-5)

    # Spherical-to-Cartesian conversion (y is up).
    output_points = torch.zeros((n, 3), device=device)
    output_points[:, 0:1] = r * torch.sin(phi) * torch.cos(theta)
    output_points[:, 2:3] = r * torch.sin(phi) * torch.sin(theta)
    output_points[:, 1:2] = r * torch.cos(phi)
    return output_points, phi, theta
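
# Illustrative sketch (hypothetical spreads): sample a few camera origins on the
# unit sphere around the frontal pose.
def _example_sample_cameras():
    origins, pitch, yaw = sample_camera_positions(
        device='cpu', n=4, r=1,
        horizontal_stddev=0.3, vertical_stddev=0.15, mode='gaussian')
    # origins: (4, 3) points on the sphere; pitch/yaw: (4, 1) angles in radians
    return origins, pitch, yaw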

def perturb_points(points, z_vals, ray_directions, device):
    distance_between_points = z_vals[:, :, 1:2, :] - z_vals[:, :, 0:1, :]
    offset = (torch.rand(z_vals.shape, device=device) - 0.5) * distance_between_points
    z_vals = z_vals + offset
    points = points + offset * ray_directions.unsqueeze(2)
    return points, z_vals

def create_cam2world_matrix(forward_vector, origin, device=None):
    """Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix."""
    forward_vector = normalize_vecs(forward_vector)
    up_vector = torch.tensor([0, 1, 0], dtype=torch.float, device=device).expand_as(forward_vector)
    left_vector = normalize_vecs(torch.cross(up_vector, forward_vector, dim=-1))
    up_vector = normalize_vecs(torch.cross(forward_vector, left_vector, dim=-1))

    rotation_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
    rotation_matrix[:, :3, :3] = torch.stack((-left_vector, up_vector, -forward_vector), axis=-1)

    translation_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
    translation_matrix[:, :3, 3] = origin

    cam2world = translation_matrix @ rotation_matrix
    return cam2world

def transform_sampled_points(
        points, z_vals, ray_directions, device,
        h_stddev=1, v_stddev=1, h_mean=math.pi * 0.5,
        v_mean=math.pi * 0.5, mode='normal'):
    """
    points: batch_size x total_pixels x num_steps x 3
    z_vals: batch_size x total_pixels x num_steps x 1
    """
    n, num_rays, num_steps, channels = points.shape
    points, z_vals = perturb_points(points, z_vals, ray_directions, device)
    camera_origin, pitch, yaw = sample_camera_positions(
        n=points.shape[0], r=1,
        horizontal_stddev=h_stddev, vertical_stddev=v_stddev,
        horizontal_mean=h_mean, vertical_mean=v_mean,
        device=device, mode=mode)
    forward_vector = normalize_vecs(-camera_origin)
    cam2world_matrix = create_cam2world_matrix(forward_vector, camera_origin, device=device)

    points_homogeneous = torch.ones((points.shape[0], points.shape[1], points.shape[2], points.shape[3] + 1), device=device)
    points_homogeneous[:, :, :, :3] = points

    # should be n x 4 x 4 , n x r^2 x num_steps x 4
    transformed_points = torch.bmm(cam2world_matrix, points_homogeneous.reshape(n, -1, 4).permute(0, 2, 1)).permute(0, 2, 1).reshape(n, num_rays, num_steps, 4)
    transformed_ray_directions = torch.bmm(cam2world_matrix[..., :3, :3], ray_directions.reshape(n, -1, 3).permute(0, 2, 1)).permute(0, 2, 1).reshape(n, num_rays, 3)

    homogeneous_origins = torch.zeros((n, 4, num_rays), device=device)
    homogeneous_origins[:, 3, :] = 1
    transformed_ray_origins = torch.bmm(cam2world_matrix, homogeneous_origins).permute(0, 2, 1).reshape(n, num_rays, 4)[..., :3]

    return transformed_points[..., :3], z_vals, transformed_ray_directions, transformed_ray_origins, pitch, yaw
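
# Illustrative sketch (hypothetical batch size, fov and depth range): generate
# camera-space rays, then move them to world space with a randomly sampled camera.
def _example_ray_pipeline():
    n, num_steps, res = 2, 12, 32
    points, z_vals, rays_d = get_initial_rays_trig(
        num_steps, fov=12, resolution=(res, res), ray_start=0.88, ray_end=1.12)
    points = points.unsqueeze(0).repeat(n, 1, 1, 1)   # (n, res*res, num_steps, 3)
    z_vals = z_vals.unsqueeze(0).repeat(n, 1, 1, 1)   # (n, res*res, num_steps, 1)
    rays_d = rays_d.unsqueeze(0).repeat(n, 1, 1)      # (n, res*res, 3)
    world_points, z_vals, world_dirs, origins, pitch, yaw = transform_sampled_points(
        points, z_vals, rays_d, device='cpu',
        h_stddev=0.3, v_stddev=0.15, mode='normal')
    return world_points, world_dirs, origins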

def integration(
        rgb_sigma, z_vals, device, noise_std=0.5,
        last_back=False, white_back=False, clamp_mode=None, fill_mode=None):
    rgbs = rgb_sigma[..., :3]
    sigmas = rgb_sigma[..., 3:]

    # Distances between adjacent samples; the last interval is treated as infinite.
    deltas = z_vals[..., 1:, :] - z_vals[..., :-1, :]
    delta_inf = 1e10 * torch.ones_like(deltas[..., :1, :])
    deltas = torch.cat([deltas, delta_inf], -2)

    if noise_std > 0:
        noise = torch.randn(sigmas.shape, device=device) * noise_std
    else:
        noise = 0

    if clamp_mode == 'softplus':
        alphas = 1 - torch.exp(-deltas * (F.softplus(sigmas + noise)))
    elif clamp_mode == 'relu':
        alphas = 1 - torch.exp(-deltas * (F.relu(sigmas + noise)))
    else:
        raise ValueError("Need to choose clamp mode")

    alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1, :]), 1 - alphas + 1e-10], -2)
    weights = alphas * torch.cumprod(alphas_shifted, -2)[..., :-1, :]
    weights_sum = weights.sum(-2)

    if last_back:
        weights[..., -1, :] += (1 - weights_sum)

    rgb_final = torch.sum(weights * rgbs, -2)
    depth_final = torch.sum(weights * z_vals, -2)

    if white_back:
        rgb_final = rgb_final + 1 - weights_sum

    if fill_mode == 'debug':
        rgb_final[weights_sum.squeeze(-1) < 0.9] = torch.tensor([1., 0, 0], device=rgb_final.device)
    elif fill_mode == 'weight':
        rgb_final = weights_sum.expand_as(rgb_final)

    return rgb_final, depth_final, weights
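
# Illustrative sketch (random inputs): alpha-composite per-ray samples with the
# volume rendering above; shapes follow transform_sampled_points.
def _example_integration():
    n, num_rays, num_steps = 2, 64, 12
    rgb_sigma = torch.rand(n, num_rays, num_steps, 4)   # RGB + sigma per sample
    z_vals = torch.linspace(0.88, 1.12, num_steps).reshape(1, 1, num_steps, 1).repeat(n, num_rays, 1, 1)
    rgb, depth, weights = integration(rgb_sigma, z_vals, device='cpu', clamp_mode='relu')
    # rgb: (n, num_rays, 3), depth: (n, num_rays, 1), weights: (n, num_rays, num_steps, 1)
    return rgb, depth, weights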

def get_sigma_field_np(nerf, styles, resolution=512, block_resolution=64):
    # Evaluate the NeRF density (sigma) on a dense 3D grid, block by block, and
    # return it as a numpy array together with the half-extent of the grid.
    bound = (nerf.depth_range[1] - nerf.depth_range[0]) * 0.5
    X = torch.linspace(-bound, bound, resolution).split(block_resolution)

    sigma_np = np.zeros([resolution, resolution, resolution], dtype=np.float32)

    for xi, xs in enumerate(X):
        for yi, ys in enumerate(X):
            for zi, zs in enumerate(X):
                xx, yy, zz = torch.meshgrid(xs, ys, zs)
                pts = torch.stack([xx, yy, zz], dim=-1).unsqueeze(0).to(styles.device)  # B, H, H, H, C
                block_shape = [1, len(xs), len(ys), len(zs)]
                feat_out, sigma_out = nerf.fg_nerf.forward_style2(pts, None, block_shape, ws=styles)
                sigma_np[xi * block_resolution: xi * block_resolution + len(xs),
                         yi * block_resolution: yi * block_resolution + len(ys),
                         zi * block_resolution: zi * block_resolution + len(zs)] = \
                    sigma_out.reshape(block_shape[1:]).detach().cpu().numpy()

    return sigma_np, bound

def extract_geometry(nerf, styles, resolution, threshold):
    import mcubes

    print('threshold: {}'.format(threshold))
    u, bound = get_sigma_field_np(nerf, styles, resolution)
    vertices, triangles = mcubes.marching_cubes(u, threshold)
    b_min_np = np.array([-bound, -bound, -bound])
    b_max_np = np.array([bound, bound, bound])

    vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
    return vertices.astype('float32'), triangles

def render_mesh(meshes, camera_matrices, render_noise=True):
    from pytorch3d.renderer import (
        FoVPerspectiveCameras, look_at_view_transform,
        RasterizationSettings, BlendParams,
        MeshRenderer, MeshRasterizer, HardPhongShader, TexturesVertex
    )
    from pytorch3d.ops import interpolate_face_attributes
    from pytorch3d.structures.meshes import Meshes

    intrinsics, poses, _, _ = camera_matrices
    device = poses.device

    c2w = torch.matmul(poses, torch.diag(torch.tensor([-1.0, 1.0, -1.0, 1.0], device=device))[None, :, :])  # Different camera model...
    w2c = torch.inverse(c2w)
    R = c2w[:, :3, :3]
    T = w2c[:, :3, 3]   # Note: R is taken from c2w while T is taken from w2c (PyTorch3D's camera convention).
    focal = intrinsics[0, 0, 0]
    fov = torch.arctan(focal) * 2.0 / np.pi * 180

    colors = []
    offset = 1
    for res, (mesh, face_vert_noise) in meshes.items():
        raster_settings = RasterizationSettings(
            image_size=res,
            blur_radius=0.0,
            faces_per_pixel=1,
        )
        mesh = Meshes(
            verts=[torch.from_numpy(mesh.vertices).float().to(device)],
            faces=[torch.from_numpy(mesh.faces).long().to(device)])

        _colors = []
        for i in range(len(poses)):
            cameras = FoVPerspectiveCameras(device=device, R=R[i: i + 1], T=T[i: i + 1], fov=fov)
            rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
            pix_to_face, zbuf, bary_coord, dists = rasterizer(mesh)
            color = interpolate_face_attributes(pix_to_face, bary_coord, face_vert_noise).squeeze()
            # hack: shift the rasterized output by `offset` pixels.
            color[offset:, offset:] = color[:-offset, :-offset]
            _colors += [color]
        color = torch.stack(_colors, 0).permute(0, 3, 1, 2)
        colors += [color]
        offset *= 2
    return colors

def rotate_vects(v, theta):
    theta = theta / math.pi * 2
    theta = theta + (theta < 0).type_as(theta) * 4
    v = v.reshape(v.size(0), v.size(1) // 4, 4, v.size(2), v.size(3))
    vs = []
    order = [0, 2, 3, 1]   # Not working
    iorder = [0, 3, 1, 2]  # Not working
    for b in range(len(v)):
        if (theta[b] - 0) < 1e-6:
            u, l = 0, 0
        elif (theta[b] - 1) < 1e-6:
            u, l = 0, 1
        elif (theta[b] - 2) < 1e-6:
            u, l = 0, 2
        elif (theta[b] - 3) < 1e-6:
            u, l = 0, 3
        else:
            u, l = math.modf(theta[b])
        l, r = int(l), int(l + 1) % 4
        vv = v[b, :, order]  # 0 -> 1 -> 3 -> 2
        vl = torch.cat([vv[:, l:], vv[:, :l]], 1)
        if u > 0:
            vr = torch.cat([vv[:, r:], vv[:, :r]], 1)
            vv = vl * (1 - u) + vr * u
        else:
            vv = vl
        vs.append(vv[:, iorder])
    v = torch.stack(vs, 0)
    v = v.reshape(v.size(0), -1, v.size(-2), v.size(-1))
    return v

def generate_option_outputs(render_option):
    # Output debugging images (not used in the normal rendering process).
    # Note: this snippet relies on variables from the renderer's scope
    # (camera_world, fg_depth_map, ray_vector, di, styles, fg_weights, etc.).
    if ('depth' in render_option.split(',')):
        img = camera_world[:, :1] + fg_depth_map * ray_vector
        img = reformat(img, tgt_res)

    if 'gradient' in render_option.split(','):
        points = (camera_world[:, :, None] + di[:, :, :, None] * ray_vector[:, :, None]).reshape(
            batch_size, tgt_res, tgt_res, di.size(-1), 3)
        with torch.enable_grad():
            gradients = self.fg_nerf.forward_style2(
                points, None, [batch_size, tgt_res, di.size(-1), tgt_res], get_normal=True,
                ws=styles, z_shape=z_shape_obj, z_app=z_app_obj).reshape(
                batch_size, di.size(-1), 3, tgt_res * tgt_res).permute(0, 3, 1, 2)
        avg_grads = (gradients * fg_weights.unsqueeze(-1)).sum(-2)
        normal = reformat(normalize(avg_grads, axis=2)[0], tgt_res)
        img = normal

    if 'value' in render_option.split(','):
        fg_feat = fg_feat[:, :, 3:].norm(dim=-1, keepdim=True)
        img = reformat(fg_feat.repeat(1, 1, 3), tgt_res) / fg_feat.max() * 2 - 1

    if 'opacity' in render_option.split(','):
        opacity = bg_lambda.unsqueeze(-1).repeat(1, 1, 3) * 2 - 1
        img = reformat(opacity, tgt_res)

    if 'normal' in render_option.split(','):
        shift_l, shift_r = img[:, :, 2:, :], img[:, :, :-2, :]
        shift_u, shift_d = img[:, :, :, 2:], img[:, :, :, :-2]
        diff_hor = normalize(shift_r - shift_l, axis=1)[0][:, :, :, 1:-1]
        diff_ver = normalize(shift_u - shift_d, axis=1)[0][:, :, 1:-1, :]
        normal = torch.cross(diff_hor, diff_ver, dim=1)
        img = normalize(normal, axis=1)[0]

    return {'full_out': (None, img), 'reg_loss': {}}