|
import sys |
|
from pathlib import Path |
|
import subprocess |
|
import torch |
|
import logging |
|
|
|
from ..utils.base_model import BaseModel |
|
|
|
# Root of the vendored third-party "example" project. It is appended to
# sys.path so that modules from that project can be imported by name.
example_path = Path(__file__).parent.joinpath("../../third_party/example")
sys.path.append(f"{example_path}")

# Any imports from the third-party project must come after the sys.path
# entry above, otherwise they will not resolve.

# Preferred torch device for this module: CUDA when available, else CPU.
# NOTE(review): not referenced by the Example class in this module — confirm
# whether anything else relies on it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

logger = logging.getLogger(__name__)
|
|
|
|
|
class Example(BaseModel):
    """Template keypoint/descriptor extractor wrapping a third-party network.

    To integrate a new extractor, replace the placeholder network assigned in
    ``_init`` and adapt ``_forward`` to the real network's outputs.
    """

    # Default configuration for this extractor.
    # NOTE(review): presumably merged with the user config by BaseModel — confirm.
    default_conf = {
        "name": "example",
        # "keypoint_threshold" and "max_keypoints" are not read anywhere in
        # this class; presumably they are meant for the concrete network.
        "keypoint_threshold": 0.1,
        "max_keypoints": 2000,
        # Checkpoint file name, looked up under <example_path>/checkpoints.
        "model_name": "model.pth",
    }
    # Keys that the input dict must contain; ``_forward`` reads ``data["image"]``.
    required_inputs = ["image"]

    def _init(self, conf):
        """Build the network and load its checkpoint weights.

        Args:
            conf: Configuration mapping. ``conf["model_name"]`` is the
                checkpoint file name under ``<example_path>/checkpoints``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        model_path = example_path / "checkpoints" / str(conf["model_name"])
        if not model_path.exists():
            # Fail early with a clear message instead of logging and letting
            # torch.load fail on the missing file. If the checkpoint can be
            # downloaded, fetch it here before this check instead.
            raise FileNotFoundError(f"No model found at {model_path}")

        # NOTE(review): placeholder. `callable` is the builtin function, not a
        # network, so `load_state_dict` below raises AttributeError until this
        # is replaced with the real model instance.
        self.net = callable

        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints. Weights are mapped to CPU regardless of where they were saved.
        state_dict = torch.load(model_path, map_location="cpu")
        # The checkpoint is expected to store the weights under "model_state".
        self.net.load_state_dict(state_dict["model_state"])
        logger.info("Load example model done.")

    def _forward(self, data):
        """Run the network on ``data["image"]``.

        Args:
            data: Input mapping; only the ``"image"`` entry is used.

        Returns:
            dict with ``"keypoints"``, ``"scores"`` and ``"descriptors"``,
            which are the three values returned by ``self.net``, unchanged.
        """
        # NOTE(review): the image layout/value range and the output shapes are
        # defined by the wrapped network and its callers. Likely something like
        # B x C x H x W input, B x N x 2 keypoints, B x N scores and
        # B x D x N descriptors — confirm for the concrete network.
        image = data["image"]
        keypoints, scores, descriptors = self.net(image)
        return {
            "keypoints": keypoints,
            "scores": scores,
            "descriptors": descriptors,
        }
|
|