import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Literal
from functools import partial
import itertools

# LRM
from .embedder import CameraEmbedder
from .transformer import TransformerDecoder

# from accelerate.logging import get_logger
# logger = get_logger(__name__)
class LRM_VSD_Mesh_Net(nn.Module):
    """
    Predict VSD plane/line feature components from posed input images using a transformer decoder.
    """

    def __init__(self,
                 camera_embed_dim: int,
                 transformer_dim: int, transformer_layers: int, transformer_heads: int,
                 triplane_low_res: int, triplane_high_res: int, triplane_dim: int,
                 encoder_freeze: bool = True, encoder_type: str = 'dino',
                 encoder_model_name: str = 'facebook/dino-vitb16', encoder_feat_dim: int = 768,
                 app_dim: int = 27, density_dim: int = 8,
                 app_n_comp: int = 24, density_n_comp: int = 8):
        super().__init__()
        # attributes
        self.encoder_feat_dim = encoder_feat_dim
        self.camera_embed_dim = camera_embed_dim
        self.triplane_low_res = triplane_low_res
        self.triplane_high_res = triplane_high_res
        self.triplane_dim = triplane_dim
        self.transformer_dim = transformer_dim
        # modules
        self.encoder = self._encoder_fn(encoder_type)(
            model_name=encoder_model_name,
            modulation_dim=self.camera_embed_dim,  # modulate encoder features with the camera embedding
            freeze=encoder_freeze,
        )
        # raw camera vector has 12 + 4 values (typically a flattened 3x4 extrinsic matrix plus intrinsics)
        self.camera_embedder = CameraEmbedder(
            raw_dim=12 + 4, embed_dim=camera_embed_dim,
        )
        self.n_comp = app_n_comp + density_n_comp
        self.app_dim = app_dim
        self.density_dim = density_dim
        self.app_n_comp = app_n_comp
        self.density_n_comp = density_n_comp
        # learnable query tokens: 3 planes of (low_res^2) tokens plus 3 lines of low_res tokens
        self.pos_embed = nn.Parameter(
            torch.randn(1, 3 * (triplane_low_res ** 2) + 3 * triplane_low_res, transformer_dim)
            * (1. / transformer_dim) ** 0.5
        )
        self.transformer = TransformerDecoder(
            block_type='cond',
            num_layers=transformer_layers, num_heads=transformer_heads,
            inner_dim=transformer_dim, cond_dim=encoder_feat_dim, mod_dim=None,
        )
        # plane tokens: 2x spatial upsampling with projection to n_comp channels
        self.upsampler = nn.ConvTranspose2d(transformer_dim, self.n_comp, kernel_size=2, stride=2, padding=0)
        # line tokens: channel projection, then 2x upsampling along the line axis
        self.dim_map = nn.Linear(transformer_dim, self.n_comp)
        self.up_line = nn.Linear(triplane_low_res, triplane_low_res * 2)

    @staticmethod
    def _encoder_fn(encoder_type: str):
        encoder_type = encoder_type.lower()
        assert encoder_type in ['dino', 'dinov2'], "Unsupported encoder type"
        if encoder_type == 'dino':
            from .encoders.dino_wrapper import DinoWrapper
            # logger.info("Using DINO as the encoder")
            return DinoWrapper
        elif encoder_type == 'dinov2':
            from .encoders.dinov2_wrapper import Dinov2Wrapper
            # logger.info("Using DINOv2 as the encoder")
            return Dinov2Wrapper

    def forward_transformer(self, image_feats, camera_embeddings=None):
        N = image_feats.shape[0]
        x = self.pos_embed.repeat(N, 1, 1)  # [N, L, D]
        x = self.transformer(
            x,
            cond=image_feats,
            mod=camera_embeddings,
        )
        return x

    def reshape_upsample(self, tokens):
        # tokens: [N, 3*H*W + 3*H, transformer_dim]
        N = tokens.shape[0]
        H = W = self.triplane_low_res
        P = self.n_comp
        offset = 3 * H * W
        # planes
        plane_tokens = tokens[:, :3 * H * W, :].view(N, H, W, 3, self.transformer_dim)
        plane_tokens = torch.einsum('nhwip->inphw', plane_tokens)  # [3, N, D, H, W]
        plane_tokens = plane_tokens.contiguous().view(3 * N, -1, H, W)  # [3*N, D, H, W]
        plane_tokens = self.upsampler(plane_tokens)  # [3*N, P, H', W']
        plane_tokens = plane_tokens.view(3, N, *plane_tokens.shape[-3:])  # [3, N, P, H', W']
        plane_tokens = torch.einsum('inphw->niphw', plane_tokens)  # [N, 3, P, H', W']
        plane_tokens = plane_tokens.reshape(N, 3 * P, *plane_tokens.shape[-2:])  # [N, 3*P, H', W']
        plane_tokens = plane_tokens.contiguous()
        # lines
        line_tokens = tokens[:, 3 * H * W:3 * H * W + 3 * H, :].view(N, H, 3, self.transformer_dim)
        line_tokens = self.dim_map(line_tokens)  # [N, H, 3, P]
        line_tokens = torch.einsum('nhip->npih', line_tokens)  # [N, P, 3, H]
        line_tokens = self.up_line(line_tokens)  # [N, P, 3, H']
        line_tokens = torch.einsum('npih->niph', line_tokens)  # [N, 3, P, H']
        line_tokens = line_tokens.reshape(N, 3 * P, line_tokens.shape[-1], 1)
        line_tokens = line_tokens.contiguous()
        mat_tokens = None
        d_mat_tokens = None
        # split: the first app_n_comp*3 channels are used as appearance components, the rest as density
        return (plane_tokens[:, :self.app_n_comp * 3, :, :],
                line_tokens[:, :self.app_n_comp * 3, :, :],
                mat_tokens,
                d_mat_tokens,
                plane_tokens[:, self.app_n_comp * 3:, :, :],
                line_tokens[:, self.app_n_comp * 3:, :, :])
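
    # Shape walkthrough for reshape_upsample (illustrative values, not necessarily the authors' config):
    # with triplane_low_res=32, transformer_dim=512, app_n_comp=24, density_n_comp=8 (so n_comp=32),
    # the transformer emits 3*32*32 + 3*32 = 3168 tokens per sample, and reshape_upsample returns
    #   app_planes:     [N, 3*24, 64, 64]
    #   app_lines:      [N, 3*24, 64, 1]
    #   basis_mat:      None
    #   d_basis_mat:    None
    #   density_planes: [N, 3*8, 64, 64]
    #   density_lines:  [N, 3*8, 64, 1]
    # i.e. a plane/line factorization whose first app_n_comp*3 channels are treated as appearance
    # and whose remaining channels are treated as density.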

    def forward_planes(self, image, camera):
        # image: [N, V, C_img, H_img, W_img]
        # camera: [N, V, D_cam_raw]
        N, V, _, H, W = image.shape
        image = image.reshape(N * V, 3, H, W)
        camera = camera.reshape(N * V, -1)
        # embed camera
        camera_embeddings = self.camera_embedder(camera)
        assert camera_embeddings.shape[-1] == self.camera_embed_dim, \
            f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {self.camera_embed_dim}"
        # encode image
        image_feats = self.encoder(image, camera_embeddings)
        assert image_feats.shape[-1] == self.encoder_feat_dim, \
            f"Feature dimension mismatch: {image_feats.shape[-1]} vs {self.encoder_feat_dim}"
        # merge the view dimension back into the token dimension: [N, V*L, C]
        image_feats = image_feats.reshape(N, V * image_feats.shape[-2], image_feats.shape[-1])
        # transformer generating planes
        tokens = self.forward_transformer(image_feats)
        app_planes, app_lines, basis_mat, d_basis_mat, density_planes, density_lines = self.reshape_upsample(tokens)
        return app_planes, app_lines, basis_mat, d_basis_mat, density_planes, density_lines

    def forward(self, image, source_camera):
        # image: [N, V, C_img, H_img, W_img]
        # source_camera: [N, V, D_cam_raw]
        assert image.shape[0] == source_camera.shape[0], "Batch size mismatch for image and source_camera"
        planes = self.forward_planes(image, source_camera)
        # tuple: (app_planes, app_lines, basis_mat, d_basis_mat, density_planes, density_lines)
        return planes
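

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the parameter values below are
    # illustrative rather than the authors' configuration, and because of the relative imports
    # this only runs when invoked as a module inside its package (python -m <package>.<module>).
    model = LRM_VSD_Mesh_Net(
        camera_embed_dim=1024,
        transformer_dim=512, transformer_layers=12, transformer_heads=8,
        triplane_low_res=32, triplane_high_res=64, triplane_dim=32,
        encoder_freeze=True, encoder_type='dino',
        encoder_model_name='facebook/dino-vitb16', encoder_feat_dim=768,
        app_dim=27, density_dim=8, app_n_comp=24, density_n_comp=8,
    )
    image = torch.randn(1, 4, 3, 224, 224)   # [N, V, C, H, W]: 1 object, 4 source views
    camera = torch.randn(1, 4, 16)           # [N, V, 12+4] raw camera vectors (random, for shape check only)
    with torch.no_grad():
        outputs = model(image, camera)
    names = ['app_planes', 'app_lines', 'basis_mat', 'd_basis_mat', 'density_planes', 'density_lines']
    for name, out in zip(names, outputs):
        print(name, None if out is None else tuple(out.shape))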