# python3.8
"""Contains the implementation of generator described in SGBEV3D."""
import torch
import torch.nn as nn
from models.utils.official_stylegan2_model_helper import Generator as StyleGAN2Backbone
from models.utils.official_stylegan3_model_helper import Generator as StyleGAN3Backbone
from models.utils.unet import Generator as StyleGAN4Backbone
from models.utils.official_stylegan2_model_helper import FullyConnectedLayer
from models.utils.eg3d_superres import SuperresolutionHybrid2X
from models.utils.eg3d_superres import SuperresolutionHybrid4X
from models.utils.eg3d_superres import SuperresolutionHybrid4X_conststyle
from models.utils.eg3d_superres import SuperresolutionHybrid8XDC
from models.rendering.renderer import Renderer
from models.rendering.feature_extractor import FeatureExtractor
from models.utils.spade import SPADEGenerator
class SGBEV3DGenerator(nn.Module):
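    """BEV-conditioned 3D-aware image generator.

    A StyleGAN-style backbone synthesizes a 32-channel BEV feature plane from
    latent codes and a semantic layout `seg`; the feature extractor and volume
    renderer turn that plane into a low-resolution feature/depth image, and a
    super-resolution network (or plain bilinear upsampling when
    `interpolate_sr` is set) produces the final image.
    """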
def __init__(
self,
z_dim,
c_dim,
w_dim,
semantic_nc,
ngf,
bev_grid_size,
aspect_ratio,
num_upsampling_layers,
not_use_vae,
norm_G,
interpolate_sr,
segmask=False,
dim_seq='16,8,4,2,1',
xyz_pe=False,
reverse_xy=True,
hidden_dim=64,
additional_layer_num=0,
block_num=5,
layer_num=2,
ff_input=False,
ref_mode='bev_plane_clevr_256',
sel_type=None,
backbone_ver=2,
img_resolution=256,
bev_resolution=256,
sr_num_fp16_res=0, # Number of fp16 layers of SR Network.
mapping_kwargs={},
rendering_kwargs={}, # Arguments for rendering.
sr_kwargs={}, # Arguments for SuperResolution Network.
**synthesis_kwargs
):
super().__init__()
self.z_dim = z_dim
self.interpolate_sr = interpolate_sr
self.segmask = segmask
# Set up the overall renderer.
self.renderer = Renderer()
# Set up the feature extractor.
self.feature_extractor = FeatureExtractor(ref_mode=ref_mode, xyz_pe=xyz_pe, reverse_xy=reverse_xy)
# Set up the reference representation generator.
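        # `backbone_ver` selects one of the StyleGAN{2,3,4}Backbone aliases
        # imported at the top of this file, looked up by name via `globals()`.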
        self.backbone = globals()[f'StyleGAN{backbone_ver}Backbone'](
            z_dim, c_dim, w_dim,
            img_resolution=bev_resolution,
            img_channels=32,
            label_nc=semantic_nc,
            use_sel=True,
            sel_type=sel_type,
            mapping_kwargs=mapping_kwargs,
            ff_input=ff_input,
            block_num=block_num,
            layer_num=layer_num,
            **synthesis_kwargs)
# Set up the post module in the feature extractor.
self.post_module = None
# Set up the post neural renderer.
self.post_neural_renderer = None
sr_kwargs_total = dict(
channels=32,
img_resolution=img_resolution,
sr_num_fp16_res=sr_num_fp16_res,
sr_antialias=rendering_kwargs['sr_antialias'],)
sr_kwargs_total.update(**sr_kwargs)
if img_resolution == 128:
self.post_neural_renderer = SuperresolutionHybrid2X(
**sr_kwargs_total)
elif img_resolution == 256:
self.post_neural_renderer = SuperresolutionHybrid4X_conststyle(
**sr_kwargs_total)
elif img_resolution == 512:
self.post_neural_renderer = SuperresolutionHybrid8XDC(
**sr_kwargs_total)
else:
            raise ValueError(f'Unsupported image resolution: {img_resolution}!')
# Set up the fully-connected layer head.
self.fc_head = OSGDecoder(
            128 if xyz_pe else 64, {
'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
'decoder_output_dim': 32
},
hidden_dim=hidden_dim,
additional_layer_num=additional_layer_num
)
# Set up some rendering related arguments.
self.neural_rendering_resolution = rendering_kwargs.get(
'resolution', 64)
self.rendering_kwargs = rendering_kwargs
def mapping(self,
z,
c,
truncation_psi=1,
truncation_cutoff=None,
update_emas=False):
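        # Map noise `z` (and optional conditioning `c`) to intermediate latents.
        # Conditioning is zeroed out when `c_gen_conditioning_zero` is set, and is
        # further scaled by `c_scale` (which defaults to 0, i.e. disabled).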
if self.rendering_kwargs['c_gen_conditioning_zero']:
c = torch.zeros_like(c)
        return self.backbone.mapping(
            z,
            c * self.rendering_kwargs.get('c_scale', 0),
            truncation_psi=truncation_psi,
            truncation_cutoff=truncation_cutoff,
            update_emas=update_emas)
def synthesis(self,
wp,
c,
seg,
neural_rendering_resolution=None,
update_emas=False,
**synthesis_kwargs):
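        # The first 16 entries of `c` hold the flattened 4x4 cam2world matrix.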
cam2world_matrix = c[:, :16].view(-1, 4, 4)
if self.rendering_kwargs.get('random_pose', False):
cam2world_matrix = None
if neural_rendering_resolution is None:
neural_rendering_resolution = self.neural_rendering_resolution
else:
self.neural_rendering_resolution = neural_rendering_resolution
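        # Synthesize the 32-channel BEV feature plane, conditioned on the
        # semantic layout `seg` (fed to the backbone as `heatmap`).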
xy_planes = self.backbone.synthesis(wp, heatmap=seg, update_emas=update_emas, **synthesis_kwargs)
if self.segmask:
xy_planes = xy_planes * seg[:, 0, ...][:, None, ...]
rendering_result = self.renderer(
wp=wp,
feature_extractor=self.feature_extractor,
rendering_options=self.rendering_kwargs,
cam2world_matrix=cam2world_matrix,
position_encoder=None,
ref_representation=xy_planes,
post_module=self.post_module,
fc_head=self.fc_head)
feature_samples = rendering_result['composite_rgb']
depth_samples = rendering_result['composite_depth']
# Reshape to keep consistent with 'raw' neural-rendered image.
N = wp.shape[0]
H = W = self.neural_rendering_resolution
feature_image = feature_samples.permute(0, 2, 1).reshape(
N, feature_samples.shape[-1], H, W).contiguous()
depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
# Run the post neural renderer to get final image.
# Here, the post neural renderer is a super-resolution network.
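        # The first 3 feature channels form the raw RGB image; the full feature
        # map is passed alongside it to the super-resolution network.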
rgb_image = feature_image[:, :3]
if self.interpolate_sr:
sr_image = torch.nn.functional.interpolate(rgb_image, size=(256, 256), mode='bilinear', align_corners=False)
else:
sr_image = self.post_neural_renderer(
rgb_image,
feature_image,
# wp, # todo: study SR with wp
noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
**{
k: synthesis_kwargs[k]
for k in synthesis_kwargs.keys() if k != 'noise_mode'
})
return {
'image': sr_image,
'image_raw': rgb_image,
'image_depth': depth_image,
'plane': xy_planes,
'points': rendering_result['points'],
'sigmas': rendering_result['sigmas']
}
def sample(self,
coordinates,
directions,
z,
c,
seg,
truncation_psi=1,
truncation_cutoff=None,
update_emas=False,
**synthesis_kwargs):
# Compute RGB features, density for arbitrary 3D coordinates.
# Mostly used for extracting shapes.
cam2world_matrix = c[:, :16].view(-1, 4, 4)
wp = self.mapping(z, c, truncation_psi=truncation_psi,
truncation_cutoff=truncation_cutoff,
update_emas=update_emas)
xy_planes = self.backbone.synthesis(wp, heatmap=seg, update_emas=update_emas, **synthesis_kwargs)
result = self.renderer.get_sigma_rgb(
wp=wp,
points=coordinates,
feature_extractor=self.feature_extractor,
fc_head=self.fc_head,
rendering_options=self.rendering_kwargs,
ref_representation=xy_planes,
post_module=self.post_module,
ray_dirs=directions,
cam_matrix=cam2world_matrix)
return result
def sample_mixed(self,
coordinates,
directions,
wp, c, seg,
truncation_psi=1,
truncation_cutoff=None,
update_emas=False,
**synthesis_kwargs):
# Same as function `self.sample()`, but expects latent vectors 'wp'
# instead of Gaussian noise 'z'.
cam2world_matrix = c[:, :16].view(-1, 4, 4)
xy_planes = self.backbone.synthesis(wp, heatmap=seg, update_emas=update_emas, **synthesis_kwargs)
result = self.renderer.get_sigma_rgb(
wp=wp,
points=coordinates,
feature_extractor=self.feature_extractor,
fc_head=self.fc_head,
rendering_options=self.rendering_kwargs,
ref_representation=xy_planes,
post_module=self.post_module,
ray_dirs=directions,
cam_matrix=cam2world_matrix)
return result
def forward(self,
z,
c,
seg,
c_swapped=None, # `c_swapped` is swapped pose conditioning.
style_mixing_prob=0,
truncation_psi=1,
truncation_cutoff=None,
neural_rendering_resolution=None,
update_emas=False,
sample_mixed=False,
coordinates=None,
**synthesis_kwargs):
# Render a batch of generated images.
c_wp = c.clone()
if c_swapped is not None:
c_wp = c_swapped.clone()
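        # When `c_swapped` is given, the swapped pose conditions the mapping
        # network only; the original pose `c` still drives rendering in
        # `synthesis` below.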
wp = self.mapping(z, c_wp, truncation_psi=truncation_psi,
truncation_cutoff=truncation_cutoff,
update_emas=update_emas)
        # TODO: implement style mixing (`style_mixing_prob` is accepted but currently unused).
if not sample_mixed:
gen_output = self.synthesis(
wp,
c,
seg,
update_emas=update_emas,
neural_rendering_resolution=neural_rendering_resolution,
**synthesis_kwargs)
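            # NOTE: the `wp` entry below carries the input noise `z`, not the
            # mapped latents.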
return {
'wp': z,
'gen_output': gen_output,
}
else:
# Only for density regularization in training process.
assert coordinates is not None
sample_sigma = self.sample_mixed(coordinates,
torch.randn_like(coordinates),
wp, c, seg,
update_emas=False)['sigma']
return {
'wp': z,
'sample_sigma': sample_sigma
}
class OSGDecoder(nn.Module):
"""Defines fully-connected layer head in EG3D."""
def __init__(self, n_features, options, hidden_dim=64, additional_layer_num=0):
super().__init__()
self.hidden_dim = hidden_dim
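        # MLP: n_features -> hidden_dim (Softplus), `additional_layer_num` extra
        # hidden layers, then 1 + decoder_output_dim outputs (density + features).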
lst = []
lst.append(FullyConnectedLayer(n_features, self.hidden_dim, lr_multiplier=options['decoder_lr_mul']))
lst.append(nn.Softplus())
for i in range(additional_layer_num):
lst.append(FullyConnectedLayer(self.hidden_dim, self.hidden_dim, lr_multiplier=options['decoder_lr_mul']))
lst.append(nn.Softplus())
lst.append(FullyConnectedLayer(self.hidden_dim, 1+options['decoder_output_dim'], lr_multiplier=options['decoder_lr_mul']))
self.net = nn.Sequential(*lst)
# self.net = nn.Sequential(
# FullyConnectedLayer(n_features,
# self.hidden_dim,
# lr_multiplier=options['decoder_lr_mul']),
# nn.Softplus(),
# FullyConnectedLayer(self.hidden_dim,
# 1 + options['decoder_output_dim'],
# lr_multiplier=options['decoder_lr_mul']))
def forward(self, point_features, wp=None, dirs=None):
        # point_features.shape: [N, R, K, C] (batch, rays, samples per ray, channels).
        # Flatten to per-point vectors and run the MLP on each point individually;
        # note that no cross-plane averaging is performed here.
N, R, K, C = point_features.shape
x = point_features.reshape(-1, point_features.shape[-1])
x = self.net(x)
x = x.view(N, -1, x.shape[-1])
# Uses sigmoid clamping from MipNeRF
rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001
sigma = x[..., 0:1]
return {'rgb': rgb, 'sigma': sigma}
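

# ----------------------------------------------------------------------------
# Minimal usage sketch (commented out, illustrative only). It assumes the rest
# of the BerfScene repo is importable and that `rendering_kwargs` carries at
# least the keys read in this file (`sr_antialias`, `c_gen_conditioning_zero`,
# `superresolution_noise_mode`, ...); the Renderer may require more. All
# dimensions and values below are placeholder assumptions, not repo defaults.
#
#   G = SGBEV3DGenerator(
#       z_dim=512, c_dim=25, w_dim=512, semantic_nc=14, ngf=64,
#       bev_grid_size=256, aspect_ratio=1.0, num_upsampling_layers=6,
#       not_use_vae=True, norm_G='spectralinstance', interpolate_sr=False,
#       img_resolution=256, bev_resolution=256,
#       rendering_kwargs={
#           'sr_antialias': True,
#           'c_gen_conditioning_zero': True,
#           'superresolution_noise_mode': 'none',
#       })
#   z = torch.randn(1, 512)                     # latent noise
#   c = torch.randn(1, 25)                      # flattened 4x4 cam2world + extras
#   seg = torch.randn(1, 14, 256, 256)          # semantic BEV layout (`heatmap`)
#   out = G(z, c, seg)
#   image = out['gen_output']['image']          # super-resolved RGB
#   image_raw = out['gen_output']['image_raw']  # low-resolution RGB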