vggsfm / app.py
JianyuanWang's picture
push
34d287c
raw
history blame
5.67 kB
import os
import cv2
import torch
import numpy as np
import gradio as gr
import trimesh
import sys
import os
sys.path.append('vggsfm_code/')
import shutil
from vggsfm_code.hf_demo import demo_fn
from omegaconf import DictConfig, OmegaConf
from viz_utils.viz_fn import add_camera
#
from scipy.spatial.transform import Rotation
import PIL
import spaces
@spaces.GPU
def vggsfm_demo(
    input_image,
    input_video,
    query_frame_num,
    max_query_pts,
    # grid_size: int = 10,
):
    """Run the VGGSfM demo on uploaded images or a video and export a GLB scene.

    The inputs are staged into a freshly recreated ``input_images/images``
    directory, the demo config is loaded from ``vggsfm_code/cfgs/demo.yaml``,
    and ``demo_fn`` is invoked on that scene directory.

    Args:
        input_image: Collection of image files to copy into the scene
            directory, or ``None``. Takes precedence over ``input_video``.
            Only the first 20 entries (after sorting) are used.
        input_video: Video source handed to ``cv2.VideoCapture``, or ``None``.
            Roughly one frame per second of video is extracted, up to 20.
        query_frame_num: Value written to ``cfg.query_frame_num``.
        max_query_pts: Value written to ``cfg.max_query_pts``.

    Returns:
        A ``(glb_path, message)`` tuple. ``glb_path`` is ``None`` when the
        input is missing or provides fewer than three frames.
    """
    cfg_file = "vggsfm_code/cfgs/demo.yaml"
    cfg = OmegaConf.load(cfg_file)

    max_input_image = 20

    # Start from an empty scene directory so frames from a previous request
    # never leak into this reconstruction.
    target_dir = "input_images"
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    target_dir_images = target_dir + "/images"
    os.makedirs(target_dir_images)

    if input_image is not None:
        if len(input_image) < 3:
            return None, "Please input at least three frames"

        input_image = sorted(input_image)[:max_input_image]

        # Copy files to the new directory
        for file_name in input_image:
            shutil.copy(file_name, target_dir_images)
    elif input_video is not None:
        vs = cv2.VideoCapture(input_video)
        try:
            fps = vs.get(cv2.CAP_PROP_FPS)
            frame_rate = 1
            # The reported FPS can be 0 (unknown) or below 1; clamp the
            # sampling interval to at least one frame so the modulo below
            # can never divide by zero.
            frame_interval = max(int(fps * frame_rate), 1) if fps > 0 else 1

            video_frame_num = 0
            count = 0

            # Strictly less-than: the video path honours the same frame cap
            # as the image path.
            while video_frame_num < max_input_image:
                gotit, frame = vs.read()
                # Stop before touching `frame`: a failed read yields no image.
                if not gotit:
                    break
                count += 1
                if count % frame_interval == 0:
                    cv2.imwrite(
                        target_dir_images + "/" + f"{video_frame_num:06}.png",
                        frame,
                    )
                    video_frame_num += 1
        finally:
            # Always free the decoder handle, even if writing a frame fails.
            vs.release()

        if video_frame_num < 3:
            return None, "Please input at least three frames"
    else:
        return None, "Input format incorrect"

    cfg.query_frame_num = query_frame_num
    cfg.max_query_pts = max_query_pts

    print(f"Files have been copied to {target_dir_images}")
    cfg.SCENE_DIR = target_dir

    predictions = demo_fn(cfg)

    glbfile = vggsfm_predictions_to_glb(predictions)

    print(input_image)
    print(input_video)

    return glbfile, "Success"
