|
|
|
|
|
|
|
import os, glob, cv2 |
|
import argparse |
|
from argparse import Namespace |
|
import yaml |
|
from tqdm import tqdm |
|
import torch |
|
from torch.utils.data import Dataset, DataLoader, SequentialSampler |
|
|
|
from src.datasets.custom_dataloader import TestDataLoader |
|
from src.utils.dataset import read_img_gray |
|
from configs.data.base import cfg as data_cfg |
|
import viz |
|
|
|
|
|
def get_model_config(method_name, dataset_name, root_dir="viz"):
    """Return the ``dataset_name`` section of a method's YAML config file.

    The file read is ``<root_dir>/configs/<method_name>.yml``. Its top level is
    indexed by ``dataset_name``, so it is expected to be a mapping keyed by
    dataset name.

    Args:
        method_name: Base name of the YAML file, without the ``.yml`` extension.
        dataset_name: Top-level key to pull out of the parsed YAML document.
        root_dir: Directory that contains the ``configs`` folder.

    Returns:
        The value stored under ``dataset_name`` in the parsed YAML.

    Raises:
        FileNotFoundError: If the config file does not exist (raised by ``open``).
        KeyError: If the parsed mapping has no ``dataset_name`` entry.
    """
    config_path = "{}/configs/{}.yml".format(root_dir, method_name)
    with open(config_path, "r") as stream:
        # NOTE(review): FullLoader can build some Python objects; prefer
        # yaml.safe_load if config files could ever come from untrusted sources.
        per_dataset = yaml.load(stream, Loader=yaml.FullLoader)
    return per_dataset[dataset_name]
|
|
|
|
|
class DemoDataset(Dataset):
    """Image dataset feeding the interactive matching demo.

    The image list comes from one of two sources:

    * ``img_file is None``: every ``*.*`` entry directly inside ``dataset_dir``,
      sorted by path.
    * otherwise: a text file with one image path per line, each relative to
      ``dataset_dir``. Surrounding whitespace is stripped and blank lines are
      ignored.

    Args:
        dataset_dir: Directory holding the images.
        img_file: Optional path to the list file described above.
        resize: Forwarded unchanged to ``read_img_gray``.
        down_factor: Forwarded unchanged to ``read_img_gray``.
    """

    def __init__(self, dataset_dir, img_file=None, resize=0, down_factor=16):
        self.dataset_dir = dataset_dir
        if img_file is None:
            self.list_img_files = sorted(glob.glob(os.path.join(dataset_dir, "*.*")))
        else:
            with open(img_file) as list_file:
                names = (line.strip() for line in list_file)
                # Skip blank lines: os.path.join(dataset_dir, "") would yield the
                # directory itself, which is not a readable image.
                self.list_img_files = [
                    os.path.join(dataset_dir, name) for name in names if name
                ]
        self.resize = resize
        self.down_factor = down_factor

    def __len__(self):
        """Return the number of images in the list."""
        return len(self.list_img_files)

    def __getitem__(self, idx):
        """Load image ``idx`` via ``read_img_gray``.

        Returns:
            dict with ``"img"`` (the image returned by ``read_img_gray``),
            ``"id"`` (``idx``) and ``"img_path"`` (the path that was read).
        """
        img_path = self.list_img_files[idx]
        # read_img_gray also returns a scale value; the demo does not use it.
        img, _scale = read_img_gray(
            img_path, resize=self.resize, down_factor=self.down_factor
        )
        return {"img": img, "id": idx, "img_path": img_path}
