Realcat commited on
Commit
8ff3c52
1 Parent(s): 8888dc8
Files changed (6) hide show
  1. api/client.py +147 -0
  2. api/server.py +135 -0
  3. format.sh +3 -3
  4. requirements.txt +1 -0
  5. test_app_cli.py +1 -4
  6. ui/viz.py +5 -7
api/client.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import pickle
3
+ import time
4
+ from typing import Dict
5
+
6
+ import numpy as np
7
+ import requests
8
+ from loguru import logger
9
+
10
+ API_URL_MATCH = "http://127.0.0.1:8001/v1/match"
11
+ API_URL_EXTRACT = "http://127.0.0.1:8001/v1/extract"
12
+ API_URL_EXTRACT_V2 = "http://127.0.0.1:8001/v2/extract"
13
+
14
+
15
+ def send_generate_request(path0: str, path1: str) -> Dict[str, np.ndarray]:
16
+ """
17
+ Send a request to the API to generate a match between two images.
18
+
19
+ Args:
20
+ path0 (str): The path to the first image.
21
+ path1 (str): The path to the second image.
22
+
23
+ Returns:
24
+ Dict[str, np.ndarray]: A dictionary containing the generated matches.
25
+ The keys are "keypoints0", "keypoints1", "matches0", and "matches1",
26
+ and the values are ndarrays of shape (N, 2), (N, 2), (N, 2), and
27
+ (N, 2), respectively.
28
+ """
29
+ files = {"image0": open(path0, "rb"), "image1": open(path1, "rb")}
30
+ try:
31
+ response = requests.post(API_URL_MATCH, files=files)
32
+ pred = {}
33
+ if response.status_code == 200:
34
+ pred = response.json()
35
+ for key in list(pred.keys()):
36
+ pred[key] = np.array(pred[key])
37
+ else:
38
+ print(
39
+ f"Error: Response code {response.status_code} - {response.text}"
40
+ )
41
+ finally:
42
+ files["image0"].close()
43
+ files["image1"].close()
44
+ return pred
45
+
46
+
47
+ def send_generate_request1(path0: str) -> Dict[str, np.ndarray]:
48
+ """
49
+ Send a request to the API to extract features from an image.
50
+
51
+ Args:
52
+ path0 (str): The path to the image.
53
+
54
+ Returns:
55
+ Dict[str, np.ndarray]: A dictionary containing the extracted features.
56
+ The keys are "keypoints", "descriptors", and "scores", and the
57
+ values are ndarrays of shape (N, 2), (N, 128), and (N,),
58
+ respectively.
59
+ """
60
+ files = {"image": open(path0, "rb")}
61
+ try:
62
+ response = requests.post(API_URL_EXTRACT, files=files)
63
+ pred: Dict[str, np.ndarray] = {}
64
+ if response.status_code == 200:
65
+ pred = response.json()
66
+ for key in list(pred.keys()):
67
+ pred[key] = np.array(pred[key])
68
+ else:
69
+ print(
70
+ f"Error: Response code {response.status_code} - {response.text}"
71
+ )
72
+ finally:
73
+ files["image"].close()
74
+ return pred
75
+
76
+
77
+ def send_generate_request2(image_path: str) -> Dict[str, np.ndarray]:
78
+ """
79
+ Send a request to the API to extract features from an image.
80
+
81
+ Args:
82
+ image_path (str): The path to the image.
83
+
84
+ Returns:
85
+ Dict[str, np.ndarray]: A dictionary containing the extracted features.
86
+ The keys are "keypoints", "descriptors", and "scores", and the
87
+ values are ndarrays of shape (N, 2), (N, 128), and (N,), respectively.
88
+ """
89
+ data = {
90
+ "image_path": image_path,
91
+ "max_keypoints": 1024,
92
+ "reference_points": [[0.0, 0.0], [1.0, 1.0]],
93
+ }
94
+ pred = {}
95
+ try:
96
+ response = requests.post(API_URL_EXTRACT_V2, json=data)
97
+ pred: Dict[str, np.ndarray] = {}
98
+ if response.status_code == 200:
99
+ pred = response.json()
100
+ for key in list(pred.keys()):
101
+ pred[key] = np.array(pred[key])
102
+ else:
103
+ print(
104
+ f"Error: Response code {response.status_code} - {response.text}"
105
+ )
106
+ except Exception as e:
107
+ print(f"An error occurred: {e}")
108
+ return pred
109
+
110
+
111
+ if __name__ == "__main__":
112
+ parser = argparse.ArgumentParser(
113
+ description="Send text to stable audio server and receive generated audio."
114
+ )
115
+ parser.add_argument(
116
+ "--image0",
117
+ required=False,
118
+ help="Path for the file's melody",
119
+ default="../datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg",
120
+ )
121
+ parser.add_argument(
122
+ "--image1",
123
+ required=False,
124
+ help="Path for the file's melody",
125
+ default="../datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg",
126
+ )
127
+ args = parser.parse_args()
128
+ for i in range(10):
129
+ t1 = time.time()
130
+ preds = send_generate_request(args.image0, args.image1)
131
+ t2 = time.time()
132
+ logger.info(f"Time cost1: {(t2 - t1)} seconds")
133
+
134
+ for i in range(10):
135
+ t1 = time.time()
136
+ preds = send_generate_request1(args.image0)
137
+ t2 = time.time()
138
+ logger.info(f"Time cost2: {(t2 - t1)} seconds")
139
+
140
+ for i in range(10):
141
+ t1 = time.time()
142
+ preds = send_generate_request2(args.image0)
143
+ t2 = time.time()
144
+ logger.info(f"Time cost2: {(t2 - t1)} seconds")
145
+
146
+ with open("preds.pkl", "wb") as f:
147
+ pickle.dump(preds, f)
api/server.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # server.py
2
+ import sys
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import uvicorn
8
+ from fastapi import FastAPI, File, UploadFile
9
+ from fastapi.responses import JSONResponse
10
+ from PIL import Image
11
+
12
+ sys.path.append("..")
13
+ from pydantic import BaseModel
14
+
15
+ from ui.api import ImageMatchingAPI
16
+ from ui.utils import DEVICE
17
+
18
+
19
+ class ImageInfo(BaseModel):
20
+ image_path: str
21
+ max_keypoints: int
22
+ reference_points: list
23
+
24
+
25
+ class ImageMatchingService:
26
+ def __init__(self, conf: dict, device: str):
27
+ self.api = ImageMatchingAPI(conf=conf, device=device)
28
+ self.app = FastAPI()
29
+ self.register_routes()
30
+
31
+ def register_routes(self):
32
+ @self.app.post("/v1/match")
33
+ async def match(
34
+ image0: UploadFile = File(...), image1: UploadFile = File(...)
35
+ ):
36
+ try:
37
+ image0_array = self.load_image(image0)
38
+ image1_array = self.load_image(image1)
39
+
40
+ output = self.api(image0_array, image1_array)
41
+
42
+ skip_keys = ["image0_orig", "image1_orig"]
43
+ pred = self.filter_output(output, skip_keys)
44
+
45
+ return JSONResponse(content=pred)
46
+ except Exception as e:
47
+ return JSONResponse(content={"error": str(e)}, status_code=500)
48
+
49
+ @self.app.post("/v1/extract")
50
+ async def extract(image: UploadFile = File(...)):
51
+ try:
52
+ image_array = self.load_image(image)
53
+ output = self.api.extract(image_array)
54
+ skip_keys = ["descriptors", "image", "image_orig"]
55
+ pred = self.filter_output(output, skip_keys)
56
+ return JSONResponse(content=pred)
57
+ except Exception as e:
58
+ return JSONResponse(content={"error": str(e)}, status_code=500)
59
+
60
+ @self.app.post("/v2/extract")
61
+ async def extract_v2(image_path: ImageInfo):
62
+ img_path = image_path.image_path
63
+ try:
64
+ safe_path = Path(img_path).resolve(strict=False)
65
+ image_array = self.load_image(str(safe_path))
66
+ output = self.api.extract(image_array)
67
+ skip_keys = ["descriptors", "image", "image_orig"]
68
+ pred = self.filter_output(output, skip_keys)
69
+ return JSONResponse(content=pred)
70
+ except Exception as e:
71
+ return JSONResponse(content={"error": str(e)}, status_code=500)
72
+
73
+ def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray:
74
+ """
75
+ Reads an image from a file path or an UploadFile object.
76
+
77
+ Args:
78
+ file_path: A file path or an UploadFile object.
79
+
80
+ Returns:
81
+ A numpy array representing the image.
82
+ """
83
+ if isinstance(file_path, str):
84
+ file_path = Path(file_path).resolve(strict=False)
85
+ else:
86
+ file_path = file_path.file
87
+ with Image.open(file_path) as img:
88
+ image_array = np.array(img)
89
+ return image_array
90
+
91
+ def filter_output(self, output: dict, skip_keys: list) -> dict:
92
+ pred = {}
93
+ for key, value in output.items():
94
+ if key in skip_keys:
95
+ continue
96
+ if isinstance(value, np.ndarray):
97
+ pred[key] = value.tolist()
98
+ return pred
99
+
100
+ def run(self, host: str = "0.0.0.0", port: int = 8001):
101
+ uvicorn.run(self.app, host=host, port=port)
102
+
103
+
104
+ if __name__ == "__main__":
105
+ conf = {
106
+ "feature": {
107
+ "output": "feats-superpoint-n4096-rmax1600",
108
+ "model": {
109
+ "name": "superpoint",
110
+ "nms_radius": 3,
111
+ "max_keypoints": 4096,
112
+ "keypoint_threshold": 0.005,
113
+ },
114
+ "preprocessing": {
115
+ "grayscale": True,
116
+ "force_resize": True,
117
+ "resize_max": 1600,
118
+ "width": 640,
119
+ "height": 480,
120
+ "dfactor": 8,
121
+ },
122
+ },
123
+ "matcher": {
124
+ "output": "matches-NN-mutual",
125
+ "model": {
126
+ "name": "nearest_neighbor",
127
+ "do_mutual_check": True,
128
+ "match_threshold": 0.2,
129
+ },
130
+ },
131
+ "dense": False,
132
+ }
133
+
134
+ service = ImageMatchingService(conf=conf, device=DEVICE)
135
+ service.run()
format.sh CHANGED
@@ -1,3 +1,3 @@
1
- python -m flake8 ui/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
2
- python -m isort ui/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
3
- python -m black ui/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
 
