Commit 94f372a
yunusserhat committed
1 Parent(s): abd15df
Upload 40 files
Browse files
- metrics/__init__.py +0 -0
- metrics/distance_based.py +129 -0
- metrics/elo.py +21 -0
- metrics/utils.py +85 -0
- models/__init__.py +0 -0
- models/classification/utils_global.py +177 -0
- models/eval_best_model.py +62 -0
- models/huggingface.py +24 -0
- models/losses.py +614 -0
- models/misc.py +9 -0
- models/module.py +157 -0
- models/networks/backbones.py +162 -0
- models/networks/heads/__init__.py +0 -0
- models/networks/heads/auxilliary.py +33 -0
- models/networks/heads/classification.py +17 -0
- models/networks/heads/hybrid.py +194 -0
- models/networks/heads/id_to_gps.py +33 -0
- models/networks/heads/random.py +53 -0
- models/networks/heads/regression.py +44 -0
- models/networks/mlp.py +258 -0
- models/networks/network.py +335 -0
- models/networks/utils.py +22 -0
- models/utils.py +54 -0
- scripts/download-dataset.py +27 -0
- scripts/preprocessing/enrich-metadata-adaptive-quadtrees.py +225 -0
- scripts/preprocessing/enrich-metadata-quadtree.py +208 -0
- scripts/preprocessing/enrich-metadata.py +123 -0
- scripts/preprocessing/fix_namimbia.py +64 -0
- scripts/preprocessing/nearest-neighbors.py +140 -0
- scripts/preprocessing/preprocess.py +400 -0
- scripts/preprocessing/train-val-split.py +15 -0
- scripts/retrieval/backbone.py +150 -0
- scripts/retrieval/retrieval.py +143 -0
- scripts/retrieval/street-clip-zero-shot.py +299 -0
- scripts/retrieval/utils.py +113 -0
- utils/__init__.py +0 -0
- utils/image_processing.py +58 -0
- utils/lr_scheduler.py +96 -0
- utils/model_utils.py +14 -0
- utils/quadtree_10_1000.csv +0 -0
metrics/__init__.py
ADDED
File without changes
metrics/distance_based.py
ADDED
@@ -0,0 +1,129 @@
import torch

from metrics.utils import haversine, reverse

from torchmetrics import Metric


class HaversineMetrics(Metric):
    """
    Computes the average haversine distance between the predicted and ground truth points.
    Computes the accuracy given some radiuses.
    Computes the Geoguessr score given some radiuses.

    Args:
        acc_radiuses (list): list of radiuses to compute the accuracy from
        acc_area (list): list of areas to compute the accuracy from.
        aux_data (list): list of auxiliary data to compute the accuracy from.
    """

    def __init__(
        self,
        acc_radiuses=[],
        acc_area=["country", "region", "sub-region", "city"],
        aux_data=[],
    ):
        super().__init__()
        self.add_state("haversine_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("geoguessr_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        for acc in acc_radiuses:
            self.add_state(
                f"close_enough_points_{acc}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
        for acc in acc_area:
            self.add_state(
                f"close_enough_points_{acc}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
            self.add_state(
                f"count_{acc}", default=torch.tensor(0), dist_reduce_fx="sum"
            )
        self.acc_radius = acc_radiuses
        self.acc_area = acc_area
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.aux = len(aux_data) > 0
        self.aux_list = aux_data
        if self.aux:
            self.aux_count = {}
            for col in self.aux_list:
                self.add_state(
                    f"aux_{col}",
                    default=torch.tensor(0.0),
                    dist_reduce_fx="sum",
                )

    def update(self, pred, gt):
        haversine_distance = haversine(pred["gps"], gt["gps"])
        for acc in self.acc_radius:
            self.__dict__[f"close_enough_points_{acc}"] += (
                haversine_distance < acc
            ).sum()
        if len(self.acc_area) > 0:
            area_pred, area_gt = reverse(pred["gps"], gt, self.acc_area)
            for acc in self.acc_area:
                self.__dict__[f"close_enough_points_{acc}"] += (
                    area_pred[acc] == area_gt["_".join(["unique", acc])]
                ).sum()
                self.__dict__[f"count_{acc}"] += len(area_gt["_".join(["unique", acc])])
        self.haversine_sum += haversine_distance.sum()
        self.geoguessr_sum += 5000 * torch.exp(-haversine_distance / 1492.7).sum()

        if self.aux:
            if "land_cover" in self.aux_list:
                col = "land_cover"
                self.__dict__[f"aux_{col}"] += (
                    pred[col].argmax(dim=1) == gt[col].argmax(dim=1)
                ).sum()
            if "road_index" in self.aux_list:
                col = "road_index"
                self.__dict__[f"aux_{col}"] += (
                    pred[col].argmax(dim=1) == gt[col].argmax(dim=1)
                ).sum()
            if "drive_side" in self.aux_list:
                col = "drive_side"
                self.__dict__[f"aux_{col}"] += (
                    (pred[col] > 0.5).float() == gt[col]
                ).sum()
            if "climate" in self.aux_list:
                col = "climate"
                self.__dict__[f"aux_{col}"] += (
                    pred[col].argmax(dim=1) == gt[col].argmax(dim=1)
                ).sum()
            if "soil" in self.aux_list:
                col = "soil"
                self.__dict__[f"aux_{col}"] += (
                    pred[col].argmax(dim=1) == gt[col].argmax(dim=1)
                ).sum()
            if "dist_sea" in self.aux_list:
                col = "dist_sea"
                self.__dict__[f"aux_{col}"] += (
                    (pred[col] - gt[col]).pow(2).sum(dim=1).sum()
                )

        self.count += pred["gps"].shape[0]

    def compute(self):
        output = {
            "Haversine": self.haversine_sum / self.count,
            "Geoguessr": self.geoguessr_sum / self.count,
        }
        for acc in self.acc_radius:
            output[f"Accuracy_{acc}_km_radius"] = (
                self.__dict__[f"close_enough_points_{acc}"] / self.count
            )
        for acc in self.acc_area:
            output[f"Accuracy_{acc}"] = (
                self.__dict__[f"close_enough_points_{acc}"]
                / self.__dict__[f"count_{acc}"]
            )

        if self.aux:
            for col in self.aux_list:
                output["_".join(["Accuracy", col])] = (
                    self.__dict__[f"aux_{col}"] / self.count
                )

        return output
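For reference, a minimal usage sketch of the metric above, assuming coordinates are given as (lat, lon) in radians (as the haversine() helper expects) and that no area or auxiliary accuracies are requested; the tensors and radii below are hypothetical, not part of the repository:

import torch
from metrics.distance_based import HaversineMetrics

# accuracy at 1, 25 and 200 km; skip reverse-geocoded area accuracies
metric = HaversineMetrics(acc_radiuses=[1, 25, 200], acc_area=[])
pred = {"gps": torch.deg2rad(torch.tensor([[48.85, 2.35], [40.71, -74.01]]))}
gt = {"gps": torch.deg2rad(torch.tensor([[48.86, 2.34], [34.05, -118.24]]))}
metric.update(pred, gt)
print(metric.compute())  # {'Haversine': ..., 'Geoguessr': ..., 'Accuracy_1_km_radius': ..., ...}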
metrics/elo.py
ADDED
@@ -0,0 +1,21 @@
import os
import torch
from metrics.utils import haversine

from torchmetrics import Metric


class HaversineELOMetric(Metric):
    """
    Computes the ELO score of the current network given previous players

    Args:
        previous_players_scores (str): path to the csv containing the scores of the previous players
        previous_players_predictions (str): path to the folder containing the predictions of the previous players
        tag (str): tag of the current experiment

    """

    def __init__(self, cache_folder, tag):
        ### TODO
        pass
metrics/utils.py
ADDED
@@ -0,0 +1,85 @@
import torch
import reverse_geocoder
import numpy as np


def haversine(pred, gt):
    # expects inputs to be torch tensors in (lat, lon) format as radians
    # N x 2

    # calculate the difference in latitude and longitude between the predicted and ground truth points
    lat_diff = pred[:, 0] - gt[:, 0]
    lon_diff = pred[:, 1] - gt[:, 1]

    # calculate the haversine formula components
    lhs = torch.sin(lat_diff / 2) ** 2
    rhs = torch.cos(pred[:, 0]) * torch.cos(gt[:, 0]) * torch.sin(lon_diff / 2) ** 2
    a = lhs + rhs

    # calculate the final distance using the haversine formula
    c = 2 * torch.arctan2(torch.sqrt(a), torch.sqrt(1 - a))
    distance = 6371 * c

    return distance


def reverse(pred, gt, area):
    df = {}
    gt_area = {}
    nan_mask = {}
    areas = ["_".join(["unique", ar]) for ar in area]
    if "unique_continent" in areas:
        areas.remove("unique_continent")
    for ar in areas:
        inter = np.array(gt[ar])
        nan_mask[ar] = inter != "nan"
        gt_area[ar] = inter[nan_mask[ar]]
    location = reverse_geocoder.search(
        [
            (lat, lon)
            for lat, lon in zip(
                np.degrees(pred[:, 0].cpu()), np.degrees(pred[:, 1].cpu())
            )
        ]
    )
    if "continent" in area:
        continent = torch.load("continent.pt")
        inter = np.array([l.get("cc", "") for l in location])[
            nan_mask["unique_country"]
        ]
        df["continent"] = np.array([continent[i] for i in inter])
        gt_area["unique_continent"] = np.array(
            [continent[i] for i in gt_area["unique_country"]]
        )

    if "country" in area:
        df["country"] = np.array([l.get("cc", "") for l in location])[
            nan_mask["unique_country"]
        ]
    if "region" in area:
        df["region"] = np.array(
            ["_".join([l.get("admin1", ""), l.get("cc", "")]) for l in location]
        )[nan_mask["unique_region"]]
    if "sub-region" in area:
        df["sub-region"] = np.array(
            [
                "_".join([l.get("admin2", ""), l.get("admin1", ""), l.get("cc", "")])
                for l in location
            ]
        )[nan_mask["unique_sub-region"]]
    if "city" in area:
        df["city"] = np.array(
            [
                "_".join(
                    [
                        l.get("name", ""),
                        l.get("admin2", ""),
                        l.get("admin1", ""),
                        l.get("cc", ""),
                    ]
                )
                for l in location
            ]
        )[nan_mask["unique_city"]]

    return df, gt_area
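For clarity, haversine() above implements the standard haversine great-circle distance, with predicted coordinates $(\varphi_p, \lambda_p)$ and ground-truth coordinates $(\varphi_g, \lambda_g)$ in radians and Earth radius $R = 6371$ km:

$$
a = \sin^2\!\left(\frac{\varphi_p-\varphi_g}{2}\right) + \cos\varphi_p\,\cos\varphi_g\,\sin^2\!\left(\frac{\lambda_p-\lambda_g}{2}\right),
\qquad
d = 2R\,\operatorname{atan2}\!\left(\sqrt{a},\,\sqrt{1-a}\right).
$$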
models/__init__.py
ADDED
File without changes
models/classification/utils_global.py
ADDED
@@ -0,0 +1,177 @@
import logging
from collections import OrderedDict
from pathlib import Path
from typing import Union, List

import torch
import torchvision


def check_is_valid_torchvision_architecture(architecture: str):
    """Raises a ValueError if architecture is not part of available torchvision models."""
    available = sorted(
        name
        for name in torchvision.models.__dict__
        if name.islower()
        and not name.startswith("__")
        and callable(torchvision.models.__dict__[name])
    )
    if architecture not in available:
        raise ValueError(f"{architecture} not in {available}")


def build_base_model(arch: str):

    model = torchvision.models.__dict__[arch](pretrained=True)

    # get input dimension before classification layer
    if arch in ["mobilenet_v2"]:
        nfeatures = model.classifier[-1].in_features
        model = torch.nn.Sequential(*list(model.children())[:-1])
    elif arch in ["densenet121", "densenet161", "densenet169"]:
        nfeatures = model.classifier.in_features
        model = torch.nn.Sequential(*list(model.children())[:-1])
    elif "resne" in arch:
        # usually all ResNet variants
        nfeatures = model.fc.in_features
        model = torch.nn.Sequential(*list(model.children())[:-2])
    else:
        raise NotImplementedError

    model.avgpool = torch.nn.AdaptiveAvgPool2d(1)
    model.flatten = torch.nn.Flatten(start_dim=1)
    return model, nfeatures


def load_weights_if_available(
    model: torch.nn.Module, classifier: torch.nn.Module, weights_path: Union[str, Path]
):

    checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)

    state_dict_features = OrderedDict()
    state_dict_classifier = OrderedDict()
    for k, w in checkpoint["state_dict"].items():
        if k.startswith("model"):
            state_dict_features[k.replace("model.", "")] = w
        elif k.startswith("classifier"):
            state_dict_classifier[k.replace("classifier.", "")] = w
        else:
            logging.warning(f"Unexpected prefix in state_dict: {k}")
    model.load_state_dict(state_dict_features, strict=True)
    return model, classifier


def vectorized_gc_distance(latitudes, longitudes, latitudes_gt, longitudes_gt):
    R = 6371
    factor_rad = 0.01745329252
    longitudes = factor_rad * longitudes
    longitudes_gt = factor_rad * longitudes_gt
    latitudes = factor_rad * latitudes
    latitudes_gt = factor_rad * latitudes_gt
    delta_long = longitudes_gt - longitudes
    delta_lat = latitudes_gt - latitudes
    subterm0 = torch.sin(delta_lat / 2) ** 2
    subterm1 = torch.cos(latitudes) * torch.cos(latitudes_gt)
    subterm2 = torch.sin(delta_long / 2) ** 2
    subterm1 = subterm1 * subterm2
    a = subterm0 + subterm1
    c = 2 * torch.asin(torch.sqrt(a))
    gcd = R * c
    return gcd


def gcd_threshold_eval(gc_dists, thresholds=[1, 25, 200, 750, 2500]):
    # calculate accuracy for given gcd thresholds
    results = {}
    for thres in thresholds:
        results[thres] = torch.true_divide(
            torch.sum(gc_dists <= thres), len(gc_dists)
        ).item()
    return results


def accuracy(output, target, partitioning_shortnames: list, topk=(1, 5, 10)):
    def _accuracy(output, target, topk=(1,)):
        """Computes the accuracy over the k top predictions for the specified values of k"""
        with torch.no_grad():
            maxk = max(topk)
            batch_size = target.size(0)

            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = {}
            for k in topk:
                correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
                res[k] = correct_k / batch_size
            return res

    with torch.no_grad():
        out_dict = {}
        for i, pname in enumerate(partitioning_shortnames):
            res_dict = _accuracy(output[i], target[i], topk=topk)
            for k, v in res_dict.items():
                out_dict[f"acc{k}_val/{pname}"] = v

        return out_dict


def summarize_gcd_stats(pnames: List[str], outputs, hierarchy=None):
    gcd_dict = {}
    metric_names = [f"gcd_{p}_val" for p in pnames]
    if hierarchy is not None:
        metric_names.append("gcd_hierarchy_val")
    for metric_name in metric_names:
        distances_flat = [output[metric_name] for output in outputs]
        distances_flat = torch.cat(distances_flat, dim=0)
        gcd_results = gcd_threshold_eval(distances_flat)
        for gcd_thres, acc in gcd_results.items():
            gcd_dict[f"{metric_name}/{gcd_thres}"] = acc
    return gcd_dict


def summarize_test_gcd(pnames, outputs, hierarchy=None):
    def _eval(output):
        # calculate acc@km for a list of given thresholds
        accuracy_outputs = {}
        if hierarchy is not None:
            pnames.append("hierarchy")
        for pname in pnames:
            # concat batches of distances
            distances_flat = torch.cat([x[pname] for x in output], dim=0)
            # acc for all distances
            acc_dict = gcd_threshold_eval(distances_flat)
            accuracy_outputs[f"acc_test/{pname}"] = acc_dict
        return accuracy_outputs

    result = {}

    if isinstance(outputs[0], dict):  # only one testset
        result = _eval(outputs)
    elif isinstance(outputs[0], list):  # multiple testsets
        for testset_index, output in enumerate(outputs):
            result[testset_index] = _eval(output)
    else:
        raise TypeError

    return result


def summarize_loss_acc_stats(pnames: List[str], outputs, topk=[1, 5, 10]):

    loss_acc_dict = {}
    metric_names = []
    for k in topk:
        accuracy_names = [f"acc{k}_val/{p}" for p in pnames]
        metric_names.extend(accuracy_names)
    metric_names.extend([f"loss_val/{p}" for p in pnames])
    for metric_name in ["loss_val/total", *metric_names]:
        metric_total = 0
        for output in outputs:
            metric_value = output[metric_name]
            metric_total += metric_value
        loss_acc_dict[metric_name] = metric_total / len(outputs)
    return loss_acc_dict
models/eval_best_model.py
ADDED
@@ -0,0 +1,62 @@
import os
from typing import Any
import pytorch_lightning as L
import torch
from hydra.utils import instantiate
from models.huggingface import Geolocalizer


class EvalModule(L.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        os.chdir(cfg.network.root_dir)
        self.model = Geolocalizer.from_pretrained('osv5m/baseline')
        self.test_metrics = instantiate(cfg.test_metrics)

    def training_step(self, batch, batch_idx):
        pred = self.model(batch)
        pass

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        pred = self.model(batch)
        pass

    def on_validation_epoch_end(self):
        pass

    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        pred = self.model.forward_tensor(batch)
        self.test_metrics.update({"gps": pred}, batch)

    def on_test_epoch_end(self):
        metrics = self.test_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"test/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )

    def lr_scheduler_step(self, scheduler, metric):
        scheduler.step(self.global_step)


def get_parameter_names(model, forbidden_layer_types):
    """
    Returns the names of the model parameters that are not inside a forbidden layer.
    Taken from HuggingFace transformers.
    """
    result = []
    for name, child in model.named_children():
        result += [
            f"{name}.{n}"
            for n in get_parameter_names(child, forbidden_layer_types)
            if not isinstance(child, tuple(forbidden_layer_types))
        ]
    # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
    result += list(model._parameters.keys())
    return result
models/huggingface.py
ADDED
@@ -0,0 +1,24 @@
import torch
from torch import nn
from hydra.utils import instantiate
from omegaconf import OmegaConf
from huggingface_hub import PyTorchModelHubMixin


class Geolocalizer(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config):
        super().__init__()
        self.config = OmegaConf.create(config)
        self.transform = instantiate(self.config.transform)
        self.model = instantiate(self.config.model)
        self.head = self.model.head
        self.mid = self.model.mid
        self.backbone = self.model.backbone

    def forward(self, img: torch.Tensor):
        output = self.head(self.mid(self.backbone({"img": img})), None)
        return output["gps"]

    def forward_tensor(self, img: torch.Tensor):
        output = self.head(self.mid(self.backbone(img)), None)
        return output["gps"]
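A rough inference sketch for the hub wrapper above. The from_pretrained call comes from PyTorchModelHubMixin, and 'osv5m/baseline' is the checkpoint name used in models/eval_best_model.py; the image path is hypothetical, and the sketch assumes the instantiated self.transform maps a PIL image to a normalized CHW tensor, which is defined by the config rather than by this file:

import torch
from PIL import Image
from models.huggingface import Geolocalizer

model = Geolocalizer.from_pretrained("osv5m/baseline")  # provided by PyTorchModelHubMixin
model.eval()
img = Image.open("street_view.jpg")  # hypothetical local image
x = model.transform(img).unsqueeze(0)  # assumption: transform returns a CHW tensor
with torch.no_grad():
    gps = model(x)  # (1, 2) predicted coordinates
print(gps)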
models/losses.py
ADDED
@@ -0,0 +1,614 @@
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from os.path import join
from models.networks.utils import NormGPS


class L1(nn.Module):
    def __init__(self):
        super(L1, self).__init__()

    def forward(self, x, y):
        """
        Args:
            x: dict that contains "gps": torch.Tensor Bx2
            y: dict that contains "gps": torch.Tensor Bx2
        Returns:
            torch.Tensor: L1 loss between x and y: torch.Tensor([B])
        """
        return {"L1_loss": torch.abs(x["gps"] - y["gps"]).mean(dim=-1)}


class L2(nn.Module):
    def __init__(self):
        super(L2, self).__init__()

    def forward(self, x, y):
        """
        Args:
            x: dict that contains "gps": torch.Tensor Bx2
            y: dict that contains "gps": torch.Tensor Bx2
        Returns:
            torch.Tensor: L2 loss between x and y: torch.Tensor([B])
        """
        return {"L2_loss": ((x["gps"] - y["gps"]) ** 2).mean(dim=-1)}


class L2Hybrid(nn.Module):
    def __init__(self):
        super(L2Hybrid, self).__init__()
        self.norm = NormGPS()

    def forward(self, x, y):
        """
        Args:
            x: dict that contains "gps": torch.Tensor Bx2
            y: dict that contains "gps": torch.Tensor Bx2
        Returns:
            torch.Tensor: L2 loss between x and y: torch.Tensor([B])
        """
        return {
            "L2_loss": (
                (x["reg"] - (self.norm(y["gps"]) - x["center"]) * x["size"]) ** 2
            ).mean(dim=-1)
        }


class CrossEntropy(nn.Module):
    def __init__(self):
        super(CrossEntropy, self).__init__()
        self.loss = nn.CrossEntropyLoss(reduction="none")

    def forward(self, x, y):
        """
        Args:
            x: dict that contains "label": torch.Tensor BxN
            y: dict that contains "label": torch.Tensor BxN
        Returns:
            torch.Tensor: CrossEntropy loss between x and y: torch.Tensor([B])
        """
        return {"cross_entropy_loss": self.loss(x["label"], y["label"])}


class HierarchicalCrossEntropyQuad(nn.Module):
    def __init__(self, data_path=""):
        super(HierarchicalCrossEntropyQuad, self).__init__()
        self.dict_losses = {"classif_loss": nn.CrossEntropyLoss(reduction="none")}
        for i in range(1, 10):
            self.dict_losses[f"quadtree_{i}_loss"] = nn.NLLLoss()
        self.matrixes = torch.load(join(data_path, "quadtree_matrixes.pt"))
        self.dicts = torch.load(join(data_path, "quadtree_dicts.pt"))
        self.id_to_quad = torch.load(join(data_path, "id_to_quad_10_1000.pt"))

    def forward(self, x, y):
        """
        Args:
            x: dict that contains "label": torch.Tensor BxN
            y: dict that contains "label": torch.Tensor BxN
        Returns:
            torch.Tensor: Hierarchical CrossEntropy for Quadtrees loss between x and y: torch.Tensor([B])
        """
        out = {"classif_loss": self.dict_losses["classif_loss"](x["label"], y["label"])}
        probas = nn.functional.softmax(x["label"], dim=1)
        device = x["label"].device
        gt = self.id_to_quad[y["label"].cpu()]
        for i in range(9):
            logits = torch.log(torch.mm(probas, self.matrixes[i].to(device)) + 1e-10)
            l = [s[: 9 - i] if len(s) >= 10 - i else s for s in gt]
            out[f"quadtree_{i+1}_loss"] = self.dict_losses[f"quadtree_{i+1}_loss"](
                logits, torch.tensor([self.dicts[i][item] for item in l]).to(device)
            )

        return out


class HierarchicalCrossEntropy(nn.Module):
    def __init__(self, path=""):
        super(HierarchicalCrossEntropy, self).__init__()
        self.city_loss = nn.CrossEntropyLoss(reduction="none")
        self.country_loss = nn.NLLLoss()
        self.area_loss = nn.NLLLoss()
        self.region_loss = nn.NLLLoss()
        self.city_to_country = torch.load(path + "city_to_country.pt")
        self.city_to_region = torch.load(path + "city_to_region.pt")
        self.city_to_area = torch.load(path + "city_to_area.pt")
        self.country_to_idx = torch.load(path + "country_to_idx.pt")
        self.region_to_idx = torch.load(path + "region_to_idx.pt")
        self.area_to_idx = torch.load(path + "area_to_idx.pt")

    def forward(self, x, y):
        """
        Args:
            x: dict that contains "label": torch.Tensor BxN
            y: dict that contains "label": torch.Tensor BxN
        Returns:
            torch.Tensor: Hierarchical CrossEntropy loss between x and y: torch.Tensor([B])
        """
        country_mask = np.array(y["unique_country"]) != "NaN"
        self.city_to_country = self.city_to_country.to(x["label"].device)
        countries_probas = nn.functional.softmax(x["label"][country_mask], dim=1)
        countries_logits = torch.log(
            torch.mm(countries_probas, self.city_to_country) + 1e-10
        )
        country_gt = torch.tensor(
            [
                self.country_to_idx[item]
                for item in np.array(y["unique_country"])[country_mask]
            ]
        ).to(x["label"].device)

        region_mask = np.array(y["unique_region"]) != "NaN"
        self.city_to_region = self.city_to_region.to(x["label"].device)
        regions_probas = nn.functional.softmax(x["label"][region_mask], dim=1)
        regions_logits = torch.log(
            torch.mm(regions_probas, self.city_to_region) + 1e-10
        )
        region_gt = torch.tensor(
            [
                self.region_to_idx[item]
                for item in np.array(y["unique_region"])[region_mask]
            ]
        ).to(x["label"].device)

        area_mask = np.array(y["unique_sub-region"]) != "NaN"
        self.city_to_area = self.city_to_area.to(x["label"].device)
        areas_probas = nn.functional.softmax(x["label"][area_mask], dim=1)
        areas_logits = torch.log(torch.mm(areas_probas, self.city_to_area) + 1e-10)
        area_gt = torch.tensor(
            [
                self.area_to_idx[item]
                for item in np.array(y["unique_sub-region"])[area_mask]
            ]
        ).to(x["label"].device)

        return {
            "cross_entropy_country_loss": self.country_loss(
                countries_logits, country_gt
            ),
            "cross_entropy_city_loss": self.city_loss(x["label"], y["label"]),
            "cross_entropy_area_loss": self.area_loss(areas_logits, area_gt),
            "cross_entropy_region_loss": self.region_loss(regions_logits, region_gt),
        }


class LandCoverLoss(nn.Module):
    def __init__(self):
        super(LandCoverLoss, self).__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x, y):
        """
        Args:
            x: dict that contains "land_cover": torch.Tensor BxN
            y: dict that contains "land_cover": torch.Tensor BxN
        Returns:
            torch.Tensor: CrossEntropy loss between x and y: torch.Tensor([B])
        """
        return {
            "land_cover_cross_entropy_loss": self.loss(x["land_cover"], y["land_cover"])
        }


class RoadIndexLoss(nn.Module):
    def __init__(self):
        super(RoadIndexLoss, self).__init__()
        self.loss = nn.MSELoss()

    def forward(self, x, y):
        """
        Args:
            x: dict that contains "road_index": torch.Tensor BxN
            y: dict that contains "road_index": torch.Tensor BxN
        Returns:
            torch.Tensor: MSE loss between x and y: torch.Tensor([B])
        """
        return {"road_index_mse_loss": self.loss(x["road_index"], y["road_index"])}


class DriveSideLoss(nn.Module):
    def __init__(self):
        super(DriveSideLoss, self).__init__()
        self.loss = nn.BCELoss()

    def forward(self, x, y):
        """
        Args:
            x: dict that contains "drive_side": torch.Tensor BxN
            y: dict that contains "drive_side": torch.Tensor BxN
        Returns:
            torch.Tensor: BCE loss between x and y: torch.Tensor([B])
        """
        return {"drive_side_bce_loss": self.loss(x["drive_side"], y["drive_side"])}


class ClimateLoss(nn.Module):
    def __init__(self):
        super(ClimateLoss, self).__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x, y):
        """
        Args:
            x: dict that contains "climate": torch.Tensor BxN
            y: dict that contains "climate": torch.Tensor BxN
        Returns:
            torch.Tensor: CrossEntropy loss between x and y: torch.Tensor([B])
        """
        return {"climate_cross_entropy_loss": self.loss(x["climate"], y["climate"])}


class SoilLoss(nn.Module):
    def __init__(self):
        super(SoilLoss, self).__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x, y):
        """
        Args:
            x: dict that contains "soil": torch.Tensor BxN
            y: dict that contains "soil": torch.Tensor BxN
        Returns:
            torch.Tensor: CrossEntropy loss between x and y: torch.Tensor([B])
        """
        return {"soil_cross_entropy_loss": self.loss(x["soil"], y["soil"])}


class DistSeaLoss(nn.Module):
    def __init__(self):
        super(DistSeaLoss, self).__init__()
        self.loss = nn.MSELoss()

    def forward(self, x, y):
        """
        Args:
            x: dict that contains "dist_sea": torch.Tensor BxN
            y: dict that contains "dist_sea": torch.Tensor BxN
        Returns:
            torch.Tensor: MSE loss between x and y: torch.Tensor([B])
        """
        return {"dist_sea_mse_loss": self.loss(x["dist_sea"], y["dist_sea"])}


class Haversine(nn.Module):
    def __init__(self):
        super(Haversine, self).__init__()

    def forward(self, x, y):
        """
        Args:
            x: dict that contains "gps": torch.Tensor Bx2
            y: dict that contains "gps": torch.Tensor Bx2
        Returns:
            torch.Tensor: Haversine loss between x and y: torch.Tensor([B])
        Note:
            Haversine distance doesn't contain the 2 * 6371 constant.
        """
        x, y = x["gps"], y["gps"]
        lhs = torch.sin((x[:, 0] - y[:, 0]) / 2) ** 2
        rhs = (
            torch.cos(x[:, 0])
            * torch.cos(y[:, 0])
            * torch.sin((x[:, 1] - y[:, 1]) / 2) ** 2
        )
        a = lhs + rhs
        return {
            "haversine_loss": torch.arctan2(torch.sqrt(a), torch.sqrt(1 - a))
        }  # omitting the 2 * 6371 factor as it is a constant


class GeoguessrLoss(Haversine):
    def __init__(self):
        super(GeoguessrLoss, self).__init__()

    def forward(self, x, y):
        distance = super().forward(x, y)["haversine_loss"]
        loss = torch.exp(-distance / 1852)
        return {"geoguessr_loss": loss}


class InfoNCE(nn.Module):
    def __init__(self, tau=0.1):
        super(InfoNCE, self).__init__()
        self.tau = tau

    def cosine_similarity(self, a, b, normalize=True):
        if normalize:
            w1 = a.norm(p=2, dim=1, keepdim=True)
            w2 = b.norm(p=2, dim=1, keepdim=True)
            sim_matrix = torch.mm(a, b.t()) / (w1 * w2.t()).clamp(min=1e-8)
        else:
            sim_matrix = torch.mm(a, b.t())
        return sim_matrix

    def forward(self, x, y=None):
        """
        neg_sim: BxB
        pos_sim: Bx1
        """
        features = x["features"]
        positive_features = x["pos_features"]
        pos_sim = F.cosine_similarity(
            features, positive_features, dim=1, eps=1e-8
        ).unsqueeze(1)
        neg_sim = self.cosine_similarity(features, features, normalize=True)

        b = neg_sim.shape[0]
        logits = (1 - torch.eye(b)).type_as(neg_sim) * neg_sim + torch.eye(b).type_as(
            pos_sim
        ) * pos_sim
        logits = logits / self.tau
        labels = torch.arange(b, dtype=torch.long).cuda()
        loss = F.cross_entropy(logits, labels)
        return {
            "contrastive_loss": loss,
        }


class TextNCE(nn.Module):
    def __init__(self, tau=0.1, num_devices=1):
        super(TextNCE, self).__init__()
        self.distributed = num_devices > 1
        self.tau = tau

    def cosine_similarity(self, a, b, normalize=True):
        if normalize:
            w1 = a.norm(p=2, dim=1, keepdim=True)
            w2 = b.norm(p=2, dim=1, keepdim=True)
            sim_matrix = torch.mm(a, b.t()) / (w1 * w2.t()).clamp(min=1e-8)
        else:
            sim_matrix = torch.mm(a, b.t())
        return sim_matrix

    def forward(self, x, y=None):
        """
        neg_sim: BxB
        pos_sim: Bx1
        """
        if self.distributed:
            all_image_features = torch.cat(
                torch.distributed.nn.all_gather(x["features"]), dim=0
            )
            all_text_features = torch.cat(
                torch.distributed.nn.all_gather(x["text_features"]), dim=0
            )
            all_labels = torch.cat(torch.distributed.nn.all_gather(y["label"]), dim=0)
        else:
            all_image_features = x["features"]
            all_text_features = x["text_features"]
            all_labels = y["label"]
        labels_u = torch.unique(all_labels)
        logits = self.cosine_similarity(
            all_image_features, all_text_features, normalize=True
        )
        rows, cols = logits.size()
        indices = torch.arange(0, rows, device=all_image_features.device)
        loss = torch.sum(
            torch.logsumexp(
                logits[indices != indices.view(-1, 1)].view(rows, cols - 1) / self.tau,
                dim=1,
            )
        )
        for label in labels_u:
            if not (label == "NaN"):
                # Get the positive and negative examples
                idx = all_labels == label
                pos_logits = logits[idx][:, idx]
                # Compute the MIL-NCE loss
                loss += torch.sum(-torch.logsumexp(pos_logits / self.tau, dim=1))
        return {
            "contrastive_loss": loss,
        }


class MILNCE(nn.Module):
    def __init__(self, tau=0.1, num_devices=1):
        super(MILNCE, self).__init__()
        self.distributed = num_devices > 1
        self.tau = tau

    def cosine_similarity(self, a, b, normalize=True):
        if normalize:
            w1 = a.norm(p=2, dim=1, keepdim=True)
            w2 = b.norm(p=2, dim=1, keepdim=True)
            sim_matrix = torch.mm(a, b.t()) / (w1 * w2.t()).clamp(min=1e-8)
        else:
            sim_matrix = torch.mm(a, b.t())
        return sim_matrix

    def forward(self, x, y=None):
        """
        Compute the MIL-NCE loss.
        """
        if self.distributed:
            all_image_features = torch.cat(
                torch.distributed.nn.all_gather(x["features"]), dim=0
            )
            all_pos_features = torch.cat(
                torch.distributed.nn.all_gather(x["pos_features"]), dim=0
            )
            all_labels = torch.cat(torch.distributed.nn.all_gather(y["label"]), dim=0)
        else:
            all_image_features = x["features"]
            all_pos_features = x["pos_features"]
            all_labels = y["label"]
        labels_u = torch.unique(all_labels)
        features = torch.cat([all_image_features, all_pos_features])
        labels = torch.cat([all_labels, all_labels])
        logits = self.cosine_similarity(features, features, normalize=True)
        rows, cols = logits.size()
        indices = torch.arange(0, rows, device=features.device)
        loss = torch.sum(
            torch.logsumexp(
                logits[indices != indices.view(-1, 1)].view(rows, cols - 1) / self.tau,
                dim=1,
            )
        )
        for label in labels_u:
            if not (label == "NaN"):
                # Get the positive and negative examples
                idx = labels == label
                pos_logits = logits[idx][:, idx]

                rows, cols = pos_logits.size()
                indices = torch.arange(0, rows, device=features.device)
                pos_logits = pos_logits[indices != indices.view(-1, 1)].view(
                    rows, cols - 1
                )

                # Compute the MIL-NCE loss
                loss += torch.sum(-torch.logsumexp(pos_logits / self.tau, dim=1))
        return {
            "contrastive_loss": loss,
        }


class RegionMILNCE(nn.Module):
    def __init__(self, tau=0.1, num_devices=1):
        super(RegionMILNCE, self).__init__()
        self.distributed = num_devices > 1
        self.tau = tau

    def cosine_similarity(self, a, b, normalize=True):
        if normalize:
            w1 = a.norm(p=2, dim=1, keepdim=True)
            w2 = b.norm(p=2, dim=1, keepdim=True)
            sim_matrix = torch.mm(a, b.t()) / (w1 * w2.t()).clamp(min=1e-8)
        else:
            sim_matrix = torch.mm(a, b.t())
        return sim_matrix

    def forward(self, x, y=None):
        """
        neg_sim: BxB
        pos_sim: Bx1
        """
        if self.distributed:
            all_image_features = torch.cat(
                torch.distributed.nn.all_gather(x["features"]), dim=0
            )
            all_pos_features = torch.cat(
                torch.distributed.nn.all_gather(x["pos_features"]), dim=0
            )
            all_labels = torch.cat(torch.distributed.nn.all_gather(y["label"]), dim=0)
        else:
            all_image_features = x["features"]
            all_pos_features = x["pos_features"]
            all_labels = y["label"]
        labels_u = torch.unique(all_labels)
        features = torch.cat([all_image_features, all_pos_features])
        labels = torch.cat([all_labels, all_labels])
        logits = self.cosine_similarity(features, features, normalize=True)
        rows, cols = logits.size()
        indices = torch.arange(0, rows, device=features.device)
        loss = torch.sum(
            torch.logsumexp(
                logits[indices != indices.view(-1, 1)].view(rows, cols - 1) / self.tau,
                dim=1,
            )
        )
        for label in labels_u:
            if not (label == "NaN"):
                # Get the positive and negative examples
                idx = labels == label
                pos_logits = logits[idx][:, idx]

                rows, cols = pos_logits.size()
                indices = torch.arange(0, rows, device=features.device)
                pos_logits = pos_logits[indices != indices.view(-1, 1)].view(
                    rows, cols - 1
                )

                # Compute the MIL-NCE loss
                loss += torch.sum(-torch.logsumexp(pos_logits / self.tau, dim=1))
        return {
            "contrastive_loss": loss / len(all_labels),
        }


LOSSES = {
    "l1": L1,
    "l2": L2,
    "l2_hybrid": L2Hybrid,
    "haversine": Haversine,
    "geoguessr": GeoguessrLoss,
    "crossentropy": CrossEntropy,
    "infonce": InfoNCE,
    "mil-nce": MILNCE,
    "text-nce": TextNCE,
    "land_cover": LandCoverLoss,
    "road_index": RoadIndexLoss,
    "drive_side": DriveSideLoss,
    "climate": ClimateLoss,
    "soil": SoilLoss,
    "dist_sea": DistSeaLoss,
    "hierarchical": HierarchicalCrossEntropy,
    "hier_quad": HierarchicalCrossEntropyQuad,
    "region_mil": RegionMILNCE,
}
AVERAGE = {False: lambda x: x, True: lambda x: x.mean(dim=-1)}


class Losses(nn.Module):
    """The Losses meta-object that can take a mix of losses."""

    def __init__(self, mix={}, aux_data=[], path="", num_devices=1):
        """Initializes the Losses object.
        Args:
            mix (dict): dictionary with keys "loss_name" and values weight
        """
        super(Losses, self).__init__()
        assert len(mix)
        self.aux = len(aux_data) > 0
        if self.aux:
            self.aux_list = aux_data
            total = ["land_cover", "drive_side", "climate", "soil", "dist_sea"]
            for col in self.aux_list:
                total.remove(col)
            for col in total:
                del mix[col]
        self.init_losses(mix, path, num_devices)

    def init_losses(self, mix, path="", num_devices=1):
        """Initializes the losses.
        Args:
            mix (dict): dictionary with keys "loss_name" and values weight
        """
        self.loss = {}
        for m, v in mix.items():
            m = m.lower()
            if m in ["hierarchical", "hier_quad"]:
                try:
                    self.loss[m] = (LOSSES[m](path), v)
                except KeyError:
                    raise KeyError(f"Loss {m} not found in {LOSSES.keys()}")
            elif m in ["region_mil", "mil-nce", "text-nce"]:
                try:
                    self.loss[m] = (LOSSES[m](num_devices=num_devices), v)
                except KeyError:
                    raise KeyError(f"Loss {m} not found in {LOSSES.keys()}")
            else:
                try:
                    self.loss[m] = (LOSSES[m](), v)
                except KeyError:
                    raise KeyError(f"Loss {m} not found in {LOSSES.keys()}")

    def forward(self, x, y, average=True):
        """Computes the losses.
        Args:
            x: dict that contains "gps": torch.Tensor Bx2 or "label": torch.Tensor BxN
            y: dict that contains "gps": torch.Tensor Bx2 or "label": torch.Tensor BxN
            average (bool): whether to average the losses or not
        Returns:
            dict: dictionary with losses
        """
        output = {"loss": 0}
        for loss_name, (loss, weight) in self.loss.items():
            loss_output = loss(x, y)
            for k, v in loss_output.items():
                v = AVERAGE[average](v)
                if k.endswith("_loss"):
                    output["loss"] += weight * v
                output[k] = v
        return output
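A small usage sketch of the Losses wrapper above; the mix weights are hypothetical, and the sketch assumes predictions and targets both expose a "gps" tensor of shape Bx2 in radians:

import torch
from models.losses import Losses

# weight 1.0 on the L2 term and 0.5 on the haversine term (hypothetical mix)
criterion = Losses(mix={"l2": 1.0, "haversine": 0.5})
x = {"gps": torch.rand(4, 2)}  # dummy predictions
y = {"gps": torch.rand(4, 2)}  # dummy ground truth
out = criterion(x, y, average=True)
print(out["loss"], out["L2_loss"], out["haversine_loss"])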
models/misc.py
ADDED
@@ -0,0 +1,9 @@
from torch import nn


class DoNothingOptimizer(nn.Module):
    def __init__(self, *args, **kwargs):
        pass

    def step(self, *args, **kwargs):
        pass

    def zero_grad(self, *args, **kwargs):
        pass
models/module.py
ADDED
@@ -0,0 +1,157 @@
import os
from typing import Any
import pytorch_lightning as L
import torch
import torch.nn as nn
from hydra.utils import instantiate
import copy
import pandas as pd
import numpy as np


class Geolocalizer(L.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.model = instantiate(cfg.network.instance)
        if cfg.text_tuning:
            self.text_model = instantiate(cfg.text_network.instance)
        self.loss = instantiate(cfg.loss)
        self.val_metrics = instantiate(cfg.val_metrics)
        self.test_metrics = instantiate(cfg.test_metrics)
        self.text_tuning = cfg.text_tuning

    def training_step(self, batch, batch_idx):
        pred = self.model(batch)
        if self.text_tuning:
            pred["text_features"] = self.text_model(batch)
        loss = self.loss(pred, batch, average=True)
        for metric_name, metric_value in loss.items():
            self.log(
                f"train/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=True,
                on_epoch=True,
            )
        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        pred = self.model(batch)
        if self.text_tuning:
            pred["text_features"] = self.text_model(batch)
        loss = self.loss(pred, batch, average=True)["loss"]
        self.val_metrics.update(pred, batch)
        self.log("val/loss", loss, sync_dist=True, on_step=False, on_epoch=True)

    def on_validation_epoch_end(self):
        metrics = self.val_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"val/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )

    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        pred = self.model(batch)
        self.test_metrics.update(pred, batch)

    def on_test_epoch_end(self):
        metrics = self.test_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"test/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )

    def configure_optimizers(self):
        lora_params = []
        backbone_params = []
        other_params = []
        last_block_params = []
        for name, param in self.model.named_parameters():
            if "lora" in name:
                lora_params.append(param)
            elif "backbone" in name:
                if self.cfg.optimizer.diff_backbone_last and ".11." in name:
                    last_block_params.append(param)
                else:
                    backbone_params.append(param)
            else:
                other_params.append(param)

        params_to_optimize = [{"params": other_params}]
        if self.cfg.optimizer.unfreeze_lr:
            params_to_optimize += [
                {"params": backbone_params, "lr": self.cfg.optimizer.backbone_lr}
            ]
            if self.cfg.optimizer.diff_backbone_last:
                params_to_optimize += [
                    {
                        "params": last_block_params,
                        "lr": self.cfg.optimizer.last_block_lr,
                    }
                ]
        if len(lora_params) > 0:
            # LoRA params sometimes train better with a different lr (~1e-4 for CLIP)
            params_to_optimize += [
                {"params": lora_params, "lr": self.cfg.optimizer.lora_lr}
            ]
        if self.cfg.optimizer.exclude_ln_and_biases_from_weight_decay:
            parameters_names_wd = get_parameter_names(self.model, [nn.LayerNorm])
            parameters_names_wd = [
                name for name in parameters_names_wd if "bias" not in name
            ]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in self.model.named_parameters()
                        if n in parameters_names_wd
                    ],
                    "weight_decay": self.cfg.optimizer.optim.weight_decay,
                },
                {
                    "params": [
                        p
                        for n, p in self.model.named_parameters()
                        if n not in parameters_names_wd
                    ],
                    "weight_decay": 0.0,
                },
            ]
            optimizer = instantiate(
                self.cfg.optimizer.optim, optimizer_grouped_parameters
            )
        else:
            optimizer = instantiate(self.cfg.optimizer.optim, params_to_optimize)
        scheduler = instantiate(self.cfg.lr_scheduler)(optimizer)
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    def lr_scheduler_step(self, scheduler, metric):
        scheduler.step(self.global_step)


def get_parameter_names(model, forbidden_layer_types):
    """
    Returns the names of the model parameters that are not inside a forbidden layer.
    Taken from HuggingFace transformers.
    """
    result = []
    for name, child in model.named_children():
        result += [
            f"{name}.{n}"
            for n in get_parameter_names(child, forbidden_layer_types)
            if not isinstance(child, tuple(forbidden_layer_types))
        ]
    # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
    result += list(model._parameters.keys())
    return result
models/networks/backbones.py
ADDED
@@ -0,0 +1,162 @@
import torch.hub

from transformers import (
    CLIPVisionModel,
    CLIPVisionConfig,
    CLIPModel,
    CLIPProcessor,
    AutoTokenizer,
    CLIPTextModelWithProjection,
    CLIPTextConfig,
    CLIPVisionModelWithProjection,
    ResNetModel,
    ResNetConfig
)
from torch import nn

from PIL import Image
import requests


class CLIP(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP model."""
        super().__init__()
        if path == "":
            config_vision = CLIPVisionConfig()
            self.clip = CLIPVisionModel(config_vision)
        else:
            self.clip = CLIPVisionModel.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP features from an image.
        Args:
            x (dict that contains "img": torch.Tensor): Input batch
        """
        features = self.clip(pixel_values=x["img"])["last_hidden_state"]
        return features


class CLIPJZ(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP model."""
        super().__init__()
        if path == "":
            config_vision = CLIPVisionConfig()
            self.clip = CLIPVisionModel(config_vision)
        else:
            self.clip = CLIPVisionModel.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP features from an image.
        Args:
            x (dict that contains "img": torch.Tensor): Input batch
        """
        features = self.clip(pixel_values=x["img"])["last_hidden_state"]
        return features


class StreetCLIP(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP model."""
        super().__init__()
        self.clip = CLIPModel.from_pretrained(path)
        self.transform = CLIPProcessor.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP features from an image.
        Args:
            x (dict that contains "img": torch.Tensor): Input batch
        """
        features = self.clip.get_image_features(
            **self.transform(images=x["img"], return_tensors="pt").to(x["gps"].device)
        ).unsqueeze(1)
        return features


class CLIPText(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP model."""
        super().__init__()
        if path == "":
            config_vision = CLIPVisionConfig()
            self.clip = CLIPVisionModel(config_vision)
        else:
            self.clip = CLIPVisionModelWithProjection.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP features from an image.
        Args:
            x (dict that contains "img": torch.Tensor): Input batch
        """
        features = self.clip(pixel_values=x["img"])
        return features.image_embeds, features.last_hidden_state


class TextEncoder(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP text model."""
        super().__init__()
        if path == "":
            config_vision = CLIPTextConfig()
            self.clip = CLIPTextModelWithProjection(config_vision)
            self.transform = AutoTokenizer()
        else:
            self.clip = CLIPTextModelWithProjection.from_pretrained(path)
            self.transform = AutoTokenizer.from_pretrained(path)
        for p in self.clip.parameters():
            p.requires_grad = False
        self.clip.eval()

    def forward(self, x):
        """Predicts CLIP features from text.
        Args:
            x (dict that contains "text": list): Input batch
        """
        features = self.clip(
            **self.transform(x["text"], padding=True, return_tensors="pt").to(
                x["gps"].device
            )
        ).text_embeds
        return features


class DINOv2(nn.Module):
    def __init__(self, tag) -> None:
        """Initializes the DINO model."""
        super().__init__()
        self.dino = torch.hub.load("facebookresearch/dinov2", tag)
        self.stride = 14  # ugly but dinov2 stride = 14

    def forward(self, x):
        """Predicts DINO features from an image."""
        x = x["img"]

        # crop for stride
        _, _, H, W = x.shape
        H_new = H - H % self.stride
        W_new = W - W % self.stride
        x = x[:, :, :H_new, :W_new]

        # forward features
        x = self.dino.forward_features(x)
        x = x["x_prenorm"]
        return x

class ResNet(nn.Module):
    def __init__(self, path):
        """Initializes the ResNet model."""
        super().__init__()
        if path == "":
            config_vision = ResNetConfig()
            self.resnet = ResNetModel(config_vision)
        else:
            self.resnet = ResNetModel.from_pretrained(path)

    def forward(self, x):
        """Predicts ResNet50 features from an image.
        Args:
            x (dict that contains "img": torch.Tensor): Input batch
        """
        features = self.resnet(x["img"])["pooler_output"]
        return features.squeeze()
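A minimal sketch (not part of the commit) of how one of these backbones is driven: each wrapper takes a batch dictionary and returns token features. The checkpoint id below is only an example of a public CLIP vision model, not necessarily the one used in this repository.

# Illustrative only: forward a dummy batch through the CLIP backbone wrapper.
import torch

backbone = CLIP("openai/clip-vit-base-patch16")   # example checkpoint id (assumption)
batch = {"img": torch.randn(2, 3, 224, 224)}      # pixel values expected under the "img" key
with torch.no_grad():
    tokens = backbone(batch)                      # (batch, num_patches + 1, hidden_dim)
print(tokens.shape)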
models/networks/heads/__init__.py
ADDED
File without changes
models/networks/heads/auxilliary.py
ADDED
@@ -0,0 +1,33 @@
import torch.nn as nn
from models.networks.utils import UnormGPS
from torch.nn.functional import tanh, sigmoid, softmax


class AuxHead(nn.Module):
    def __init__(self, aux_data=[], use_tanh=False):
        super().__init__()
        self.aux_data = aux_data
        self.unorm = UnormGPS()
        self.use_tanh = use_tanh

    def forward(self, x):
        """Forward pass of the network.
        x : Union[torch.Tensor, dict] with the output of the backbone.
        """
        if self.use_tanh:
            gps = tanh(x["gps"])
        else:
            gps = x["gps"]  # pass raw coordinates through when tanh squashing is disabled
        gps = self.unorm(gps)
        output = {"gps": gps}
        if "land_cover" in self.aux_data:
            output["land_cover"] = softmax(x["land_cover"])
        if "road_index" in self.aux_data:
            output["road_index"] = x["road_index"]
        if "drive_side" in self.aux_data:
            output["drive_side"] = sigmoid(x["drive_side"])
        if "climate" in self.aux_data:
            output["climate"] = softmax(x["climate"])
        if "soil" in self.aux_data:
            output["soil"] = softmax(x["soil"])
        if "dist_sea" in self.aux_data:
            output["dist_sea"] = x["dist_sea"]
        return output
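A minimal sketch (illustrative only) of what AuxHead expects and returns: raw per-task outputs keyed by task name go in, squashed/normalized predictions come out. The task list and tensor shapes below are assumptions for the example.

# Illustrative only: post-process raw head outputs for two auxiliary tasks.
import torch

head = AuxHead(aux_data=["drive_side", "climate"], use_tanh=True)
raw = {
    "gps": torch.randn(4, 2),          # unbounded regression output, tanh-squashed then unnormalized
    "drive_side": torch.randn(4, 1),   # logit, mapped through sigmoid
    "climate": torch.randn(4, 30),     # logits, mapped through softmax
}
out = head(raw)                        # out["gps"] is in radians after UnormGPS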
models/networks/heads/classification.py
ADDED
@@ -0,0 +1,17 @@
import torch
import torch.nn as nn


class ClassificationHead(nn.Module):
    """Classification head for the network."""

    def __init__(self, id_to_gps):
        super().__init__()
        self.id_to_gps = id_to_gps

    def forward(self, x):
        """Forward pass of the network.
        x : Union[torch.Tensor, dict] with the output of the backbone.
        """
        gps = self.id_to_gps(x.argmax(dim=-1))
        return {"label": x, **gps}
models/networks/heads/hybrid.py
ADDED
@@ -0,0 +1,194 @@
import torch
import torch.nn as nn
import pandas as pd

from models.networks.utils import UnormGPS


class HybridHead(nn.Module):
    """Classification head followed by regression head for the network."""

    def __init__(self, final_dim, quadtree_path, use_tanh, scale_tanh):
        super().__init__()
        self.final_dim = final_dim
        self.use_tanh = use_tanh
        self.scale_tanh = scale_tanh

        self.unorm = UnormGPS()

        if quadtree_path is not None:
            quadtree = pd.read_csv(quadtree_path)
            self.init_quadtree(quadtree)

    def init_quadtree(self, quadtree):
        quadtree[["min_lat", "max_lat"]] /= 90.0
        quadtree[["min_lon", "max_lon"]] /= 180.0
        self.register_buffer(
            "cell_center",
            0.5 * torch.tensor(quadtree[["max_lat", "max_lon"]].values)
            + 0.5 * torch.tensor(quadtree[["min_lat", "min_lon"]].values),
        )
        self.register_buffer(
            "cell_size",
            torch.tensor(quadtree[["max_lat", "max_lon"]].values)
            - torch.tensor(quadtree[["min_lat", "min_lon"]].values),
        )

    def forward(self, x, gt_label):
        """Forward pass of the network.
        x : Union[torch.Tensor, dict] with the output of the backbone.
        """

        classification_logits = x[..., : self.final_dim]
        classification = classification_logits.argmax(dim=-1)

        regression = x[..., self.final_dim :]

        if self.use_tanh:
            regression = self.scale_tanh * torch.tanh(regression)

        regression = regression.view(regression.shape[0], -1, 2)

        if self.training:
            regression = torch.gather(
                regression,
                1,
                gt_label.unsqueeze(-1).unsqueeze(-1).expand(regression.shape[0], 1, 2),
            )[:, 0, :]
            size = 2.0 / self.cell_size[gt_label]
            center = self.cell_center[gt_label]
            gps = (
                self.cell_center[gt_label] + regression * self.cell_size[gt_label] / 2.0
            )
        else:
            regression = torch.gather(
                regression,
                1,
                classification.unsqueeze(-1)
                .unsqueeze(-1)
                .expand(regression.shape[0], 1, 2),
            )[:, 0, :]
            size = 2.0 / self.cell_size[classification]
            center = self.cell_center[classification]
            gps = (
                self.cell_center[classification]
                + regression * self.cell_size[classification] / 2.0
            )

        gps = self.unorm(gps)

        return {
            "label": classification_logits,
            "gps": gps,
            "size": size,
            "center": center,
            "reg": regression,
        }

class HybridHeadCentroid(nn.Module):
    """Classification head followed by regression head for the network."""

    def __init__(self, final_dim, quadtree_path, use_tanh, scale_tanh):
        super().__init__()
        self.final_dim = final_dim
        self.use_tanh = use_tanh
        self.scale_tanh = scale_tanh

        self.unorm = UnormGPS()
        if quadtree_path is not None:
            quadtree = pd.read_csv(quadtree_path)
            self.init_quadtree(quadtree)

    def init_quadtree(self, quadtree):
        quadtree[["min_lat", "max_lat", "mean_lat"]] /= 90.0
        quadtree[["min_lon", "max_lon", "mean_lon"]] /= 180.0
        self.cell_center = torch.tensor(quadtree[["mean_lat", "mean_lon"]].values)
        self.cell_size_up = torch.tensor(quadtree[["max_lat", "max_lon"]].values) - torch.tensor(quadtree[["mean_lat", "mean_lon"]].values)
        self.cell_size_down = torch.tensor(quadtree[["mean_lat", "mean_lon"]].values) - torch.tensor(quadtree[["min_lat", "min_lon"]].values)

    def forward(self, x, gt_label):
        """Forward pass of the network.
        x : Union[torch.Tensor, dict] with the output of the backbone.
        """
        classification_logits = x[..., : self.final_dim]
        classification = classification_logits.argmax(dim=-1)
        self.cell_size_up = self.cell_size_up.to(classification.device)
        self.cell_center = self.cell_center.to(classification.device)
        self.cell_size_down = self.cell_size_down.to(classification.device)

        regression = x[..., self.final_dim :]

        if self.use_tanh:
            regression = self.scale_tanh * torch.tanh(regression)

        regression = regression.view(regression.shape[0], -1, 2)

        if self.training:
            regression = torch.gather(
                regression,
                1,
                gt_label.unsqueeze(-1).unsqueeze(-1).expand(regression.shape[0], 1, 2),
            )[:, 0, :]
            size = torch.where(
                regression > 0,
                self.cell_size_up[gt_label],
                self.cell_size_down[gt_label],
            )
            center = self.cell_center[gt_label]
            gps = self.cell_center[gt_label] + regression * size
        else:
            regression = torch.gather(
                regression,
                1,
                classification.unsqueeze(-1)
                .unsqueeze(-1)
                .expand(regression.shape[0], 1, 2),
            )[:, 0, :]
            size = torch.where(
                regression > 0,
                self.cell_size_up[classification],
                self.cell_size_down[classification],
            )
            center = self.cell_center[classification]
            gps = self.cell_center[classification] + regression * size

        gps = self.unorm(gps)

        return {
            "label": classification_logits,
            "gps": gps,
            "size": 1.0 / size,
            "center": center,
            "reg": regression,
        }


class SharedHybridHead(HybridHead):
    """Classification head followed by SHARED regression head for the network."""

    def forward(self, x, gt_label):
        """Forward pass of the network.
        x : Union[torch.Tensor, dict] with the output of the backbone.
        """

        classification_logits = x[..., : self.final_dim]
        classification = classification_logits.argmax(dim=-1)

        regression = x[..., self.final_dim :]

        if self.use_tanh:
            regression = self.scale_tanh * torch.tanh(regression)

        if self.training:
            gps = (
                self.cell_center[gt_label] + regression * self.cell_size[gt_label] / 2.0
            )
        else:
            gps = (
                self.cell_center[classification]
                + regression * self.cell_size[classification] / 2.0
            )

        gps = self.unorm(gps)

        return {"label": classification_logits, "gps": gps}
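A worked sketch (made-up numbers, illustrative only) of the coordinate convention used by HybridHead: the predicted cell supplies a normalized center and size, and the tanh-bounded offset moves the prediction within that cell before UnormGPS maps it to radians.

# Illustrative only: cell id + [-1, 1] offset -> normalized coordinates.
import torch

cell_center = torch.tensor([[0.25, -0.50]])   # normalized (lat/90, lon/180) of one cell
cell_size = torch.tensor([[0.10, 0.20]])      # normalized height/width of that cell
offset = torch.tensor([[0.5, -1.0]])          # tanh-bounded regression output
gps_normalized = cell_center + offset * cell_size / 2.0   # -> [[0.275, -0.600]]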
models/networks/heads/id_to_gps.py
ADDED
@@ -0,0 +1,33 @@
import torch
from models.networks.utils import UnormGPS
import torch.nn as nn
import numpy as np


class IdToGPS(nn.Module):
    def __init__(self, id_to_gps: str):
        """Map index to gps coordinates (indices can be country or city ids)"""
        super().__init__()
        if "quadtree" in id_to_gps:
            self.id_to_gps = torch.load(
                "_".join(id_to_gps.split("_")[:-4] + id_to_gps.split("_")[-3:])
            )
        else:
            self.id_to_gps = torch.load(id_to_gps)
        #self.unorm = UnormGPS()

    def forward(self, x):
        """Mapping from country id to gps coordinates
        Args:
            x: torch.Tensor with features
        """

        if isinstance(x, dict):
            # for oracle
            labels, x = x["label"], x["img"]
        else:
            # predicted labels
            labels = x
        self.id_to_gps = self.id_to_gps.to(labels.device)
        #return {"gps": self.unorm(self.id_to_gps[labels])}
        return {"gps": self.id_to_gps[labels]}
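A minimal sketch (illustrative only) of the lookup this module performs: the file name and the tiny two-row table below are hypothetical, but the indexing behavior matches the forward pass above.

# Illustrative only: build a toy index-to-GPS table on disk and query it by label.
import torch

torch.save(torch.tensor([[0.1, 0.2], [-0.3, 0.4]]), "toy_index_to_gps.pt")  # hypothetical file
mapper = IdToGPS("toy_index_to_gps.pt")
print(mapper(torch.tensor([1, 0, 1]))["gps"])   # rows of the table, gathered by label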
models/networks/heads/random.py
ADDED
@@ -0,0 +1,53 @@
import pandas as pd
import torch
from torch import nn
from models.networks.utils import UnormGPS


class Random(nn.Module):
    def __init__(self, num_output):
        """Random"""
        super().__init__()
        self.num_output = num_output
        self.unorm = UnormGPS()

    def forward(self, x):
        """Predicts GPS coordinates from an image.
        Args:
            x: torch.Tensor with features
        """
        #x = x["img"]
        gps = torch.rand((x.shape[0], self.num_output), device=x.device) * 2 - 1
        return {"gps": self.unorm(gps)}


class RandomCoords(nn.Module):
    def __init__(self, coords_path: str):
        """Randomly sample from a list of coordinates
        Args:
            coords_path: str with path to csv file with coordinates
        """
        super().__init__()
        coordinates = pd.read_csv(coords_path)
        longitudes = coordinates["longitude"].values / 180
        latitudes = coordinates["latitude"].values / 90
        self.unorm = UnormGPS()
        del coordinates

        self.N = len(longitudes)
        assert len(longitudes) == len(latitudes)
        self.coordinates = torch.stack(
            [torch.tensor(latitudes), torch.tensor(longitudes)],
            dim=-1,
        )
        del longitudes, latitudes

    def forward(self, x):
        """Predicts GPS coordinates from an image.
        Args:
            x: torch.Tensor with features
        """
        x = x["img"]
        # randomly select a coordinate in the list
        n = torch.randint(0, self.N, (x.shape[0],))
        return {"gps": self.unorm(self.coordinates[n].to(x.device))}
models/networks/heads/regression.py
ADDED
@@ -0,0 +1,44 @@
from models.networks.utils import UnormGPS
import torch.nn as nn
from torch.nn.functional import tanh
import torch


class RegressionHead(nn.Module):
    def __init__(self, use_tanh=False):
        super().__init__()
        self.unorm = UnormGPS()
        self.use_tanh = use_tanh

    def forward(self, x):
        """Forward pass of the network.
        x : Union[torch.Tensor, dict] with the output of the backbone.
        """
        if self.use_tanh:
            x = tanh(x)
        gps = self.unorm(x)
        return {"gps": gps}


class RegressionHeadAngle(nn.Module):
    def __init__(self):
        super().__init__()
        self.unorm = UnormGPS()

    def forward(self, x):
        """Forward pass of the network.
        x : Union[torch.Tensor, dict] with the output of the backbone.
        """
        x1 = x[:, 0].pow(2)
        x2 = x[:, 1].pow(2)
        x3 = x[:, 2].pow(2)
        x4 = x[:, 3].pow(2)
        cos_lambda = x1 / (x1 + x2)
        sin_lambda = x2 / (x1 + x2)
        cos_phi = x3 / (x3 + x4)
        sin_phi = x4 / (x3 + x4)
        lbd = torch.atan2(sin_lambda, cos_lambda)
        phi = torch.atan2(sin_phi, cos_phi)
        gps = torch.cat((lbd.unsqueeze(1), phi.unsqueeze(1)), dim=1)
        # gps = self.unorm(x)
        return {"gps": gps}
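A minimal sketch (illustrative only) of the plain regression head: with tanh enabled, the raw two-dimensional output is squashed into [-1, 1] and then rescaled by UnormGPS, so the returned coordinates are bounded by (pi/2, pi).

# Illustrative only: bounded regression output in radians.
import torch

head = RegressionHead(use_tanh=True)
print(head(torch.randn(4, 2))["gps"])   # latitude in (-pi/2, pi/2), longitude in (-pi, pi)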
models/networks/mlp.py
ADDED
@@ -0,0 +1,258 @@
import torch
from torch import nn


class MLP(nn.Module):
    def __init__(
        self,
        initial_dim=512,
        hidden_dim=[128, 32, 2],
        final_dim=2,
        norm=nn.InstanceNorm1d,
        activation=nn.ReLU,
        aux_data=[],
    ):
        """
        Initializes an MLP Classification Head
        Args:
            hidden_dim (list): list of hidden dimensions for the MLP
            norm (nn.Module): normalization layer
            activation (nn.Module): activation layer
        """
        super().__init__()
        self.aux_data = aux_data
        self.aux = len(self.aux_data) > 0
        if self.aux:
            hidden_dim_aux = hidden_dim
            hidden_dim_aux[-1] = 128
            final_dim_aux_dict = {
                "land_cover": 12,
                "climate": 30,
                "soil": 14,
                "road_index": 1,
                "drive_side": 1,
                "dist_sea": 1,
            }
            self.idx = {}
            final_dim_aux = 0
            for col in self.aux_data:
                self.idx[col] = [
                    final_dim_aux + i for i in range(final_dim_aux_dict[col])
                ]
                final_dim_aux += final_dim_aux_dict[col]
            dim = [initial_dim] + hidden_dim_aux + [final_dim_aux]
            args = self.init_layers(dim, norm, activation)
            self.mlp_aux = nn.Sequential(*args)
        dim = [initial_dim] + hidden_dim + [final_dim]
        args = self.init_layers(dim, norm, activation)
        self.mlp = nn.Sequential(*args)

    def init_layers(self, dim, norm, activation):
        """Initializes the MLP layers."""
        args = [nn.LayerNorm(dim[0])]
        for i in range(len(dim) - 1):
            args.append(nn.Linear(dim[i], dim[i + 1]))
            if i < len(dim) - 2:
                # args.append(norm(dim[i + 1]))
                args.append(norm(4, dim[i + 1]))
                args.append(activation())
        return args

    def forward(self, x):
        """Predicts GPS coordinates from an image.
        Args:
            x: torch.Tensor with features
        """
        if self.aux:
            out = {"gps": self.mlp(x[:, 0, :])}
            x = self.mlp_aux(x[:, 0, :])
            for col in list(self.idx.keys()):
                out[col] = x[:, self.idx[col]]
            return out
        return self.mlp(x[:, 0, :])

class MLPResNet(nn.Module):
    def __init__(
        self,
        initial_dim=512,
        hidden_dim=[128, 32, 2],
        final_dim=2,
        norm=nn.InstanceNorm1d,
        activation=nn.ReLU,
        aux_data=[],
    ):
        """
        Initializes an MLP Classification Head
        Args:
            hidden_dim (list): list of hidden dimensions for the MLP
            norm (nn.Module): normalization layer
            activation (nn.Module): activation layer
        """
        super().__init__()
        self.aux_data = aux_data
        self.aux = len(self.aux_data) > 0
        if self.aux:
            hidden_dim_aux = hidden_dim
            hidden_dim_aux[-1] = 128
            final_dim_aux_dict = {
                "land_cover": 12,
                "climate": 30,
                "soil": 14,
                "road_index": 1,
                "drive_side": 1,
                "dist_sea": 1,
            }
            self.idx = {}
            final_dim_aux = 0
            for col in self.aux_data:
                self.idx[col] = [
                    final_dim_aux + i for i in range(final_dim_aux_dict[col])
                ]
                final_dim_aux += final_dim_aux_dict[col]
            dim = [initial_dim] + hidden_dim_aux + [final_dim_aux]
            args = self.init_layers(dim, norm, activation)
            self.mlp_aux = nn.Sequential(*args)
        dim = [initial_dim] + hidden_dim + [final_dim]
        args = self.init_layers(dim, norm, activation)
        self.mlp = nn.Sequential(*args)

    def init_layers(self, dim, norm, activation):
        """Initializes the MLP layers."""
        args = [nn.LayerNorm(dim[0])]
        for i in range(len(dim) - 1):
            args.append(nn.Linear(dim[i], dim[i + 1]))
            if i < len(dim) - 2:
                # args.append(norm(dim[i + 1]))
                args.append(norm(4, dim[i + 1]))
                args.append(activation())
        return args

    def forward(self, x):
        """Predicts GPS coordinates from an image.
        Args:
            x: torch.Tensor with features
        """
        if self.aux:
            out = {"gps": self.mlp(x[:, 0, :])}
            x = self.mlp_aux(x[:, 0, :])
            for col in list(self.idx.keys()):
                out[col] = x[:, self.idx[col]]
            return out
        return self.mlp(x)


class MLPCentroid(nn.Module):
    def __init__(
        self,
        initial_dim=512,
        hidden_dim=[128, 32, 2],
        final_dim=2,
        norm=nn.InstanceNorm1d,
        activation=nn.ReLU,
        aux_data=[],
    ):
        """
        Initializes an MLP Classification Head
        Args:
            hidden_dim (list): list of hidden dimensions for the MLP
            norm (nn.Module): normalization layer
            activation (nn.Module): activation layer
        """
        super().__init__()
        self.aux_data = aux_data
        self.aux = len(self.aux_data) > 0
        dim = [initial_dim] + hidden_dim + [final_dim // 3]
        args = self.init_layers(dim, norm, activation)
        self.classif = nn.Sequential(*args)
        dim = [initial_dim] + hidden_dim + [2 * final_dim // 3]
        args = self.init_layers(dim, norm, activation)
        self.reg = nn.Sequential(*args)
        # torch.nn.init.normal_(self.reg.weight, mean=0.0, std=0.01)
        if self.aux:
            self.dim = [initial_dim] + hidden_dim
            self.predictors = {"gps": self.mlp}
            self.init_aux(dim, norm, activation)

    def init_layers(self, dim, norm, activation):
        """Initializes the MLP layers."""
        args = [nn.LayerNorm(dim[0])]
        for i in range(len(dim) - 1):
            args.append(nn.Linear(dim[i], dim[i + 1]))
            if i < len(dim) - 2:
                # args.append(norm(dim[i + 1]))
                args.append(norm(4, dim[i + 1]))
                args.append(activation())
        return args

    def init_aux(self, dim, norm, activation):
        final_dim_aux = {
            "land_cover": 12,
            "climate": 30,
            "soil": 14,
            "road_index": 1,
            "drive_side": 1,
            "dist_sea": 1,
        }
        if "land_cover" in self.aux_data:
            args = self.init_layers(
                self.dim + [final_dim_aux["land_cover"]], norm, activation
            )
            self.land_cover = nn.Sequential(*args)
            self.predictors["land_cover"] = self.land_cover
        if "road_index" in self.aux_data:
            args = self.init_layers(
                self.dim + [final_dim_aux["road_index"]], norm, activation
            )
            self.road_index = nn.Sequential(*args)
            self.predictors["road_index"] = self.road_index
        if "drive_side" in self.aux_data:
            args = self.init_layers(
                self.dim + [final_dim_aux["drive_side"]], norm, activation
            )
            self.drive_side = nn.Sequential(*args)
            self.predictors["drive_side"] = self.drive_side
        if "climate" in self.aux_data:
            args = self.init_layers(
                self.dim + [final_dim_aux["climate"]], norm, activation
            )
            self.climate = nn.Sequential(*args)
            self.predictors["climate"] = self.climate
        if "soil" in self.aux_data:
            args = self.init_layers(
                self.dim + [final_dim_aux["soil"]], norm, activation
            )
            self.soil = nn.Sequential(*args)
            self.predictors["soil"] = self.soil
        if "dist_sea" in self.aux_data:
            args = self.init_layers(
                self.dim + [final_dim_aux["dist_sea"]], norm, activation
            )
            self.dist_sea = nn.Sequential(*args)
            self.predictors["dist_sea"] = self.dist_sea

    def forward(self, x):
        """Predicts GPS coordinates from an image.
        Args:
            x: torch.Tensor with features
        """
        if self.aux:
            return {
                col: self.predictors[col](x[:, 0, :]) for col in self.predictors.keys()
            }
        return torch.cat([self.classif(x[:, 0, :]), self.reg(x[:, 0, :])], dim=1)


class Identity(nn.Module):
    def __init__(
        self
    ):
        """
        Initializes an Identity module
        """
        super().__init__()

    def forward(self, x):
        """
        Return same as input
        """
        return x
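A minimal sketch (illustrative only) of MLPCentroid's output layout when final_dim = 3 * num_cells: the first num_cells values are classification logits and the remaining 2 * num_cells are (lat, lon) offsets. Because init_layers calls norm(4, dim), the example passes nn.GroupNorm explicitly, which matches that (num_groups, num_channels) call signature; the dimensions below are assumptions, not the repository's configuration.

# Illustrative only: inspect the concatenated logits + offsets produced by MLPCentroid.
import torch
from torch import nn

num_cells = 16
head = MLPCentroid(initial_dim=768, hidden_dim=[512, 256],
                   final_dim=3 * num_cells, norm=nn.GroupNorm)
features = torch.randn(2, 5, 768)        # token features; only the first token is used
out = head(features)                     # shape (2, 48)
logits, offsets = out[:, :num_cells], out[:, num_cells:]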
models/networks/network.py
ADDED
@@ -0,0 +1,335 @@
import torch
import numpy as np
from abc import ABC, abstractmethod
from torch import nn
from hydra.utils import instantiate
import copy
from peft import LoraConfig, get_peft_model
from utils.model_utils import print_trainable_parameters


def freeze(model):
    """Freezes the parameters of a model."""
    for p in model.parameters():
        p.requires_grad = False
    model.eval()


def unfreeze(model):
    """Unfreezes all model parameters except CLIP's final post-layernorm."""
    model_parameters = model.named_parameters()
    for name, param in model_parameters:
        if name in [
            "clip.vision_model.post_layernorm.weight",
            "clip.vision_model.post_layernorm.bias",
        ]:
            param.requires_grad = False
        else:
            param.requires_grad = True
    model.train()


def unfreeze_last(model):
    """Unfreezes only the parameters of the last transformer block (index 11)."""
    model_parameters = model.named_parameters()
    for name, param in model_parameters:
        if len(name.split(".")) > 5:
            if name.split(".")[4] == "11":
                param.requires_grad = True
            else:
                param.requires_grad = False
        else:
            param.requires_grad = False
    model.train()


class FrozenBackbone(nn.Module):
    """Freezes the backbone of a network."""

    def __init__(self, backbone, mid, head):
        super().__init__()
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        self.target_key = head.target_key
        freeze(self.backbone)

    def forward(self, x):
        """Forward pass of the network.
        x : Union[torch.Tensor, dict] with the output of the backbone.
        """
        with torch.no_grad():
            x = self.backbone(x)
        x = self.mid(x)
        x = self.head(x)
        return x


class UnfrozenBackbone(nn.Module):
    """Unfreezes the backbone of a network."""

    def __init__(self, backbone, mid, head):
        super().__init__()
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        self.target_key = head.target_key
        unfreeze(self.backbone)

    def forward(self, x):
        """Forward pass of the network.
        x : Union[torch.Tensor, dict] with the output of the backbone.
        """
        x = self.backbone(x)
        x = self.mid(x)
        x = self.head(x)
        return x


class UnfrozenPartBackbone(nn.Module):
    """Unfreezes only the last block of the backbone of a network."""

    def __init__(self, backbone, mid, head):
        super().__init__()
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        self.target_key = head.target_key
        unfreeze_last(self.backbone)

    def forward(self, x):
        """Forward pass of the network.
        x : Union[torch.Tensor, dict] with the output of the backbone.
        """
        x = self.backbone(x)
        x = self.mid(x)
        x = self.head(x)
        return x


class NoFeatureBackbone(nn.Module):
    """Randomizes the backbone of a network."""

    def __init__(self, head):
        super().__init__()
        self.head = head.instance
        self.target_key = head.target_key

    def forward(self, x):
        """Forward pass of the network.
        x : Union[torch.Tensor, dict] with the output of the backbone.
        """
        return self.head(x)


class ContrastiveFrozenBackbone(FrozenBackbone):
    """Freezes the backbone of a network (contrastive variant)."""

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        self.mode = mode

    def forward(self, x):
        with torch.no_grad():
            features = self.backbone(x)
        if self.mode != "eval":
            x_pos = {
                k.strip("pos_"): v.clone()
                if isinstance(v, torch.Tensor)
                else copy.deepcopy(v)
                for k, v in x.items()
                if k.startswith("pos_")
            }
            pos_features = self.backbone(x_pos)
        x = self.mid(features)
        x = self.head(x)
        if self.mode != "eval":
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **x,
            }
        return {
            "features": features[:, 0, :],
            **x,
        }


class ContrastiveUnFrozenPartBackbone(UnfrozenPartBackbone):
    """Unfreezes only the last backbone block (contrastive variant)."""

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        self.mode = mode

    def forward(self, x):
        features = self.backbone(x)
        if self.mode != "eval":
            x_pos = {
                k.strip("pos_"): v.clone()
                if isinstance(v, torch.Tensor)
                else copy.deepcopy(v)
                for k, v in x.items()
                if k.startswith("pos_")
            }
            pos_features = self.backbone(x_pos)
        x = self.mid(features)
        x = self.head(x)
        if self.mode != "eval":
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **x,
            }
        return {
            "features": features[:, 0, :],
            **x,
        }


class ContrastiveUnFrozenBackbone(UnfrozenBackbone):
    """Unfreezes the backbone of a network (contrastive variant)."""

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        self.mode = mode

    def forward(self, x):
        features = self.backbone(x)
        if self.mode != "eval":
            x_pos = {
                k.strip("pos_"): v.clone()
                if isinstance(v, torch.Tensor)
                else copy.deepcopy(v)
                for k, v in x.items()
                if k.startswith("pos_")
            }
            pos_features = self.backbone(x_pos)
        x = self.mid(features)
        x = self.head(x)
        if self.mode != "eval":
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **x,
            }
        return {
            "features": features[:, 0, :],
            **x,
        }


class TextContrastiveUnFrozenBackbone(UnfrozenBackbone):
    """Unfreezes the backbone of a network (text-contrastive variant)."""

    def __init__(self, backbone, mid, head):
        super().__init__(backbone, mid, head)

    def forward(self, x):
        con, features = self.backbone(x)
        x = self.mid(features)
        x = self.head(x)
        return {
            "features": con,
            **x,
        }


class LoraBackbone(nn.Module):
    """Wraps the backbone in a PEFT model for LoRA tuning."""

    def __init__(self, backbone, mid, head, r, alpha, dropout, bias):
        super().__init__()
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        self.target_key = head.target_key
        freeze(self.backbone)

        config = LoraConfig(
            r=r,
            lora_alpha=alpha,
            lora_dropout=dropout,
            bias=bias,
            target_modules=["q_proj", "k_proj", "v_proj"],
        )
        self.backbone = get_peft_model(self.backbone, config)
        print_trainable_parameters(self)

    def forward(self, x):
        """Forward pass of the network.
        x : Union[torch.Tensor, dict] with the output of the backbone.
        """
        x = self.backbone(x)
        x = self.mid(x)
        return self.head(x)


class HybridFrozenBackbone(FrozenBackbone):
    """Freezes the backbone of a network (hybrid head variant)."""

    def forward(self, x):
        """Forward pass of the network.
        x : Union[torch.Tensor, dict] with the output of the backbone.
        """

        gt_label = x["label"] if self.training else None

        with torch.no_grad():
            x = self.backbone(x)
        x = self.mid(x)
        x = self.head(x, gt_label)
        return x


class HybridUnfrozenBackbone(UnfrozenBackbone):
    """Unfreezes the backbone of a network (hybrid head variant)."""

    def forward(self, x):
        """Forward pass of the network.
        x : Union[torch.Tensor, dict] with the output of the backbone.
        """

        gt_label = x["label"] if self.training else None

        x = self.backbone(x)
        x = self.mid(x)
        x = self.head(x, gt_label)
        return x


class ContrastiveHybridUnFrozenBackbone(UnfrozenBackbone):
    """Unfreezes the backbone of a network (contrastive hybrid variant)."""

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        self.mode = mode

    def forward(self, x):
        gt_label = x["label"] if self.training else None
        features = self.backbone(x)
        if self.mode != "eval":
            x_pos = {
                k.strip("pos_"): v.clone()
                if isinstance(v, torch.Tensor)
                else copy.deepcopy(v)
                for k, v in x.items()
                if k.startswith("pos_")
            }
            pos_features = self.backbone(x_pos)
        x = self.mid(features)
        x = self.head(x, gt_label)
        if self.mode != "eval":
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **x,
            }
        return {
            "features": features[:, 0, :],
            **x,
        }
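A minimal sketch (illustrative only) of the composition contract these wrappers rely on: each of backbone, mid, and head is a hydra-style node that already carries an ".instance" attribute (and ".target_key" for the head). SimpleNamespace and nn.Identity stand in for real configured modules here; they are not part of the repository.

# Illustrative only: construct a FrozenBackbone from stand-in nodes.
from types import SimpleNamespace
from torch import nn

backbone = SimpleNamespace(instance=nn.Identity())
mid = SimpleNamespace(instance=nn.Identity())
head = SimpleNamespace(instance=nn.Identity(), target_key="gps")
net = FrozenBackbone(backbone, mid, head)   # backbone parameters are frozen and set to eval()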
models/networks/utils.py
ADDED
@@ -0,0 +1,22 @@
import torch
import numpy as np
from torch import nn


class NormGPS(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Normalize latitude/longitude in radians to [-1, 1]."""  # not used currently
        return x / torch.Tensor([np.pi * 0.5, np.pi]).unsqueeze(0).to(x.device)


class UnormGPS(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Map normalized latitude/longitude in [-1, 1] back to radians."""
        x = torch.clamp(x, -1, 1)
        return x * torch.Tensor([np.pi * 0.5, np.pi]).unsqueeze(0).to(x.device)
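A quick numerical check (illustrative only) of the scaling convention: UnormGPS multiplies a clamped (lat, lon) pair by (pi/2, pi), so (1.0, -0.5) maps to roughly (1.5708, -1.5708) radians.

# Illustrative only: round normalized coordinates up to radians.
import torch

unorm = UnormGPS()
print(unorm(torch.tensor([[1.0, -0.5]])))   # tensor([[ 1.5708, -1.5708]])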
models/utils.py
ADDED
@@ -0,0 +1,54 @@
import os
from os.path import abspath as abp
import torch
import hydra
from hydra import initialize, compose
from models.module import Geolocalizer
from omegaconf import OmegaConf, open_dict
from os.path import join
from hydra.utils import instantiate

def load_model_config(path):
    # given the directory of os.cwd()
    # compute the relative path to path
    path = abp(path)
    rel_path = os.path.relpath(path, start=os.path.split(__file__)[0])

    with initialize(version_base=None, config_path=rel_path):
        cfg = compose(config_name="config", overrides=[])

    checkpoint = torch.load(join(path, "last.ckpt"))
    del checkpoint["state_dict"][
        "model.backbone.clip.vision_model.embeddings.position_ids"
    ]
    torch.save(checkpoint, join(path, "last2.ckpt"))

    with open_dict(cfg):
        cfg.checkpoint = join(path, "last2.ckpt")

        cfg.num_classes = 11399
        cfg.model.network.mid.instance.final_dim = cfg.num_classes * 3
        cfg.model.network.head.final_dim = cfg.num_classes * 3
        cfg.model.network.head.instance.quadtree_path = join(path, "quadtree_10_1000.csv")

        cfg.dataset.train_dataset.path = ""
        cfg.dataset.val_dataset.path = ""
        cfg.dataset.test_dataset.path = ""
        cfg.logger.save_dir = ""
        cfg.data_dir = ""
        cfg.root_dir = ""
        cfg.mode = "test"
        cfg.model.network.backbone.instance.path = (
            "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
        )
    return cfg.dataset.test_transform, cfg.model, join(path, "last2.ckpt"), True

def load_model(path):
    transform_config, model_config, checkpoint_path, delete = load_model_config(path)

    transform = instantiate(transform_config)
    model = Geolocalizer.load_from_checkpoint(checkpoint_path, cfg=model_config)
    if delete:
        os.remove(checkpoint_path)

    return model, transform
scripts/download-dataset.py
ADDED
@@ -0,0 +1,27 @@
import os, zipfile
from huggingface_hub import snapshot_download

# Define the base directory
base_dir = os.path.join(os.getcwd(), 'datasets')

# Ensure the base directory exists
if not os.path.exists(base_dir):
    os.mkdir(base_dir)

# Define the specific dataset directory
dataset_dir = os.path.join(base_dir, "osv5m")

# Ensure the specific dataset directory exists
if not os.path.exists(dataset_dir):
    os.mkdir(dataset_dir)

# Download the dataset
snapshot_download(repo_id="osv5m/osv5m", local_dir=dataset_dir, repo_type='dataset')

# Extract zip files and remove them after extraction
for root, dirs, files in os.walk(dataset_dir):
    for file in files:
        if file.endswith(".zip"):
            with zipfile.ZipFile(os.path.join(root, file), 'r') as zip_ref:
                zip_ref.extractall(root)
            os.remove(os.path.join(root, file))
scripts/preprocessing/enrich-metadata-adaptive-quadtrees.py
ADDED
@@ -0,0 +1,225 @@
import hydra
import torch
import numpy as np
import pandas as pd
import statistics
from os.path import join, dirname
import matplotlib.pyplot as plt


class QuadTree(object):
    def __init__(self, data, id="", depth=3, do_split=5000):
        self.id = id
        self.data = data

        coord = data[["latitude", "longitude"]].to_numpy()

        # if mins is None:
        mins = coord.min(0)
        # if maxs is None:
        maxs = coord.max(0)

        self.mins = np.asarray(mins)
        self.maxs = np.asarray(maxs)
        self.sizes = self.maxs - self.mins

        self.children = []

        # sort by latitude
        sorted_data_lat = sorted(coord, key=lambda point: point[0])

        # get the median lat
        median_lat = statistics.median(point[0] for point in sorted_data_lat)

        # Divide the cell into two half-cells based on the median lat
        data_left = [point for point in sorted_data_lat if point[0] <= median_lat]
        data_right = [point for point in sorted_data_lat if point[0] > median_lat]

        # Sort the data points by longitude in each half-cell
        sorted_data_left_lon = sorted(data_left, key=lambda point: point[1])
        sorted_data_right_lon = sorted(data_right, key=lambda point: point[1])

        # Calculate the median longitude coordinate in each half-cell
        median_lon_left = statistics.median(point[1] for point in sorted_data_left_lon)
        median_lon_right = statistics.median(
            point[1] for point in sorted_data_right_lon
        )

        if (depth > 0) and (len(self.data) >= do_split):
            # split the data into four quadrants
            data_q1 = data[
                (data["latitude"] < median_lat) & (data["longitude"] < median_lon_left)
            ]
            data_q2 = data[
                (data["latitude"] < median_lat) & (data["longitude"] >= median_lon_left)
            ]
            data_q3 = data[
                (data["latitude"] >= median_lat)
                & (data["longitude"] < median_lon_right)
            ]
            data_q4 = data[
                (data["latitude"] >= median_lat)
                & (data["longitude"] >= median_lon_right)
            ]

            # recursively build a quad tree on each quadrant which has data
            if data_q1.shape[0] > 0:
                self.children.append(
                    QuadTree(
                        data_q1,
                        id + "0",
                        depth - 1,
                        do_split=do_split,
                    )
                )
            if data_q2.shape[0] > 0:
                self.children.append(
                    QuadTree(
                        data_q2,
                        id + "1",
                        depth - 1,
                        do_split=do_split,
                    )
                )
            if data_q3.shape[0] > 0:
                self.children.append(
                    QuadTree(
                        data_q3,
                        id + "2",
                        depth - 1,
                        do_split=do_split,
                    )
                )
            if data_q4.shape[0] > 0:
                self.children.append(
                    QuadTree(
                        data_q4,
                        id + "3",
                        depth - 1,
                        do_split=do_split,
                    )
                )

    def unwrap(self):
        if len(self.children) == 0:
            return {self.id: [self.mins, self.maxs, self.data.copy()]}
        else:
            d = dict()
            for child in self.children:
                d.update(child.unwrap())
            return d


def extract(qt, name_new_column):
    cluster = qt.unwrap()
    boundaries, data = {}, []
    for i, (id, vs) in zip(np.arange(len(cluster)), cluster.items()):
        (min_lat, min_lon), (max_lat, max_lon), points = vs
        points[name_new_column] = int(i)
        data.append(points)
        boundaries[i] = (
            float(min_lat),
            float(min_lon),
            float(max_lat),
            float(max_lon),
            points["latitude"].mean(),
            points["longitude"].mean(),
        )

    data = pd.concat(data)
    return boundaries, data


def vizu(name_new_column, df_train, boundaries, do_split):
    plt.hist(df_train[name_new_column], bins=len(boundaries))
    plt.xlabel("Cluster ID")
    plt.ylabel("Number of images")
    plt.title("Cluster distribution")
    plt.yscale("log")
    plt.ylim(10, do_split)
    plt.savefig(f"{name_new_column}_distrib.png")
    plt.clf()

    plt.scatter(
        df_train["longitude"].to_numpy(),
        df_train["latitude"].to_numpy(),
        c=np.random.permutation(len(boundaries))[df_train[name_new_column].to_numpy()],
        cmap="tab20",
        s=0.1,
        alpha=0.5,
    )
    plt.xlabel("Longitude")
    plt.ylabel("Latitude")
    plt.title("Quadtree map")
    plt.savefig(f"{name_new_column}_map.png")


@hydra.main(
    config_path="../configs/scripts",
    config_name="enrich-metadata-quadtree",
    version_base=None,
)
def main(cfg):

    data_path = join(cfg.data_dir, "osv5m")
    name_new_column = f"adaptive_quadtree_{cfg.depth}_{cfg.do_split}"

    # Create clusters from train images
    train_fp = join(data_path, f"train.csv")
    df_train = pd.read_csv(train_fp)

    qt = QuadTree(df_train, depth=cfg.depth, do_split=cfg.do_split)
    boundaries, df_train = extract(qt, name_new_column)

    vizu(name_new_column, df_train, boundaries, cfg.do_split)

    # Save clusters
    boundaries = pd.DataFrame.from_dict(
        boundaries,
        orient="index",
        columns=["min_lat", "min_lon", "max_lat", "max_lon", "mean_lat", "mean_lon"],
    )
    boundaries.to_csv(f"{name_new_column}.csv", index_label="cluster_id")

    # Assign test images to clusters
    test_fp = join(data_path, f"test.csv")
    df_test = pd.read_csv(test_fp)

    above_lat = np.expand_dims(df_test["latitude"].to_numpy(), -1) > np.expand_dims(
        boundaries["min_lat"].to_numpy(), 0
    )
    below_lat = np.expand_dims(df_test["latitude"].to_numpy(), -1) < np.expand_dims(
        boundaries["max_lat"].to_numpy(), 0
    )
    above_lon = np.expand_dims(df_test["longitude"].to_numpy(), -1) > np.expand_dims(
        boundaries["min_lon"].to_numpy(), 0
    )
    below_lon = np.expand_dims(df_test["longitude"].to_numpy(), -1) < np.expand_dims(
        boundaries["max_lon"].to_numpy(), 0
    )

    mask = np.logical_and(
        np.logical_and(above_lat, below_lat), np.logical_and(above_lon, below_lon)
    )

    df_test[name_new_column] = np.argmax(mask, axis=1)

    # save index_to_gps_quadtree file
    lat = torch.tensor(boundaries["mean_lat"])
    lon = torch.tensor(boundaries["mean_lon"])
    coord = torch.stack([lat / 90, lon / 180], dim=-1)
    torch.save(
        coord,
        join(
            data_path, f"index_to_gps_adaptive_quadtree_{cfg.depth}_{cfg.do_split}.pt"
        ),
    )

    # Overwrite test.csv and train.csv
    if cfg.overwrite_csv:
        df_train.to_csv(train_fp, index=False)
        df_test.to_csv(test_fp, index=False)


if __name__ == "__main__":
    main()
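A minimal sketch (illustrative only) of running the adaptive quadtree on a toy DataFrame: the median-based split keeps leaf populations roughly balanced instead of splitting at the geometric midpoint. The random data, depth, and do_split threshold below are made up for the example.

# Illustrative only: build an adaptive quadtree over random points and inspect the leaves.
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
toy = pd.DataFrame({"latitude": rng.uniform(-60, 60, 2000),
                    "longitude": rng.uniform(-150, 150, 2000)})
qt = QuadTree(toy, depth=2, do_split=500)
boundaries, labeled = extract(qt, "adaptive_quadtree_2_500")
print(len(boundaries), labeled["adaptive_quadtree_2_500"].value_counts().head())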
scripts/preprocessing/enrich-metadata-quadtree.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hydra
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
from os.path import join, dirname
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
class QuadTree(object):
|
10 |
+
def __init__(self, data, mins=None, maxs=None, id="", depth=3, do_split=1000):
|
11 |
+
self.id = id
|
12 |
+
self.data = data
|
13 |
+
|
14 |
+
if mins is None:
|
15 |
+
mins = data[["latitude", "longitude"]].to_numpy().min(0)
|
16 |
+
if maxs is None:
|
17 |
+
maxs = data[["latitude", "longitude"]].to_numpy().max(0)
|
18 |
+
|
19 |
+
self.mins = np.asarray(mins)
|
20 |
+
self.maxs = np.asarray(maxs)
|
21 |
+
self.sizes = self.maxs - self.mins
|
22 |
+
|
23 |
+
self.children = []
|
24 |
+
|
25 |
+
mids = 0.5 * (self.mins + self.maxs)
|
26 |
+
xmin, ymin = self.mins
|
27 |
+
xmax, ymax = self.maxs
|
28 |
+
xmid, ymid = mids
|
29 |
+
|
30 |
+
if (depth > 0) and (len(self.data) >= do_split):
|
31 |
+
# split the data into four quadrants
|
32 |
+
data_q1 = data[(data["latitude"] < mids[0]) & (data["longitude"] < mids[1])]
|
33 |
+
data_q2 = data[
|
34 |
+
(data["latitude"] < mids[0]) & (data["longitude"] >= mids[1])
|
35 |
+
]
|
36 |
+
data_q3 = data[
|
37 |
+
(data["latitude"] >= mids[0]) & (data["longitude"] < mids[1])
|
38 |
+
]
|
39 |
+
data_q4 = data[
|
40 |
+
(data["latitude"] >= mids[0]) & (data["longitude"] >= mids[1])
|
41 |
+
]
|
42 |
+
|
43 |
+
# recursively build a quad tree on each quadrant which has data
|
44 |
+
if data_q1.shape[0] > 0:
|
45 |
+
self.children.append(
|
46 |
+
QuadTree(
|
47 |
+
data_q1,
|
48 |
+
[xmin, ymin],
|
49 |
+
[xmid, ymid],
|
50 |
+
id + "0",
|
51 |
+
depth - 1,
|
52 |
+
do_split=do_split,
|
53 |
+
)
|
54 |
+
)
|
55 |
+
if data_q2.shape[0] > 0:
|
56 |
+
self.children.append(
|
57 |
+
QuadTree(
|
58 |
+
data_q2,
|
59 |
+
[xmin, ymid],
|
60 |
+
[xmid, ymax],
|
61 |
+
id + "1",
|
62 |
+
depth - 1,
|
63 |
+
do_split=do_split,
|
64 |
+
)
|
65 |
+
)
|
66 |
+
if data_q3.shape[0] > 0:
|
67 |
+
self.children.append(
|
68 |
+
QuadTree(
|
69 |
+
data_q3,
|
70 |
+
[xmid, ymin],
|
71 |
+
[xmax, ymid],
|
72 |
+
id + "2",
|
73 |
+
depth - 1,
|
74 |
+
do_split=do_split,
|
75 |
+
)
|
76 |
+
)
|
77 |
+
if data_q4.shape[0] > 0:
|
78 |
+
self.children.append(
|
79 |
+
QuadTree(
|
80 |
+
data_q4,
|
81 |
+
[xmid, ymid],
|
82 |
+
[xmax, ymax],
|
83 |
+
id + "3",
|
84 |
+
depth - 1,
|
85 |
+
do_split=do_split,
|
86 |
+
)
|
87 |
+
)
|
88 |
+
|
89 |
+
def unwrap(self):
|
90 |
+
if len(self.children) == 0:
|
91 |
+
return {self.id: [self.mins, self.maxs, self.data.copy()]}
|
92 |
+
else:
|
93 |
+
d = dict()
|
94 |
+
for child in self.children:
|
95 |
+
d.update(child.unwrap())
|
96 |
+
return d
|
97 |
+
|
98 |
+
|
99 |
+
def extract(qt, name_new_column):
|
100 |
+
cluster = qt.unwrap()
|
101 |
+
boundaries, data = {}, []
|
102 |
+
id_to_quad = np.array(list(cluster.keys()))
|
103 |
+
for i, (id, vs) in zip(np.arange(len(cluster)), cluster.items()):
|
104 |
+
(min_lat, min_lon), (max_lat, max_lon), points = vs
|
105 |
+
points[name_new_column] = int(i)
|
106 |
+
data.append(points)
|
107 |
+
boundaries[i] = (
|
108 |
+
float(min_lat),
|
109 |
+
float(min_lon),
|
110 |
+
float(max_lat),
|
111 |
+
float(max_lon),
|
112 |
+
points["latitude"].mean(),
|
113 |
+
points["longitude"].mean(),
|
114 |
+
)
|
115 |
+
|
116 |
+
data = pd.concat(data)
|
117 |
+
return boundaries, data, id_to_quad
|
118 |
+
|
119 |
+
|
120 |
+
def vizu(name_new_column, df_train, boundaries):
|
121 |
+
plt.hist(df_train[name_new_column], bins=len(boundaries))
|
122 |
+
plt.xlabel("Cluster ID")
|
123 |
+
plt.ylabel("Number of images")
|
124 |
+
plt.title("Cluster distribution")
|
125 |
+
plt.yscale("log")
|
126 |
+
plt.savefig(f"{name_new_column}_distrib.png")
|
127 |
+
plt.clf()
|
128 |
+
|
129 |
+
plt.scatter(
|
130 |
+
df_train["longitude"].to_numpy(),
|
131 |
+
df_train["latitude"].to_numpy(),
|
132 |
+
c=np.random.permutation(len(boundaries))[df_train[name_new_column].to_numpy()],
|
133 |
+
cmap="tab20",
|
134 |
+
s=0.1,
|
135 |
+
alpha=0.5,
|
136 |
+
)
|
137 |
+
plt.xlabel("Longitude")
|
138 |
+
plt.ylabel("Latitude")
|
139 |
+
plt.title("Quadtree map")
|
140 |
+
plt.savefig(f"{name_new_column}_map.png")
|
141 |
+
|
142 |
+
|
143 |
+
@hydra.main(
|
144 |
+
config_path="../configs/scripts",
|
145 |
+
config_name="enrich-metadata-quadtree",
|
146 |
+
version_base=None,
|
147 |
+
)
|
148 |
+
def main(cfg):
|
149 |
+
data_path = join(cfg.data_dir, "osv5m")
|
150 |
+
name_new_column = f"quadtree_{cfg.depth}_{cfg.do_split}"
|
151 |
+
|
152 |
+
# Create clusters from train images
|
153 |
+
train_fp = join(data_path, f"train.csv")
|
154 |
+
df_train = pd.read_csv(train_fp)
|
155 |
+
|
156 |
+
qt = QuadTree(df_train, depth=cfg.depth, do_split=cfg.do_split)
|
157 |
+
boundaries, df_train, id_to_quad = extract(qt, name_new_column)
|
158 |
+
|
159 |
+
vizu(name_new_column, df_train, boundaries)
|
160 |
+
|
161 |
+
# Save clusters
|
162 |
+
boundaries = pd.DataFrame.from_dict(
|
163 |
+
boundaries,
|
164 |
+
orient="index",
|
165 |
+
columns=["min_lat", "min_lon", "max_lat", "max_lon", "mean_lat", "mean_lon"],
|
166 |
+
)
|
167 |
+
boundaries.to_csv(f"{name_new_column}.csv", index_label="cluster_id")
|
168 |
+
|
169 |
+
# Assign test images to clusters
|
170 |
+
test_fp = join(data_path, f"test.csv")
|
171 |
+
df_test = pd.read_csv(test_fp)
|
172 |
+
|
173 |
+
above_lat = np.expand_dims(df_test["latitude"].to_numpy(), -1) > np.expand_dims(
|
174 |
+
boundaries["min_lat"].to_numpy(), 0
|
175 |
+
)
|
176 |
+
below_lat = np.expand_dims(df_test["latitude"].to_numpy(), -1) < np.expand_dims(
|
177 |
+
boundaries["max_lat"].to_numpy(), 0
|
178 |
+
)
|
179 |
+
above_lon = np.expand_dims(df_test["longitude"].to_numpy(), -1) > np.expand_dims(
|
180 |
+
boundaries["min_lon"].to_numpy(), 0
|
181 |
+
)
|
182 |
+
below_lon = np.expand_dims(df_test["longitude"].to_numpy(), -1) < np.expand_dims(
|
183 |
+
boundaries["max_lon"].to_numpy(), 0
|
184 |
+
)
|
185 |
+
|
186 |
+
mask = np.logical_and(
|
187 |
+
np.logical_and(above_lat, below_lat), np.logical_and(above_lon, below_lon)
|
188 |
+
)
|
189 |
+
|
190 |
+
df_test[name_new_column] = np.argmax(mask, axis=1)
|
191 |
+
|
192 |
+
# save index_to_gps_quadtree file
|
193 |
+
lat = torch.tensor(boundaries["mean_lat"])
|
194 |
+
lon = torch.tensor(boundaries["mean_lon"])
|
195 |
+
coord = torch.stack([lat / 90, lon / 180], dim=-1)
|
196 |
+
torch.save(
|
197 |
+
coord, join(data_path, f"index_to_gps_quadtree_{cfg.depth}_{cfg.do_split}.pt")
|
198 |
+
)
|
199 |
+
|
200 |
+
torch.save(id_to_quad, join(data_path, f"id_to_quad_{cfg.depth}_{cfg.do_split}.pt"))
|
201 |
+
# Overwrite test.csv and train.csv
|
202 |
+
if cfg.overwrite_csv:
|
203 |
+
df_train.to_csv(train_fp, index=False)
|
204 |
+
df_test.to_csv(test_fp, index=False)
|
205 |
+
|
206 |
+
|
207 |
+
if __name__ == "__main__":
|
208 |
+
main()
|
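A minimal sketch of reading back what this script saves (the "datasets" data_dir and the depth=10, do_split=1000 values are assumed examples; the 1/90 and 1/180 rescaling mirrors the coord computation above):

import torch
import pandas as pd

coords = torch.load("datasets/osv5m/index_to_gps_quadtree_10_1000.pt")  # (num_clusters, 2), scaled to [-1, 1]
bounds = pd.read_csv("quadtree_10_1000.csv", index_col="cluster_id")    # per-cluster bounding box and mean
cluster_id = 42                                                         # hypothetical predicted cluster
lat = coords[cluster_id, 0].item() * 90.0                               # undo the latitude scaling
lon = coords[cluster_id, 1].item() * 180.0                              # undo the longitude scaling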
scripts/preprocessing/enrich-metadata.py (ADDED)
@@ -0,0 +1,123 @@
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import joblib
|
4 |
+
import pandas as pd
|
5 |
+
import numpy as np
|
6 |
+
import reverse_geocoder
|
7 |
+
from os.path import join, dirname
|
8 |
+
|
9 |
+
|
10 |
+
class QuadTree(object):
|
11 |
+
def __init__(
|
12 |
+
self, data, mins=None, maxs=None, id="", depth=3, min_split=0, do_split=1000
|
13 |
+
):
|
14 |
+
self.id = id
|
15 |
+
self.data = data
|
16 |
+
|
17 |
+
if mins is None:
|
18 |
+
mins = data[["latitude", "longitude"]].to_numpy().min(0)
|
19 |
+
if maxs is None:
|
20 |
+
maxs = data[["latitude", "longitude"]].to_numpy().max(0)
|
21 |
+
|
22 |
+
self.mins = np.asarray(mins)
|
23 |
+
self.maxs = np.asarray(maxs)
|
24 |
+
self.sizes = self.maxs - self.mins
|
25 |
+
|
26 |
+
self.children = []
|
27 |
+
|
28 |
+
mids = 0.5 * (self.mins + self.maxs)
|
29 |
+
xmin, ymin = self.mins
|
30 |
+
xmax, ymax = self.maxs
|
31 |
+
xmid, ymid = mids
|
32 |
+
|
33 |
+
if depth > 0 and len(self.data) >= do_split:
|
34 |
+
# split the data into four quadrants
|
35 |
+
data_q1 = data[(data["latitude"] < mids[0]) & (data["longitude"] < mids[1])]
|
36 |
+
data_q2 = data[
|
37 |
+
(data["latitude"] < mids[0]) & (data["longitude"] >= mids[1])
|
38 |
+
]
|
39 |
+
data_q3 = data[
|
40 |
+
(data["latitude"] >= mids[0]) & (data["longitude"] < mids[1])
|
41 |
+
]
|
42 |
+
data_q4 = data[
|
43 |
+
(data["latitude"] >= mids[0]) & (data["longitude"] >= mids[1])
|
44 |
+
]
|
45 |
+
|
46 |
+
# recursively build a quad tree on each quadrant which has data
|
47 |
+
if data_q1.shape[0] > min_split:
|
48 |
+
self.children.append(
|
49 |
+
QuadTree(data_q1, [xmin, ymin], [xmid, ymid], id + "0", depth - 1)
|
50 |
+
)
|
51 |
+
if data_q2.shape[0] > min_split:
|
52 |
+
self.children.append(
|
53 |
+
QuadTree(data_q2, [xmin, ymid], [xmid, ymax], id + "1", depth - 1)
|
54 |
+
)
|
55 |
+
if data_q3.shape[0] > min_split:
|
56 |
+
self.children.append(
|
57 |
+
QuadTree(data_q3, [xmid, ymin], [xmax, ymid], id + "2", depth - 1)
|
58 |
+
)
|
59 |
+
if data_q4.shape[0] > min_split:
|
60 |
+
self.children.append(
|
61 |
+
QuadTree(data_q4, [xmid, ymid], [xmax, ymax], id + "3", depth - 1)
|
62 |
+
)
|
63 |
+
|
64 |
+
def unwrap(self):
|
65 |
+
if len(self.children) == 0:
|
66 |
+
return {self.id: [self.mins, self.maxs, self.data.copy()]}
|
67 |
+
else:
|
68 |
+
d = dict()
|
69 |
+
for child in self.children:
|
70 |
+
d.update(child.unwrap())
|
71 |
+
return d
|
72 |
+
|
73 |
+
|
74 |
+
def extract(qt):
|
75 |
+
cluster = qt.unwrap()
|
76 |
+
boundaries, data = {}, []
|
77 |
+
for id, vs in cluster.items():
|
78 |
+
(min_lat, min_lon), (max_lat, max_lon), points = vs
|
79 |
+
points["category"] = id
|
80 |
+
data.append(points)
|
81 |
+
boundaries[id] = (
|
82 |
+
float(min_lat),
|
83 |
+
float(min_lon),
|
84 |
+
float(max_lat),
|
85 |
+
float(max_lon),
|
86 |
+
)
|
87 |
+
|
88 |
+
data = pd.concat(data)
|
89 |
+
return boundaries, data
|
90 |
+
|
91 |
+
|
92 |
+
if __name__ == "__main__":
|
93 |
+
# merge into one DataFrame
|
94 |
+
data_path = join(dirname(dirname(__file__)), "datasets", "osv5m")
|
95 |
+
train_fp = join(data_path, f"train.csv")
|
96 |
+
test_fp = join(data_path, f"test.csv")
|
97 |
+
|
98 |
+
df_train = pd.read_csv(train_fp)
|
99 |
+
df_train["split"] = "train"
|
100 |
+
|
101 |
+
df_test = pd.read_csv(test_fp)
|
102 |
+
df_test["split"] = "test"
|
103 |
+
|
104 |
+
df = pd.concat([df_train, df_test])
|
105 |
+
size_before = df.shape[0]
|
106 |
+
qt = QuadTree(df, depth=15)
|
107 |
+
boundaries, df = extract(qt)
|
108 |
+
assert df.shape[0] == size_before
|
109 |
+
|
110 |
+
location = reverse_geocoder.search(
|
111 |
+
[(lat, lon) for lat, lon in zip(df["latitude"], df["longitude"])]
|
112 |
+
)
|
113 |
+
df["city"] = [l.get("name", "") for l in location]
|
114 |
+
df["country"] = [l.get("cc", "") for l in location]
|
115 |
+
del location
|
116 |
+
|
117 |
+
df_train = df[df["split"] == "train"].drop(["split"], axis=1)
|
118 |
+
df_test = df[df["split"] == "test"].drop(["split"], axis=1)
|
119 |
+
assert (df_train.shape[0] + df_test.shape[0]) == size_before
|
120 |
+
|
121 |
+
json.dump(boundaries, open(join(data_path, "borders.json"), "w"))
|
122 |
+
df_train.to_csv(train_fp, index=False)
|
123 |
+
df_test.to_csv(test_fp, index=False)
|
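For reference, reverse_geocoder.search takes a list of (lat, lon) tuples and returns one record per query whose fields include 'name', 'admin1', 'admin2' and 'cc', which is what the "city" and "country" columns above are built from; a small illustrative call with arbitrary coordinates:

import reverse_geocoder

records = reverse_geocoder.search([(48.8566, 2.3522), (40.7128, -74.0060)])
for r in records:
    print(r["name"], r["admin1"], r["cc"])  # nearest city, admin-1 region, ISO-2 country code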
scripts/preprocessing/fix_namimbia.py (ADDED)
@@ -0,0 +1,64 @@
1 |
+
from os.path import join, dirname
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
# Define the list of cities
|
7 |
+
cities = [
|
8 |
+
"Walvis Bay",
|
9 |
+
"Keetmanshoop",
|
10 |
+
"Warmbad",
|
11 |
+
"Rundu",
|
12 |
+
"Outapi",
|
13 |
+
"Karibib",
|
14 |
+
"Otjimbingwe",
|
15 |
+
"Ondangwa",
|
16 |
+
"Oranjemund",
|
17 |
+
"Maltahohe",
|
18 |
+
"Otavi",
|
19 |
+
"Outjo",
|
20 |
+
"Swakopmund",
|
21 |
+
"Gobabis",
|
22 |
+
"Karasburg",
|
23 |
+
"Opuwo",
|
24 |
+
"Hentiesbaai",
|
25 |
+
"Katima Mulilo",
|
26 |
+
"Oshikango",
|
27 |
+
"Bethanie",
|
28 |
+
"Ongandjera",
|
29 |
+
"Mariental",
|
30 |
+
"Bagani",
|
31 |
+
"Nkurenkuru",
|
32 |
+
"Usakos",
|
33 |
+
"Rehoboth",
|
34 |
+
"Aranos",
|
35 |
+
"Omaruru",
|
36 |
+
"Arandis",
|
37 |
+
"Windhoek",
|
38 |
+
"Khorixas",
|
39 |
+
"Okahandja",
|
40 |
+
"Grootfontein",
|
41 |
+
"Tsumeb",
|
42 |
+
]
|
43 |
+
|
44 |
+
csv_dtype = {"category": str, "country": str, "city": str}
|
45 |
+
for split in ["train", "test"]:
|
46 |
+
fp = join(
|
47 |
+
dirname(dirname(__file__)), "datasets", "osv5m", f"{split}.csv"
|
48 |
+
)
|
49 |
+
|
50 |
+
# Read the CSV file into a pandas DataFrame
|
51 |
+
df = pd.read_csv(fp, dtype=csv_dtype)
|
52 |
+
|
53 |
+
# Check if the "country" column contains any of the cities in the list
|
54 |
+
mask = df["city"].isin(cities)
|
55 |
+
|
56 |
+
# If a city is found, set the corresponding rows in the "country" column to 'NMB'
|
57 |
+
df.loc[mask, "country"] = "NMB"
|
58 |
+
assert all(map(lambda x: isinstance(x, str), df["country"].unique().tolist()))
|
59 |
+
|
60 |
+
# Drop the columns that are all NaN
|
61 |
+
df.dropna(subset=["id", "latitude", "longitude"], inplace=True)
|
62 |
+
|
63 |
+
# Save the modified DataFrame back to the CSV file
|
64 |
+
df.to_csv(fp, index=False)
|
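A likely reason this fix is needed (an assumption, not stated in the code): Namibia's ISO alpha-2 code "NA" is one of pandas' default NaN markers, so the country value is silently read back as missing; a minimal reproduction:

import pandas as pd
from io import StringIO

df = pd.read_csv(StringIO("id,country\n1,NA\n2,FR\n"))  # "NA" is parsed as NaN by default
print(df["country"].isna().tolist())                    # [True, False]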
scripts/preprocessing/nearest-neighbors.py (ADDED)
@@ -0,0 +1,140 @@
1 |
+
import sys, os
|
2 |
+
import json
|
3 |
+
from PIL import Image
|
4 |
+
from tqdm import tqdm
|
5 |
+
from os.path import dirname, join
|
6 |
+
|
7 |
+
sys.path.append(dirname(dirname(__file__)))
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from transformers import AutoImageProcessor, AutoModel
|
11 |
+
from transformers import CLIPProcessor, CLIPModel
|
12 |
+
from transformers import pipeline
|
13 |
+
|
14 |
+
from data.data import osv5m
|
15 |
+
from json_stream import streamable_list
|
16 |
+
|
17 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
18 |
+
|
19 |
+
|
20 |
+
def load_model_clip():
|
21 |
+
model = CLIPModel.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K")
|
22 |
+
processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K")
|
23 |
+
return processor, model.to(DEVICE)
|
24 |
+
|
25 |
+
|
26 |
+
def load_model_dino():
|
27 |
+
model = AutoModel.from_pretrained("facebook/dinov2-base")
|
28 |
+
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
|
29 |
+
return processor, model.to(DEVICE)
|
30 |
+
|
31 |
+
|
32 |
+
def compute_dino(processor, model, x):
|
33 |
+
inputs = processor(images=x[0], return_tensors="pt", device=DEVICE).to(DEVICE)
|
34 |
+
outputs = model(**inputs)
|
35 |
+
last_hidden_states = outputs.last_hidden_state.cpu().numpy()
|
36 |
+
for i in range(len(x[0])):
|
37 |
+
yield [last_hidden_states[i].tolist(), x[1][i], x[2][i], x[3][i]]
|
38 |
+
|
39 |
+
|
40 |
+
def compute_clip(processor, model, x):
|
41 |
+
inputs = processor(images=x[0], return_tensors="pt", device=DEVICE).to(DEVICE)
|
42 |
+
features = model.get_image_features(**inputs)
|
43 |
+
features /= features.norm(dim=-1, keepdim=True)
|
44 |
+
features = features.cpu().numpy()
|
45 |
+
for i in range(len(x[0])):
|
46 |
+
yield [features[i].tolist(), x[1][i], x[2][i], x[3][i]]
|
47 |
+
|
48 |
+
|
49 |
+
def get_batch(dataset, batch_size):
|
50 |
+
data, lats, lons, ids = [], [], [], []
|
51 |
+
for i in range(len(dataset)):
|
52 |
+
id, lat, lon = dataset.df.iloc[i]
|
53 |
+
data.append(Image.open(join(dataset.image_folder, f"{int(id)}.jpg")))
|
54 |
+
lats.append(lat)
|
55 |
+
lons.append(lon)
|
56 |
+
ids.append(id)
|
57 |
+
if len(data) == batch_size:
|
58 |
+
yield data, lats, lons, ids
|
59 |
+
data, lats, lons, ids = [], [], [], []
|
60 |
+
|
61 |
+
if len(data) > 0:
|
62 |
+
yield data, lats, lons, ids
|
63 |
+
data, lats, lons, ids = [], [], [], []
|
64 |
+
|
65 |
+
|
66 |
+
if __name__ == "__main__":
|
67 |
+
import argparse
|
68 |
+
|
69 |
+
parser = argparse.ArgumentParser()
|
70 |
+
parser.add_argument("--batch_size", type=int, default=256)
|
71 |
+
parser.add_argument("--compute_features", action="store_true")
|
72 |
+
parser.add_argument("--compute_nearest", action="store_true")
|
73 |
+
parser.add_argument("--json_path", default="features")
|
74 |
+
parser.add_argument("--which", type=str, default="clip", choices=["clip", "dino"])
|
75 |
+
args = parser.parse_args()
|
76 |
+
json_path = join(args.json_path, args.which)
|
77 |
+
|
78 |
+
os.makedirs(json_path, exist_ok=True)
|
79 |
+
if args.compute_features:
|
80 |
+
processor, model = (
|
81 |
+
load_model_clip() if args.which == "clip" else load_model_dino()
|
82 |
+
)
|
83 |
+
compute_fn = compute_clip if args.which == "clip" else compute_dino
|
84 |
+
|
85 |
+
for split in ["test"]: #'train',
|
86 |
+
# open existing json and read as dictionary
|
87 |
+
json_path_ = join(json_path, f"{split}.json")
|
88 |
+
|
89 |
+
dataset = osv5m(
|
90 |
+
"datasets/osv5m", transforms=None, split=split, dont_split=True
|
91 |
+
)
|
92 |
+
|
93 |
+
@torch.no_grad()
|
94 |
+
def compute(batch_size):
|
95 |
+
for data in tqdm(
|
96 |
+
get_batch(dataset, batch_size),
|
97 |
+
total=len(dataset) // batch_size,
|
98 |
+
desc=f"Computing {split} on {args.which}",
|
99 |
+
):
|
100 |
+
features = compute_fn(processor, model, data)
|
101 |
+
for feature, lat, lon, id in features:
|
102 |
+
yield feature, lat, lon, id
|
103 |
+
|
104 |
+
data = streamable_list(compute(args.batch_size))
|
105 |
+
json.dump(data, open(json_path_, "w"), indent=4)
|
106 |
+
|
107 |
+
if args.compute_nearest:
|
108 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
109 |
+
import numpy as np
|
110 |
+
|
111 |
+
train, test = [
|
112 |
+
json.load(open(join(json_path, f"{split}.json"), "r"))
|
113 |
+
for split in ["train", "test"]
|
114 |
+
]
|
115 |
+
|
116 |
+
def get_neighbors(k=10):
|
117 |
+
for i, test_data in enumerate(tqdm(test)):
|
118 |
+
feature, lat, lon, id = test_data
|
119 |
+
features_train = np.stack(
|
120 |
+
[np.array(train_data[0]) for train_data in train]
|
121 |
+
)
|
122 |
+
cs = np.squeeze(
|
123 |
+
cosine_similarity(np.expand_dims(feature, axis=0), features_train),
|
124 |
+
axis=0,
|
125 |
+
)
|
126 |
+
i = np.argsort(cs)[-k:][::-1].tolist()
|
127 |
+
yield [
|
128 |
+
{n: x}
|
129 |
+
for idx in i
|
130 |
+
for n, x in zip(
|
131 |
+
["feature", "lat", "lon", "id", "distance"],
|
132 |
+
train[idx]
|
133 |
+
+ [
|
134 |
+
cs[idx],
|
135 |
+
],
|
136 |
+
)
|
137 |
+
]
|
138 |
+
|
139 |
+
data = streamable_list(get_neighbors())
|
140 |
+
json.dump(data, open(join(json_path, "nearest.json"), "w"), indent=4)
|
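The --compute_nearest step boils down to cosine similarity between L2-normalized features; the same ranking in isolation, on toy arrays:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

train_feats = np.random.randn(100, 768)         # stand-ins for the stored train features
query = np.random.randn(1, 768)                 # one test feature
cs = cosine_similarity(query, train_feats)[0]   # shape (100,)
top10 = np.argsort(cs)[-10:][::-1]              # indices of the 10 most similar train images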
scripts/preprocessing/preprocess.py (ADDED)
@@ -0,0 +1,400 @@
1 |
+
import pandas as pd
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from os.path import join
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import hydra
|
7 |
+
|
8 |
+
|
9 |
+
class QuadTree(object):
|
10 |
+
def __init__(self, data, mins=None, maxs=None, id="", depth=3, do_split=1000):
|
11 |
+
self.id = id
|
12 |
+
self.data = data
|
13 |
+
|
14 |
+
if mins is None:
|
15 |
+
mins = data[["latitude", "longitude"]].to_numpy().min(0)
|
16 |
+
if maxs is None:
|
17 |
+
maxs = data[["latitude", "longitude"]].to_numpy().max(0)
|
18 |
+
|
19 |
+
self.mins = np.asarray(mins)
|
20 |
+
self.maxs = np.asarray(maxs)
|
21 |
+
self.sizes = self.maxs - self.mins
|
22 |
+
|
23 |
+
self.children = []
|
24 |
+
|
25 |
+
mids = 0.5 * (self.mins + self.maxs)
|
26 |
+
xmin, ymin = self.mins
|
27 |
+
xmax, ymax = self.maxs
|
28 |
+
xmid, ymid = mids
|
29 |
+
|
30 |
+
if (depth > 0) and (len(self.data) >= do_split):
|
31 |
+
# split the data into four quadrants
|
32 |
+
data_q1 = data[(data["latitude"] < mids[0]) & (data["longitude"] < mids[1])]
|
33 |
+
data_q2 = data[
|
34 |
+
(data["latitude"] < mids[0]) & (data["longitude"] >= mids[1])
|
35 |
+
]
|
36 |
+
data_q3 = data[
|
37 |
+
(data["latitude"] >= mids[0]) & (data["longitude"] < mids[1])
|
38 |
+
]
|
39 |
+
data_q4 = data[
|
40 |
+
(data["latitude"] >= mids[0]) & (data["longitude"] >= mids[1])
|
41 |
+
]
|
42 |
+
|
43 |
+
# recursively build a quad tree on each quadrant which has data
|
44 |
+
if data_q1.shape[0] > 0:
|
45 |
+
self.children.append(
|
46 |
+
QuadTree(
|
47 |
+
data_q1,
|
48 |
+
[xmin, ymin],
|
49 |
+
[xmid, ymid],
|
50 |
+
id + "0",
|
51 |
+
depth - 1,
|
52 |
+
do_split=do_split,
|
53 |
+
)
|
54 |
+
)
|
55 |
+
if data_q2.shape[0] > 0:
|
56 |
+
self.children.append(
|
57 |
+
QuadTree(
|
58 |
+
data_q2,
|
59 |
+
[xmin, ymid],
|
60 |
+
[xmid, ymax],
|
61 |
+
id + "1",
|
62 |
+
depth - 1,
|
63 |
+
do_split=do_split,
|
64 |
+
)
|
65 |
+
)
|
66 |
+
if data_q3.shape[0] > 0:
|
67 |
+
self.children.append(
|
68 |
+
QuadTree(
|
69 |
+
data_q3,
|
70 |
+
[xmid, ymin],
|
71 |
+
[xmax, ymid],
|
72 |
+
id + "2",
|
73 |
+
depth - 1,
|
74 |
+
do_split=do_split,
|
75 |
+
)
|
76 |
+
)
|
77 |
+
if data_q4.shape[0] > 0:
|
78 |
+
self.children.append(
|
79 |
+
QuadTree(
|
80 |
+
data_q4,
|
81 |
+
[xmid, ymid],
|
82 |
+
[xmax, ymax],
|
83 |
+
id + "3",
|
84 |
+
depth - 1,
|
85 |
+
do_split=do_split,
|
86 |
+
)
|
87 |
+
)
|
88 |
+
|
89 |
+
def unwrap(self):
|
90 |
+
if len(self.children) == 0:
|
91 |
+
return {self.id: [self.mins, self.maxs, self.data.copy()]}
|
92 |
+
else:
|
93 |
+
d = dict()
|
94 |
+
for child in self.children:
|
95 |
+
d.update(child.unwrap())
|
96 |
+
return d
|
97 |
+
|
98 |
+
|
99 |
+
def extract(qt, name_new_column):
|
100 |
+
cluster = qt.unwrap()
|
101 |
+
boundaries, data = {}, []
|
102 |
+
id_to_quad = np.array(list(cluster.keys()))
|
103 |
+
for i, (id, vs) in zip(np.arange(len(cluster)), cluster.items()):
|
104 |
+
(min_lat, min_lon), (max_lat, max_lon), points = vs
|
105 |
+
points[name_new_column] = int(i)
|
106 |
+
data.append(points)
|
107 |
+
boundaries[i] = (
|
108 |
+
float(min_lat),
|
109 |
+
float(min_lon),
|
110 |
+
float(max_lat),
|
111 |
+
float(max_lon),
|
112 |
+
points["latitude"].mean(),
|
113 |
+
points["longitude"].mean(),
|
114 |
+
)
|
115 |
+
|
116 |
+
data = pd.concat(data)
|
117 |
+
return boundaries, data, id_to_quad
|
118 |
+
|
119 |
+
|
120 |
+
def vizu(name_new_column, df_train, boundaries, save_path):
|
121 |
+
plt.hist(df_train[name_new_column], bins=len(boundaries))
|
122 |
+
plt.xlabel("Cluster ID")
|
123 |
+
plt.ylabel("Number of images")
|
124 |
+
plt.title("Cluster distribution")
|
125 |
+
plt.yscale("log")
|
126 |
+
plt.savefig(join(save_path, f"{name_new_column}_distrib.png"))
|
127 |
+
plt.clf()
|
128 |
+
|
129 |
+
plt.scatter(
|
130 |
+
df_train["longitude"].to_numpy(),
|
131 |
+
df_train["latitude"].to_numpy(),
|
132 |
+
c=np.random.permutation(len(boundaries))[df_train[name_new_column].to_numpy()],
|
133 |
+
cmap="tab20",
|
134 |
+
s=0.1,
|
135 |
+
alpha=0.5,
|
136 |
+
)
|
137 |
+
plt.xlabel("Longitude")
|
138 |
+
plt.ylabel("Latitude")
|
139 |
+
plt.title("Quadtree map")
|
140 |
+
plt.savefig(join(save_path, f"{name_new_column}_map.png"))
|
141 |
+
|
142 |
+
|
143 |
+
@hydra.main(
|
144 |
+
config_path="../../configs/scripts",
|
145 |
+
config_name="preprocess",
|
146 |
+
version_base=None,
|
147 |
+
)
|
148 |
+
def main(cfg):
|
149 |
+
data_path = join(cfg.data_dir, "osv5m")
|
150 |
+
save_path = cfg.data_dir
|
151 |
+
name_new_column = f"quadtree_{cfg.depth}_{cfg.do_split}"
|
152 |
+
|
153 |
+
# Create clusters from train images
|
154 |
+
train_fp = join(data_path, f"train.csv")
|
155 |
+
df_train = pd.read_csv(train_fp, low_memory=False)
|
156 |
+
|
157 |
+
qt = QuadTree(df_train, depth=cfg.depth, do_split=cfg.do_split)
|
158 |
+
boundaries, df_train, id_to_quad = extract(qt, name_new_column)
|
159 |
+
|
160 |
+
vizu(name_new_column, df_train, boundaries, save_path)
|
161 |
+
|
162 |
+
# Save clusters
|
163 |
+
boundaries = pd.DataFrame.from_dict(
|
164 |
+
boundaries,
|
165 |
+
orient="index",
|
166 |
+
columns=["min_lat", "min_lon", "max_lat", "max_lon", "mean_lat", "mean_lon"],
|
167 |
+
)
|
168 |
+
boundaries.to_csv(
|
169 |
+
join(save_path, f"{name_new_column}.csv"), index_label="cluster_id"
|
170 |
+
)
|
171 |
+
|
172 |
+
# Assign test images to clusters
|
173 |
+
test_fp = join(data_path, f"test.csv")
|
174 |
+
df_test = pd.read_csv(test_fp)
|
175 |
+
|
176 |
+
above_lat = np.expand_dims(df_test["latitude"].to_numpy(), -1) > np.expand_dims(
|
177 |
+
boundaries["min_lat"].to_numpy(), 0
|
178 |
+
)
|
179 |
+
below_lat = np.expand_dims(df_test["latitude"].to_numpy(), -1) < np.expand_dims(
|
180 |
+
boundaries["max_lat"].to_numpy(), 0
|
181 |
+
)
|
182 |
+
above_lon = np.expand_dims(df_test["longitude"].to_numpy(), -1) > np.expand_dims(
|
183 |
+
boundaries["min_lon"].to_numpy(), 0
|
184 |
+
)
|
185 |
+
below_lon = np.expand_dims(df_test["longitude"].to_numpy(), -1) < np.expand_dims(
|
186 |
+
boundaries["max_lon"].to_numpy(), 0
|
187 |
+
)
|
188 |
+
|
189 |
+
mask = np.logical_and(
|
190 |
+
np.logical_and(above_lat, below_lat), np.logical_and(above_lon, below_lon)
|
191 |
+
)
|
192 |
+
|
193 |
+
df_test[name_new_column] = np.argmax(mask, axis=1)
|
194 |
+
|
195 |
+
# save index_to_gps_quadtree file
|
196 |
+
lat = torch.tensor(boundaries["mean_lat"])
|
197 |
+
lon = torch.tensor(boundaries["mean_lon"])
|
198 |
+
coord = torch.stack([lat, lon], dim=-1)
|
199 |
+
torch.save(
|
200 |
+
coord, join(save_path, f"index_to_gps_quadtree_{cfg.depth}_{cfg.do_split}.pt")
|
201 |
+
)
|
202 |
+
|
203 |
+
torch.save(id_to_quad, join(save_path, f"id_to_quad_{cfg.depth}_{cfg.do_split}.pt"))
|
204 |
+
# Overwrite test.csv and train.csv
|
205 |
+
if cfg.overwrite_csv:
|
206 |
+
df_train.to_csv(train_fp, index=False)
|
207 |
+
df_test.to_csv(test_fp, index=False)
|
208 |
+
|
209 |
+
df = pd.read_csv(join(data_path, "train.csv"), low_memory=False).fillna("NaN")
|
210 |
+
# Compute the average location for each unique country
|
211 |
+
country_avg = (
|
212 |
+
df.groupby("unique_country")[["latitude", "longitude"]].mean().reset_index()
|
213 |
+
)
|
214 |
+
country_avg.to_csv(
|
215 |
+
join(save_path, "country_center.csv"),
|
216 |
+
columns=["unique_country", "latitude", "longitude"],
|
217 |
+
index=False,
|
218 |
+
)
|
219 |
+
# Compute the average location for each unique admin1 (region)
|
220 |
+
region_avg = (
|
221 |
+
df.groupby(["unique_region"])[["latitude", "longitude"]].mean().reset_index()
|
222 |
+
)
|
223 |
+
region_avg.to_csv(
|
224 |
+
join(save_path, "region_center.csv"),
|
225 |
+
columns=["unique_region", "latitude", "longitude"],
|
226 |
+
index=False,
|
227 |
+
)
|
228 |
+
# Compute the average location for each unique admin2 (area)
|
229 |
+
area_avg = (
|
230 |
+
df.groupby(["unique_sub-region"])[["latitude", "longitude"]]
|
231 |
+
.mean()
|
232 |
+
.reset_index()
|
233 |
+
)
|
234 |
+
area_avg.to_csv(
|
235 |
+
join(save_path, "sub-region_center.csv"),
|
236 |
+
columns=["unique_sub-region", "latitude", "longitude"],
|
237 |
+
index=False,
|
238 |
+
)
|
239 |
+
# Compute the average location for each unique city
|
240 |
+
city_avg = (
|
241 |
+
df.groupby(["unique_city"])[["latitude", "longitude"]].mean().reset_index()
|
242 |
+
)
|
243 |
+
city_avg.to_csv(
|
244 |
+
join(save_path, "city_center.csv"),
|
245 |
+
columns=["unique_city", "latitude", "longitude"],
|
246 |
+
index=False,
|
247 |
+
)
|
248 |
+
|
249 |
+
for class_name in [
|
250 |
+
"unique_country",
|
251 |
+
"unique_sub-region",
|
252 |
+
"unique_region",
|
253 |
+
"unique_city",
|
254 |
+
]:
|
255 |
+
# Load CSV data into a Pandas DataFrame
|
256 |
+
csv_file = class_name.split("_")[-1] + "_center.csv"
|
257 |
+
df = pd.read_csv(join(save_path, csv_file), low_memory=False)
|
258 |
+
|
259 |
+
splits = ["train"]
|
260 |
+
categories = sorted(
|
261 |
+
pd.concat(
|
262 |
+
[
|
263 |
+
pd.read_csv(
|
264 |
+
join(data_path, f"{split}.csv"), low_memory=False
|
265 |
+
)[class_name]
|
266 |
+
for split in splits
|
267 |
+
]
|
268 |
+
)
|
269 |
+
.fillna("NaN")
|
270 |
+
.unique()
|
271 |
+
.tolist()
|
272 |
+
)
|
273 |
+
|
274 |
+
if "NaN" in categories:
|
275 |
+
categories.remove("NaN")
|
276 |
+
|
277 |
+
# compute the total number of categories - this name is fixed and will be used as a lookup during init
|
278 |
+
num_classes = len(categories)
|
279 |
+
|
280 |
+
# create a mapping from category to index
|
281 |
+
category_to_index = {category: i for i, category in enumerate(categories)}
|
282 |
+
|
283 |
+
dictionary = torch.zeros((num_classes, 2))
|
284 |
+
for index, row in df.iterrows():
|
285 |
+
key = row.iloc[0]
|
286 |
+
value = [row.iloc[1], row.iloc[2]]
|
287 |
+
if key in categories:
|
288 |
+
(
|
289 |
+
dictionary[category_to_index[key], 0],
|
290 |
+
dictionary[category_to_index[key], 1],
|
291 |
+
) = np.radians(row.iloc[1]), np.radians(row.iloc[2])
|
292 |
+
|
293 |
+
# Save the PyTorch tensor to a .pt file
|
294 |
+
output_file = join(save_path, "index_to_gps_" + class_name + ".pt")
|
295 |
+
torch.save(dictionary, output_file)
|
296 |
+
|
297 |
+
train = pd.read_csv(join(data_path, "train.csv"), low_memory=False).fillna(
|
298 |
+
"NaN"
|
299 |
+
)
|
300 |
+
|
301 |
+
u = train.groupby("unique_city").sample(n=1)
|
302 |
+
|
303 |
+
country_df = (
|
304 |
+
u.pivot(index="unique_city", columns="unique_country", values="unique_city")
|
305 |
+
.notna()
|
306 |
+
.astype(int)
|
307 |
+
.fillna(0)
|
308 |
+
)
|
309 |
+
country_to_idx = {
|
310 |
+
category: i for i, category in enumerate(list(country_df.columns))
|
311 |
+
}
|
312 |
+
city_country_matrix = torch.tensor(country_df.values) / 1.0
|
313 |
+
|
314 |
+
region_df = (
|
315 |
+
u.pivot(index="unique_city", columns="unique_region", values="unique_city")
|
316 |
+
.notna()
|
317 |
+
.astype(int)
|
318 |
+
.fillna(0)
|
319 |
+
)
|
320 |
+
region_to_idx = {category: i for i, category in enumerate(list(region_df.columns))}
|
321 |
+
city_region_matrix = torch.tensor(region_df.values) / 1.0
|
322 |
+
|
323 |
+
country_df = (
|
324 |
+
u.pivot(index="unique_city", columns="unique_country", values="unique_city")
|
325 |
+
.notna()
|
326 |
+
.astype(int)
|
327 |
+
.fillna(0)
|
328 |
+
)
|
329 |
+
country_to_idx = {
|
330 |
+
category: i for i, category in enumerate(list(country_df.columns))
|
331 |
+
}
|
332 |
+
city_country_matrix = torch.tensor(country_df.values) / 1.0
|
333 |
+
|
334 |
+
output_file = join(save_path, "city_to_country.pt")
|
335 |
+
torch.save(city_country_matrix, output_file)
|
336 |
+
|
337 |
+
output_file = join(save_path, "country_to_idx.pt")
|
338 |
+
torch.save(country_to_idx, output_file)
|
339 |
+
|
340 |
+
region_df = (
|
341 |
+
u.pivot(index="unique_city", columns="unique_region", values="unique_city")
|
342 |
+
.notna()
|
343 |
+
.astype(int)
|
344 |
+
.fillna(0)
|
345 |
+
)
|
346 |
+
region_to_idx = {category: i for i, category in enumerate(list(region_df.columns))}
|
347 |
+
city_region_matrix = torch.tensor(region_df.values) / 1.0
|
348 |
+
|
349 |
+
output_file = join(save_path, "city_to_region.pt")
|
350 |
+
torch.save(city_region_matrix, output_file)
|
351 |
+
|
352 |
+
output_file = join(save_path, "region_to_idx.pt")
|
353 |
+
torch.save(region_to_idx, output_file)
|
354 |
+
|
355 |
+
area_df = (
|
356 |
+
u.pivot(index="unique_city", columns="unique_sub-region", values="unique_city")
|
357 |
+
.notna()
|
358 |
+
.astype(int)
|
359 |
+
.fillna(0)
|
360 |
+
)
|
361 |
+
area_to_idx = {category: i for i, category in enumerate(list(area_df.columns))}
|
362 |
+
city_area_matrix = torch.tensor(area_df.values) / 1.0
|
363 |
+
|
364 |
+
output_file = join(save_path, "city_to_area.pt")
|
365 |
+
torch.save(city_area_matrix, output_file)
|
366 |
+
|
367 |
+
output_file = join(save_path, "area_to_idx.pt")
|
368 |
+
torch.save(area_to_idx, output_file)
|
369 |
+
gt = torch.load(join(save_path, f"id_to_quad_{cfg.depth}_{cfg.do_split}.pt"))
|
370 |
+
matrixes = []
|
371 |
+
dicts = []
|
372 |
+
for i in range(1, cfg.depth):
|
373 |
+
# Step 2: Truncate strings to size cfg.depth - 1
|
374 |
+
l = [s[: cfg.depth - i] if len(s) >= cfg.depth + 1 - i else s for s in gt]
|
375 |
+
|
376 |
+
# Step 3: Get unique values in the modified list l
|
377 |
+
h = list(set(l))
|
378 |
+
|
379 |
+
# Step 4: Create a dictionary to map unique values to their index
|
380 |
+
h_dict = {value: index for index, value in enumerate(h)}
|
381 |
+
dicts.append(h_dict)
|
382 |
+
|
383 |
+
# Step 5: Initialize a torch matrix with zeros
|
384 |
+
matrix = torch.zeros((len(gt), len(h)))
|
385 |
+
|
386 |
+
# Step 6: Fill in the matrix with 1s based on the mapping
|
387 |
+
for h in range(len(gt)):
|
388 |
+
j = h_dict[l[h]]
|
389 |
+
matrix[h, j] = 1
|
390 |
+
matrixes.append(matrix)
|
391 |
+
|
392 |
+
output_file = join(save_path, "quadtree_matrixes.pt")
|
393 |
+
torch.save(matrixes, output_file)
|
394 |
+
|
395 |
+
output_file = join(save_path, "quadtree_dicts.pt")
|
396 |
+
torch.save(dicts, output_file)
|
397 |
+
|
398 |
+
|
399 |
+
if __name__ == "__main__":
|
400 |
+
main()
|
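quadtree_matrixes.pt stores, for each coarser level, a 0/1 matrix that maps every leaf cluster to its truncated-prefix ancestor, so leaf probabilities can be pooled upward with a single matrix product; a minimal sketch (the file name is taken from the code above, the batch is hypothetical):

import torch

matrixes = torch.load("quadtree_matrixes.pt")                              # list of (num_leaves, num_parents) 0/1 tensors
leaf_probs = torch.softmax(torch.randn(4, matrixes[0].shape[0]), dim=-1)   # hypothetical batch of leaf probabilities
parent_probs = leaf_probs @ matrixes[0]                                    # pools leaves sharing the same ancestor cell
print(parent_probs.sum(-1))                                                # still sums to 1 per sample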
scripts/preprocessing/train-val-split.py (ADDED)
@@ -0,0 +1,15 @@
1 |
+
import os
|
2 |
+
from os.path import dirname, join
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
from sklearn.model_selection import train_test_split
|
6 |
+
|
7 |
+
if __name__ == "__main__":
|
8 |
+
data_path = join(dirname(dirname(__file__)), "datasets", "osv5m")
|
9 |
+
train_fp = join(data_path, f"train.csv")
|
10 |
+
val_fp = join(data_path, f"val.csv")
|
11 |
+
os.makedirs(dirname(val_fp), exist_ok=True)
|
12 |
+
df = pd.read_csv(train_fp, dtype={"category": str, "country": str, "city": str})
|
13 |
+
df_train, df_val = train_test_split(df, stratify=df["category"], test_size=0.1)
|
14 |
+
df_train.to_csv(train_fp, index=False)
|
15 |
+
df_val.to_csv(val_fp, index=False)
|
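A quick sanity check that the stratified split preserves the per-category proportions (illustrative only; the dataset path is a placeholder):

from os.path import join
import pandas as pd

data_path = "datasets/osv5m"   # placeholder, matching the layout used in the script above
tr = pd.read_csv(join(data_path, "train.csv"), dtype={"category": str})
va = pd.read_csv(join(data_path, "val.csv"), dtype={"category": str})
print(tr["category"].value_counts(normalize=True).head())
print(va["category"].value_counts(normalize=True).head())   # proportions should roughly match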
scripts/retrieval/backbone.py (ADDED)
@@ -0,0 +1,150 @@
1 |
+
import os     # needed below for os.listdir, os.path.isdir and os.remove
import torch  # needed below for torch.load, torch.no_grad and torch.stack
from os.path import join
|
2 |
+
import PIL
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
import reverse_geocoder
|
6 |
+
from torch.utils.data import Dataset
|
7 |
+
|
8 |
+
|
9 |
+
class GeoDataset(Dataset):
|
10 |
+
def __init__(self, image_folder, annotation_file, transformation, tag="image_id"):
|
11 |
+
self.image_folder = image_folder
|
12 |
+
gt = pd.read_csv(annotation_file, dtype={tag: str})
|
13 |
+
files = set([f.replace(".jpg", "") for f in os.listdir(image_folder)])
|
14 |
+
gt = gt[gt[tag].isin(files)]
|
15 |
+
self.processor = transformation
|
16 |
+
self.gt = [
|
17 |
+
(g[1][tag], g[1]["latitude"], g[1]["longitude"]) for g in gt.iterrows()
|
18 |
+
]
|
19 |
+
self.tag = tag
|
20 |
+
|
21 |
+
def fid(self, i):
|
22 |
+
return self.gt[i][0]
|
23 |
+
|
24 |
+
def latlon(self, i):
|
25 |
+
return self.gt[i][1]
|
26 |
+
|
27 |
+
def __len__(self):
|
28 |
+
return len(self.gt)
|
29 |
+
|
30 |
+
def __getitem__(self, idx):
|
31 |
+
fp = join(self.image_folder, self.gt[idx][0] + ".jpg")
|
32 |
+
return self.processor(self, idx, fp)
|
33 |
+
|
34 |
+
|
35 |
+
def load_plonk(path):
|
36 |
+
import hydra
|
37 |
+
from hydra import initialize, compose
|
38 |
+
from models.module import Geolocalizer
|
39 |
+
from omegaconf import OmegaConf, open_dict
|
40 |
+
from os.path import join
|
41 |
+
from hydra.utils import instantiate
|
42 |
+
|
43 |
+
# load config from path
|
44 |
+
# make path relative to current_dir
|
45 |
+
with initialize(version_base=None, config_path="osv5m__best_model"):
|
46 |
+
cfg = compose(config_name="config", overrides=[])
|
47 |
+
|
48 |
+
checkpoint = torch.load(join(path, "last.ckpt"))
|
49 |
+
del checkpoint["state_dict"][
|
50 |
+
"model.backbone.clip.vision_model.embeddings.position_ids"
|
51 |
+
]
|
52 |
+
torch.save(checkpoint, join(path, "last2.ckpt"))
|
53 |
+
|
54 |
+
with open_dict(cfg):
|
55 |
+
cfg.checkpoint = join(path, "last2.ckpt")
|
56 |
+
|
57 |
+
cfg.num_classes = 11399
|
58 |
+
cfg.model.network.mid.instance.final_dim = cfg.num_classes * 3
|
59 |
+
cfg.model.network.head.final_dim = cfg.num_classes * 3
|
60 |
+
cfg.model.network.head.instance.quadtree_path = join(path, "quadtree_10_1000.csv")
|
61 |
+
|
62 |
+
cfg.dataset.train_dataset.path = ""
|
63 |
+
cfg.dataset.val_dataset.path = ""
|
64 |
+
cfg.dataset.test_dataset.path = ""
|
65 |
+
cfg.logger.save_dir = ""
|
66 |
+
cfg.data_dir = ""
|
67 |
+
cfg.root_dir = ""
|
68 |
+
cfg.mode = "test"
|
69 |
+
cfg.model.network.backbone.instance.path = (
|
70 |
+
"laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
|
71 |
+
)
|
72 |
+
transform = instantiate(cfg.dataset.test_transform)
|
73 |
+
model = Geolocalizer.load_from_checkpoint(join(path, "last2.ckpt"), cfg=cfg.model)
|
74 |
+
os.remove(join(path, "last2.ckpt"))
|
75 |
+
|
76 |
+
@torch.no_grad()
|
77 |
+
def inference(model, x):
|
78 |
+
return x[0], model.model.backbone({"img": x[1].to(model.device)})[:, 0, :].cpu()
|
79 |
+
|
80 |
+
def collate_fn(batch):
|
81 |
+
return [b[0] for b in batch], torch.stack([b[1] for b in batch], dim=0)
|
82 |
+
|
83 |
+
def operate(self, idx, fp):
|
84 |
+
proc = self.processor(PIL.Image.open(fp))
|
85 |
+
return self.gt[idx][0], proc
|
86 |
+
|
87 |
+
return model, operate, inference, collate_fn
|
88 |
+
|
89 |
+
|
90 |
+
def load_clip(which):
|
91 |
+
# We evaluate on:
|
92 |
+
# - "openai/clip-vit-base-patch32"
|
93 |
+
# - "openai/clip-vit-large-patch14-336"
|
94 |
+
# - "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
95 |
+
# - "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
|
96 |
+
# - "geolocal/StreetCLIP"
|
97 |
+
from transformers import CLIPProcessor, CLIPModel
|
98 |
+
|
99 |
+
@torch.no_grad()
|
100 |
+
def inference(model, img):
|
101 |
+
image_ids = img.data.pop("image_id")
|
102 |
+
image_input = img.to(model.device)
|
103 |
+
image_input["pixel_values"] = image_input["pixel_values"].squeeze(1)
|
104 |
+
features = model.get_image_features(**image_input)
|
105 |
+
features /= features.norm(dim=-1, keepdim=True)
|
106 |
+
return image_ids, features.cpu()
|
107 |
+
|
108 |
+
processor = CLIPProcessor.from_pretrained(which)
|
109 |
+
|
110 |
+
def operate(self, idx, fp):
|
111 |
+
pil = PIL.Image.open(fp)
|
112 |
+
proc = processor(images=pil, return_tensors="pt")
|
113 |
+
proc["image_id"] = self.gt[idx][0]
|
114 |
+
return proc
|
115 |
+
|
116 |
+
return CLIPModel.from_pretrained(which), operate, inference, None
|
117 |
+
|
118 |
+
|
119 |
+
def load_dino(which):
|
120 |
+
# We evaluate on:
|
121 |
+
# - 'facebook/dinov2-large'
|
122 |
+
from transformers import AutoImageProcessor, AutoModel
|
123 |
+
|
124 |
+
@torch.no_grad()
|
125 |
+
def inference(model, img):
|
126 |
+
image_ids = img.data.pop("image_id")
|
127 |
+
image_input = img.to(model.device)
|
128 |
+
image_input["pixel_values"] = image_input["pixel_values"].squeeze(1)
|
129 |
+
features = model(**image_input).last_hidden_state[:, 0]
|
130 |
+
features /= features.norm(dim=-1, keepdim=True)
|
131 |
+
return image_ids, features.cpu()
|
132 |
+
|
133 |
+
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
|
134 |
+
|
135 |
+
def operate(self, idx, fp):
|
136 |
+
pil = PIL.Image.open(fp)
|
137 |
+
proc = processor(images=pil, return_tensors="pt")
|
138 |
+
proc["image_id"] = self.gt[idx][0]
|
139 |
+
return proc
|
140 |
+
|
141 |
+
return AutoModel.from_pretrained("facebook/dinov2-large"), operate, inference, None
|
142 |
+
|
143 |
+
|
144 |
+
def get_backbone(name):
|
145 |
+
if os.path.isdir(name):
|
146 |
+
return load_plonk(name)
|
147 |
+
elif "clip" in name.lower():
|
148 |
+
return load_clip(name)
|
149 |
+
elif "dino" in name.lower():
|
150 |
+
return load_dino(name)
|
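A hedged usage sketch of get_backbone() together with the GeoDataset defined above (paths are placeholders; `operate` is the `transformation` argument and is called as transformation(dataset, index, image_path) inside __getitem__):

import torch

model, operate, inference, collate_fn = get_backbone("openai/clip-vit-base-patch32")
dataset = GeoDataset("data/test", "data/test.csv", operate, tag="id")
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
sample = dataset[0]                            # a processed image, with its id carried along
image_ids, feats = inference(model, sample)    # batches from a DataLoader (with collate_fn) work the same way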
scripts/retrieval/retrieval.py (ADDED)
@@ -0,0 +1,143 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import PIL
|
4 |
+
import json
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
+
import operator
|
9 |
+
|
10 |
+
from PIL import Image
|
11 |
+
from itertools import cycle
|
12 |
+
from tqdm.auto import tqdm, trange
|
13 |
+
from os.path import join
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
+
from tqdm import tqdm
|
17 |
+
from torch.utils.data import Dataset, DataLoader
|
18 |
+
from torch.nn import functional as F
|
19 |
+
|
20 |
+
from backbone import GeoDataset, get_backbone
|
21 |
+
from utils import haversine, get_filenames, get_match_values, compute_print_accuracy
|
22 |
+
|
23 |
+
|
24 |
+
def compute_features(path, data_dir, csv_file, tag, args):
|
25 |
+
data = GeoDataset(data_dir, csv_file, tag=tag)
|
26 |
+
if not os.path.isdir(path) or len(
|
27 |
+
os.listdir(path)
|
28 |
+
) != len(data):
|
29 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
+
model, transform, inference, collate_fn = get_backbone(args.name)
|
31 |
+
dataloader = DataLoader(
|
32 |
+
data,
|
33 |
+
batch_size=args.batch_size,
|
34 |
+
shuffle=False,
|
35 |
+
num_workers=8,
|
36 |
+
collate_fn=collate_fn,
|
37 |
+
)
|
38 |
+
model = model.to(device)
|
39 |
+
os.makedirs(path, exist_ok=True)
|
40 |
+
|
41 |
+
for i, x in enumerate(tqdm(dataloader)):
|
42 |
+
image_ids, features = inference(model, x)
|
43 |
+
# save features as numpy array
|
44 |
+
for j, image_id in zip(range(features.shape[0]), image_ids):
|
45 |
+
np.save(join(path, f"{image_id}.npy"), features[j].unsqueeze(0).numpy())
|
46 |
+
|
47 |
+
|
48 |
+
def get_results(args, train_test):
|
49 |
+
import joblib
|
50 |
+
|
51 |
+
if not os.path.isfile(join(args.features_parent, ".cache", "1-nn.pkl")):
|
52 |
+
import faiss, glob, bisect
|
53 |
+
|
54 |
+
# import sys; sys.exit(0)
|
55 |
+
indexes = [
|
56 |
+
get_filenames(idx) for idx in tqdm(range(1, 6), desc="Loading indexes...")
|
57 |
+
]
|
58 |
+
|
59 |
+
train_gt = pd.read_csv(
|
60 |
+
join(args.data_parent, args.annotation_file), dtype={"image_id": str}
|
61 |
+
)[["image_id", "latitude", "longitude"]]
|
62 |
+
test_gt = pd.read_csv(test_path_csv, dtype={"id": str})[
|
63 |
+
["id", "latitude", "longitude"]
|
64 |
+
]
|
65 |
+
|
66 |
+
# make a map between image_id and lat/lon
|
67 |
+
train_gt = {
|
68 |
+
g[1]["image_id"]: np.array([g[1]["latitude"], g[1]["longitude"]])
|
69 |
+
for g in tqdm(
|
70 |
+
train_gt.iterrows(), total=len(train_gt), desc="Loading train_gt"
|
71 |
+
)
|
72 |
+
}
|
73 |
+
test_gt = {
|
74 |
+
g[1]["id"]: np.array([g[1]["latitude"], g[1]["longitude"]])
|
75 |
+
for g in tqdm(
|
76 |
+
test_gt.iterrows(), total=len(test_gt), desc="Loading test_gt"
|
77 |
+
)
|
78 |
+
}
|
79 |
+
|
80 |
+
train_test = []
|
81 |
+
os.makedirs(join(args.features_parent, ".cache"), exist_ok=True)
|
82 |
+
for f in tqdm(os.listdir(test_features_dir)):
|
83 |
+
query_vector = np.load(join(test_features_dir, f))
|
84 |
+
|
85 |
+
neighbors = []
|
86 |
+
for index, ids in indexes:
|
87 |
+
distances, indices = index.search(query_vector, 1)
|
88 |
+
distances, indices = np.squeeze(distances), np.squeeze(indices)
|
89 |
+
bisect.insort(
|
90 |
+
neighbors, (ids[indices], distances), key=operator.itemgetter(1)
|
91 |
+
)
|
92 |
+
|
93 |
+
neighbors = list(reversed(neighbors))
|
94 |
+
train_gps = train_gt[neighbors[0][0].replace(".npy", "")][None, :]
|
95 |
+
test_gps = test_gt[f.replace(".npy", "")][None, :]
|
96 |
+
train_test.append((train_gps, test_gps))
|
97 |
+
joblib.dump(train_test, join(args.features_parent, ".cache", "1-nn.pkl"))
|
98 |
+
else:
|
99 |
+
train_test = joblib.load(join(args.features_parent, ".cache", "1-nn.pkl"))
|
100 |
+
|
101 |
+
return train_test
|
102 |
+
|
103 |
+
|
104 |
+
if __name__ == "__main__":
|
105 |
+
# make a train/eval argparser
|
106 |
+
import argparse
|
107 |
+
|
108 |
+
parser = argparse.ArgumentParser()
|
109 |
+
parser.add_argument("--id", type=int, default=1) # maybe need to remove/refactor
|
110 |
+
parser.add_argument("--batch_size", type=int, default=512)
|
111 |
+
parser.add_argument(
|
112 |
+
"--annotation_file", type=str, required=False, default="train.csv"
|
113 |
+
)
|
114 |
+
parser.add_argument("--name", type=str, default="openai/clip-vit-base-patch32")
|
115 |
+
parser.add_argument("--features_parent", type=str, default="faiss/")
|
116 |
+
parser.add_argument("--data_parent", type=str, default="data/")
|
117 |
+
parser.add_argument("--test", action="store_true")
|
118 |
+
|
119 |
+
args = parser.parse_args()
|
120 |
+
args.features_parent = join(args.features_parent, args.name)
|
121 |
+
if args.test:
|
122 |
+
csv_file = join(args.data_parent, "test.csv")
|
123 |
+
data_dir = join(args.data_parent, "test")
|
124 |
+
path = test_features_dir = join(args.features_parent, "features-test")
test_path_csv = csv_file  # get_results() reads these module-level names
|
125 |
+
model = get_backbone(args.name)
|
126 |
+
compute_features(path, data_dir, csv_file, tag="id", args=args)
|
127 |
+
train_test = get_results(args, None)
|
128 |
+
|
129 |
+
from collections import Counter
|
130 |
+
|
131 |
+
N, pos = Counter(), Counter()
|
132 |
+
for train_gps, test_gps in tqdm(train_test, desc="Computing accuracy..."):
|
133 |
+
get_match_values(train_gps, test_gps, N, pos)
|
134 |
+
|
135 |
+
for train_gps, test_gps in tqdm(train_test, desc="Computing haversine..."):
|
136 |
+
haversine(train_gps, test_gps, N, pos)
|
137 |
+
|
138 |
+
compute_print_accuracy(N, pos)
|
139 |
+
else:
|
140 |
+
csv_file = join(args.data_parent, args.annotation_file)
|
141 |
+
path = join(args.features_parent, f"features-{args.id}")
|
142 |
+
data_dir = join(args.data_parent, f"images-{args.id}", "train")
|
143 |
+
compute_features(path, data_dir, csv_file, tag="image_id", args=args)
|
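The 1-NN lookup above relies on autofaiss/faiss; its core, in isolation on toy features (the build_index call mirrors utils.get_filenames):

import numpy as np
from autofaiss import build_index

feats = np.random.rand(1000, 512).astype("float32")           # stand-ins for the saved train features
index = build_index(embeddings=feats, save_on_disk=False)[0]  # returns (index, index_infos)
query = np.random.rand(1, 512).astype("float32")
distances, indices = index.search(query, 1)                   # nearest train feature per query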
scripts/retrieval/street-clip-zero-shot.py (ADDED)
@@ -0,0 +1,299 @@
1 |
+
import traceback
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import PIL
|
5 |
+
import json
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
import operator
|
10 |
+
import joblib
|
11 |
+
import reverse_geocoder
|
12 |
+
|
13 |
+
from PIL import Image
|
14 |
+
from itertools import cycle
|
15 |
+
from tqdm.auto import tqdm, trange
|
16 |
+
from os.path import join
|
17 |
+
from PIL import Image
|
18 |
+
|
19 |
+
from tqdm import tqdm
|
20 |
+
from collections import Counter
|
21 |
+
from transformers import CLIPProcessor, CLIPModel
|
22 |
+
from torch.utils.data import Dataset, DataLoader
|
23 |
+
from torch.nn import functional as F
|
24 |
+
from utils import haversine
|
25 |
+
|
26 |
+
|
27 |
+
class GeoDataset(Dataset):
|
28 |
+
def __init__(self, image_folder, annotation_file, tag="image_id"):
|
29 |
+
self.image_folder = image_folder
|
30 |
+
gt = pd.read_csv(annotation_file, dtype={tag: str})
|
31 |
+
files = set([f.replace(".jpg", "") for f in os.listdir(image_folder)])
|
32 |
+
gt = gt[gt[tag].isin(files)]
|
33 |
+
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
34 |
+
self.gt = [
|
35 |
+
(g[1][tag], g[1]["latitude"], g[1]["longitude"]) for g in gt.iterrows()
|
36 |
+
]
|
37 |
+
self.tag = tag
|
38 |
+
|
39 |
+
def fid(self, i):
|
40 |
+
return self.gt[i][0]
|
41 |
+
|
42 |
+
def latlon(self, i):
|
43 |
+
return self.gt[i][1]
|
44 |
+
|
45 |
+
def __len__(self):
|
46 |
+
return len(self.gt)
|
47 |
+
|
48 |
+
def __getitem__(self, idx):
|
49 |
+
fp = join(self.image_folder, self.gt[idx][0] + ".jpg")
|
50 |
+
pil = PIL.Image.open(fp)
|
51 |
+
proc = self.processor(images=pil, return_tensors="pt")
|
52 |
+
proc["image_id"] = self.gt[idx][0]
|
53 |
+
return proc
|
54 |
+
|
55 |
+
|
56 |
+
@torch.no_grad()
|
57 |
+
def compute_features_clip(img, model):
|
58 |
+
image_ids = img.data.pop("image_id")
|
59 |
+
image_input = img.to(model.device)
|
60 |
+
image_input["pixel_values"] = image_input["pixel_values"].squeeze(1)
|
61 |
+
features = model.get_image_features(**image_input)
|
62 |
+
features /= features.norm(dim=-1, keepdim=True)
|
63 |
+
return image_ids, features.cpu()
|
64 |
+
|
65 |
+
|
66 |
+
def get_prompts(country, region, sub_region, city):
|
67 |
+
a = country if country != "" else None
|
68 |
+
b, c, d = None, None, None
|
69 |
+
if a is not None:
|
70 |
+
b = country + ", " + region if region != "" else None
|
71 |
+
if b is not None:
|
72 |
+
c = (
|
73 |
+
country + ", " + region + ", " + sub_region
|
74 |
+
if sub_region != ""
|
75 |
+
else None
|
76 |
+
)
|
77 |
+
d = (
|
78 |
+
country + ", " + region + ", " + sub_region + ", " + city
|
79 |
+
if city != ""
|
80 |
+
else None
|
81 |
+
)
|
82 |
+
return a, b, c, d
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == "__main__":
|
86 |
+
# make a train/eval argparser
|
87 |
+
import argparse
|
88 |
+
|
89 |
+
parser = argparse.ArgumentParser()
|
90 |
+
parser.add_argument(
|
91 |
+
"--annotation_file", type=str, required=False, default="train.csv"
|
92 |
+
)
|
93 |
+
parser.add_argument(
|
94 |
+
"--features_parent", type=str, default="/home/isig/gaia-v2/faiss/street-clip"
|
95 |
+
)
|
96 |
+
parser.add_argument(
|
97 |
+
"--data_parent", type=str, default="/home/isig/gaia-v2/loic-data/"
|
98 |
+
)
|
99 |
+
|
100 |
+
args = parser.parse_args()
|
101 |
+
test_path_csv = join(args.data_parent, "test.csv")
|
102 |
+
test_image_dir = join(args.data_parent, "test")
|
103 |
+
save_path = join(args.features_parent, "indexes/test.index")
|
104 |
+
test_features_dir = join(args.features_parent, "indexes/features-test")
|
105 |
+
|
106 |
+
processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
|
107 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
108 |
+
model = CLIPModel.from_pretrained("geolocal/StreetCLIP").to(device)
|
109 |
+
|
110 |
+
@torch.no_grad()
|
111 |
+
def compute_text_features_clip(text):
|
112 |
+
text_pt = processor(text=text, return_tensors="pt").to(device)
|
113 |
+
features = model.get_text_features(**text_pt)
|
114 |
+
features /= features.norm(dim=-1, keepdim=True)
|
115 |
+
return features.cpu().squeeze(0).numpy()
|
116 |
+
|
117 |
+
import country_converter as coco
|
118 |
+
|
119 |
+
if not os.path.isfile("text_street-clip-features.pkl"):
|
120 |
+
if not os.path.isfile("rg_cities1000.csv"):
|
121 |
+
os.system(
|
122 |
+
"wget https://raw.githubusercontent.com/thampiman/reverse-geocoder/master/reverse_geocoder/rg_cities1000.csv"
|
123 |
+
)
|
124 |
+
|
125 |
+
cities = pd.read_csv("rg_cities1000.csv")
|
126 |
+
cities = cities[["lat", "lon", "name", "admin1", "admin2", "cc"]]
|
127 |
+
reprs = {0: {}, 1: {}, 2: {}, 3: {}}
|
128 |
+
for line in tqdm(
|
129 |
+
cities.iterrows(), total=len(cities), desc="Creating hierarchy"
|
130 |
+
):
|
131 |
+
lat, lon, city, region, sub_region, cc = line[1]
|
132 |
+
try:
|
133 |
+
city, region, sub_region, cc = [
|
134 |
+
("" if pd.isna(x) else x)
|
135 |
+
for x in [
|
136 |
+
city,
|
137 |
+
region,
|
138 |
+
sub_region,
|
139 |
+
coco.convert(cc, to="name_short"),
|
140 |
+
]
|
141 |
+
]
|
142 |
+
a, b, c, d = get_prompts(cc, region, sub_region, city)
|
143 |
+
if a is not None:
|
144 |
+
if a not in reprs[0]:
|
145 |
+
reprs[0][a] = {
|
146 |
+
"gps": {(lat, lon)},
|
147 |
+
"embedding": compute_text_features_clip(a),
|
148 |
+
}
|
149 |
+
else:
|
150 |
+
reprs[0][a]["gps"].add((lat, lon))
|
151 |
+
|
152 |
+
if b is not None:
|
153 |
+
if b not in reprs[1]:
|
154 |
+
reprs[1][b] = {
|
155 |
+
"gps": {(lat, lon)},
|
156 |
+
"embedding": compute_text_features_clip(b),
|
157 |
+
}
|
158 |
+
else:
|
159 |
+
reprs[1][b]["gps"].add((lat, lon))
|
160 |
+
|
161 |
+
if c is not None:
|
162 |
+
if c not in reprs[2]:
|
163 |
+
reprs[2][c] = {
|
164 |
+
"gps": {(lat, lon)},
|
165 |
+
"embedding": compute_text_features_clip(c),
|
166 |
+
}
|
167 |
+
else:
|
168 |
+
reprs[2][c]["gps"].add((lat, lon))
|
169 |
+
|
170 |
+
if d is not None:
|
171 |
+
if d not in reprs[3]:
|
172 |
+
reprs[3][d] = {
|
173 |
+
"gps": {(lat, lon)},
|
174 |
+
"embedding": compute_text_features_clip(
|
175 |
+
d.replace(", , ", ", ")
|
176 |
+
),
|
177 |
+
}
|
178 |
+
else:
|
179 |
+
reprs[3][d]["gps"].add((lat, lon))
|
180 |
+
except Exception as e:
|
181 |
+
# print stack trace into file log.txt
|
182 |
+
with open("log.txt", "a") as f:
|
183 |
+
print(traceback.format_exc(), file=f)
|
184 |
+
|
185 |
+
reprs[-1] = {"": {"gps": (0, 0), "embedding": compute_text_features_clip("")}}
|
186 |
+
|
187 |
+
# compute mean for gps of all 'a' and 'b' and 'c' and 'd'
|
188 |
+
for i in range(4):
|
189 |
+
        for k in reprs[i].keys():
            reprs[i][k]["gps"] = tuple(
                np.array(list(reprs[i][k]["gps"])).mean(axis=0).tolist()
            )

    joblib.dump(reprs, "text_street-clip-features.pkl")
else:
    reprs = joblib.load("text_street-clip-features.pkl")


def get_loc(x):
    location = reverse_geocoder.search(x[0].tolist())[0]
    country = coco.convert(names=location["cc"], to="name_short")
    region = location.get("admin1", "")
    sub_region = location.get("admin2", "")
    city = location.get("name", "")
    a, b, c, d = get_prompts(country, region, sub_region, city)
    return a, b, c, d


def matches(embed, repr, control, gt, sw=None):
    first_max = max(
        (
            (k, embed.dot(v["embedding"]))
            for k, v in repr.items()
            if sw is None or k.startswith(sw)
        ),
        key=operator.itemgetter(1),
    )
    if first_max[1] > embed.dot(control["embedding"]):
        return repr[first_max[0]]["gps"], gt == first_max[0]
    else:
        return control["gps"], False


def get_match_values(gt, embed, N, pos):
    xa, xb, xc, xd = get_loc(gt)

    if xa is not None:
        N["country"] += 1
        gps, flag = matches(embed, reprs[0], reprs[-1][""], xa)
        if flag:
            pos["country"] += 1
            if xb is not None:
                N["region"] += 1
                gps, flag = matches(embed, reprs[1], reprs[0][xa], xb, sw=xa)
                if flag:
                    pos["region"] += 1
                    if xc is not None:
                        N["sub-region"] += 1
                        gps, flag = matches(
                            embed, reprs[2], reprs[1][xb], xc, sw=xb
                        )
                        if flag:
                            pos["sub-region"] += 1
                            if xd is not None:
                                N["city"] += 1
                                gps, flag = matches(
                                    embed, reprs[3], reprs[2][xc], xd, sw=xc
                                )
                                if flag:
                                    pos["city"] += 1
                    else:
                        if xd is not None:
                            N["city"] += 1
                            gps, flag = matches(
                                embed, reprs[3], reprs[1][xb], xd, sw=xb + ", "
                            )
                            if flag:
                                pos["city"] += 1

    haversine(np.array(gps)[None, :], np.array(gt), N, pos)


def compute_print_accuracy(N, pos):
    for k in N.keys():
        pos[k] /= N[k]

    # pretty-print accuracy in percentage with 2 floating points
    print(
        f'Accuracy: {pos["country"]*100.0:.2f} (country), {pos["region"]*100.0:.2f} (region), {pos["sub-region"]*100.0:.2f} (sub-region), {pos["city"]*100.0:.2f} (city)'
    )
    print(
        f'Haversine: {pos["haversine"]:.2f} (haversine), {pos["geoguessr"]:.2f} (geoguessr)'
    )


import joblib

data = GeoDataset(test_image_dir, test_path_csv, tag="id")
test_gt = pd.read_csv(test_path_csv, dtype={"id": str})[
    ["id", "latitude", "longitude"]
]
test_gt = {
    g[1]["id"]: np.array([g[1]["latitude"], g[1]["longitude"]])
    for g in tqdm(test_gt.iterrows(), total=len(test_gt), desc="Loading test_gt")
}

with open("/home/isig/gaia-v2/loic/plonk/test3_indices.txt", "r") as f:
    # read lines
    lines = f.readlines()
    # remove whitespace characters like `\n` at the end of each line
    lines = [l.strip() for l in lines]
    # and convert to set
    lines = set(lines)

train_test = []
N, pos = Counter(), Counter()
for f in tqdm(os.listdir(test_features_dir)):
    if f.replace(".npy", "") not in lines:
        continue
    query_vector = np.squeeze(np.load(join(test_features_dir, f)))
    test_gps = test_gt[f.replace(".npy", "")][None, :]
    get_match_values(test_gps, query_vector, N, pos)

compute_print_accuracy(N, pos)
|
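The hierarchical zero-shot evaluation above boils down to dot products between one image embedding and pre-computed text embeddings, where a candidate is only accepted if it scores higher than a "control" prompt. Below is a minimal, self-contained sketch of that decision rule, using toy 2-D embeddings and made-up country entries; only the matches function mirrors the script above, everything else is illustrative.

import operator

import numpy as np


def matches(embed, repr, control, gt, sw=None):
    # keep the best-scoring key (optionally restricted to prefix `sw`) and
    # accept it only if it beats the control embedding
    first_max = max(
        (
            (k, embed.dot(v["embedding"]))
            for k, v in repr.items()
            if sw is None or k.startswith(sw)
        ),
        key=operator.itemgetter(1),
    )
    if first_max[1] > embed.dot(control["embedding"]):
        return repr[first_max[0]]["gps"], gt == first_max[0]
    return control["gps"], False


# toy country-level representations (hypothetical embeddings and mean GPS)
countries = {
    "France": {"embedding": np.array([1.0, 0.0]), "gps": (46.0, 2.0)},
    "Japan": {"embedding": np.array([0.0, 1.0]), "gps": (36.0, 138.0)},
}
control = {"embedding": np.array([0.3, 0.3]), "gps": (0.0, 0.0)}

query = np.array([0.9, 0.1])  # closer to the "France" prompt
gps, correct = matches(query, countries, control, "France")
print(gps, correct)  # (46.0, 2.0) True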
scripts/retrieval/utils.py
ADDED
@@ -0,0 +1,113 @@
import os
from os.path import join

import numpy as np
import reverse_geocoder
from tqdm import tqdm


def get_loc(x):
    location = reverse_geocoder.search(x[0].tolist())[0]
    country = location.get("cc", "")
    region = location.get("admin1", "")
    sub_region = location.get("admin2", "")
    city = location.get("name", "")

    a = country if country != "" else None
    b, c, d = None, None, None
    if a is not None:
        b = country + "," + region if region != "" else None
        if b is not None:
            c = country + "," + region + "," + sub_region if sub_region != "" else None
            d = (
                country + "," + region + "," + sub_region + "," + city
                if city != ""
                else None
            )

    return a, b, c, d


def get_match_values(pred, gt, N, pos):
    xa, xb, xc, xd = get_loc(gt)
    ya, yb, yc, yd = get_loc(pred)

    if xa is not None:
        N["country"] += 1
        if xa == ya:
            pos["country"] += 1
            if xb is not None:
                N["region"] += 1
                if xb == yb:
                    pos["region"] += 1
                    if xc is not None:
                        N["sub-region"] += 1
                        if xc == yc:
                            pos["sub-region"] += 1
                            if xd is not None:
                                N["city"] += 1
                                if xd == yd:
                                    pos["city"] += 1


def compute_print_accuracy(N, pos):
    for k in N.keys():
        pos[k] /= N[k]

    # pretty-print accuracy in percentage with 2 floating points
    print(
        f'Accuracy: {pos["country"]*100.0:.2f} (country), {pos["region"]*100.0:.2f} (region), {pos["sub-region"]*100.0:.2f} (sub-region), {pos["city"]*100.0:.2f} (city)'
    )
    print(
        f'Haversine: {pos["haversine"]:.2f} (haversine), {pos["geoguessr"]:.2f} (geoguessr)'
    )


def get_filenames(idx):
    from autofaiss import build_index

    # `args.features_parent` is expected to be provided by the calling script
    path = join(args.features_parent, f"features-{idx}/")
    files = [f for f in os.listdir(path)]
    full_files = [join(path, f) for f in os.listdir(path)]
    index = build_index(
        embeddings=np.concatenate([np.load(f) for f in tqdm(full_files)], axis=0),
        nb_cores=12,
        save_on_disk=False,
    )[0]
    return index, files


def normalize(x):
    """Used to put all lat/lon inside ±90 and ±180."""
    lat, lon = x[:, 0], x[:, 1]
    lat = (lat + 90) % 360 - 90
    # points that wrapped past a pole are mirrored back and shifted to the opposite meridian
    over_pole = lat > 90
    lat = np.where(over_pole, 180 - lat, lat)
    lon = np.where(over_pole, lon + 180, lon)
    lon = (lon + 180) % 360 - 180
    return np.stack([lat, lon], axis=1)


def haversine(pred, gt, N, p):
    # pred and gt are (N, 2) arrays of (lat, lon) in degrees; they are
    # normalized and converted to radians before applying the haversine formula
    pred = np.radians(normalize(pred))
    gt = np.radians(normalize(gt))

    # difference in latitude and longitude between the predicted and ground truth points
    lat_diff = pred[:, 0] - gt[:, 0]
    lon_diff = pred[:, 1] - gt[:, 1]

    # haversine formula components
    lhs = np.sin(lat_diff / 2) ** 2
    rhs = np.cos(pred[:, 0]) * np.cos(gt[:, 0]) * np.sin(lon_diff / 2) ** 2
    a = lhs + rhs

    # final great-circle distance in km (Earth radius 6371 km)
    c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))

    haversine_distance = 6371 * c[0]  # single-sample distance
    geoguessr_sum = 5000 * np.exp(-haversine_distance / 1492.7)

    N["geoguessr"] += 1
    p["geoguessr"] += geoguessr_sum

    N["haversine"] += 1
    p["haversine"] += haversine_distance
|
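The GeoGuessr-style score accumulated in haversine() is 5000 * exp(-d / 1492.7), with d the great-circle distance in km, so a perfect guess scores 5000 and the score halves roughly every 1035 km. A quick standalone check of that scoring rule (the helper name geoguessr_score is only for illustration):

import numpy as np


def geoguessr_score(distance_km):
    # same scoring rule as in haversine() above
    return 5000 * np.exp(-distance_km / 1492.7)


print(geoguessr_score(0.0))     # 5000.0 for a perfect guess
print(geoguessr_score(1492.7))  # ~1839.4, i.e. 5000 / e
print(geoguessr_score(1034.7))  # ~2500, the score halves about every 1035 km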
utils/__init__.py
ADDED
File without changes
|
utils/image_processing.py
ADDED
@@ -0,0 +1,58 @@
import torch
import torch.nn.functional as F
import torchvision


def remap_image_torch(image):
    image_torch = ((image + 1) / 2.0) * 255.0
    image_torch = torch.clip(image_torch, 0, 255).to(torch.uint8)
    return image_torch


class CenterCrop(torch.nn.Module):
    """Crops the given image at the center. Allows to crop to the maximum possible size.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        ratio (str): Desired output ratio of the crop that will do the maximum possible crop with the given ratio.
    """

    def __init__(self, size=None, ratio="1:1"):
        super().__init__()
        self.size = size
        self.ratio = ratio

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        if self.size is None:
            if isinstance(img, torch.Tensor):
                h, w = img.shape[-2:]
            else:
                w, h = img.size
            ratio = self.ratio.split(":")
            ratio = float(ratio[0]) / float(ratio[1])
            ratioed_w = int(h * ratio)
            ratioed_h = int(w / ratio)
            if w >= h:
                if ratioed_h <= h:
                    size = (ratioed_h, w)
                else:
                    size = (h, ratioed_w)
            else:
                if ratioed_w <= w:
                    size = (h, ratioed_w)
                else:
                    size = (ratioed_h, w)
        else:
            size = self.size
        return torchvision.transforms.functional.center_crop(img, size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"
|
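With no fixed size, CenterCrop takes the largest centered crop with the requested aspect ratio. A small sketch on a dummy tensor, assuming the class above is importable as utils.image_processing.CenterCrop:

import torch

from utils.image_processing import CenterCrop

img = torch.zeros(3, 300, 400)  # C x H x W landscape image

square = CenterCrop(ratio="1:1")  # largest centered square
print(square(img).shape)  # torch.Size([3, 300, 300])

wide = CenterCrop(ratio="2:1")  # largest centered 2:1 crop
print(wide(img).shape)  # torch.Size([3, 200, 400])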
utils/lr_scheduler.py
ADDED
@@ -0,0 +1,96 @@
import math


class WarmupLR:
    """
    Linear warmup learning rate scheduler. After warmup, the learning rate is
    constant.

    Args:
        optimizer (torch.optim.Optimizer): optimizer
        warmup_steps (int): number of warmup steps
    """

    def __init__(self, optimizer, warmup_steps):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.base_lr = None

    def get_lr(self, lr, step):
        return lr * min(step / max(self.warmup_steps, 1), 1.0)

    def step(self, step):
        if self.base_lr is None:
            self.base_lr = [
                param_group["lr"] for param_group in self.optimizer.param_groups
            ]
        for param_group, base_lr_group in zip(
            self.optimizer.param_groups, self.base_lr
        ):
            param_group["lr"] = self.get_lr(base_lr_group, step)

    def state_dict(self):
        return {
            key: value for key, value in self.__dict__.items() if key != "optimizer"
        }

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)


class WarmupCosineDecayLR:
    """
    Linear warmup learning rate scheduler. After warmup, the learning rate
    follows a cosine decay.

    Args:
        optimizer (torch.optim.Optimizer): optimizer
        warmup_steps (int): number of warmup steps
        total_steps (int): total number of steps
        rate (float): cosine decay rate
    """

    def __init__(self, optimizer, warmup_steps, total_steps, rate=1.0):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.base_lr = None
        self.total_steps = total_steps
        self.rate = rate

    def get_lr(self, lr, step):
        if step < self.warmup_steps:
            return lr * min(step / max(self.warmup_steps, 1), 1.0)
        else:
            return (
                0.5
                * lr
                * (
                    1
                    + math.cos(
                        self.rate
                        * math.pi
                        * (step - self.warmup_steps)
                        / (self.total_steps - self.warmup_steps)
                    )
                )
            )

    def step(self, step):
        if self.base_lr is None:
            self.base_lr = [
                param_group["lr"] for param_group in self.optimizer.param_groups
            ]
        for param_group, base_lr_group in zip(
            self.optimizer.param_groups, self.base_lr
        ):
            param_group["lr"] = self.get_lr(base_lr_group, step)

    def state_dict(self):
        return {
            key: value for key, value in self.__dict__.items() if key != "optimizer"
        }

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)
|
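Both schedulers are driven manually with an explicit global step rather than subclassing torch.optim.lr_scheduler. A minimal sketch of how a training loop would use WarmupCosineDecayLR (dummy parameter and step counts chosen for illustration):

import torch

from utils.lr_scheduler import WarmupCosineDecayLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = WarmupCosineDecayLR(optimizer, warmup_steps=10, total_steps=100)

for step in range(100):
    scheduler.step(step)  # sets the LR for this step from the stored base LR
    # ... forward / backward / optimizer.step() would go here ...
    if step in (0, 10, 55, 99):
        print(step, optimizer.param_groups[0]["lr"])
# 0 -> 0.0 (start of warmup), 10 -> 0.1 (peak), 55 -> 0.05, 99 -> ~0.0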
utils/model_utils.py
ADDED
@@ -0,0 +1,14 @@
def print_trainable_parameters(model):
    """
    Prints the number and percentage of trainable parameters in the model.
    Useful for tracking % parameters trained for LoRA.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )
|
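For example, freezing one layer of a small model and calling the helper prints the trainable fraction directly (a sketch, assuming the repo root is on the Python path):

import torch.nn as nn

from utils.model_utils import print_trainable_parameters

model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 4))
for p in model[0].parameters():  # freeze the first layer
    p.requires_grad = False

print_trainable_parameters(model)
# trainable params: 132 || all params: 676 || trainable%: ~19.5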
utils/quadtree_10_1000.csv
ADDED
The diff for this file is too large to render.
|
|