|
|
|
|
|
def _parse_args(argv=None):
    """Parse this script's command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``
            (the ``argparse`` default).

    Returns:
        ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Visualize matches")
    parser.add_argument(
        "--gpu",
        "-gpu",
        type=str,
        default=None,
        help="GPU id(s) exported as CUDA_VISIBLE_DEVICES; "
        "leave unset to keep the current environment",
    )
    # The model config is read from viz/configs/<method>.yml, so a method is
    # mandatory; fail early with a usage error instead of opening "None.yml".
    parser.add_argument("--method", type=str, required=True)
    parser.add_argument("--dataset_dir", type=str, default="data/aachen-day-night")
    parser.add_argument("--pair_dir", type=str, default=None)
    parser.add_argument(
        "--dataset_name",
        type=str,
        choices=["megadepth", "scannet", "aachen_v1.1", "inloc"],
        default="megadepth",
    )
    parser.add_argument("--measure_time", action="store_true")
    parser.add_argument("--no_viz", action="store_true")
    parser.add_argument("--compute_eval_metrics", action="store_true")
    parser.add_argument("--run_demo", action="store_true")
    return parser.parse_args(argv)


def _configure_test_data(args, model_cfg):
    """Merge the test settings for ``args.dataset_name`` into the global ``data_cfg``.

    MegaDepth and ScanNet merge the predefined ``configs.data.*_test_1500``
    configs. Aachen v1.1 and InLoc are configured from the CLI paths and the
    model's ``imsize``.
    """
    if args.dataset_name == "megadepth":
        from configs.data.megadepth_test_1500 import cfg

        data_cfg.merge_from_other_cfg(cfg)
    elif args.dataset_name == "scannet":
        from configs.data.scannet_test_1500 import cfg

        data_cfg.merge_from_other_cfg(cfg)
    elif args.dataset_name in ("aachen_v1.1", "inloc"):
        if args.dataset_name == "aachen_v1.1":
            # Aachen images live in a fixed sub-folder of the dataset directory.
            data_root = os.path.join(args.dataset_dir, "images/images_upright")
        else:
            data_root = args.dataset_dir
        data_cfg.merge_from_list(
            [
                "DATASET.TEST_DATA_SOURCE",
                args.dataset_name,
                "DATASET.TEST_DATA_ROOT",
                data_root,
                "DATASET.TEST_LIST_PATH",
                args.pair_dir,
                "DATASET.TEST_IMGSIZE",
                model_cfg["imsize"],
            ]
        )


def _run_evaluation(model, args):
    """Run ``model.match_and_draw`` over every batch of the configured test set.

    Afterwards, optionally print the timing and compute evaluation metrics.
    """
    # Only these sources are treated as having ground truth for evaluation.
    has_ground_truth = str(data_cfg.DATASET.TEST_DATA_SOURCE).lower() in (
        "megadepth",
        "scannet",
    )
    dataloader = TestDataLoader(data_cfg)
    img_root_dir = data_cfg.DATASET.TEST_DATA_ROOT
    use_cuda = torch.cuda.is_available()  # constant for the whole run

    with torch.no_grad():
        for data_dict in tqdm(dataloader):
            if use_cuda:
                # Replacing values of existing keys is safe while iterating.
                for key, value in data_dict.items():
                    if isinstance(value, torch.Tensor):
                        data_dict[key] = value.cuda()
            model.match_and_draw(
                data_dict,
                root_dir=img_root_dir,
                ground_truth=has_ground_truth,
                measure_time=args.measure_time,
                viz_matches=not args.no_viz,
            )

    if args.measure_time:
        print(
            "Running time for each image is {} milliseconds".format(
                model.measure_time()
            )
        )
    if args.compute_eval_metrics:
        if has_ground_truth:
            model.compute_eval_metrics()
        else:
            print("Skipping --compute_eval_metrics: this dataset has no ground truth")


def _run_demo(model, args):
    """Feed the demo images to ``model.run_demo`` and record into ``topicfm_demo.mp4``."""
    demo_dataset = DemoDataset(args.dataset_dir, img_file=args.pair_dir, resize=640)
    dataloader = DataLoader(
        demo_dataset, batch_size=1, sampler=SequentialSampler(demo_dataset)
    )
    # Frame size is (width, height) and must match the frames model.run_demo
    # writes. TODO confirm the layout against the viz implementation.
    writer = cv2.VideoWriter(
        "topicfm_demo.mp4",
        cv2.VideoWriter_fourcc(*"mp4v"),
        15,
        (640 * 2 + 5, 480 * 2 + 10),
    )
    try:
        model.run_demo(iter(dataloader), writer)
    finally:
        # Finalize the container deterministically, even if the demo fails;
        # release() is safe to call on an already-released writer.
        writer.release()


if __name__ == "__main__":
    args = _parse_args()
    if args.gpu is not None:
        # Must be set before CUDA is initialized; the model constructor below
        # may already touch the GPU.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    model_cfg = get_model_config(args.method, args.dataset_name)
    model = viz.__dict__[model_cfg["class"]](model_cfg)

    if args.run_demo:
        _run_demo(model, args)
    else:
        _configure_test_data(args, model_cfg)
        _run_evaluation(model, args)
|
|