1
+ python -m flake8 ui/*.py api/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
2
+ python -m isort ui/*.py api/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
3
+ python -m black ui/*.py api/*.py hloc/*.py hloc/matchers/*.py hloc/extractors/*.py
requirements.txt CHANGED
@@ -36,3 +36,4 @@ torchvision==0.19.0
36
  roma #dust3r
37
  tqdm
38
  yacs
 
 
36
  roma #dust3r
37
  tqdm
38
  yacs
39
+ fastapi
test_app_cli.py CHANGED
@@ -1,7 +1,4 @@
1
  import cv2
2
- import warnings
3
- import numpy as np
4
- from pathlib import Path
5
  from hloc import logger
6
  from ui.utils import (
7
  get_matcher_zoo,
@@ -71,7 +68,7 @@ def test_one():
71
  "dense": False,
72
  }
73
  api = ImageMatchingAPI(conf=conf, device=DEVICE)
74
- api(image0, image1)
75
  log_path = ROOT / "experiments" / "one"
76
  log_path.mkdir(exist_ok=True, parents=True)
77
  api.visualize(log_path=log_path)
 
1
  import cv2
 
 
 
2
  from hloc import logger
3
  from ui.utils import (
4
  get_matcher_zoo,
 
68
  "dense": False,
69
  }
70
  api = ImageMatchingAPI(conf=conf, device=DEVICE)
71
+ pred = api(image0, image1)
72
  log_path = ROOT / "experiments" / "one"
73
  log_path.mkdir(exist_ok=True, parents=True)
74
  api.visualize(log_path=log_path)
ui/viz.py CHANGED
@@ -10,6 +10,10 @@ import seaborn as sns
10
 
11
  from hloc.utils.viz import add_text, plot_keypoints
12
 
 
 
 
 
13
 
14
  def plot_images(
15
  imgs: List[np.ndarray],
@@ -232,11 +236,6 @@ def error_colormap(
232
  )
233
 
234
 
235
- np.random.seed(1995)
236
- color_map = np.arange(100)
237
- np.random.shuffle(color_map)
238
-
239
-
240
  def fig2im(fig: matplotlib.figure.Figure) -> np.ndarray:
241
  """
242
  Convert a matplotlib figure to a numpy array with RGB values.
@@ -284,9 +283,8 @@ def draw_matches_core(
284
  The figure as a numpy array with shape (height, width, 3) and dtype uint8
285
  containing the RGB values of the figure.
286
  """
287
- thr = 5e-4
288
  thr = 0.5
289
- color = error_colormap(conf, thr, alpha=0.1)
290
  text = [
291
  # "image name",
292
  f"#Matches: {len(mkpts0)}",
 
10
 
11
  from hloc.utils.viz import add_text, plot_keypoints
12
 
13
+ np.random.seed(1995)
14
+ color_map = np.arange(100)
15
+ np.random.shuffle(color_map)
16
+
17
 
18
  def plot_images(
19
  imgs: List[np.ndarray],
 
236
  )
237
 
238
 
 
 
 
 
 
239
  def fig2im(fig: matplotlib.figure.Figure) -> np.ndarray:
240
  """
241
  Convert a matplotlib figure to a numpy array with RGB values.
 
283
  The figure as a numpy array with shape (height, width, 3) and dtype uint8
284
  containing the RGB values of the figure.
285
  """
 
286
  thr = 0.5
287
+ color = error_colormap(1 - conf, thr, alpha=0.1)
288
  text = [
289
  # "image name",
290
  f"#Matches: {len(mkpts0)}",