File size: 4,989 Bytes
9223079 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import kornia
from kornia.feature.laf import laf_from_center_scale_ori, extract_patches_from_pyramid
import numpy as np
import torch
import pycolmap
from ..utils.base_model import BaseModel
# Small constant used to avoid divisions by zero and as a numerical tolerance.
EPS = 1e-6


def sift_to_rootsift(x):
    """Convert SIFT descriptors to RootSIFT descriptors.

    The conversion is applied along the last axis: each descriptor is
    L1-normalized, floored at ``EPS``, square-rooted element-wise and finally
    L2-normalized.

    Args:
        x: NumPy array of descriptors; the last axis is the descriptor axis.

    Returns:
        A NumPy array with the same shape as ``x``.
    """
    l1_norm = np.linalg.norm(x, ord=1, axis=-1, keepdims=True)
    l1_normalized = x / (l1_norm + EPS)
    # Flooring at EPS keeps the square root away from zero/negative entries.
    rooted = np.sqrt(np.clip(l1_normalized, EPS, None))
    l2_norm = np.linalg.norm(rooted, axis=-1, keepdims=True)
    return rooted / (l2_norm + EPS)
class DoG(BaseModel):
    """Keypoint extractor built on pycolmap's SIFT detector.

    Keypoints are always detected with ``pycolmap.Sift``. Depending on
    ``conf["descriptor"]`` they are described either with the SIFT / RootSIFT
    descriptors returned by pycolmap, or with a kornia network (SOSNet or
    HardNet) applied to patches extracted around each keypoint.
    """

    default_conf = {
        # Forwarded to pycolmap.SiftExtractionOptions.
        "options": {
            "first_octave": 0,
            "peak_threshold": 0.01,
        },
        # One of "sift", "rootsift", "sosnet", "hardnet".
        "descriptor": "rootsift",
        # Keep only the best-scoring keypoints; -1 keeps all of them.
        "max_keypoints": -1,
        # Side length (pixels) of the patches fed to the learned descriptors.
        "patch_size": 32,
        # Multiplier applied to the keypoint scale when building the patches.
        "mr_size": 12,
    }
    required_inputs = ["image"]
    # NOTE(review): not used inside this class; presumably read by callers.
    detection_noise = 1.0
    # Maximum number of patches passed to the descriptor network at once.
    max_batch_size = 1024

    def _init(self, conf):
        """Create the descriptor network (if any) and reset the lazy state.

        Raises:
            ValueError: if ``conf["descriptor"]`` is not a known descriptor.
        """
        if conf["descriptor"] == "sosnet":
            self.describe = kornia.feature.SOSNet(pretrained=True)
        elif conf["descriptor"] == "hardnet":
            self.describe = kornia.feature.HardNet(pretrained=True)
        elif conf["descriptor"] not in ["sift", "rootsift"]:
            raise ValueError(f'Unknown descriptor: {conf["descriptor"]}')

        self.sift = None  # lazily instantiated on the first image
        self.device = torch.device("cpu")

    def to(self, *args, **kwargs):
        """Move the module and remember the target device.

        The device is read from the ``device`` keyword or from the first
        positional argument that is a ``torch.device`` or a string. When the
        device type changes, the cached SIFT extractor is dropped so that the
        next forward pass re-creates it for the new device instead of reusing
        one bound to the old device.
        """
        device = kwargs.get("device")
        if device is None:
            device = next(
                (a for a in args if isinstance(a, (torch.device, str))), None
            )
        if device is not None:
            new_device = torch.device(device)
            if new_device.type != self.device.type:
                # The extractor backend is chosen from the device type only.
                self.sift = None
            self.device = new_device
        return super().to(*args, **kwargs)

    def _get_sift(self):
        """Return the pycolmap SIFT extractor, creating it on first use."""
        if self.sift is None:
            use_gpu = pycolmap.has_cuda and self.device.type == "cuda"
            options = {**self.conf["options"]}
            if self.conf["descriptor"] == "rootsift":
                options["normalization"] = pycolmap.Normalization.L1_ROOT
            else:
                options["normalization"] = pycolmap.Normalization.L2
            self.sift = pycolmap.Sift(
                options=pycolmap.SiftExtractionOptions(options),
                device=getattr(pycolmap.Device, "cuda" if use_gpu else "cpu"),
            )
        return self.sift

    def _describe_patches(self, image, keypoints, scales, oris):
        """Describe keypoints by running ``self.describe`` on image patches.

        Args:
            image: the input image tensor passed to the model.
            keypoints: NumPy array whose first two columns are x, y.
            scales: NumPy array with one scale per keypoint.
            oris: NumPy array with one orientation (degrees) per keypoint.

        Returns:
            A ``(num_keypoints, 128)`` tensor on the same device as ``image``.
        """
        # NOTE(review): the 0.5 shift and the sign flip presumably convert
        # between the pycolmap and kornia conventions - confirm.
        center = keypoints[:, :2] + 0.5
        laf_scale = scales * self.conf["mr_size"] / 2
        laf_ori = -oris
        lafs = laf_from_center_scale_ori(
            torch.from_numpy(center)[None],
            torch.from_numpy(laf_scale)[None, :, None, None],
            torch.from_numpy(laf_ori)[None, :, None],
        ).to(image.device)
        patches = extract_patches_from_pyramid(
            image, lafs, PS=self.conf["patch_size"]
        )[0]
        descriptors = patches.new_zeros((len(patches), 128))
        # Process the patches in chunks to bound the memory usage.
        for start in range(0, len(patches), self.max_batch_size):
            stop = start + self.max_batch_size
            descriptors[start:stop] = self.describe(patches[start:stop])
        return descriptors

    def _forward(self, data):
        """Detect and describe keypoints in ``data["image"]``.

        Only the first image of the batch is processed. The image must have a
        single channel and values in ``[0, 1]`` (up to ``EPS``).

        Returns:
            A dict with ``keypoints`` (x, y), ``scales``, ``oris`` (degrees),
            ``scores`` and ``descriptors``, each with a leading batch
            dimension of size 1. Descriptors are transposed so that the
            descriptor dimension comes before the keypoint dimension.

        Raises:
            ValueError: if ``conf["descriptor"]`` is not a known descriptor.
        """
        image = data["image"]
        # Validate the layout before indexing into the array.
        assert image.shape[1] == 1
        image_np = image.cpu().numpy()[0, 0]
        assert image_np.min() >= -EPS and image_np.max() <= 1 + EPS

        keypoints, scores, descriptors = self._get_sift().extract(image_np)

        max_keypoints = self.conf["max_keypoints"]
        if max_keypoints != -1:
            # TODO: check that the scores from PyCOLMAP are 100% correct,
            # follow https://github.com/mihaidusmanu/pycolmap/issues/8
            # Select before describing so that learned descriptors are only
            # computed for the keypoints that are actually returned.
            num_kept = min(scores.shape[0], max_keypoints)
            keep = torch.topk(torch.from_numpy(scores), num_kept).indices
            keep = keep.numpy()
            keypoints = keypoints[keep]
            scores = scores[keep]
            descriptors = descriptors[keep]

        # Keypoint columns are x, y, scale, orientation (radians).
        scales = keypoints[:, 2]
        oris = np.rad2deg(keypoints[:, 3])

        descriptor = self.conf["descriptor"]
        if descriptor in ("sift", "rootsift"):
            # We still renormalize because COLMAP does not normalize well,
            # maybe due to numerical errors
            if descriptor == "rootsift":
                descriptors = sift_to_rootsift(descriptors)
            descriptors = torch.from_numpy(descriptors)
        elif descriptor in ("sosnet", "hardnet"):
            descriptors = self._describe_patches(image, keypoints, scales, oris)
        else:
            raise ValueError(f'Unknown descriptor: {self.conf["descriptor"]}')

        return {
            "keypoints": torch.from_numpy(keypoints[:, :2])[None],  # x, y only
            "scales": torch.from_numpy(scales)[None],
            "oris": torch.from_numpy(oris)[None],
            "scores": torch.from_numpy(scores)[None],
            "descriptors": descriptors.T[None],
        }
|