StyleNeRF / dnnlib /geometry.py
Jiatao Gu
add code from the original repo
94ada0b
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
import torch.nn.functional as F
import math
import random
import numpy as np
def positional_encoding(p, size, pe='normal', use_pos=False):
if pe == 'gauss':
p_transformed = np.pi * p @ size
p_transformed = torch.cat(
[torch.sin(p_transformed), torch.cos(p_transformed)], dim=-1)
else:
p_transformed = torch.cat([torch.cat(
[torch.sin((2 ** i) * np.pi * p),
torch.cos((2 ** i) * np.pi * p)],
dim=-1) for i in range(size)], dim=-1)
if use_pos:
p_transformed = torch.cat([p_transformed, p], -1)
return p_transformed
def upsample(img_nerf, size, filter=None):
up = size // img_nerf.size(-1)
if up <= 1:
return img_nerf
if filter is not None:
from torch_utils.ops import upfirdn2d
for _ in range(int(math.log2(up))):
img_nerf = upfirdn2d.downsample2d(img_nerf, filter, up=2)
else:
img_nerf = F.interpolate(img_nerf, (size, size), mode='bilinear', align_corners=False)
return img_nerf
def downsample(img0, size, filter=None):
down = img0.size(-1) // size
if down <= 1:
return img0
if filter is not None:
from torch_utils.ops import upfirdn2d
for _ in range(int(math.log2(down))):
img0 = upfirdn2d.downsample2d(img0, filter, down=2)
else:
img0 = F.interpolate(img0, (size, size), mode='bilinear', align_corners=False)
return img0
def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
"""
Normalize vector lengths.
"""
return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
def repeat_vecs(vecs, n, dim=0):
return torch.stack(n*[vecs], dim=dim)
def get_grids(H, W, device, align=True):
ch = 1 if align else 1 - (1 / H)
cw = 1 if align else 1 - (1 / W)
x, y = torch.meshgrid(torch.linspace(-cw, cw, W, device=device),
torch.linspace(ch, -ch, H, device=device))
return torch.stack([x, y], -1)
def local_ensemble(pi, po, resolution):
ii = range(resolution)
ia = torch.tensor([max((i - 1)//2, 0) for i in ii]).long()
ib = torch.tensor([min((i + 1)//2, resolution//2-1) for i in ii]).long()
ul = torch.meshgrid(ia, ia)
ur = torch.meshgrid(ia, ib)
ll = torch.meshgrid(ib, ia)
lr = torch.meshgrid(ib, ib)
d_ul, p_ul = po - pi[ul], torch.stack(ul, -1)
d_ur, p_ur = po - pi[ur], torch.stack(ur, -1)
d_ll, p_ll = po - pi[ll], torch.stack(ll, -1)
d_lr, p_lr = po - pi[lr], torch.stack(lr, -1)
c_ul = d_ul.prod(dim=-1).abs()
c_ur = d_ur.prod(dim=-1).abs()
c_ll = d_ll.prod(dim=-1).abs()
c_lr = d_lr.prod(dim=-1).abs()
D = torch.stack([d_ul, d_ur, d_ll, d_lr], 0)
P = torch.stack([p_ul, p_ur, p_ll, p_lr], 0)
C = torch.stack([c_ul, c_ur, c_ll, c_lr], 0)
C = C / C.sum(dim=0, keepdim=True)
return D, P, C
def get_initial_rays_trig(num_steps, fov, resolution, ray_start, ray_end, device='cpu'):
"""Returns sample points, z_vals, ray directions in camera space."""
W, H = resolution
# Create full screen NDC (-1 to +1) coords [x, y, 0, 1].
# Y is flipped to follow image memory layouts.
x, y = torch.meshgrid(torch.linspace(-1, 1, W, device=device),
torch.linspace(1, -1, H, device=device))
x = x.T.flatten()
y = y.T.flatten()
z = -torch.ones_like(x, device=device) / math.tan((2 * math.pi * fov / 360)/2)
rays_d_cam = normalize_vecs(torch.stack([x, y, z], -1))
z_vals = torch.linspace(ray_start, ray_end, num_steps, device=device).reshape(1, num_steps, 1).repeat(W*H, 1, 1)
points = rays_d_cam.unsqueeze(1).repeat(1, num_steps, 1) * z_vals
return points, z_vals, rays_d_cam
def sample_camera_positions(
device, n=1, r=1, horizontal_stddev=1, vertical_stddev=1,
horizontal_mean=math.pi*0.5, vertical_mean=math.pi*0.5, mode='normal'):
"""
Samples n random locations along a sphere of radius r.
Uses a gaussian distribution for pitch and yaw
"""
if mode == 'uniform':
theta = (torch.rand((n, 1),device=device) - 0.5) * 2 * horizontal_stddev + horizontal_mean
phi = (torch.rand((n, 1),device=device) - 0.5) * 2 * vertical_stddev + vertical_mean
elif mode == 'normal' or mode == 'gaussian':
theta = torch.randn((n, 1), device=device) * horizontal_stddev + horizontal_mean
phi = torch.randn((n, 1), device=device) * vertical_stddev + vertical_mean
elif mode == 'hybrid':
if random.random() < 0.5:
theta = (torch.rand((n, 1),device=device) - 0.5) * 2 * horizontal_stddev * 2 + horizontal_mean
phi = (torch.rand((n, 1),device=device) - 0.5) * 2 * vertical_stddev * 2 + vertical_mean
else:
theta = torch.randn((n, 1), device=device) * horizontal_stddev + horizontal_mean
phi = torch.randn((n, 1), device=device) * vertical_stddev + vertical_mean
else:
phi = torch.ones((n, 1), device=device, dtype=torch.float) * vertical_mean
theta = torch.ones((n, 1), device=device, dtype=torch.float) * horizontal_mean
phi = torch.clamp(phi, 1e-5, math.pi - 1e-5)
output_points = torch.zeros((n, 3), device=device)# torch.cuda.FloatTensor(n, 3).fill_(0)#torch.zeros((n, 3))
output_points[:, 0:1] = r*torch.sin(phi) * torch.cos(theta)
output_points[:, 2:3] = r*torch.sin(phi) * torch.sin(theta)
output_points[:, 1:2] = r*torch.cos(phi)
return output_points, phi, theta
def perturb_points(points, z_vals, ray_directions, device):
distance_between_points = z_vals[:,:,1:2,:] - z_vals[:,:,0:1,:]
offset = (torch.rand(z_vals.shape, device=device)-0.5) * distance_between_points
z_vals = z_vals + offset
points = points + offset * ray_directions.unsqueeze(2)
return points, z_vals
def create_cam2world_matrix(forward_vector, origin, device=None):
"""Takes in the direction the camera is pointing and the camera origin and returns a world2cam matrix."""
forward_vector = normalize_vecs(forward_vector)
up_vector = torch.tensor([0, 1, 0], dtype=torch.float, device=device).expand_as(forward_vector)
left_vector = normalize_vecs(torch.cross(up_vector, forward_vector, dim=-1))
up_vector = normalize_vecs(torch.cross(forward_vector, left_vector, dim=-1))
rotation_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
rotation_matrix[:, :3, :3] = torch.stack((-left_vector, up_vector, -forward_vector), axis=-1)
translation_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
translation_matrix[:, :3, 3] = origin
cam2world = translation_matrix @ rotation_matrix
return cam2world
def transform_sampled_points(
points, z_vals, ray_directions, device,
h_stddev=1, v_stddev=1, h_mean=math.pi * 0.5,
v_mean=math.pi * 0.5, mode='normal'):
"""
points: batch_size x total_pixels x num_steps x 3
z_vals: batch_size x total_pixels x num_steps
"""
n, num_rays, num_steps, channels = points.shape
points, z_vals = perturb_points(points, z_vals, ray_directions, device)
camera_origin, pitch, yaw = sample_camera_positions(
n=points.shape[0], r=1,
horizontal_stddev=h_stddev, vertical_stddev=v_stddev,
horizontal_mean=h_mean, vertical_mean=v_mean,
device=device, mode=mode)
forward_vector = normalize_vecs(-camera_origin)
cam2world_matrix = create_cam2world_matrix(forward_vector, camera_origin, device=device)
points_homogeneous = torch.ones((points.shape[0], points.shape[1], points.shape[2], points.shape[3] + 1), device=device)
points_homogeneous[:, :, :, :3] = points
# should be n x 4 x 4 , n x r^2 x num_steps x 4
transformed_points = torch.bmm(cam2world_matrix, points_homogeneous.reshape(n, -1, 4).permute(0,2,1)).permute(0, 2, 1).reshape(n, num_rays, num_steps, 4)
transformed_ray_directions = torch.bmm(cam2world_matrix[..., :3, :3], ray_directions.reshape(n, -1, 3).permute(0,2,1)).permute(0, 2, 1).reshape(n, num_rays, 3)
homogeneous_origins = torch.zeros((n, 4, num_rays), device=device)
homogeneous_origins[:, 3, :] = 1
transformed_ray_origins = torch.bmm(cam2world_matrix, homogeneous_origins).permute(0, 2, 1).reshape(n, num_rays, 4)[..., :3]
return transformed_points[..., :3], z_vals, transformed_ray_directions, transformed_ray_origins, pitch, yaw
def integration(
rgb_sigma, z_vals, device, noise_std=0.5,
last_back=False, white_back=False, clamp_mode=None, fill_mode=None):
rgbs = rgb_sigma[..., :3]
sigmas = rgb_sigma[..., 3:]
deltas = z_vals[..., 1:, :] - z_vals[..., :-1, :]
delta_inf = 1e10 * torch.ones_like(deltas[..., :1, :])
deltas = torch.cat([deltas, delta_inf], -2)
if noise_std > 0:
noise = torch.randn(sigmas.shape, device=device) * noise_std
else:
noise = 0
if clamp_mode == 'softplus':
alphas = 1 - torch.exp(-deltas * (F.softplus(sigmas + noise)))
elif clamp_mode == 'relu':
alphas = 1 - torch.exp(-deltas * (F.relu(sigmas + noise)))
else:
raise "Need to choose clamp mode"
alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1, :]), 1-alphas + 1e-10], -2)
weights = alphas * torch.cumprod(alphas_shifted, -2)[..., :-1, :]
weights_sum = weights.sum(-2)
if last_back:
weights[..., -1, :] += (1 - weights_sum)
rgb_final = torch.sum(weights * rgbs, -2)
depth_final = torch.sum(weights * z_vals, -2)
if white_back:
rgb_final = rgb_final + 1-weights_sum
if fill_mode == 'debug':
rgb_final[weights_sum.squeeze(-1) < 0.9] = torch.tensor([1., 0, 0], device=rgb_final.device)
elif fill_mode == 'weight':
rgb_final = weights_sum.expand_as(rgb_final)
return rgb_final, depth_final, weights
def get_sigma_field_np(nerf, styles, resolution=512, block_resolution=64):
# return numpy array of forwarded sigma value
bound = (nerf.depth_range[1] - nerf.depth_range[0]) * 0.5
X = torch.linspace(-bound, bound, resolution).split(block_resolution)
sigma_np = np.zeros([resolution, resolution, resolution], dtype=np.float32)
for xi, xs in enumerate(X):
for yi, ys in enumerate(X):
for zi, zs in enumerate(X):
xx, yy, zz = torch.meshgrid(xs, ys, zs)
pts = torch.stack([xx, yy, zz], dim=-1).unsqueeze(0).to(styles.device) # B, H, H, H, C
block_shape = [1, len(xs), len(ys), len(zs)]
feat_out, sigma_out = nerf.fg_nerf.forward_style2(pts, None, block_shape, ws=styles)
sigma_np[xi * block_resolution: xi * block_resolution + len(xs), \
yi * block_resolution: yi * block_resolution + len(ys), \
zi * block_resolution: zi * block_resolution + len(zs)] = sigma_out.reshape(block_shape[1:]).detach().cpu().numpy()
return sigma_np, bound
def extract_geometry(nerf, styles, resolution, threshold):
import mcubes
print('threshold: {}'.format(threshold))
u, bound = get_sigma_field_np(nerf, styles, resolution)
vertices, triangles = mcubes.marching_cubes(u, threshold)
b_min_np = np.array([-bound, -bound, -bound])
b_max_np = np.array([ bound, bound, bound])
vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
return vertices.astype('float32'), triangles
def render_mesh(meshes, camera_matrices, render_noise=True):
from pytorch3d.renderer import (
FoVPerspectiveCameras, look_at_view_transform,
RasterizationSettings, BlendParams,
MeshRenderer, MeshRasterizer, HardPhongShader, TexturesVertex
)
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.structures.meshes import Meshes
intrinsics, poses, _, _ = camera_matrices
device = poses.device
c2w = torch.matmul(poses, torch.diag(torch.tensor([-1.0, 1.0, -1.0, 1.0], device=device))[None, :, :]) # Different camera model...
w2c = torch.inverse(c2w)
R = c2w[:, :3, :3]
T = w2c[:, :3, 3] # So weird..... Why one is c2w and another is w2c?
focal = intrinsics[0, 0, 0]
fov = torch.arctan(focal) * 2.0 / np.pi * 180
colors = []
offset = 1
for res, (mesh, face_vert_noise) in meshes.items():
raster_settings = RasterizationSettings(
image_size=res,
blur_radius=0.0,
faces_per_pixel=1,
)
mesh = Meshes(
verts=[torch.from_numpy(mesh.vertices).float().to(device)],
faces=[torch.from_numpy(mesh.faces).long().to(device)])
_colors = []
for i in range(len(poses)):
cameras = FoVPerspectiveCameras(device=device, R=R[i: i+1], T=T[i: i+1], fov=fov)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
pix_to_face, zbuf, bary_coord, dists = rasterizer(mesh)
color = interpolate_face_attributes(pix_to_face, bary_coord, face_vert_noise).squeeze()
# hack
color[offset:, offset:] = color[:-offset, :-offset]
_colors += [color]
color = torch.stack(_colors, 0).permute(0,3,1,2)
colors += [color]
offset *= 2
return colors
def rotate_vects(v, theta):
theta = theta / math.pi * 2
theta = theta + (theta < 0).type_as(theta) * 4
v = v.reshape(v.size(0), v.size(1) // 4, 4, v.size(2), v.size(3))
vs = []
order = [0,2,3,1] # Not working
iorder = [0,3,1,2] # Not working
for b in range(len(v)):
if (theta[b] - 0) < 1e-6:
u, l = 0, 0
elif (theta[b] - 1) < 1e-6:
u, l = 0, 1
elif (theta[b] - 2) < 1e-6:
u, l = 0, 2
elif (theta[b] - 3) < 1e-6:
u, l = 0, 3
else:
u, l = math.modf(theta[b])
l, r = int(l), int(l + 1) % 4
vv = v[b, :, order] # 0 -> 1 -> 3 -> 2
vl = torch.cat([vv[:, l:], vv[:, :l]], 1)
if u > 0:
vr = torch.cat([vv[:, r:], vv[:, :r]], 1)
vv = vl * (1-u) + vr * u
else:
vv = vl
vs.append(vv[:, iorder])
v = torch.stack(vs, 0)
v = v.reshape(v.size(0), -1, v.size(-2), v.size(-1))
return v
def generate_option_outputs(render_option):
# output debugging outputs (not used in normal rendering process)
if ('depth' in render_option.split(',')):
img = camera_world[:, :1] + fg_depth_map * ray_vector
img = reformat(img, tgt_res)
if 'gradient' in render_option.split(','):
points = (camera_world[:,:,None]+di[:,:,:,None]*ray_vector[:,:,None]).reshape(
batch_size, tgt_res, tgt_res, di.size(-1), 3)
with torch.enable_grad():
gradients = self.fg_nerf.forward_style2(
points, None, [batch_size, tgt_res, di.size(-1), tgt_res], get_normal=True,
ws=styles, z_shape=z_shape_obj, z_app=z_app_obj).reshape(
batch_size, di.size(-1), 3, tgt_res * tgt_res).permute(0,3,1,2)
avg_grads = (gradients * fg_weights.unsqueeze(-1)).sum(-2)
normal = reformat(normalize(avg_grads, axis=2)[0], tgt_res)
img = normal
if 'value' in render_option.split(','):
fg_feat = fg_feat[:,:,3:].norm(dim=-1,keepdim=True)
img = reformat(fg_feat.repeat(1,1,3), tgt_res) / fg_feat.max() * 2 - 1
if 'opacity' in render_option.split(','):
opacity = bg_lambda.unsqueeze(-1).repeat(1,1,3) * 2 - 1
img = reformat(opacity, tgt_res)
if 'normal' in render_option.split(','):
shift_l, shift_r = img[:,:,2:,:], img[:,:,:-2,:]
shift_u, shift_d = img[:,:,:,2:], img[:,:,:,:-2]
diff_hor = normalize(shift_r - shift_l, axis=1)[0][:, :, :, 1:-1]
diff_ver = normalize(shift_u - shift_d, axis=1)[0][:, :, 1:-1, :]
normal = torch.cross(diff_hor, diff_ver, dim=1)
img = normalize(normal, axis=1)[0]
return {'full_out': (None, img), 'reg_loss': {}}