import gradio as gr
import torch
import torchvision.transforms.functional as torchvision_F
import numpy as np
import os
import shutil
import importlib
import trimesh
import tempfile
import subprocess
import utils.options as options
import shlex
import time
import rembg
from utils.util import EasyDict as edict
from PIL import Image
from utils.eval_3D import get_dense_3D_grid, compute_level_grid, convert_to_explicit

def get_1d_bounds(arr):
    nz = np.flatnonzero(arr)
    return nz[0], nz[-1]

def get_bbox_from_mask(mask, thr):
    masks_for_box = (mask > thr).astype(np.float32)
    assert masks_for_box.sum() > 0, "Empty mask!"
    x0, x1 = get_1d_bounds(masks_for_box.sum(axis=-2))
    y0, y1 = get_1d_bounds(masks_for_box.sum(axis=-1))
    return x0, y0, x1, y1

def square_crop(image, bbox, crop_ratio=1.):
    x1, y1, x2, y2 = bbox
    h, w = y2-y1, x2-x1
    yc, xc = (y1+y2)/2, (x1+x2)/2
    # crop a square window around the bbox center, padded by 20%
    S = max(h, w)*1.2
    scale = S*crop_ratio
    image = torchvision_F.crop(image, top=int(yc-scale/2), left=int(xc-scale/2), height=int(scale), width=int(scale))
    return image

def preprocess_image(opt, image, bbox):
    image = square_crop(image, bbox=bbox)
    if image.size[0] != opt.W or image.size[1] != opt.H:
        image = image.resize((opt.W, opt.H))
    image = torchvision_F.to_tensor(image)
    rgb, mask = image[:3], image[3:]
    if opt.data.bgcolor is not None:
        # replace the background color using the mask
        rgb = rgb * mask + opt.data.bgcolor * (1 - mask)
    mask = (mask > 0.5).float()
    return rgb, mask

def get_image(opt, image_fname, mask_fname):
    image = Image.open(image_fname).convert("RGB")
    mask = Image.open(mask_fname).convert("L")
    mask_np = np.array(mask)
    # binarize the mask with a midpoint threshold
    mask_np[mask_np <= 127] = 0
    mask_np[mask_np > 127] = 1
    image = Image.merge("RGBA", (*image.split(), mask))
    bbox = get_bbox_from_mask(mask_np, 0.5)
    rgb_input_map, mask_input_map = preprocess_image(opt, image, bbox=bbox)
    return rgb_input_map, mask_input_map

def get_intr(opt):
    # camera intrinsics with a fixed normalized focal length
    f = 1.3875
    K = torch.tensor([
        [f*opt.W, 0, opt.W/2],
        [0, f*opt.H, opt.H/2],
        [0, 0, 1]
    ]).float()
    return K

def get_pixel_grid(H, W, device='cuda'):
    y_range = torch.arange(H, dtype=torch.float32).to(device)
    x_range = torch.arange(W, dtype=torch.float32).to(device)
    Y, X = torch.meshgrid(y_range, x_range, indexing='ij')
    Z = torch.ones_like(Y).to(device)
    xyz_grid = torch.stack([X, Y, Z], dim=-1).view(-1, 3)
    return xyz_grid

def unproj_depth(depth, intr):
    '''
    depth: [B, H, W]
    intr: [B, 3, 3]
    '''
    batch_size, H, W = depth.shape
    intr = intr.to(depth.device)
    # [B, 3, 3]
    K_inv = torch.linalg.inv(intr).float()
    # [1, H*W, 3]
    pixel_grid = get_pixel_grid(H, W, depth.device).unsqueeze(0)
    # [B, H*W, 3]
    pixel_grid = pixel_grid.repeat(batch_size, 1, 1)
    # [B, 3, H*W]
    ray_dirs = K_inv @ pixel_grid.permute(0, 2, 1).contiguous()
    # [B, H*W, 3], in camera coordinates
    seen_points = ray_dirs.permute(0, 2, 1).contiguous() * depth.view(batch_size, H*W, 1)
    # [B, H, W, 3]
    seen_points = seen_points.view(batch_size, H, W, 3)
    return seen_points

def prepare_data(opt, image_path, mask_path):
    var = edict()
    rgb_input_map, mask_input_map = get_image(opt, image_path, mask_path)
    intr = get_intr(opt)
    var.rgb_input_map = rgb_input_map.unsqueeze(0).to(opt.device)
    var.mask_input_map = mask_input_map.unsqueeze(0).to(opt.device)
    var.intr = intr.unsqueeze(0).to(opt.device)
    var.idx = torch.tensor([0]).to(opt.device).long()
    var.pose_gt = False
    return var
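# ---------------------------------------------------------------------------
# Illustrative sketch, not called anywhere by the demo: a CPU-runnable sanity
# check for the geometry helpers above. The helper name and the 224x224
# resolution are assumptions made only for this example. With the intrinsics
# from get_intr(), unprojecting a constant depth map should place every point
# at that depth, and the principal-point pixel should land on the optical axis.
def _sanity_check_unprojection():
    opt = edict()
    opt.W, opt.H = 224, 224
    intr = get_intr(opt).unsqueeze(0)           # [1, 3, 3]
    depth = torch.full((1, opt.H, opt.W), 2.0)  # constant depth of 2 units
    points = unproj_depth(depth, intr)          # [1, H, W, 3]
    # every unprojected point lies at the supplied depth
    assert torch.allclose(points[..., 2], depth)
    # the pixel at the principal point unprojects onto the optical axis
    center = points[0, opt.H // 2, opt.W // 2]
    assert torch.allclose(center[:2], torch.zeros(2), atol=1e-5)
# ---------------------------------------------------------------------------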
@torch.no_grad()
def marching_cubes(opt, var, impl_network, visualize_attn=False):
    # [B, N, N, N, 3]
    points_3D = get_dense_3D_grid(opt, var)
    level_vox, attn_vis = compute_level_grid(opt, impl_network, var.latent_depth, var.latent_semantic,
                                             points_3D, var.rgb_input_map, visualize_attn)
    if attn_vis:
        var.attn_vis = attn_vis
    # level_grids: a list of length B, each entry is [N, N, N]
    level_grids = list(level_vox.cpu().numpy())
    meshes = convert_to_explicit(opt, level_grids, isoval=0.5, to_pointcloud=False)
    var.mesh_pred = meshes
    return var

@torch.no_grad()
def infer_sample(opt, var, graph):
    var = graph.forward(opt, var, training=False, get_loss=False)
    var = marching_cubes(opt, var, graph.impl_network, visualize_attn=True)
    return var.mesh_pred[0]

def infer(input_image_path, input_mask_path):
    opt_cmd = options.parse_arguments([
        "--yaml=options/shape.yaml",
        "--datadir=examples",
        "--eval.vox_res=128",
        "--ckpt=/data/shape.ckpt",
    ])
    opt = options.set(opt_cmd=opt_cmd, safe_check=False)
    opt.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # build the model
    print("Building model...")
    opt.pretrain.depth = None
    opt.arch.depth.pretrained = None
    module = importlib.import_module("model.compute_graph.graph_shape")
    graph = module.Graph(opt).to(opt.device)

    # download the checkpoint if it is not on disk yet
    if not os.path.isfile(opt.ckpt):
        print("Downloading checkpoint...")
        subprocess.run(
            shlex.split(
                "wget -q -O /data/shape.ckpt https://www.dropbox.com/scl/fi/hv3w9z59dqytievwviko4/shape.ckpt?rlkey=a2gut89kavrldmnt8b3df92oi&dl=0"
            )
        )
    # wait if the checkpoint is still downloading (e.g. in another worker)
    while not os.path.isfile(opt.ckpt):
        time.sleep(1)

    # load the checkpoint
    print("Loading checkpoint...")
    checkpoint = torch.load(opt.ckpt, map_location=opt.device)
    graph.load_state_dict(checkpoint["graph"], strict=True)
    graph.eval()

    # load the data
    print("Loading data...")
    var = prepare_data(opt, input_image_path, input_mask_path)

    # create the save dir
    save_folder = os.path.join(opt.datadir, 'preds')
    if os.path.isdir(save_folder):
        shutil.rmtree(save_folder)
    os.makedirs(save_folder)
    opt.output_path = opt.datadir

    # run inference and export the result
    print("Inferencing...")
    mesh_pred = infer_sample(opt, var, graph)

    # rotate the mesh 180 degrees around the x-axis (flip upside down)
    mesh_pred.apply_transform(trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0]))
    mesh_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
    mesh_pred.export(mesh_path.name, file_type="glb")
    return mesh_path.name

def infer_wrapper_mask(input_image_path, input_mask_path):
    return infer(input_image_path, input_mask_path)

def infer_wrapper_nomask(input_image_path):
    input_image = Image.open(input_image_path)
    segmented = rembg.remove(input_image)
    mask = segmented.split()[-1]
    mask_path = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    mask.save(mask_path.name)
    return infer(input_image_path, mask_path.name), mask_path.name

def assert_input_image(input_image):
    if input_image is None:
        raise gr.Error("No image selected or uploaded!")

def assert_mask_image(input_mask):
    if input_mask is None:
        raise gr.Error("No mask selected or uploaded! Please switch to the \"Estimated Mask\" tab if you do not have the mask.")
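# ---------------------------------------------------------------------------
# Illustrative sketch, not wired into the UI: the two wrappers above can also
# be driven programmatically, e.g. to batch-reconstruct the bundled examples.
# The helper name and the output directory are assumptions made only for this
# sketch; note that infer() rebuilds the graph and reloads the checkpoint on
# every call, so this favors simplicity over speed.
def _batch_reconstruct_examples(out_dir="examples/batch_preds"):
    os.makedirs(out_dir, exist_ok=True)
    names = ["armchair", "bolt", "bucket", "case", "dispenser",
             "hat", "teddy_bear", "tiger", "toy", "wedding_cake"]
    for name in names:
        image_path = os.path.join("examples/images", name + ".png")
        mask_path = os.path.join("examples/masks", name + ".png")
        if os.path.isfile(mask_path):
            # use the groundtruth mask when it ships with the example
            mesh_file = infer_wrapper_mask(image_path, mask_path)
        else:
            # otherwise fall back to the rembg-estimated mask
            mesh_file, _ = infer_wrapper_nomask(image_path)
        shutil.copy(mesh_file, os.path.join(out_dir, name + ".glb"))
# ---------------------------------------------------------------------------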
def demo_gradio():
    with gr.Blocks(analytics_enabled=False) as demo_ui:

        # HEADERS
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ZeroShape: Regression-based Zero-shot Shape Reconstruction')
                gr.Markdown("[\[Arxiv\]](https://arxiv.org/pdf/2312.14198.pdf) | [\[Project\]](https://zixuanh.com/projects/zeroshape.html) | [\[GitHub\]](https://github.com/zxhuang1698/ZeroShape)")
                gr.Markdown("Please switch to the \"Estimated Mask\" tab if you do not have the foreground mask. The demo will try to estimate the mask for you.")

        # with mask
        with gr.Tab("Groundtruth Mask"):
            with gr.Row():
                input_image_tab1 = gr.Image(label="Input Image", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
                mask_tab1 = gr.Image(label="Foreground Mask", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
                output_mesh_tab1 = gr.Model3D(label="Output Mesh")
            with gr.Row():
                submit_tab1 = gr.Button('Reconstruct', elem_id="recon_button_tab1", variant='primary')
            # examples
            with gr.Row():
                examples_tab1 = [
                    ['examples/images/armchair.png', 'examples/masks/armchair.png'],
                    ['examples/images/bolt.png', 'examples/masks/bolt.png'],
                    ['examples/images/bucket.png', 'examples/masks/bucket.png'],
                    ['examples/images/case.png', 'examples/masks/case.png'],
                    ['examples/images/dispenser.png', 'examples/masks/dispenser.png'],
                    ['examples/images/hat.png', 'examples/masks/hat.png'],
                    ['examples/images/teddy_bear.png', 'examples/masks/teddy_bear.png'],
                    ['examples/images/tiger.png', 'examples/masks/tiger.png'],
                    ['examples/images/toy.png', 'examples/masks/toy.png'],
                    ['examples/images/wedding_cake.png', 'examples/masks/wedding_cake.png'],
                ]
                gr.Examples(
                    examples=examples_tab1,
                    inputs=[input_image_tab1, mask_tab1],
                    outputs=[output_mesh_tab1],
                    fn=infer_wrapper_mask,
                    cache_examples=False,  # os.getenv('SYSTEM') == 'spaces'
                )

        # without mask
        with gr.Tab("Estimated Mask"):
            with gr.Row():
                input_image_tab2 = gr.Image(label="Input Image", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
                mask_tab2 = gr.Image(label="Foreground Mask", image_mode="RGB", sources="upload", type="filepath", elem_id="content_image", width=300)
                output_mesh_tab2 = gr.Model3D(label="Output Mesh")
            with gr.Row():
                submit_tab2 = gr.Button('Reconstruct', elem_id="recon_button_tab2", variant='primary')
            # examples
            with gr.Row():
                examples_tab2 = [
                    ['examples/images/armchair.png'],
                    ['examples/images/bolt.png'],
                    ['examples/images/bucket.png'],
                    ['examples/images/case.png'],
                    ['examples/images/dispenser.png'],
                    ['examples/images/hat.png'],
                    ['examples/images/teddy_bear.png'],
                    ['examples/images/tiger.png'],
                    ['examples/images/toy.png'],
                    ['examples/images/wedding_cake.png'],
                ]
                gr.Examples(
                    examples=examples_tab2,
                    inputs=[input_image_tab2],
                    outputs=[output_mesh_tab2, mask_tab2],
                    fn=infer_wrapper_nomask,
                    cache_examples=False,  # os.getenv('SYSTEM') == 'spaces'
                )

        submit_tab1.click(
            fn=assert_input_image,
            inputs=[input_image_tab1],
            queue=False
        ).success(
            fn=assert_mask_image,
            inputs=[mask_tab1],
            queue=False
        ).success(
            fn=infer_wrapper_mask,
            inputs=[input_image_tab1, mask_tab1],
            outputs=[output_mesh_tab1],
        )

        submit_tab2.click(
            fn=assert_input_image,
            inputs=[input_image_tab2],
            queue=False
        ).success(
            fn=infer_wrapper_nomask,
            inputs=[input_image_tab2],
            outputs=[output_mesh_tab2, mask_tab2],
        )

    return demo_ui

if __name__ == "__main__":
    demo_ui = demo_gradio()
    demo_ui.queue(max_size=10)
    demo_ui.launch()
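# Note: when hosting the demo beyond localhost, launch() also accepts a bind
# address, port, and public share link. The values below are illustrative
# placeholders, not part of the original demo:
#
#     demo_ui.launch(server_name="0.0.0.0", server_port=7860, share=True)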