# Source: 3dilize_anything / dust3r / render_to_3d.py
# Uploaded by yansong1616 ("Upload 90 files", commit 56cd6b7, verified); file size 3.52 kB.
import os
import torch
import numpy as np
import trimesh
from scipy.spatial.transform import Rotation
from dust3r.utils.device import to_numpy
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
import matplotlib.pyplot as plt
# Import-time side effect: switch matplotlib to interactive mode.
plt.ion()
# Import-time side effect: let CUDA matmuls use TF32.
torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
# NOTE(review): batch_size is not referenced by the functions visible in this file —
# confirm whether another module still reads it.
batch_size = 1
# Save the rendered 3D scene as a glb file inside the output directory
def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
cam_color=None, as_pointcloud=False, transparent_cams=False):
assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
pts3d = to_numpy(pts3d)
imgs = to_numpy(imgs)
focals = to_numpy(focals)
cams2world = to_numpy(cams2world)
scene = trimesh.Scene()
# full pointcloud
if as_pointcloud:
pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
scene.add_geometry(pct)
else:
meshes = []
for i in range(len(imgs)):
meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
mesh = trimesh.Trimesh(**cat_meshes(meshes))
scene.add_geometry(mesh)
# add each camera
for i, pose_c2w in enumerate(cams2world):
if isinstance(cam_color, list):
camera_edge_color = cam_color[i]
else:
camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
add_scene_cam(scene, pose_c2w, camera_edge_color,
None if transparent_cams else imgs[i], focals[i],
imsize=imgs[i].shape[1::-1], screen_width=cam_size)
rot = np.eye(4)
rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
outfile = os.path.join(outdir, 'scene.glb')
print('(exporting 3D scene to', outfile, ')')
os.makedirs(outdir, exist_ok=True)
scene.export(file_obj=outfile)
return outfile
def get_3D_model_from_scene(outdir, scene, sam2_masks, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
                            clean_depth=False, transparent_cams=False, cam_size=0.05):
    """
    Export a reconstructed scene to a glb file, restricted by the SAM2 masks.

    Returns None when ``scene`` is None. Otherwise returns a tuple
    ``(glb_path, masks)``, where ``masks`` holds, for each view, the logical
    AND of the scene's confidence mask and the matching entry of ``sam2_masks``.
    """
    if scene is None:
        return None

    # Optional post-processing passes; each one replaces the scene object.
    if clean_depth:
        scene = scene.clean_pointcloud()
    if mask_sky:
        scene = scene.mask_sky()

    # Pull the optimized values out of the scene.
    images = scene.imgs
    focal_lengths = scene.get_focals().cpu()
    poses_c2w = scene.get_im_poses().cpu()

    # 3D points as produced by the scene.
    points = to_numpy(scene.get_pts3d())

    # Store the transformed confidence threshold on the scene before asking
    # for its masks (presumably get_masks() reads it — confirm in the scene class).
    transformed_thr = scene.conf_trf(torch.tensor(min_conf_thr))
    scene.min_conf_thr = float(transformed_thr)
    combined_masks = to_numpy(scene.get_masks())
    assert len(combined_masks) == len(sam2_masks)

    # Intersect each SAM2 mask with the confidence-thresholded scene mask.
    for view_idx in range(len(sam2_masks)):
        confident = combined_masks[view_idx]
        segmented = sam2_masks[view_idx]
        combined_masks[view_idx] = np.logical_and(confident, segmented)

    export_options = {
        'as_pointcloud': as_pointcloud,
        'transparent_cams': transparent_cams,
        'cam_size': cam_size,
    }
    glb_path = _convert_scene_output_to_glb(outdir, images, points, combined_masks,
                                            focal_lengths, poses_c2w, **export_options)
    # The second element is the confidence-mask / SAM2-mask intersection.
    return glb_path, combined_masks