import os
import cv2
import torch
import numpy as np
import gradio as gr
import trimesh
import sys

sys.path.append('vggsfm_code/')
import shutil

from vggsfm_code.hf_demo import demo_fn
from omegaconf import DictConfig, OmegaConf
from viz_utils.viz_fn import add_camera
import PIL
import spaces


# Upper bound on how many frames (uploaded images or extracted video frames) are reconstructed.
MAX_INPUT_IMAGES = 20
# Fewer frames than this are rejected before running reconstruction.
MIN_INPUT_FRAMES = 3

# Palette cycled over cameras so consecutive frustums get distinct edge colours.
CAMERA_EDGE_COLORS = [
    (255, 0, 0), (0, 0, 255), (0, 255, 0), (255, 0, 255), (255, 204, 0),
    (0, 204, 204), (128, 255, 255), (255, 128, 255), (255, 255, 128),
    (0, 0, 0), (128, 128, 128),
]


def _prepare_scene_dir(target_dir):
    """Delete and recreate ``target_dir`` with an empty ``images`` subfolder; return that subfolder."""
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    target_dir_images = target_dir + "/images"
    os.makedirs(target_dir_images)  # also creates target_dir itself
    return target_dir_images


def _extract_video_frames(video_path, target_dir_images, max_frames, seconds_per_frame=1):
    """Save one frame every ``seconds_per_frame`` seconds of video as zero-padded PNGs.

    Args:
        video_path: Path readable by ``cv2.VideoCapture``.
        target_dir_images: Existing directory the PNGs are written into.
        max_frames: Stop after writing this many frames.
        seconds_per_frame: Sampling period, converted to a frame stride via the stream's FPS.

    Returns:
        The number of frames actually written (0 if the video could not be read).
    """
    capture = cv2.VideoCapture(video_path)
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
        # CAP_PROP_FPS may report 0 (unknown) or < 1; clamp so the modulo never divides by zero.
        frame_interval = max(1, int(fps * seconds_per_frame))
        saved = 0
        count = 0
        while saved < max_frames:
            gotit, frame = capture.read()
            # Check the read result before using `frame`: at end of stream it is None.
            if not gotit:
                break
            count += 1
            if count % frame_interval == 0:
                cv2.imwrite(target_dir_images + "/" + f"{saved:06}.png", frame)
                saved += 1
    finally:
        capture.release()
    return saved


@spaces.GPU
def vggsfm_demo(
    input_image,
    input_video,
    query_frame_num,
    max_query_pts=None,
):
    """Run VGGSfM on uploaded images or a video and export the result as a GLB file.

    Args:
        input_image: List of image file paths (takes priority over ``input_video``), or None.
        input_video: Video file path, or None.
        query_frame_num: Value written to ``cfg.query_frame_num``; None keeps the YAML value.
        max_query_pts: Value written to ``cfg.max_query_pts``; None keeps the YAML value.

    Returns:
        A ``(glb_path, message)`` tuple; ``glb_path`` is None when the input is rejected.
    """
    cfg = OmegaConf.load("vggsfm_code/cfgs/demo.yaml")

    target_dir = "input_images"
    target_dir_images = _prepare_scene_dir(target_dir)

    if input_image is not None:
        if len(input_image) < MIN_INPUT_FRAMES:
            return None, "Please input at least three frames"
        input_image = sorted(input_image)[:MAX_INPUT_IMAGES]
        for file_name in input_image:
            shutil.copy(file_name, target_dir_images)
    elif input_video is not None:
        saved = _extract_video_frames(input_video, target_dir_images, MAX_INPUT_IMAGES)
        if saved < MIN_INPUT_FRAMES:
            return None, "Please input at least three frames"
    else:
        return None, "Input format incorrect"

    # Only override the YAML defaults when the caller supplied a value.
    if query_frame_num is not None:
        cfg.query_frame_num = int(query_frame_num)
    if max_query_pts is not None:
        cfg.max_query_pts = int(max_query_pts)

    print(f"Files have been copied to {target_dir_images}")
    cfg.SCENE_DIR = target_dir

    predictions = demo_fn(cfg)
    glbfile = vggsfm_predictions_to_glb(predictions)

    print(input_image)
    print(input_video)
    return glbfile, "Success"


def vggsfm_predictions_to_glb(predictions):
    """Build a trimesh scene (point cloud + camera frustums) and export it to ``glbscene.glb``.

    ``predictions`` must provide tensors under ``points3D``, ``points3D_rgb`` (scaled by 255
    here, so presumably in [0, 1] — confirm against demo_fn) and ``extrinsics_opencv``
    (N, 3, 4), which this code treats as camera-from-world matrices.

    Returns:
        The path of the written GLB file.
    """
    # learned from https://github.com/naver/dust3r/blob/main/dust3r/viz.py
    points3D = predictions["points3D"].cpu().numpy()
    points3D_rgb = (predictions["points3D_rgb"].cpu().numpy() * 255).astype(np.uint8)
    extrinsics_opencv = predictions["extrinsics_opencv"].cpu().numpy()

    glbscene = trimesh.Scene()
    glbscene.add_geometry(trimesh.PointCloud(points3D, colors=points3D_rgb))

    # Promote the (N, 3, 4) extrinsics to homogeneous 4x4 so they can be inverted.
    frame_num = len(extrinsics_opencv)
    extrinsics_opencv_4x4 = np.zeros((frame_num, 4, 4))
    extrinsics_opencv_4x4[:, :3, :4] = extrinsics_opencv
    extrinsics_opencv_4x4[:, 3, 3] = 1

    for idx, cam_from_world in enumerate(extrinsics_opencv_4x4):
        cam_to_world = np.linalg.inv(cam_from_world)
        cur_cam_color = CAMERA_EDGE_COLORS[idx % len(CAMERA_EDGE_COLORS)]
        add_camera(glbscene, cam_to_world, cur_cam_color, image=None, imsize=(1024, 1024),
                   focal=None, screen_width=0.35)

    # OpenCV camera axes -> OpenGL camera axes (flip y and z).
    opengl_mat = np.diag([1.0, -1.0, -1.0, 1.0])
    # 180-degree rotation about the y axis, written explicitly: (x, y, z) -> (-x, y, -z).
    # This replaces scipy's Rotation.from_euler('y', pi), whose import was commented out.
    rot = np.diag([-1.0, 1.0, -1.0, 1.0])
    # Re-express the whole scene relative to the first camera (in OpenGL convention).
    glbscene.apply_transform(np.linalg.inv(np.linalg.inv(extrinsics_opencv_4x4[0]) @ opengl_mat @ rot))

    glbfile = "glbscene.glb"
    glbscene.export(file_obj=glbfile)
    return glbfile


# Flip to False to run a one-off reconstruction on the bundled example instead of the web UI.
if True:
    demo = gr.Interface(
        title="🎨 VGGSfM: Visual Geometry Grounded Deep Structure From Motion",
        description="Welcome to VGGSfM!",
        fn=vggsfm_demo,
        inputs=[
            gr.File(file_count="multiple", label="Input Images", interactive=True),
            gr.Video(label="Input video", interactive=True),
            gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of query images"),
            gr.Slider(minimum=512, maximum=4096, step=1, value=1024, label="Number of query points"),
        ],
        outputs=[gr.Model3D(label="Reconstruction"), gr.Textbox(label="Log")],
        cache_examples=True,
        allow_flagging=False,
    )
    demo.queue(max_size=20, concurrency_count=1).launch(debug=True)
else:
    import glob

    files = glob.glob('vggsfm_code/examples/cake/images/*')
    vggsfm_demo(files, None, None)