def vggsfm_predictions_to_glb(predictions):
    """Convert VGGSfM predictions into a GLB file with points and cameras.

    Reads ``points3D``, ``points3D_rgb`` and ``extrinsics_opencv`` from
    ``predictions`` (each is moved to CPU and converted to NumPy), builds a
    ``trimesh.Scene`` containing the coloured point cloud plus one camera
    marker per frame, and exports it to ``glbscene.glb`` in the current
    working directory.

    Args:
        predictions: Mapping holding the keys listed above. The extrinsics
            are written into the top three rows of per-frame 4x4 matrices,
            so they are presumably shaped ``(frames, 3, 4)`` and map world
            to camera coordinates -- confirm against ``demo_fn``.

    Returns:
        The path of the exported file, ``"glbscene.glb"``.
    """
    # learned from https://github.com/naver/dust3r/blob/main/dust3r/viz.py
    points3D = predictions["points3D"].cpu().numpy()
    points3D_rgb = predictions["points3D_rgb"].cpu().numpy()
    # Colours are scaled by 255 and stored as 8-bit values for trimesh.
    points3D_rgb = (points3D_rgb * 255).astype(np.uint8)

    extrinsics_opencv = predictions["extrinsics_opencv"].cpu().numpy()

    glbscene = trimesh.Scene()
    point_cloud = trimesh.PointCloud(points3D, colors=points3D_rgb)
    glbscene.add_geometry(point_cloud)

    camera_edge_colors = [
        (255, 0, 0), (0, 0, 255), (0, 255, 0), (255, 0, 255),
        (255, 204, 0), (0, 204, 204), (128, 255, 255), (255, 128, 255),
        (255, 255, 128), (0, 0, 0), (128, 128, 128),
    ]

    # Promote the extrinsics to homogeneous 4x4 matrices so they can be
    # inverted and composed with the viewer transforms below.
    frame_num = len(extrinsics_opencv)
    extrinsics_opencv_4x4 = np.zeros((frame_num, 4, 4))
    extrinsics_opencv_4x4[:, :3, :4] = extrinsics_opencv
    extrinsics_opencv_4x4[:, 3, 3] = 1

    # np.linalg.inv operates on stacked matrices, so invert every frame in
    # one call instead of once per loop iteration.
    cam_to_world_all = np.linalg.inv(extrinsics_opencv_4x4)

    for idx, cam_to_world in enumerate(cam_to_world_all):
        # Cycle through the palette when there are more frames than colours.
        cur_cam_color = camera_edge_colors[idx % len(camera_edge_colors)]
        # Camera markers use a fixed nominal image size rather than the real
        # image dimensions.
        add_camera(
            glbscene,
            cam_to_world,
            cur_cam_color,
            image=None,
            imsize=(1024, 1024),
            focal=None,
            screen_width=0.35,
        )

    # Flips the y and z axes (OpenCV -> OpenGL style camera axes).
    opengl_mat = np.array([[1, 0, 0, 0],
                           [0, -1, 0, 0],
                           [0, 0, -1, 0],
                           [0, 0, 0, 1]])

    # Half-turn about the y axis.
    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()

    # Re-express the whole scene relative to the first camera's pose combined
    # with the axis flip and half-turn above (same recipe as the dust3r
    # reference; presumably so the viewer opens facing the scene).
    glbscene.apply_transform(
        np.linalg.inv(cam_to_world_all[0] @ opengl_mat @ rot)
    )

    glbfile = "glbscene.glb"
    glbscene.export(file_obj=glbfile)
    return glbfile
# Entry point: the first branch serves the Gradio UI; flip the condition to
# run a local smoke test on the bundled example images instead.
if True:
    demo = gr.Interface(
        title="🎨 VGGSfM: Visual Geometry Grounded Deep Structure From Motion",
        description=(
            "<div style='text-align: left;'> "
            "<p>Welcome to <a href='https://github.com/facebookresearch/vggsfm' target='_blank'>VGGSfM</a>!"
        ),
        fn=vggsfm_demo,
        # Input order must match the positional parameters of vggsfm_demo.
        inputs=[
            gr.File(file_count="multiple", label="Input Images", interactive=True),
            gr.Video(label="Input video", interactive=True),
            gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of query images"),
            gr.Slider(minimum=512, maximum=4096, step=1, value=1024, label="Number of query points"),
        ],
        # Output order must match the (glb_path, message) tuple returned by vggsfm_demo.
        outputs=[gr.Model3D(label="Reconstruction"), gr.Textbox(label="Log")],
        cache_examples=True,
        allow_flagging=False,
    )
    demo.queue(max_size=20, concurrency_count=1).launch(debug=True)
    # demo.launch(debug=True, share=True)
else:
    import glob

    files = glob.glob('vggsfm_code/examples/cake/images/*', recursive=True)
    # vggsfm_demo takes four arguments; pass the same query settings the UI
    # sliders default to so this smoke run mirrors a default UI request.
    vggsfm_demo(files, None, 5, 1024)
    # demo.queue(max_size=20, concurrency_count=1).launch(debug=True, share=True)