Spaces: yansong1616
Commit 56cd6b7 • 1 Parent(s): 633d2c0
Upload 90 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- dust3r/__init__.py +2 -0
- dust3r/__pycache__/__init__.cpython-310.pyc +0 -0
- dust3r/__pycache__/__init__.cpython-38.pyc +0 -0
- dust3r/__pycache__/__init__.cpython-39.pyc +0 -0
- dust3r/__pycache__/image_pairs.cpython-310.pyc +0 -0
- dust3r/__pycache__/image_pairs.cpython-38.pyc +0 -0
- dust3r/__pycache__/inference.cpython-310.pyc +0 -0
- dust3r/__pycache__/inference.cpython-38.pyc +0 -0
- dust3r/__pycache__/inference.cpython-39.pyc +0 -0
- dust3r/__pycache__/model.cpython-310.pyc +0 -0
- dust3r/__pycache__/model.cpython-38.pyc +0 -0
- dust3r/__pycache__/model.cpython-39.pyc +0 -0
- dust3r/__pycache__/optim_factory.cpython-310.pyc +0 -0
- dust3r/__pycache__/optim_factory.cpython-38.pyc +0 -0
- dust3r/__pycache__/patch_embed.cpython-310.pyc +0 -0
- dust3r/__pycache__/patch_embed.cpython-38.pyc +0 -0
- dust3r/__pycache__/post_process.cpython-310.pyc +0 -0
- dust3r/__pycache__/render_to_3d.cpython-310.pyc +0 -0
- dust3r/__pycache__/viz.cpython-310.pyc +0 -0
- dust3r/__pycache__/viz.cpython-38.pyc +0 -0
- dust3r/cloud_opt/__init__.py +29 -0
- dust3r/cloud_opt/__pycache__/__init__.cpython-310.pyc +0 -0
- dust3r/cloud_opt/__pycache__/__init__.cpython-38.pyc +0 -0
- dust3r/cloud_opt/__pycache__/base_opt.cpython-310.pyc +0 -0
- dust3r/cloud_opt/__pycache__/base_opt.cpython-38.pyc +0 -0
- dust3r/cloud_opt/__pycache__/commons.cpython-310.pyc +0 -0
- dust3r/cloud_opt/__pycache__/commons.cpython-38.pyc +0 -0
- dust3r/cloud_opt/__pycache__/init_im_poses.cpython-310.pyc +0 -0
- dust3r/cloud_opt/__pycache__/init_im_poses.cpython-38.pyc +0 -0
- dust3r/cloud_opt/__pycache__/optimizer.cpython-310.pyc +0 -0
- dust3r/cloud_opt/__pycache__/optimizer.cpython-38.pyc +0 -0
- dust3r/cloud_opt/__pycache__/pair_viewer.cpython-310.pyc +0 -0
- dust3r/cloud_opt/base_opt.py +380 -0
- dust3r/cloud_opt/commons.py +91 -0
- dust3r/cloud_opt/init_im_poses.py +316 -0
- dust3r/cloud_opt/optimizer.py +249 -0
- dust3r/cloud_opt/pair_viewer.py +125 -0
- dust3r/datasets/__init__.py +42 -0
- dust3r/datasets/base/__init__.py +2 -0
- dust3r/datasets/base/base_stereo_view_dataset.py +220 -0
- dust3r/datasets/base/batched_sampler.py +74 -0
- dust3r/datasets/base/easy_dataset.py +157 -0
- dust3r/datasets/co3d.py +146 -0
- dust3r/datasets/utils/__init__.py +2 -0
- dust3r/datasets/utils/cropping.py +119 -0
- dust3r/datasets/utils/transforms.py +11 -0
- dust3r/heads/__init__.py +19 -0
- dust3r/heads/__pycache__/__init__.cpython-310.pyc +0 -0
- dust3r/heads/__pycache__/__init__.cpython-38.pyc +0 -0
- dust3r/heads/__pycache__/__init__.cpython-39.pyc +0 -0
dust3r/__init__.py
ADDED
@@ -0,0 +1,2 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
dust3r/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (145 Bytes)

dust3r/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (143 Bytes)

dust3r/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (143 Bytes)

dust3r/__pycache__/image_pairs.cpython-310.pyc
ADDED
Binary file (3.19 kB)

dust3r/__pycache__/image_pairs.cpython-38.pyc
ADDED
Binary file (3.25 kB)

dust3r/__pycache__/inference.cpython-310.pyc
ADDED
Binary file (5.2 kB)

dust3r/__pycache__/inference.cpython-38.pyc
ADDED
Binary file (5.21 kB)

dust3r/__pycache__/inference.cpython-39.pyc
ADDED
Binary file (5.2 kB)

dust3r/__pycache__/model.cpython-310.pyc
ADDED
Binary file (5.99 kB)

dust3r/__pycache__/model.cpython-38.pyc
ADDED
Binary file (5.96 kB)

dust3r/__pycache__/model.cpython-39.pyc
ADDED
Binary file (5.97 kB)

dust3r/__pycache__/optim_factory.cpython-310.pyc
ADDED
Binary file (371 Bytes)

dust3r/__pycache__/optim_factory.cpython-38.pyc
ADDED
Binary file (367 Bytes)

dust3r/__pycache__/patch_embed.cpython-310.pyc
ADDED
Binary file (2.74 kB)

dust3r/__pycache__/patch_embed.cpython-38.pyc
ADDED
Binary file (2.76 kB)

dust3r/__pycache__/post_process.cpython-310.pyc
ADDED
Binary file (1.65 kB)

dust3r/__pycache__/render_to_3d.cpython-310.pyc
ADDED
Binary file (2.91 kB)

dust3r/__pycache__/viz.cpython-310.pyc
ADDED
Binary file (10.6 kB)

dust3r/__pycache__/viz.cpython-38.pyc
ADDED
Binary file (10.6 kB)
dust3r/cloud_opt/__init__.py
ADDED
@@ -0,0 +1,29 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# global alignment optimization wrapper function
+# --------------------------------------------------------
+from enum import Enum
+
+from .optimizer import PointCloudOptimizer
+from .pair_viewer import PairViewer
+
+
+class GlobalAlignerMode(Enum):
+    PointCloudOptimizer = "PointCloudOptimizer"
+    PairViewer = "PairViewer"
+
+
+def global_aligner(dust3r_output, device, mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw):
+    # extract all inputs
+    view1, view2, pred1, pred2 = [dust3r_output[k] for k in 'view1 view2 pred1 pred2'.split()]
+    # build the optimizer
+    if mode == GlobalAlignerMode.PointCloudOptimizer:
+        net = PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device)
+    elif mode == GlobalAlignerMode.PairViewer:
+        net = PairViewer(view1, view2, pred1, pred2, **optim_kw).to(device)
+    else:
+        raise NotImplementedError(f'Unknown mode {mode}')
+
+    return net
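
For context, a minimal usage sketch of the `global_aligner` wrapper above, modeled on DUSt3R's documented pipeline. `load_images`, `make_pairs`, and `inference` come from other files in this upload; the model loading and image paths are placeholder assumptions:

```python
from dust3r.image_pairs import make_pairs
from dust3r.inference import inference
from dust3r.utils.image import load_images
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode

# `model` is a pretrained DUSt3R network (loading elided); paths are placeholders
images = load_images(['img1.jpg', 'img2.jpg', 'img3.jpg'], size=512)
pairs = make_pairs(images, scene_graph='complete', symmetrize=True)
output = inference(pairs, model, device='cuda', batch_size=1)

scene = global_aligner(output, device='cuda', mode=GlobalAlignerMode.PointCloudOptimizer)
scene.compute_global_alignment(init='mst', niter=300, schedule='cosine', lr=0.01)
pts3d = scene.get_pts3d()     # per-image world-frame pointmaps
poses = scene.get_im_poses()  # cam-to-world camera poses
```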
dust3r/cloud_opt/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.04 kB)

dust3r/cloud_opt/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (1.03 kB)

dust3r/cloud_opt/__pycache__/base_opt.cpython-310.pyc
ADDED
Binary file (15.6 kB)

dust3r/cloud_opt/__pycache__/base_opt.cpython-38.pyc
ADDED
Binary file (15.8 kB)

dust3r/cloud_opt/__pycache__/commons.cpython-310.pyc
ADDED
Binary file (3.36 kB)

dust3r/cloud_opt/__pycache__/commons.cpython-38.pyc
ADDED
Binary file (3.41 kB)

dust3r/cloud_opt/__pycache__/init_im_poses.cpython-310.pyc
ADDED
Binary file (8.42 kB)

dust3r/cloud_opt/__pycache__/init_im_poses.cpython-38.pyc
ADDED
Binary file (8.45 kB)

dust3r/cloud_opt/__pycache__/optimizer.cpython-310.pyc
ADDED
Binary file (11.2 kB)

dust3r/cloud_opt/__pycache__/optimizer.cpython-38.pyc
ADDED
Binary file (11.4 kB)

dust3r/cloud_opt/__pycache__/pair_viewer.cpython-310.pyc
ADDED
Binary file (4.89 kB)
dust3r/cloud_opt/base_opt.py
ADDED
@@ -0,0 +1,380 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Base class for the global alignment procedure
+# --------------------------------------------------------
+from copy import deepcopy
+
+import numpy as np
+import torch
+import torch.nn as nn
+import roma
+import tqdm
+
+from dust3r.utils.geometry import inv, geotrf
+from dust3r.utils.device import to_numpy
+from dust3r.utils.image import rgb
+from dust3r.viz import SceneViz, segment_sky, auto_cam_size
+from dust3r.optim_factory import adjust_learning_rate_by_lr
+
+from dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p,
+                                      cosine_schedule, linear_schedule, get_conf_trf)
+import dust3r.cloud_opt.init_im_poses as init_fun
+
+
+class BasePCOptimizer (nn.Module):
+    """ Optimize a global scene, given a list of pairwise observations.
+    Graph node: images
+    Graph edges: observations = (pred1, pred2)
+    """
+
+    def __init__(self, *args, **kwargs):
+        if len(args) == 1 and len(kwargs) == 0:
+            other = deepcopy(args[0])
+            attrs = '''edges is_symmetrized dist n_imgs pred_i pred_j imshapes
+                    min_conf_thr conf_thr conf_i conf_j im_conf
+                    base_scale norm_pw_scale POSE_DIM pw_poses
+                    pw_adaptors has_im_poses rand_pose imgs'''.split()
+            self.__dict__.update({k: other[k] for k in attrs})
+        else:
+            self._init_from_views(*args, **kwargs)
+
+    def _init_from_views(self, view1, view2, pred1, pred2,
+                         dist='l1',
+                         conf='log',
+                         min_conf_thr=3,
+                         base_scale=0.5,
+                         allow_pw_adaptors=False,
+                         pw_break=20,
+                         rand_pose=torch.randn,
+                         iterationsCount=None,
+                         ):
+        super().__init__()
+        if not isinstance(view1['idx'], list):
+            view1['idx'] = view1['idx'].tolist()
+        if not isinstance(view2['idx'], list):
+            view2['idx'] = view2['idx'].tolist()
+        self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
+        self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
+        self.dist = ALL_DISTS[dist]
+
+        self.n_imgs = self._check_edges()
+
+        # input data
+        pred1_pts = pred1['pts3d']
+        pred2_pts = pred2['pts3d_in_other_view']
+        self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)})
+        self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)})
+        self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)
+
+        # work in log-scale with conf
+        pred1_conf = pred1['conf']
+        pred2_conf = pred2['conf']
+        self.min_conf_thr = min_conf_thr
+        self.conf_trf = get_conf_trf(conf)
+
+        self.conf_i = NoGradParamDict({ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)})
+        self.conf_j = NoGradParamDict({ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)})
+        self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)
+
+        # pairwise pose parameters
+        self.base_scale = base_scale
+        self.norm_pw_scale = True
+        self.pw_break = pw_break
+        self.POSE_DIM = 7
+        self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1+self.POSE_DIM)))  # pairwise poses
+        self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2)))  # slight xy/z adaptation
+        self.pw_adaptors.requires_grad_(allow_pw_adaptors)
+        self.has_im_poses = False
+        self.rand_pose = rand_pose
+
+        # possibly store images for show_pointcloud
+        self.imgs = None
+        if 'img' in view1 and 'img' in view2:
+            imgs = [torch.zeros((3,)+hw) for hw in self.imshapes]
+            for v in range(len(self.edges)):
+                idx = view1['idx'][v]
+                imgs[idx] = view1['img'][v]
+                idx = view2['idx'][v]
+                imgs[idx] = view2['img'][v]
+            self.imgs = rgb(imgs)
+
+    @property
+    def n_edges(self):
+        return len(self.edges)
+
+    @property
+    def str_edges(self):
+        return [edge_str(i, j) for i, j in self.edges]
+
+    @property
+    def imsizes(self):
+        return [(w, h) for h, w in self.imshapes]
+
+    @property
+    def device(self):
+        return next(iter(self.parameters())).device
+
+    def state_dict(self, trainable=True):
+        all_params = super().state_dict()
+        return {k: v for k, v in all_params.items() if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable}
+
+    def load_state_dict(self, data):
+        return super().load_state_dict(self.state_dict(trainable=False) | data)
+
+    def _check_edges(self):
+        indices = sorted({i for edge in self.edges for i in edge})
+        assert indices == list(range(len(indices))), 'bad pair indices: missing values '
+        return len(indices)
+
+    @torch.no_grad()
+    def _compute_img_conf(self, pred1_conf, pred2_conf):
+        im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes])
+        for e, (i, j) in enumerate(self.edges):
+            im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])
+            im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])
+        return im_conf
+
+    def get_adaptors(self):  # the per-edge scale sigma_e in Eq. (5)
+        adapt = self.pw_adaptors
+        adapt = torch.cat((adapt[:, 0:1], adapt), dim=-1)  # (scale_xy, scale_xy, scale_z)
+        if self.norm_pw_scale:  # normalize so that the product == 1
+            adapt = adapt - adapt.mean(dim=1, keepdim=True)
+        return (adapt / self.pw_break).exp()  # TODO gys: what exactly is sigma_e in Eq. (5)?
+
+    def _get_poses(self, poses):  # works on self.im_poses or self.pw_poses
+        # normalize rotation
+        Q = poses[:, :4]
+        T = signed_expm1(poses[:, 4:7])
+        RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()
+        return RT
+
+    def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
+        # all poses == cam-to-world
+        pose = poses[idx]
+        if not (pose.requires_grad or force):
+            return pose
+
+        if R.shape == (4, 4):
+            assert T is None
+            T = R[:3, 3]
+            R = R[:3, :3]
+
+        if R is not None:
+            pose.data[0:4] = roma.rotmat_to_unitquat(R)
+        if T is not None:
+            pose.data[4:7] = signed_log1p(T / (scale or 1))  # translation is function of scale
+
+        if scale is not None:
+            assert poses.shape[-1] in (8, 13)
+            pose.data[-1] = np.log(float(scale))
+        return pose
+
+    def get_pw_norm_scale_factor(self):
+        if self.norm_pw_scale:
+            # normalize scales so that things cannot go south
+            # we want that exp(scale) ~= self.base_scale
+            return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
+        else:
+            return 1  # don't norm scale for known poses
+
+    def get_pw_scale(self):
+        scale = self.pw_poses[:, -1].exp()  # (n_edges,)
+        scale = scale * self.get_pw_norm_scale_factor()
+        return scale
+
+    def get_pw_poses(self):  # cam to world
+        RT = self._get_poses(self.pw_poses)
+        scaled_RT = RT.clone()
+        scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1)  # scale the rotation AND translation
+        return scaled_RT
+
+    def get_masks(self):
+        return [(conf > self.min_conf_thr) for conf in self.im_conf]
+
+    def depth_to_pts3d(self):
+        raise NotImplementedError()
+
+    def get_pts3d(self, raw=False):
+        res = self.depth_to_pts3d()
+        if not raw:
+            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
+        return res
+
+    def _set_focal(self, idx, focal, force=False):
+        raise NotImplementedError()
+
+    def get_focals(self):
+        raise NotImplementedError()
+
+    def get_known_focal_mask(self):
+        raise NotImplementedError()
+
+    def get_principal_points(self):
+        raise NotImplementedError()
+
+    def get_conf(self, mode=None):
+        trf = self.conf_trf if mode is None else get_conf_trf(mode)
+        return [trf(c) for c in self.im_conf]
+
+    def get_im_poses(self):
+        raise NotImplementedError()
+
+    def _set_depthmap(self, idx, depth, force=False):
+        raise NotImplementedError()
+
+    def get_depthmaps(self, raw=False):
+        raise NotImplementedError()
+
+    @torch.no_grad()
+    def clean_pointcloud(self, tol=0.001, max_bad_conf=0):
+        """ Method:
+        1) express all 3d points in each camera coordinate frame
+        2) if they're in front of a depthmap --> then lower their confidence
+        """
+        assert 0 <= tol < 1
+        cams = inv(self.get_im_poses())
+        K = self.get_intrinsics()
+        depthmaps = self.get_depthmaps()
+        res = deepcopy(self)
+
+        for i, pts3d in enumerate(self.depth_to_pts3d()):
+            for j in range(self.n_imgs):
+                if i == j:
+                    continue
+
+                # project 3dpts in other view
+                Hi, Wi = self.imshapes[i]
+                Hj, Wj = self.imshapes[j]
+                proj = geotrf(cams[j], pts3d[:Hi*Wi]).reshape(Hi, Wi, 3)
+                proj_depth = proj[:, :, 2]
+                u, v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1)
+
+                # check which points are actually in the visible cone
+                msk_i = (proj_depth > 0) & (0 <= u) & (u < Wj) & (0 <= v) & (v < Hj)
+                msk_j = v[msk_i], u[msk_i]
+
+                # find bad points = those in front but less confident
+                bad_points = (proj_depth[msk_i] < (1-tol) * depthmaps[j][msk_j]
+                              ) & (res.im_conf[i][msk_i] < res.im_conf[j][msk_j])
+
+                bad_msk_i = msk_i.clone()
+                bad_msk_i[msk_i] = bad_points
+                res.im_conf[i][bad_msk_i] = res.im_conf[i][bad_msk_i].clip_(max=max_bad_conf)
+
+        return res
+
+    def forward(self, ret_details=False):
+        pw_poses = self.get_pw_poses()  # cam-to-world
+        pw_adapt = self.get_adaptors()
+        proj_pts3d = self.get_pts3d()
+        # pre-compute pixel weights
+        weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
+        weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}
+
+        loss = 0
+        if ret_details:
+            details = -torch.ones((self.n_imgs, self.n_imgs))
+
+        for e, (i, j) in enumerate(self.edges):
+            i_j = edge_str(i, j)
+            # distance in image i and j
+            aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
+            aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
+            li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
+            lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
+            loss = loss + li + lj
+
+            if ret_details:
+                details[i, j] = li + lj
+        loss /= self.n_edges  # average over all pairs
+
+        if ret_details:
+            return loss, details
+        return loss
+
+    def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
+        if init is None:
+            pass
+        elif init == 'msp' or init == 'mst':
+            # Sec. 3.3 Downstream Applications: initialize the intrinsics, the extrinsics
+            # and the world-frame points that Eq. (5) of Sec. 3.4 Global Alignment estimates
+            init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
+        elif init == 'known_poses':
+            init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr, niter_PnP=niter_PnP)
+        else:
+            raise ValueError(f'bad value for {init=}')
+
+        global_alignment_loop(self, **kw)  # Sec. 3.4 Global Alignment: gradient descent on Eq. (5)
+
+    @torch.no_grad()
+    def mask_sky(self):
+        res = deepcopy(self)
+        for i in range(self.n_imgs):
+            sky = segment_sky(self.imgs[i])
+            res.im_conf[i][sky] = 0
+        return res
+
+    def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):
+        viz = SceneViz()
+        if self.imgs is None:
+            colors = np.random.randint(0, 256, size=(self.n_imgs, 3))
+            colors = list(map(tuple, colors.tolist()))
+            for n in range(self.n_imgs):
+                viz.add_pointcloud(self.get_pts3d()[n], colors[n], self.get_masks()[n])
+        else:
+            viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks())
+            colors = np.random.randint(256, size=(self.n_imgs, 3))
+
+        # camera poses
+        im_poses = to_numpy(self.get_im_poses())
+        if cam_size is None:
+            cam_size = auto_cam_size(im_poses)
+        viz.add_cameras(im_poses, self.get_focals(), colors=colors,
+                        images=self.imgs, imsizes=self.imsizes, cam_size=cam_size)
+        if show_pw_cams:
+            pw_poses = self.get_pw_poses()
+            viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)
+
+            if show_pw_pts3d:
+                pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)]
+                viz.add_pointcloud(pts, (128, 0, 128))
+
+        viz.show(**kw)
+        return viz
+
+
+def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6, verbose=False):
+    params = [p for p in net.parameters() if p.requires_grad]
+    if not params:
+        return net
+
+    if verbose:
+        print([name for name, value in net.named_parameters() if value.requires_grad])
+
+    lr_base = lr
+    optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9))
+
+    with tqdm.tqdm(total=niter) as bar:
+        while bar.n < bar.total:
+            t = bar.n / bar.total
+
+            if schedule == 'cosine':
+                lr = cosine_schedule(t, lr_base, lr_min)
+            elif schedule == 'linear':
+                lr = linear_schedule(t, lr_base, lr_min)
+            else:
+                raise ValueError(f'bad lr {schedule=}')
+            adjust_learning_rate_by_lr(optimizer, lr)
+
+            optimizer.zero_grad()
+            loss = net()  # 'Global optimization' in the paper
+            loss.backward()
+            optimizer.step()
+            loss = float(loss)
+            bar.set_postfix_str(f'{lr=:g} loss={loss:g}')
+            if bar.n % 30 == 0:
+                print(' ')
+            bar.update()
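
A small self-contained round-trip check of the pose parameterization used above (`_set_pose` stores a unit quaternion plus a signed-log1p translation; `_get_poses` decodes it with roma). This is a sketch, not part of the commit; `signed_log1p`/`signed_expm1` are inlined from commons.py below:

```python
import torch
import roma

def signed_log1p(x):
    return torch.sign(x) * torch.log1p(torch.abs(x))

def signed_expm1(x):
    return torch.sign(x) * torch.expm1(torch.abs(x))

R = roma.random_rotmat()   # ground-truth rotation
T = torch.randn(3)         # ground-truth translation

# encode as in _set_pose (scale=1): quaternion in slots 0:4, signed-log translation in 4:7
pose = torch.empty(8)
pose[0:4] = roma.rotmat_to_unitquat(R)
pose[4:7] = signed_log1p(T)
pose[7] = 0.0              # log-scale slot, unused at scale=1

# decode as in _get_poses
RT = roma.RigidUnitQuat(pose[None, :4], signed_expm1(pose[None, 4:7])).normalize().to_homogeneous()[0]

assert torch.allclose(RT[:3, :3], R, atol=1e-5)
assert torch.allclose(RT[:3, 3], T, atol=1e-5)
```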
dust3r/cloud_opt/commons.py
ADDED
@@ -0,0 +1,91 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# utility functions for global alignment
+# --------------------------------------------------------
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+def edge_str(i, j):
+    return f'{i}_{j}'
+
+
+def i_j_ij(ij):
+    return edge_str(*ij), ij
+
+
+def edge_conf(conf_i, conf_j, edge):
+    # dust3r outputs one confidence map per image of the edge; average each map
+    # over its pixels and multiply the two means to score the edge
+    return float(conf_i[edge].mean() * conf_j[edge].mean())
+
+
+def compute_edge_scores(edges, conf_i, conf_j):
+    # score every edge by the product of the mean per-pixel confidences of its two images
+    return {(i, j): edge_conf(conf_i, conf_j, e) for e, (i, j) in edges}
+
+
+def NoGradParamDict(x):
+    assert isinstance(x, dict)
+    return nn.ParameterDict(x).requires_grad_(False)
+
+
+def get_imshapes(edges, pred_i, pred_j):
+    n_imgs = max(max(e) for e in edges) + 1
+    imshapes = [None] * n_imgs
+    for e, (i, j) in enumerate(edges):
+        shape_i = tuple(pred_i[e].shape[0:2])
+        shape_j = tuple(pred_j[e].shape[0:2])
+        if imshapes[i]:
+            assert imshapes[i] == shape_i, f'incorrect shape for image {i}'
+        if imshapes[j]:
+            assert imshapes[j] == shape_j, f'incorrect shape for image {j}'
+        imshapes[i] = shape_i
+        imshapes[j] = shape_j
+    return imshapes
+
+
+def get_conf_trf(mode):
+    if mode == 'log':
+        def conf_trf(x): return x.log()
+    elif mode == 'sqrt':
+        def conf_trf(x): return x.sqrt()
+    elif mode == 'm1':
+        def conf_trf(x): return x-1
+    elif mode in ('id', 'none'):
+        def conf_trf(x): return x
+    else:
+        raise ValueError(f'bad mode for {mode=}')
+    return conf_trf
+
+
+def l2_dist(a, b, weight):
+    return ((a - b).square().sum(dim=-1) * weight)
+
+
+def l1_dist(a, b, weight):
+    return ((a - b).norm(dim=-1) * weight)  # torch.norm() computes a vector norm (L2 by default)
+
+
+ALL_DISTS = dict(l1=l1_dist, l2=l2_dist)
+
+
+def signed_log1p(x):
+    sign = torch.sign(x)
+    return sign * torch.log1p(torch.abs(x))
+
+
+def signed_expm1(x):
+    sign = torch.sign(x)
+    return sign * torch.expm1(torch.abs(x))
+
+
+def cosine_schedule(t, lr_start, lr_end):
+    assert 0 <= t <= 1
+    return lr_end + (lr_start - lr_end) * (1+np.cos(t * np.pi))/2
+
+
+def linear_schedule(t, lr_start, lr_end):
+    assert 0 <= t <= 1
+    return lr_start + (lr_end - lr_start) * t
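
To see how the two learning-rate schedules above behave across the optimization (t runs from 0 to 1), a quick standalone check with the defaults of `global_alignment_loop` (lr=0.01 down to lr_min=1e-6); the functions are copied verbatim from this file:

```python
import numpy as np

def cosine_schedule(t, lr_start, lr_end):
    return lr_end + (lr_start - lr_end) * (1 + np.cos(t * np.pi)) / 2

def linear_schedule(t, lr_start, lr_end):
    return lr_start + (lr_end - lr_start) * t

for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f't={t:.2f}  cosine={cosine_schedule(t, 0.01, 1e-6):.2e}'
          f'  linear={linear_schedule(t, 0.01, 1e-6):.2e}')
```

The cosine variant holds the rate near lr_start early on and decays it smoothly to lr_min by the end of the run.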
dust3r/cloud_opt/init_im_poses.py
ADDED
@@ -0,0 +1,316 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Initialization functions for global alignment
+# --------------------------------------------------------
+from functools import cache
+
+import numpy as np
+import scipy.sparse as sp
+import torch
+import cv2
+import roma
+from tqdm import tqdm
+
+from dust3r.utils.geometry import geotrf, inv, get_med_dist_between_poses
+from dust3r.post_process import estimate_focal_knowing_depth
+from dust3r.viz import to_numpy
+
+from dust3r.cloud_opt.commons import edge_str, i_j_ij, compute_edge_scores
+
+
+@torch.no_grad()
+def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3):
+    device = self.device
+
+    # indices of known poses
+    nkp, known_poses_msk, known_poses = get_known_poses(self)
+    assert nkp == self.n_imgs, 'not all poses are known'
+
+    # get all focals
+    nkf, _, im_focals = get_known_focals(self)
+    assert nkf == self.n_imgs
+    im_pp = self.get_principal_points()
+
+    best_depthmaps = {}
+    # init all pairwise poses
+    for e, (i, j) in enumerate(tqdm(self.edges)):
+        i_j = edge_str(i, j)
+
+        # find relative pose for this pair
+        P1 = torch.eye(4, device=device)
+        msk = self.conf_i[i_j] > min(min_conf_thr, self.conf_i[i_j].min() - 0.1)
+        _, P2 = fast_pnp(self.pred_j[i_j], float(im_focals[i].mean()),
+                         pp=im_pp[i], msk=msk, device=device, niter_PnP=niter_PnP)
+
+        # align the two predicted camera with the two gt cameras
+        s, R, T = align_multiple_poses(torch.stack((P1, P2)), known_poses[[i, j]])
+        # normally we have known_poses[i] ~= sRT_to_4x4(s,R,T,device) @ P1
+        # and geotrf(sRT_to_4x4(1,R,T,device), s*P2[:3,3])
+        self._set_pose(self.pw_poses, e, R, T, scale=s)
+
+        # remember if this is a good depthmap
+        score = float(self.conf_i[i_j].mean())
+        if score > best_depthmaps.get(i, (0,))[0]:
+            best_depthmaps[i] = score, i_j, s
+
+    # init all image poses
+    for n in range(self.n_imgs):
+        assert known_poses_msk[n]
+        _, i_j, scale = best_depthmaps[n]
+        depth = self.pred_i[i_j][:, :, 2]
+        self._set_depthmap(n, depth * scale)
+
+
+@torch.no_grad()
+def init_minimum_spanning_tree(self, **kw):
+    """ Init all camera poses (image-wise and pairwise poses) given
+        an initial set of pairwise estimations.
+    """
+    device = self.device
+    pts3d, _, im_focals, im_poses = minimum_spanning_tree(self.imshapes, self.edges,
+                                                          self.pred_i, self.pred_j, self.conf_i, self.conf_j, self.im_conf, self.min_conf_thr,
+                                                          device, has_im_poses=self.has_im_poses, **kw)
+
+    return init_from_pts3d(self, pts3d, im_focals, im_poses)  # initialization
+
+
+def init_from_pts3d(self, pts3d, im_focals, im_poses):
+    # init poses
+    nkp, known_poses_msk, known_poses = get_known_poses(self)
+    if nkp == 1:
+        raise NotImplementedError("Would be simpler to just align everything afterwards on the single known pose")
+    elif nkp > 1:
+        # global rigid SE3 alignment
+        s, R, T = align_multiple_poses(im_poses[known_poses_msk], known_poses[known_poses_msk])
+        trf = sRT_to_4x4(s, R, T, device=known_poses.device)
+
+        # rotate everything
+        im_poses = trf @ im_poses
+        im_poses[:, :3, :3] /= s  # undo scaling on the rotation part
+        for img_pts3d in pts3d:
+            img_pts3d[:] = geotrf(trf, img_pts3d)
+
+    # pw_poses: for every edge, compute P_e, the transform from the edge's camera frame
+    # (i.e. the frame of the first image fed to dust3r) into the 'world' frame
+    for e, (i, j) in enumerate(self.edges):
+        i_j = edge_str(i, j)
+        # compute transform that goes from cam to world
+        # pred_i: the 3D pointmap dust3r predicts for the first image of the edge
+        s, R, T = rigid_points_registration(self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j])  # per-edge cam-to-world transform
+        self._set_pose(self.pw_poses, e, R, T, scale=s)  # pw_poses *****************
+
+    # TODO gys: what is s_factor?  take into account the scale normalization
+    s_factor = self.get_pw_norm_scale_factor()
+    im_poses[:, :3, 3] *= s_factor  # apply downscaling factor
+    for img_pts3d in pts3d:
+        img_pts3d *= s_factor
+
+    # init all image poses
+    if self.has_im_poses:
+        for i in range(self.n_imgs):
+            cam2world = im_poses[i]
+            depth = geotrf(inv(cam2world), pts3d[i])[..., 2]  # bring the world-frame points pts3d[i] back into the camera frame
+            self._set_depthmap(i, depth)
+            self._set_pose(self.im_poses, i, cam2world)  # im_poses ********************
+            if im_focals[i] is not None:
+                self._set_focal(i, im_focals[i])
+
+    print(' init loss =', float(self()))
+
+
+def minimum_spanning_tree(imshapes, edges, pred_i, pred_j, conf_i, conf_j, im_conf, min_conf_thr,
+                          device, has_im_poses=True, niter_PnP=10):
+    n_imgs = len(imshapes)
+    sparse_graph = -dict_to_sparse_graph(compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j))  # matrix of pairwise edge confidences
+    msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()  # sparse_graph is negated, so the MST actually keeps the highest confidences
+    # the MST picks, for each image, edges with as high a confidence as possible (every image pair has an edge)
+    # temp variable to store 3d points
+    pts3d = [None] * len(imshapes)  # one slot per input image
+
+    todo = sorted(zip(-msp.data, msp.row, msp.col))  # the n_imgs-1 tree edges with the highest mean confidence; a spanning tree covers all input images  # sorted edges
+    im_poses = [None] * n_imgs
+    im_focals = [None] * n_imgs
+
+    # init with strongest edge
+    score, i, j = todo.pop()  # score is the confidence computed by compute_edge_scores
+    print(f' init edge ({i}*,{j}*) {score=}')
+    i_j = edge_str(i, j)
+    pts3d[i] = pred_i[i_j].clone()  # pointmaps of the two images of the strongest edge (dust3r outputs two pointmaps per image pair)
+    pts3d[j] = pred_j[i_j].clone()
+    done = {i, j}
+    if has_im_poses:  # ===== the camera frame of the first image of the strongest edge becomes the world frame =====
+        im_poses[i] = torch.eye(4, device=device)  # identity pose: this camera frame IS the world frame
+        im_focals[i] = estimate_focal(pred_i[i_j])  # Sec. 3.3: estimate the intrinsics (focal)
+
+    # set initial pointcloud based on pairwise graph
+    msp_edges = [(i, j)]
+    while todo:
+        # each time, predict the next one
+        score, i, j = todo.pop()  # pop removes the last (highest-score) element
+
+        if im_focals[i] is None:  # focal of image i not estimated yet
+            im_focals[i] = estimate_focal(pred_i[i_j])
+
+        if i in done:
+            print(f' init edge ({i},{j}*) {score=}')
+            assert j not in done
+            # align pred[i] with pts3d[i], and then set j accordingly
+            i_j = edge_str(i, j)
+            s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j])  # Sec. 3.3 pose estimation (s is the scale sigma), via the roma package
+            trf = sRT_to_4x4(s, R, T, device)  # pack into a 4x4 homogeneous matrix, last row [0,0,0,1]
+            pts3d[j] = geotrf(trf, pred_j[i_j])  # pred_j[i_j]: dust3r's pointmap of image j expressed in camera frame i
+            done.add(j)
+            msp_edges.append((i, j))
+
+            if has_im_poses and im_poses[i] is None:
+                im_poses[i] = sRT_to_4x4(1, R, T, device)
+
+        elif j in done:
+            print(f' init edge ({i}*,{j}) {score=}')
+            assert i not in done
+            i_j = edge_str(i, j)
+            s, R, T = rigid_points_registration(pred_j[i_j], pts3d[j], conf=conf_j[i_j])  # transform mapping pred_j[i_j] onto pts3d[j]
+            trf = sRT_to_4x4(s, R, T, device)
+            pts3d[i] = geotrf(trf, pred_i[i_j])  # apply it to bring the camera-frame points into the world frame
+            done.add(i)
+            msp_edges.append((i, j))
+
+            if has_im_poses and im_poses[i] is None:
+                im_poses[i] = sRT_to_4x4(1, R, T, device)
+        else:
+            # let's try again later
+            todo.insert(0, (score, i, j))
+
+    if has_im_poses:
+        # complete all missing informations
+        pair_scores = list(sparse_graph.values())  # already negative scores: less is best
+        edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[np.argsort(pair_scores)]
+        for i, j in edges_from_best_to_worse.tolist():
+            if im_focals[i] is None:
+                im_focals[i] = estimate_focal(pred_i[edge_str(i, j)])
+
+        for i in range(n_imgs):
+            if im_poses[i] is None:
+                msk = im_conf[i] > min_conf_thr
+                res = fast_pnp(pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP)  # estimate the pose with PnP
+                if res:
+                    im_focals[i], im_poses[i] = res
+            if im_poses[i] is None:
+                im_poses[i] = torch.eye(4, device=device)
+        im_poses = torch.stack(im_poses)
+    else:
+        im_poses = im_focals = None
+
+    return pts3d, msp_edges, im_focals, im_poses  # pts3d: each image's camera-frame points lifted into the world frame by im_poses
+
+
+def dict_to_sparse_graph(dic):
+    n_imgs = max(max(e) for e in dic) + 1  # number of images
+    res = sp.dok_array((n_imgs, n_imgs))
+    for edge, value in dic.items():
+        res[edge] = value
+    return res  # per-edge confidences stored in an n_imgs x n_imgs sparse matrix
+
+
+def rigid_points_registration(pts1, pts2, conf):
+    R, T, s = roma.rigid_points_registration(  # roma utility function
+        pts1.reshape(-1, 3), pts2.reshape(-1, 3), weights=conf.ravel(), compute_scaling=True)
+    return s, R, T  # return un-scaled (R, T)
+
+
+def sRT_to_4x4(scale, R, T, device):
+    trf = torch.eye(4, device=device)  # identity
+    trf[:3, :3] = R * scale
+    trf[:3, 3] = T.ravel()  # doesn't need scaling
+    return trf  # 4x4 homogeneous transform
+
+
+def estimate_focal(pts3d_i, pp=None):
+    if pp is None:
+        H, W, THREE = pts3d_i.shape
+        assert THREE == 3
+        pp = torch.tensor((W/2, H/2), device=pts3d_i.device)
+    focal = estimate_focal_knowing_depth(pts3d_i.unsqueeze(0), pp.unsqueeze(
+        0), focal_mode='weiszfeld', min_focal=0.5, max_focal=3.5).ravel()
+    return float(focal)
+
+
+@cache
+def pixel_grid(H, W):
+    return np.mgrid[:W, :H].T.astype(np.float32)
+
+
+def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
+    # extract camera poses and focals with RANSAC-PnP
+    if msk.sum() < 4:
+        return None  # we need at least 4 points for PnP
+    pts3d, msk = map(to_numpy, (pts3d, msk))
+
+    H, W, THREE = pts3d.shape
+    assert THREE == 3
+    pixels = pixel_grid(H, W)
+
+    if focal is None:
+        S = max(W, H)
+        tentative_focals = np.geomspace(S/2, S*3, 21)
+    else:
+        tentative_focals = [focal]
+
+    if pp is None:
+        pp = (W/2, H/2)
+    else:
+        pp = to_numpy(pp)
+
+    best = 0,
+    for focal in tentative_focals:
+        K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
+
+        success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
+                                                    iterationsCount=niter_PnP, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
+        if not success:
+            continue
+
+        score = len(inliers)
+        if success and score > best[0]:
+            best = score, R, T, focal
+
+    if not best[0]:
+        return None
+
+    _, R, T, best_focal = best
+    R = cv2.Rodrigues(R)[0]  # world to cam
+    R, T = map(torch.from_numpy, (R, T))
+    return best_focal, inv(sRT_to_4x4(1, R, T, device))  # cam to world
+
+
+def get_known_poses(self):
+    if self.has_im_poses:
+        known_poses_msk = torch.tensor([not (p.requires_grad) for p in self.im_poses])
+        known_poses = self.get_im_poses()
+        return known_poses_msk.sum(), known_poses_msk, known_poses
+    else:
+        return 0, None, None
+
+
+def get_known_focals(self):
+    if self.has_im_poses:
+        known_focal_msk = self.get_known_focal_mask()
+        known_focals = self.get_focals()
+        return known_focal_msk.sum(), known_focal_msk, known_focals
+    else:
+        return 0, None, None
+
+
+def align_multiple_poses(src_poses, target_poses):
+    N = len(src_poses)
+    assert src_poses.shape == target_poses.shape == (N, 4, 4)
+
+    def center_and_z(poses):
+        eps = get_med_dist_between_poses(poses) / 100
+        return torch.cat((poses[:, :3, 3], poses[:, :3, 3] + eps*poses[:, :3, 2]))
+    R, T, s = roma.rigid_points_registration(center_and_z(src_poses), center_and_z(target_poses), compute_scaling=True)
+    return s, R, T
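
To make the spanning-tree edge selection above concrete, here is a toy sketch of the same scipy mechanics used by `dict_to_sparse_graph` and `minimum_spanning_tree`: scores are negated because scipy keeps the smallest weights, so the retained tree is the one with the highest confidences. The scores below are made up:

```python
import scipy.sparse as sp

# made-up edge confidences for 4 images (higher = better)
scores = {(0, 1): 5.0, (1, 2): 2.0, (0, 2): 4.0, (2, 3): 3.0, (1, 3): 1.0}

n = max(max(e) for e in scores) + 1
g = sp.dok_array((n, n))
for (i, j), s in scores.items():
    g[i, j] = -s  # negate: minimum_spanning_tree keeps the *smallest* weights

mst = sp.csgraph.minimum_spanning_tree(g.tocsr()).tocoo()
todo = sorted(zip(-mst.data, mst.row, mst.col))  # as in the code above
while todo:
    score, i, j = todo.pop()  # best remaining tree edge first
    print(f'edge ({i},{j}) {score=}')
# -> edges (0,1), (0,2), (2,3): a spanning tree covering all 4 images
```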
dust3r/cloud_opt/optimizer.py
ADDED
@@ -0,0 +1,249 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Main class for the implementation of the global alignment
+# --------------------------------------------------------
+import numpy as np
+import torch
+import torch.nn as nn
+
+from dust3r.cloud_opt.base_opt import BasePCOptimizer
+from dust3r.utils.geometry import xy_grid, geotrf
+from dust3r.utils.device import to_cpu, to_numpy
+
+
+class PointCloudOptimizer(BasePCOptimizer):
+    """ Optimize a global scene, given a list of pairwise observations.
+    Graph node: images
+    Graph edges: observations = (pred1, pred2)
+    """
+
+    def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.has_im_poses = True  # by definition of this class
+        self.focal_break = focal_break
+
+        # adding thing to optimize
+        self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes)  # log(depth)
+        self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs))  # camera poses
+        self.im_focals = nn.ParameterList(torch.FloatTensor(
+            [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes)  # camera intrinsics
+        self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs))  # camera intrinsics
+        self.im_pp.requires_grad_(optimize_pp)
+
+        self.imshape = self.imshapes[0]
+        im_areas = [h*w for h, w in self.imshapes]
+        self.max_area = max(im_areas)
+
+        # adding thing to optimize
+        self.im_depthmaps = ParameterStack(self.im_depthmaps, is_param=True, fill=self.max_area)
+        self.im_poses = ParameterStack(self.im_poses, is_param=True)
+        self.im_focals = ParameterStack(self.im_focals, is_param=True)
+        self.im_pp = ParameterStack(self.im_pp, is_param=True)
+        self.register_buffer('_pp', torch.tensor([(w/2, h/2) for h, w in self.imshapes]))
+        self.register_buffer('_grid', ParameterStack(
+            [xy_grid(W, H, device=self.device) for H, W in self.imshapes], fill=self.max_area))
+
+        # pre-compute pixel weights
+        self.register_buffer('_weight_i', ParameterStack(
+            [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], fill=self.max_area))
+        self.register_buffer('_weight_j', ParameterStack(
+            [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], fill=self.max_area))
+
+        # precompute
+        self.register_buffer('_stacked_pred_i', ParameterStack(self.pred_i, self.str_edges, fill=self.max_area))
+        self.register_buffer('_stacked_pred_j', ParameterStack(self.pred_j, self.str_edges, fill=self.max_area))
+        self.register_buffer('_ei', torch.tensor([i for i, j in self.edges]))
+        self.register_buffer('_ej', torch.tensor([j for i, j in self.edges]))
+        self.total_area_i = sum([im_areas[i] for i, j in self.edges])
+        self.total_area_j = sum([im_areas[j] for i, j in self.edges])
+
+    def _check_all_imgs_are_selected(self, msk):
+        assert np.all(self._get_msk_indices(msk) == np.arange(self.n_imgs)), 'incomplete mask!'
+
+    def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world
+        self._check_all_imgs_are_selected(pose_msk)
+
+        if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
+            known_poses = [known_poses]
+        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
+            print(f' (setting pose #{idx} = {pose[:3,3]})')
+            self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose)))
+
+        # normalize scale if there's less than 1 known pose
+        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
+        self.norm_pw_scale = (n_known_poses <= 1)
+
+        self.im_poses.requires_grad_(False)
+        self.norm_pw_scale = False
+
+    def preset_focal(self, known_focals, msk=None):
+        self._check_all_imgs_are_selected(msk)
+
+        for idx, focal in zip(self._get_msk_indices(msk), known_focals):
+            print(f' (setting focal #{idx} = {focal})')
+            self._no_grad(self._set_focal(idx, focal))
+
+        self.im_focals.requires_grad_(False)
+
+    def preset_principal_point(self, known_pp, msk=None):
+        self._check_all_imgs_are_selected(msk)
+
+        for idx, pp in zip(self._get_msk_indices(msk), known_pp):
+            print(f' (setting principal point #{idx} = {pp})')
+            self._no_grad(self._set_principal_point(idx, pp))
+
+        self.im_pp.requires_grad_(False)
+
+    def _get_msk_indices(self, msk):
+        if msk is None:
+            return range(self.n_imgs)
+        elif isinstance(msk, int):
+            return [msk]
+        elif isinstance(msk, (tuple, list)):
+            return self._get_msk_indices(np.array(msk))
+        elif msk.dtype in (bool, torch.bool, np.bool_):
+            assert len(msk) == self.n_imgs
+            return np.cumsum([0] + msk.tolist())
+        elif np.issubdtype(msk.dtype, np.integer):
+            return msk
+        else:
+            raise ValueError(f'bad {msk=}')
+
+    def _no_grad(self, tensor):
+        assert tensor.requires_grad, 'it must be True at this point, otherwise no modification occurs'
+
+    def _set_focal(self, idx, focal, force=False):
+        param = self.im_focals[idx]
+        if param.requires_grad or force:  # can only init a parameter not already initialized
+            param.data[:] = self.focal_break * np.log(focal)
+        return param
+
+    def get_focals(self):  # 'Recovering intrinsics' in the paper: recover the focal length
+        log_focals = torch.stack(list(self.im_focals), dim=0)
+        return (log_focals / self.focal_break).exp()
+
+    def get_known_focal_mask(self):
+        return torch.tensor([not (p.requires_grad) for p in self.im_focals])
+
+    def _set_principal_point(self, idx, pp, force=False):
+        param = self.im_pp[idx]
+        H, W = self.imshapes[idx]
+        if param.requires_grad or force:  # can only init a parameter not already initialized
+            param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
+        return param
+
+    def get_principal_points(self):
+        return self._pp + 10 * self.im_pp  # offset between the image- and pixel-coordinate centers
+
+    def get_intrinsics(self):
+        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
+        focals = self.get_focals().flatten()
+        K[:, 0, 0] = K[:, 1, 1] = focals
+        K[:, :2, 2] = self.get_principal_points()
+        K[:, 2, 2] = 1
+        return K
+
+    def get_im_poses(self):  # cam-to-world (inverse of the extrinsics)
+        cam2world = self._get_poses(self.im_poses)
+        return cam2world
+
+    def _set_depthmap(self, idx, depth, force=False):
+        depth = _ravel_hw(depth, self.max_area)
+
+        param = self.im_depthmaps[idx]
+        if param.requires_grad or force:  # can only init a parameter not already initialized
+            param.data[:] = depth.log().nan_to_num(neginf=0)
+        return param
+
+    def get_depthmaps(self, raw=False):  # the depth D introduced just above Eq. (1) in the paper
+        res = self.im_depthmaps.exp()
+        if not raw:
+            res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)]
+        return res
+
+    def depth_to_pts3d(self):  # recover world-frame points from the depth D, per the formula above Eq. (1)
+        # Get depths and projection params if not provided
+        focals = self.get_focals()  # 'Recovering intrinsics': the focal lengths
+        pp = self.get_principal_points()  # offset between image and pixel coordinates, i.e. half the image width/height
+        im_poses = self.get_im_poses()  # camera poses
+        depth = self.get_depthmaps(raw=True)  # the depth D above Eq. (1)
+
+        # get pointmaps in camera frame; self._grid holds the pixel grids of all input images
+        rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp)  # lift pixels into camera-frame 3D points
+        # project to world frame
+        return geotrf(im_poses, rel_ptmaps)  # then camera frame -> world frame
+
+    def get_pts3d(self, raw=False):  # world-frame 3D points, computed from the depth D as above Eq. (1)
+        res = self.depth_to_pts3d()
+        if not raw:
+            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
+        return res
+
+    # forward returns exactly the loss of Eq. (5)
+    def forward(self):  # 'Global optimization' in the paper
+        pw_poses = self.get_pw_poses()  # P_e in Eq. (5): cam-to-world (inverse extrinsics), requires_grad=True
+        pw_adapt = self.get_adaptors().unsqueeze(1)  # the scale factor sigma in Eq. (5), requires_grad=False
+        proj_pts3d = self.get_pts3d(raw=True)  # the world-frame points being optimized in Eq. (5), requires_grad=True
+
+        # rotate pairwise predictions into the world frame with the pose part of Eq. (5)
+        aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i)  # _stacked_pred_i/j: dust3r's predicted pointmaps, requires_grad=False
+        aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j)
+
+        # compute the loss: compare each aligned pairwise prediction against the world-frame points proj_pts3d
+        li = self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum() / self.total_area_i
+        lj = self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() / self.total_area_j
+
+        return li + lj
+
+
+def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
+    pp = pp.unsqueeze(1)
+    focal = focal.unsqueeze(1)
+    assert focal.shape == (len(depth), 1, 1)
+    assert pp.shape == (len(depth), 1, 2)
+    assert pixel_grid.shape == depth.shape + (2,)
+    depth = depth.unsqueeze(-1)
+    return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1)  # formula above Eq. (1): back-project pixels into camera-frame 3D points from intrinsics and depth D
+
+
+def ParameterStack(params, keys=None, is_param=None, fill=0):
+    if keys is not None:
+        params = [params[k] for k in keys]
+
+    if fill > 0:
+        params = [_ravel_hw(p, fill) for p in params]
+
+    requires_grad = params[0].requires_grad
+    assert all(p.requires_grad == requires_grad for p in params)
+
+    params = torch.stack(list(params)).float().detach()
+    if is_param or requires_grad:
+        params = nn.Parameter(params)
+        params.requires_grad_(requires_grad)
+    return params
+
+
+def _ravel_hw(tensor, fill=0):
+    # ravel H,W
+    tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
+
+    if len(tensor) < fill:
+        tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),)+tensor.shape[1:])))
+    return tensor
+
+
+def acceptable_focal_range(H, W, minf=0.5, maxf=3.5):
+    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))  # size / 1.1547005383792515
+    return minf*focal_base, maxf*focal_base
+
+
+def apply_mask(img, msk):
+    img = img.copy()
+    img[msk] = 0
+    return img
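
A tiny numeric check of the pinhole back-projection in `_fast_depthmap_to_pts3d` above: a pixel at the principal point maps to (0, 0, depth), and a pixel one focal length to the right of it maps to x = depth. The function body is copied from this file (asserts dropped for brevity):

```python
import torch

def fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
    # copied from _fast_depthmap_to_pts3d above
    pp = pp.unsqueeze(1)
    focal = focal.unsqueeze(1)
    depth = depth.unsqueeze(-1)
    return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1)

focal = torch.tensor([[100.0]])            # (n_imgs=1, 1)
pp = torch.tensor([[32.0, 24.0]])          # (1, 2) principal point
pixels = torch.tensor([[[32.0, 24.0],      # at the principal point
                        [132.0, 24.0]]])   # one focal length to the right
depth = torch.tensor([[2.0, 2.0]])         # (1, n_pixels=2)

print(fast_depthmap_to_pts3d(depth, pixels, focal, pp))
# tensor([[[0., 0., 2.],
#          [2., 0., 2.]]])
```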
dust3r/cloud_opt/pair_viewer.py
ADDED
@@ -0,0 +1,125 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Dummy optimizer for visualizing pairs
# --------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import cv2

from dust3r.cloud_opt.base_opt import BasePCOptimizer
from dust3r.utils.geometry import inv, geotrf, depthmap_to_absolute_camera_coordinates
from dust3r.cloud_opt.commons import edge_str
from dust3r.post_process import estimate_focal_knowing_depth


class PairViewer (BasePCOptimizer):
    """
    This is a dummy optimizer.
    Use it only when the goal is to visualize the results for a single pair of images (with is_symmetrized).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.is_symmetrized and self.n_edges == 2
        self.has_im_poses = True

        # compute all parameters directly from raw input
        self.focals = []
        self.pp = []
        rel_poses = []
        confs = []
        for i in range(self.n_imgs):
            conf = float(self.conf_i[edge_str(i, 1-i)].mean() * self.conf_j[edge_str(i, 1-i)].mean())
            print(f' - {conf=:.3} for edge {i}-{1-i}')
            confs.append(conf)

            H, W = self.imshapes[i]
            pts3d = self.pred_i[edge_str(i, 1-i)]
            pp = torch.tensor((W/2, H/2))
            focal = float(estimate_focal_knowing_depth(pts3d[None], pp, focal_mode='weiszfeld'))
            self.focals.append(focal)
            self.pp.append(pp)

            # estimate the pose of pts1 in image 2
            pixels = np.mgrid[:W, :H].T.astype(np.float32)
            pts3d = self.pred_j[edge_str(1-i, i)].numpy()
            assert pts3d.shape[:2] == (H, W)
            msk = self.get_masks()[i].numpy()
            K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])

            try:
                res = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
                                         iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
                success, R, T, inliers = res
                assert success

                R = cv2.Rodrigues(R)[0]  # world to cam
                pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]])  # cam to world
            except Exception:
                # PnP can fail (e.g. too few valid points); fall back to the identity pose
                pose = np.eye(4)
            rel_poses.append(torch.from_numpy(pose.astype(np.float32)))

        # let's use the pair with the most confidence
        if confs[0] > confs[1]:
            # ptcloud is expressed in camera1
            self.im_poses = [torch.eye(4), rel_poses[1]]  # I, cam2-to-cam1
            self.depth = [self.pred_i['0_1'][..., 2], geotrf(inv(rel_poses[1]), self.pred_j['0_1'])[..., 2]]
        else:
            # ptcloud is expressed in camera2
            self.im_poses = [rel_poses[0], torch.eye(4)]  # cam1-to-cam2, I
            self.depth = [geotrf(inv(rel_poses[0]), self.pred_j['1_0'])[..., 2], self.pred_i['1_0'][..., 2]]

        self.im_poses = nn.Parameter(torch.stack(self.im_poses, dim=0), requires_grad=False)
        self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False)
        self.pp = nn.Parameter(torch.stack(self.pp, dim=0), requires_grad=False)
        self.depth = nn.ParameterList(self.depth)
        for p in self.parameters():
            p.requires_grad = False

    def _set_depthmap(self, idx, depth, force=False):
        print('_set_depthmap is ignored in PairViewer')
        return

    def get_depthmaps(self, raw=False):
        depth = [d.to(self.device) for d in self.depth]
        return depth

    def _set_focal(self, idx, focal, force=False):
        self.focals[idx] = focal

    def get_focals(self):
        return self.focals

    def get_known_focal_mask(self):
        return torch.tensor([not (p.requires_grad) for p in self.focals])

    def get_principal_points(self):
        return self.pp

    def get_intrinsics(self):
        focals = self.get_focals()
        pps = self.get_principal_points()
        K = torch.zeros((len(focals), 3, 3), device=self.device)
        for i in range(len(focals)):
            K[i, 0, 0] = K[i, 1, 1] = focals[i]
            K[i, :2, 2] = pps[i]
            K[i, 2, 2] = 1
        return K

    def get_im_poses(self):
        return self.im_poses

    def depth_to_pts3d(self):
        pts3d = []
        for d, intrinsics, im_pose in zip(self.depth, self.get_intrinsics(), self.get_im_poses()):
            pts, _ = depthmap_to_absolute_camera_coordinates(d.cpu().numpy(),
                                                             intrinsics.cpu().numpy(),
                                                             im_pose.cpu().numpy())
            pts3d.append(torch.from_numpy(pts).to(device=self.device))
        return pts3d

    def forward(self):
        return float('nan')
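
A hedged usage sketch (not part of the commit): PairViewer is normally reached through global_aligner with the PairViewer mode. Checkpoint path and image filenames below are placeholders, and load_model is assumed to be the loader from the original DUSt3R demo.

from dust3r.model import load_model
from dust3r.inference import inference
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode

model = load_model('checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth', 'cuda')  # placeholder path
images = load_images(['img1.jpg', 'img2.jpg'], size=512)  # placeholder filenames
pairs = make_pairs(images, scene_graph='complete', symmetrize=True)
output = inference(pairs, model, 'cuda', batch_size=1)

scene = global_aligner(output, device='cuda', mode=GlobalAlignerMode.PairViewer)
poses = scene.get_im_poses()   # (2, 4, 4) cam-to-world matrices; one is the identity
focals = scene.get_focals()    # per-image focal estimates from the Weiszfeld solver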
dust3r/datasets/__init__.py
ADDED
@@ -0,0 +1,42 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
from .utils.transforms import *
from .base.batched_sampler import BatchedRandomSampler  # noqa: F401
from .co3d import Co3d  # noqa: F401


def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
    import torch
    from croco.utils.misc import get_world_size, get_rank

    # pytorch dataset
    if isinstance(dataset, str):
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    try:
        sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size,
                                       rank=rank, drop_last=drop_last)
    except (AttributeError, NotImplementedError):
        # not avail for this dataset
        if torch.distributed.is_initialized():
            sampler = torch.utils.data.DistributedSampler(
                dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last
            )
        elif shuffle:
            sampler = torch.utils.data.RandomSampler(dataset)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )

    return data_loader
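
A short sketch of the string form (not in the file above): the dataset argument can be a string that is eval'ed in this module's scope, which is how training configs typically specify datasets. The ROOT path is a placeholder and requires the preprocessed data on disk.

from dust3r.datasets import get_data_loader

loader = get_data_loader("Co3d(split='train', ROOT='data/co3d_processed', resolution=224)",
                         batch_size=4, num_workers=0)
view1, view2 = next(iter(loader))  # two batched view dicts ('img', 'depthmap', 'pts3d', ...)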
dust3r/datasets/base/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
dust3r/datasets/base/base_stereo_view_dataset.py
ADDED
@@ -0,0 +1,220 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# base class for implementing datasets
# --------------------------------------------------------
import PIL
import numpy as np
import torch

from dust3r.datasets.base.easy_dataset import EasyDataset
from dust3r.datasets.utils.transforms import ImgNorm
from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates
import dust3r.datasets.utils.cropping as cropping


class BaseStereoViewDataset (EasyDataset):
    """ Define all basic options.

    Usage:
        class MyDataset (BaseStereoViewDataset):
            def _get_views(self, idx, rng):
                # overload here
                views = []
                views.append(dict(img=, ...))
                return views
    """

    def __init__(self, *,  # only keyword arguments
                 split=None,
                 resolution=None,  # square_size or (width, height) or list of [(width,height), ...]
                 transform=ImgNorm,
                 aug_crop=False,
                 seed=None):
        self.num_views = 2
        self.split = split
        self._set_resolutions(resolution)

        if isinstance(transform, str):
            # resolve string specs (e.g. 'ColorJitter') before storing the transform
            transform = eval(transform)
        self.transform = transform

        self.aug_crop = aug_crop
        self.seed = seed

    def __len__(self):
        return len(self.scenes)

    def get_stats(self):
        return f"{len(self)} pairs"

    def __repr__(self):
        resolutions_str = '['+';'.join(f'{w}x{h}' for w, h in self._resolutions)+']'
        return f"""{type(self).__name__}({self.get_stats()},
            {self.split=},
            {self.seed=},
            resolutions={resolutions_str},
            {self.transform=})""".replace('self.', '').replace('\n', '').replace('   ', '')

    def _get_views(self, idx, resolution, rng):
        raise NotImplementedError()

    def __getitem__(self, idx):
        if isinstance(idx, tuple):
            # the idx is specifying the aspect-ratio
            idx, ar_idx = idx
        else:
            assert len(self._resolutions) == 1
            ar_idx = 0

        # set-up the rng
        if self.seed:  # reseed for each __getitem__
            self._rng = np.random.default_rng(seed=self.seed + idx)
        elif not hasattr(self, '_rng'):
            seed = torch.initial_seed()  # this is different for each dataloader process
            self._rng = np.random.default_rng(seed=seed)

        # over-loaded code
        resolution = self._resolutions[ar_idx]  # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
        views = self._get_views(idx, resolution, self._rng)
        assert len(views) == self.num_views

        # check data-types
        for v, view in enumerate(views):
            assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
            view['idx'] = (idx, ar_idx, v)

            # encode the image
            width, height = view['img'].size
            view['true_shape'] = np.int32((height, width))
            view['img'] = self.transform(view['img'])

            assert 'camera_intrinsics' in view
            if 'camera_pose' not in view:
                view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32)
            else:
                assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}'
            assert 'pts3d' not in view
            assert 'valid_mask' not in view
            assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}'
            pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)

            view['pts3d'] = pts3d
            view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)

            # check all datatypes
            for key, val in view.items():
                res, err_msg = is_good_type(key, val)
                assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
            K = view['camera_intrinsics']

        # last thing done!
        for view in views:
            # transpose to make sure all views are the same size
            transpose_to_landscape(view)
            # this makes it possible to check whether the RNG is in the same state each time
            view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
        return views

    def _set_resolutions(self, resolutions):
        assert resolutions is not None, 'undefined resolution'

        if not isinstance(resolutions, list):
            resolutions = [resolutions]

        self._resolutions = []
        for resolution in resolutions:
            if isinstance(resolution, int):
                width = height = resolution
            else:
                width, height = resolution
            assert isinstance(width, int), f'Bad type for {width=} {type(width)=}, should be int'
            assert isinstance(height, int), f'Bad type for {height=} {type(height)=}, should be int'
            assert width >= height
            self._resolutions.append((width, height))

    def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None):
        """ This function:
            - first downsizes the image with LANCZOS interpolation,
              which suffers less from aliasing than bilinear interpolation,
            - then crops it to the target resolution, centered on the principal point.
        """
        if not isinstance(image, PIL.Image.Image):
            image = PIL.Image.fromarray(image)

        # downscale with lanczos interpolation so that image.size == resolution
        # cropping centered on the principal point
        W, H = image.size
        cx, cy = intrinsics[:2, 2].round().astype(int)
        min_margin_x = min(cx, W-cx)
        min_margin_y = min(cy, H-cy)
        assert min_margin_x > W/5, f'Bad principal point in view={info}'
        assert min_margin_y > H/5, f'Bad principal point in view={info}'
        # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
        l, t = cx - min_margin_x, cy - min_margin_y
        r, b = cx + min_margin_x, cy + min_margin_y
        crop_bbox = (l, t, r, b)
        image, depthmap, intrinsics = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)

        # transpose the resolution if necessary
        W, H = image.size  # new size
        assert resolution[0] >= resolution[1]
        if H > 1.1*W:
            # image is portrait mode
            resolution = resolution[::-1]
        elif 0.9 < H/W < 1.1 and resolution[0] != resolution[1]:
            # image is square, so we chose (portrait, landscape) randomly
            if rng.integers(2):
                resolution = resolution[::-1]

        # high-quality Lanczos down-scaling
        target_resolution = np.array(resolution)
        if self.aug_crop > 1:
            target_resolution += rng.integers(0, self.aug_crop)
        image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)

        # actual cropping (if necessary) with bilinear interpolation
        intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=0.5)
        crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
        image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)

        return image, depthmap, intrinsics2


def is_good_type(key, v):
    """ returns (is_good, err_msg)
    """
    if isinstance(v, (str, int, tuple)):
        return True, None
    if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
        return False, f"bad {v.dtype=}"
    return True, None


def view_name(view, batch_index=None):
    def sel(x): return x[batch_index] if batch_index not in (None, slice(None)) else x
    db = sel(view['dataset'])
    label = sel(view['label'])
    instance = sel(view['instance'])
    return f"{db}/{label}/{instance}"


def transpose_to_landscape(view):
    height, width = view['true_shape']

    if width < height:
        # rectify portrait to landscape
        assert view['img'].shape == (3, height, width)
        view['img'] = view['img'].swapaxes(1, 2)

        assert view['valid_mask'].shape == (height, width)
        view['valid_mask'] = view['valid_mask'].swapaxes(0, 1)

        assert view['depthmap'].shape == (height, width)
        view['depthmap'] = view['depthmap'].swapaxes(0, 1)

        assert view['pts3d'].shape == (height, width, 3)
        view['pts3d'] = view['pts3d'].swapaxes(0, 1)

        # transpose x and y pixels
        view['camera_intrinsics'] = view['camera_intrinsics'][[1, 0, 2]]
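
To make the overloading contract concrete, here is a toy subclass sketch (hypothetical, not part of the commit; images and depths are synthetic) that exercises the full __getitem__ pipeline, including _crop_resize_if_necessary:

import numpy as np
import PIL.Image


class RandomPairs (BaseStereoViewDataset):
    """ Toy dataset serving two synthetic views per index (illustration only). """

    def __init__(self, n=100, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.scenes = list(range(n))  # the base __len__ reads self.scenes

    def _get_views(self, idx, resolution, rng):
        views = []
        for _ in range(self.num_views):
            W = H = 256
            img = PIL.Image.fromarray(rng.integers(0, 255, (H, W, 3), dtype=np.uint8))
            depthmap = np.ones((H, W), dtype=np.float32)
            intrinsics = np.float32([(300, 0, W/2), (0, 300, H/2), (0, 0, 1)])
            img, depthmap, intrinsics = self._crop_resize_if_necessary(
                img, depthmap, intrinsics, resolution, rng=rng, info='toy')
            views.append(dict(img=img, depthmap=depthmap,
                              camera_intrinsics=intrinsics,
                              camera_pose=np.eye(4, dtype=np.float32),
                              dataset='Toy', label='toy', instance=str(idx)))
        return views


# views come back fully processed: normalized image tensor, pts3d, valid_mask, ...
dataset = RandomPairs(split='train', resolution=224)
view1, view2 = dataset[0]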
dust3r/datasets/base/batched_sampler.py
ADDED
@@ -0,0 +1,74 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Random sampling under a constraint
# --------------------------------------------------------
import numpy as np
import torch


class BatchedRandomSampler:
    """ Random sampling under a constraint: each sample in the batch has the same feature,
    which is chosen randomly from a known pool of 'features' for each batch.

    For instance, the 'feature' could be the image aspect-ratio.

    The index returned is a tuple (sample_idx, feat_idx).
    This sampler ensures that each series of `batch_size` indices has the same `feat_idx`.
    """

    def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True):
        self.batch_size = batch_size
        self.pool_size = pool_size

        self.len_dataset = N = len(dataset)
        self.total_size = round_by(N, batch_size*world_size) if drop_last else N
        assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode'

        # distributed sampler
        self.world_size = world_size
        self.rank = rank
        self.epoch = None

    def __len__(self):
        return self.total_size // self.world_size

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        # prepare RNG
        if self.epoch is None:
            assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used'
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
        else:
            seed = self.epoch + 777
        rng = np.random.default_rng(seed=seed)

        # random indices (will restart from 0 if not drop_last)
        sample_idxs = np.arange(self.total_size)
        rng.shuffle(sample_idxs)

        # random feat_idxs (same across each batch)
        n_batches = (self.total_size+self.batch_size-1) // self.batch_size
        feat_idxs = rng.integers(self.pool_size, size=n_batches)
        feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size))
        feat_idxs = feat_idxs.ravel()[:self.total_size]

        # put them together
        idxs = np.c_[sample_idxs, feat_idxs]  # shape = (total_size, 2)

        # Distributed sampler: we select a subset of batches
        # make sure the slice for each node is aligned with batch_size
        size_per_proc = self.batch_size * ((self.total_size + self.world_size *
                                            self.batch_size-1) // (self.world_size * self.batch_size))
        idxs = idxs[self.rank*size_per_proc: (self.rank+1)*size_per_proc]

        yield from (tuple(idx) for idx in idxs)


def round_by(total, multiple, up=False):
    if up:
        total = total + multiple-1
    return (total//multiple) * multiple
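
A small sketch (toy sizes, not part of the commit) verifying the sampler's invariant: within each consecutive batch of indices, feat_idx is constant.

# toy dataset stand-in: only its length matters to the sampler
class _Ten:
    def __len__(self):
        return 10


sampler = BatchedRandomSampler(_Ten(), batch_size=5, pool_size=3)
idxs = list(sampler)                  # [(sample_idx, feat_idx), ...]
for b in range(0, len(idxs), 5):
    feats = {feat for _, feat in idxs[b:b + 5]}
    assert len(feats) == 1            # same feat_idx within each batch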
dust3r/datasets/base/easy_dataset.py
ADDED
@@ -0,0 +1,157 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# A dataset base class that you can easily resize and combine.
# --------------------------------------------------------
import numpy as np
from dust3r.datasets.base.batched_sampler import BatchedRandomSampler


class EasyDataset:
    """ a dataset that you can easily resize and combine.
    Examples:
    ---------
        2 * dataset ==> duplicate each element 2x

        10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary)

        dataset1 + dataset2 ==> concatenate datasets
    """

    def __add__(self, other):
        return CatDataset([self, other])

    def __rmul__(self, factor):
        return MulDataset(factor, self)

    def __rmatmul__(self, factor):
        return ResizedDataset(factor, self)

    def set_epoch(self, epoch):
        pass  # nothing to do by default

    def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True):
        if not (shuffle):
            raise NotImplementedError()  # cannot deal yet
        num_of_aspect_ratios = len(self._resolutions)
        return BatchedRandomSampler(self, batch_size, num_of_aspect_ratios, world_size=world_size, rank=rank, drop_last=drop_last)


class MulDataset (EasyDataset):
    """ Artificially augmenting the size of a dataset.
    """
    multiplicator: int

    def __init__(self, multiplicator, dataset):
        assert isinstance(multiplicator, int) and multiplicator > 0
        self.multiplicator = multiplicator
        self.dataset = dataset

    def __len__(self):
        return self.multiplicator * len(self.dataset)

    def __repr__(self):
        return f'{self.multiplicator}*{repr(self.dataset)}'

    def __getitem__(self, idx):
        if isinstance(idx, tuple):
            idx, other = idx
            return self.dataset[idx // self.multiplicator, other]
        else:
            return self.dataset[idx // self.multiplicator]

    @property
    def _resolutions(self):
        return self.dataset._resolutions


class ResizedDataset (EasyDataset):
    """ Artificially changing the size of a dataset.
    """
    new_size: int

    def __init__(self, new_size, dataset):
        assert isinstance(new_size, int) and new_size > 0
        self.new_size = new_size
        self.dataset = dataset

    def __len__(self):
        return self.new_size

    def __repr__(self):
        size_str = str(self.new_size)
        for i in range((len(size_str)-1) // 3):
            sep = -4*i-3
            size_str = size_str[:sep] + '_' + size_str[sep:]
        return f'{size_str} @ {repr(self.dataset)}'

    def set_epoch(self, epoch):
        # this random shuffle only depends on the epoch
        rng = np.random.default_rng(seed=epoch+777)

        # shuffle all indices
        perm = rng.permutation(len(self.dataset))

        # rotary extension until target size is met
        shuffled_idxs = np.concatenate([perm] * (1 + (len(self)-1) // len(self.dataset)))
        self._idxs_mapping = shuffled_idxs[:self.new_size]

        assert len(self._idxs_mapping) == self.new_size

    def __getitem__(self, idx):
        assert hasattr(self, '_idxs_mapping'), 'You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()'
        if isinstance(idx, tuple):
            idx, other = idx
            return self.dataset[self._idxs_mapping[idx], other]
        else:
            return self.dataset[self._idxs_mapping[idx]]

    @property
    def _resolutions(self):
        return self.dataset._resolutions


class CatDataset (EasyDataset):
    """ Concatenation of several datasets
    """

    def __init__(self, datasets):
        for dataset in datasets:
            assert isinstance(dataset, EasyDataset)
        self.datasets = datasets
        self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])

    def __len__(self):
        return self._cum_sizes[-1]

    def __repr__(self):
        # remove uselessly long transform
        return ' + '.join(repr(dataset).replace(',transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))', '') for dataset in self.datasets)

    def set_epoch(self, epoch):
        for dataset in self.datasets:
            dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        other = None
        if isinstance(idx, tuple):
            idx, other = idx

        if not (0 <= idx < len(self)):
            raise IndexError()

        db_idx = np.searchsorted(self._cum_sizes, idx, 'right')
        dataset = self.datasets[db_idx]
        new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)

        if other is not None:
            new_idx = (new_idx, other)
        return dataset[new_idx]

    @property
    def _resolutions(self):
        resolutions = self.datasets[0]._resolutions
        for dataset in self.datasets[1:]:
            assert tuple(dataset._resolutions) == tuple(resolutions)
        return resolutions
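
A sketch of the dataset arithmetic from the EasyDataset docstring, using the toy RandomPairs dataset sketched earlier (a hypothetical class, not in the commit):

base = RandomPairs(n=100, split='train', resolution=224)

doubled = 2 * base           # MulDataset: each element repeated twice
assert len(doubled) == 200

resized = 300 @ base         # ResizedDataset: random re-sampling to 300 items
resized.set_epoch(0)         # required before indexing
assert len(resized) == 300

combo = doubled + resized    # CatDataset: concatenation
assert len(combo) == 500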
dust3r/datasets/co3d.py
ADDED
@@ -0,0 +1,146 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Dataloader for preprocessed Co3d_v2
# dataset at https://github.com/facebookresearch/co3d - Creative Commons Attribution-NonCommercial 4.0 International
# See datasets_preprocess/preprocess_co3d.py
# --------------------------------------------------------
import os.path as osp
import json
import itertools
from collections import deque

import cv2
import numpy as np

from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
from dust3r.utils.image import imread_cv2


class Co3d(BaseStereoViewDataset):
    def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        assert mask_bg in (True, False, 'rand')
        self.mask_bg = mask_bg

        # load all scenes
        with open(osp.join(self.ROOT, f'selected_seqs_{self.split}.json'), 'r') as f:
            self.scenes = json.load(f)
            self.scenes = {k: v for k, v in self.scenes.items() if len(v) > 0}
            self.scenes = {(k, k2): v2 for k, v in self.scenes.items()
                           for k2, v2 in v.items()}
        self.scene_list = list(self.scenes.keys())

        # for each scene, we have 100 images ==> 360 degrees (so 25 frames ~= 90 degrees)
        # we prepare all combinations such that i-j = +/- [5, 10, .., 90] degrees
        self.combinations = [(i, j)
                             for i, j in itertools.combinations(range(100), 2)
                             if 0 < abs(i-j) <= 30 and abs(i-j) % 5 == 0]

        self.invalidate = {scene: {} for scene in self.scene_list}

    def __len__(self):
        return len(self.scene_list) * len(self.combinations)

    def _get_views(self, idx, resolution, rng):
        # choose a scene
        obj, instance = self.scene_list[idx // len(self.combinations)]
        image_pool = self.scenes[obj, instance]
        im1_idx, im2_idx = self.combinations[idx % len(self.combinations)]

        # add a bit of randomness
        last = len(image_pool)-1

        if resolution not in self.invalidate[obj, instance]:  # flag invalid images
            self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))]

        # decide now if we mask the bg
        mask_bg = (self.mask_bg is True) or (self.mask_bg == 'rand' and rng.choice(2))

        views = []
        imgs_idxs = [max(0, min(im_idx + rng.integers(-4, 5), last)) for im_idx in [im2_idx, im1_idx]]
        imgs_idxs = deque(imgs_idxs)
        while len(imgs_idxs) > 0:  # a few images have zero valid depth, so we may need to retry
            im_idx = imgs_idxs.pop()

            if self.invalidate[obj, instance][resolution][im_idx]:
                # search for a valid image
                random_direction = 2 * rng.choice(2) - 1
                for offset in range(1, len(image_pool)):
                    tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool)
                    if not self.invalidate[obj, instance][resolution][tentative_im_idx]:
                        im_idx = tentative_im_idx
                        break

            view_idx = image_pool[im_idx]

            impath = osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.jpg')

            # load camera params
            input_metadata = np.load(impath.replace('jpg', 'npz'))
            camera_pose = input_metadata['camera_pose'].astype(np.float32)
            intrinsics = input_metadata['camera_intrinsics'].astype(np.float32)

            # load image and depth
            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(impath.replace('images', 'depths') + '.geometric.png', cv2.IMREAD_UNCHANGED)
            depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(input_metadata['maximum_depth'])

            if mask_bg:
                # load object mask
                maskpath = osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06n}.png')
                maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32)
                maskmap = (maskmap / 255.0) > 0.1

                # update the depthmap with mask
                depthmap *= maskmap

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)

            num_valid = (depthmap > 0.0).sum()
            if num_valid == 0:
                # problem, invalidate image and retry
                self.invalidate[obj, instance][resolution][im_idx] = True
                imgs_idxs.append(im_idx)
                continue

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap,
                camera_pose=camera_pose,
                camera_intrinsics=intrinsics,
                dataset='Co3d_v2',
                label=osp.join(obj, instance),
                instance=osp.split(impath)[1],
            ))
        return views


if __name__ == "__main__":
    from dust3r.datasets.base.base_stereo_view_dataset import view_name
    from dust3r.viz import SceneViz, auto_cam_size
    from dust3r.utils.image import rgb

    dataset = Co3d(split='train', ROOT="data/co3d_subset_processed", resolution=224, aug_crop=16)

    for idx in np.random.permutation(len(dataset)):
        views = dataset[idx]
        assert len(views) == 2
        print(view_name(views[0]), view_name(views[1]))
        viz = SceneViz()
        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]
        cam_size = max(auto_cam_size(poses), 0.001)
        for view_idx in [0, 1]:
            pts3d = views[view_idx]['pts3d']
            valid_mask = views[view_idx]['valid_mask']
            colors = rgb(views[view_idx]['img'])
            viz.add_pointcloud(pts3d, colors, valid_mask)
            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],
                           focal=views[view_idx]['camera_intrinsics'][0, 0],
                           color=(view_idx*255, (1 - view_idx)*255, 0),  # one color per camera
                           image=colors,
                           cam_size=cam_size)
        viz.show()
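
For reference, a quick check (illustrative, not part of the file) of the pair-combination count built in Co3d.__init__: frame gaps of 5..30 in steps of 5 over 100 frames.

import itertools

combos = [(i, j) for i, j in itertools.combinations(range(100), 2)
          if 0 < abs(i - j) <= 30 and abs(i - j) % 5 == 0]
# for a gap g there are 100 - g pairs, so:
assert len(combos) == sum(100 - g for g in range(5, 31, 5))  # 475 pairs per scene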
dust3r/datasets/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
dust3r/datasets/utils/cropping.py
ADDED
@@ -0,0 +1,119 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# cropping utilities
# --------------------------------------------------------
import PIL.Image
import os
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import cv2  # noqa
import numpy as np  # noqa
from dust3r.utils.geometry import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics  # noqa
try:
    lanczos = PIL.Image.Resampling.LANCZOS
except AttributeError:
    lanczos = PIL.Image.LANCZOS


class ImageList:
    """ Convenience class to apply the same operation to a whole set of images.
    """

    def __init__(self, images):
        if not isinstance(images, (tuple, list, set)):
            images = [images]
        self.images = []
        for image in images:
            if not isinstance(image, PIL.Image.Image):
                image = PIL.Image.fromarray(image)
            self.images.append(image)

    def __len__(self):
        return len(self.images)

    def to_pil(self):
        return tuple(self.images) if len(self.images) > 1 else self.images[0]

    @property
    def size(self):
        sizes = [im.size for im in self.images]
        assert all(sizes[0] == s for s in sizes)
        return sizes[0]

    def resize(self, *args, **kwargs):
        return ImageList(self._dispatch('resize', *args, **kwargs))

    def crop(self, *args, **kwargs):
        return ImageList(self._dispatch('crop', *args, **kwargs))

    def _dispatch(self, func, *args, **kwargs):
        return [getattr(im, func)(*args, **kwargs) for im in self.images]


def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution):
    """ Jointly rescale an (image, depthmap) pair
        so that (out_width, out_height) >= output_res
    """
    image = ImageList(image)
    input_resolution = np.array(image.size)  # (W,H)
    output_resolution = np.array(output_resolution)
    if depthmap is not None:
        # can also use this with masks instead of depthmaps
        assert tuple(depthmap.shape[:2]) == image.size[::-1]
    assert output_resolution.shape == (2,)
    # define output resolution
    scale_final = max(output_resolution / image.size) + 1e-8
    output_resolution = np.floor(input_resolution * scale_final).astype(int)

    # first rescale the image so that it contains the crop
    image = image.resize(output_resolution, resample=lanczos)
    if depthmap is not None:
        depthmap = cv2.resize(depthmap, output_resolution, fx=scale_final,
                              fy=scale_final, interpolation=cv2.INTER_NEAREST)

    # no offset here; simple rescaling
    camera_intrinsics = camera_matrix_of_crop(
        camera_intrinsics, input_resolution, output_resolution, scaling=scale_final)

    return image.to_pil(), depthmap, camera_intrinsics


def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None):
    # Margins to offset the origin
    margins = np.asarray(input_resolution) * scaling - output_resolution
    assert np.all(margins >= 0.0)
    if offset is None:
        offset = offset_factor * margins

    # Generate new camera parameters
    output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
    output_camera_matrix_colmap[:2, :] *= scaling
    output_camera_matrix_colmap[:2, 2] -= offset
    output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)

    return output_camera_matrix


def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
    """
    Return a crop of the input view.
    """
    image = ImageList(image)
    l, t, r, b = crop_bbox

    image = image.crop((l, t, r, b))
    depthmap = depthmap[t:b, l:r]

    camera_intrinsics = camera_intrinsics.copy()
    camera_intrinsics[0, 2] -= l
    camera_intrinsics[1, 2] -= t

    return image.to_pil(), depthmap, camera_intrinsics


def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution):
    out_width, out_height = output_resolution
    l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]))
    crop_bbox = (l, t, l+out_width, t+out_height)
    return crop_bbox
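
A hedged sketch (not in the file above; sizes are illustrative) of the two-step resize-then-crop pipeline these helpers implement: a Lanczos rescale so the image just covers the target, then a centered crop with the intrinsics updated to match.

import numpy as np
import PIL.Image

img = PIL.Image.new('RGB', (640, 480))
depth = np.ones((480, 640), dtype=np.float32)
K = np.float32([(500, 0, 320), (0, 500, 240), (0, 0, 1)])

target = (224, 224)
# step 1: Lanczos rescale so that min(W, H) reaches the target resolution
img, depth, K = rescale_image_depthmap(img, depth, K, np.array(target))
# step 2: centered crop, with the principal point shifted accordingly
K2 = camera_matrix_of_crop(K, img.size, target, offset_factor=0.5)
bbox = bbox_from_intrinsics_in_out(K, K2, target)
img, depth, K2 = crop_image_depthmap(img, depth, K, bbox)
assert img.size == target and depth.shape == (224, 224)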
dust3r/datasets/utils/transforms.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# DUST3R default transforms
# --------------------------------------------------------
import torchvision.transforms as tvf
from dust3r.utils.image import ImgNorm

# define the standard image transforms
ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
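
A tiny usage sketch (illustrative): both transforms map a PIL image to a normalized tensor; ColorJitter additionally perturbs brightness, contrast, saturation and hue before normalizing.

import PIL.Image
from dust3r.datasets.utils.transforms import ColorJitter, ImgNorm

x = ImgNorm(PIL.Image.new('RGB', (224, 224)))       # tensor in [-1, 1], shape (3, 224, 224)
y = ColorJitter(PIL.Image.new('RGB', (224, 224)))   # jittered, then normalized the same way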
dust3r/heads/__init__.py
ADDED
@@ -0,0 +1,19 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# head factory
# --------------------------------------------------------
from .linear_head import LinearPts3d
from .dpt_head import create_dpt_head


def head_factory(head_type, output_mode, net, has_conf=False):
    """ build a prediction head for the decoder
    """
    if head_type == 'linear' and output_mode == 'pts3d':
        return LinearPts3d(net, has_conf)
    elif head_type == 'dpt' and output_mode == 'pts3d':
        return create_dpt_head(net, has_conf=has_conf)
    else:
        raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")
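
A minimal sketch of the dispatch (not part of the file): only the ('linear'|'dpt', 'pts3d') combinations are implemented; in real use, net is the already-built network whose dimensions the heads read.

from dust3r.heads import head_factory

try:
    head_factory(head_type='linear', output_mode='depth', net=None)  # unsupported combination
except NotImplementedError as err:
    print(err)  # unexpected head_type='linear' and output_mode='depth'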
dust3r/heads/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (618 Bytes). View file
dust3r/heads/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (621 Bytes). View file
dust3r/heads/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (621 Bytes). View file