File size: 1,645 Bytes
9223079 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import sys
from pathlib import Path
import logging
from ..utils.base_model import BaseModel
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Location of the vendored LightGlue checkout, relative to this file. The path
# is left unresolved (it keeps its ".." components); it is added to sys.path so
# that the `from lightglue import ...` statement that follows can find it.
lightglue_path = Path(__file__).parent.joinpath("../../third_party/LightGlue")
sys.path.extend([str(lightglue_path)])
from lightglue import LightGlue as LG
class LightGlue(BaseModel):
    """Matcher that wraps the third-party LightGlue network (imported as ``LG``).

    ``_init`` builds the wrapped network from this model's config, and
    ``_forward`` repackages the flat ``data`` dict into the nested
    ``{"image0": {...}, "image1": {...}}`` dict that the network is called with.
    """

    default_conf = {
        "match_threshold": 0.2,
        "filter_threshold": 0.2,
        "width_confidence": 0.99,  # for point pruning
        "depth_confidence": 0.95,  # for early stopping
        "features": "superpoint",
        "model_name": "superpoint_lightglue.pth",
        "flash": True,  # enable FlashAttention if available.
        "mp": False,  # enable mixed precision
    }
    required_inputs = [
        "image0",
        "keypoints0",
        "scores0",
        "descriptors0",
        "image1",
        "keypoints1",
        "scores1",
        "descriptors1",
    ]

    def _init(self, conf):
        """Build the wrapped LightGlue network from ``conf``.

        ``conf`` is mutated in place before being splatted into ``LG``:
        ``weights`` is set to ``<lightglue_path>/weights/<model_name>``, and
        ``filter_threshold`` is overwritten with ``match_threshold``.
        """
        weight_path = lightglue_path / "weights" / conf["model_name"]
        conf["weights"] = str(weight_path)
        # Expose a single user-facing threshold: match_threshold overrides
        # filter_threshold (presumably the key LG itself reads — confirm).
        conf["filter_threshold"] = conf["match_threshold"]
        self.net = LG(**conf)
        logger.info("Load lightglue model done.")

    def _forward(self, data):
        """Run the wrapped network on one image pair.

        Args:
            data: dict holding ``image{i}``, ``keypoints{i}`` and
                ``descriptors{i}`` for i in (0, 1). ``scores{i}`` are listed
                in ``required_inputs`` but are not forwarded to the network.

        Returns:
            Whatever the wrapped ``LG`` module returns for this input.
        """
        # Built in order image0, image1 — same order as the original code.
        pair = {f"image{i}": self._pack_view(data, i) for i in (0, 1)}
        return self.net(pair)

    @staticmethod
    def _pack_view(data, idx):
        """Collect view ``idx``'s image, keypoints and descriptors into one dict."""
        return {
            "image": data[f"image{idx}"],
            # [None] prepends a new leading axis — presumably a batch dim the
            # caller's keypoints lack. TODO(review): confirm input shape.
            "keypoints": data[f"keypoints{idx}"][None],
            # Swap the last two axes; assumes descriptors arrive as (B, D, N)
            # and LG expects (B, N, D). TODO(review): confirm.
            "descriptors": data[f"descriptors{idx}"].permute(0, 2, 1),
        }
|