yunusserhat committed on
Commit
94f372a
1 Parent(s): abd15df

Upload 40 files

metrics/__init__.py ADDED
File without changes
metrics/distance_based.py ADDED
@@ -0,0 +1,129 @@
+ import torch
+
+ from metrics.utils import haversine, reverse
+
+ from torchmetrics import Metric
+
+
+ class HaversineMetrics(Metric):
+     """
+     Computes the average haversine distance between the predicted and ground-truth points.
+     Computes the accuracy within the given radii.
+     Computes the GeoGuessr score for the given radii.
+
+     Args:
+         acc_radiuses (list): list of radii (in km) to compute the accuracy from
+         acc_area (list): list of areas to compute the accuracy from.
+         aux_data (list): list of auxiliary data to compute the accuracy from.
+     """
+
+     def __init__(
+         self,
+         acc_radiuses=[],
+         acc_area=["country", "region", "sub-region", "city"],
+         aux_data=[],
+     ):
+         super().__init__()
+         self.add_state("haversine_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
+         self.add_state("geoguessr_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
+         for acc in acc_radiuses:
+             self.add_state(
+                 f"close_enough_points_{acc}",
+                 default=torch.tensor(0.0),
+                 dist_reduce_fx="sum",
+             )
+         for acc in acc_area:
+             self.add_state(
+                 f"close_enough_points_{acc}",
+                 default=torch.tensor(0.0),
+                 dist_reduce_fx="sum",
+             )
+             self.add_state(
+                 f"count_{acc}", default=torch.tensor(0), dist_reduce_fx="sum"
+             )
+         self.acc_radius = acc_radiuses
+         self.acc_area = acc_area
+         self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
+         self.aux = len(aux_data) > 0
+         self.aux_list = aux_data
+         if self.aux:
+             self.aux_count = {}
+             for col in self.aux_list:
+                 self.add_state(
+                     f"aux_{col}",
+                     default=torch.tensor(0.0),
+                     dist_reduce_fx="sum",
+                 )
+
+     def update(self, pred, gt):
+         haversine_distance = haversine(pred["gps"], gt["gps"])
+         for acc in self.acc_radius:
+             self.__dict__[f"close_enough_points_{acc}"] += (
+                 haversine_distance < acc
+             ).sum()
+         if len(self.acc_area) > 0:
+             area_pred, area_gt = reverse(pred["gps"], gt, self.acc_area)
+             for acc in self.acc_area:
+                 self.__dict__[f"close_enough_points_{acc}"] += (
+                     area_pred[acc] == area_gt["_".join(["unique", acc])]
+                 ).sum()
+                 self.__dict__[f"count_{acc}"] += len(area_gt["_".join(["unique", acc])])
+         self.haversine_sum += haversine_distance.sum()
+         self.geoguessr_sum += 5000 * torch.exp(-haversine_distance / 1492.7).sum()
+
+         if self.aux:
+             if "land_cover" in self.aux_list:
+                 col = "land_cover"
+                 self.__dict__[f"aux_{col}"] += (
+                     pred[col].argmax(dim=1) == gt[col].argmax(dim=1)
+                 ).sum()
+             if "road_index" in self.aux_list:
+                 col = "road_index"
+                 self.__dict__[f"aux_{col}"] += (
+                     pred[col].argmax(dim=1) == gt[col].argmax(dim=1)
+                 ).sum()
+             if "drive_side" in self.aux_list:
+                 col = "drive_side"
+                 self.__dict__[f"aux_{col}"] += (
+                     (pred[col] > 0.5).float() == gt[col]
+                 ).sum()
+             if "climate" in self.aux_list:
+                 col = "climate"
+                 self.__dict__[f"aux_{col}"] += (
+                     pred[col].argmax(dim=1) == gt[col].argmax(dim=1)
+                 ).sum()
+             if "soil" in self.aux_list:
+                 col = "soil"
+                 self.__dict__[f"aux_{col}"] += (
+                     pred[col].argmax(dim=1) == gt[col].argmax(dim=1)
+                 ).sum()
+             if "dist_sea" in self.aux_list:
+                 col = "dist_sea"
+                 self.__dict__[f"aux_{col}"] += (
+                     (pred[col] - gt[col]).pow(2).sum(dim=1).sum()
+                 )
+
+         self.count += pred["gps"].shape[0]
+
+     def compute(self):
+         output = {
+             "Haversine": self.haversine_sum / self.count,
+             "Geoguessr": self.geoguessr_sum / self.count,
+         }
+         for acc in self.acc_radius:
+             output[f"Accuracy_{acc}_km_radius"] = (
+                 self.__dict__[f"close_enough_points_{acc}"] / self.count
+             )
+         for acc in self.acc_area:
+             output[f"Accuracy_{acc}"] = (
+                 self.__dict__[f"close_enough_points_{acc}"]
+                 / self.__dict__[f"count_{acc}"]
+             )
+
+         if self.aux:
+             for col in self.aux_list:
+                 output["_".join(["Accuracy", col])] = (
+                     self.__dict__[f"aux_{col}"] / self.count
+                 )
+
+         return output
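A minimal usage sketch of HaversineMetrics (an illustrative example, not part of the commit): it assumes the "gps" tensors are (lat, lon) pairs in radians, requests only radius-based accuracies, and leaves the area and auxiliary paths disabled so the reverse-geocoding step is skipped.

import torch
from metrics.distance_based import HaversineMetrics

metric = HaversineMetrics(acc_radiuses=[1, 25, 200, 750, 2500], acc_area=[])
pred = {"gps": torch.deg2rad(torch.tensor([[48.85, 2.35]]))}  # predicted (lat, lon) in radians
gt = {"gps": torch.deg2rad(torch.tensor([[51.51, -0.13]]))}   # ground-truth (lat, lon) in radians
metric.update(pred, gt)
print(metric.compute())  # Haversine, Geoguessr and Accuracy_<r>_km_radius entries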
metrics/elo.py ADDED
@@ -0,0 +1,21 @@
+ import os
+ import torch
+ from metrics.utils import haversine
+
+ from torchmetrics import Metric
+
+
+ class HaversineELOMetric(Metric):
+     """
+     Computes the ELO score of the current network given previous players.
+
+     Args:
+         previous_players_scores (str): path to the csv containing the scores of the previous players
+         previous_players_predictions (str): path to the folder containing the predictions of the previous players
+         tag (str): tag of the current experiment
+
+     """
+
+     def __init__(self, cache_folder, tag):
+         ### TODO
+         pass
metrics/utils.py ADDED
@@ -0,0 +1,85 @@
+ import torch
+ import reverse_geocoder
+ import numpy as np
+
+
+ def haversine(pred, gt):
+     # expects inputs to be torch tensors in (lat, lon) format, in radians
+     # N x 2
+
+     # calculate the difference in latitude and longitude between the predicted and ground truth points
+     lat_diff = pred[:, 0] - gt[:, 0]
+     lon_diff = pred[:, 1] - gt[:, 1]
+
+     # calculate the haversine formula components
+     lhs = torch.sin(lat_diff / 2) ** 2
+     rhs = torch.cos(pred[:, 0]) * torch.cos(gt[:, 0]) * torch.sin(lon_diff / 2) ** 2
+     a = lhs + rhs
+
+     # calculate the final distance using the haversine formula
+     c = 2 * torch.arctan2(torch.sqrt(a), torch.sqrt(1 - a))
+     distance = 6371 * c
+
+     return distance
+
+
+ def reverse(pred, gt, area):
+     df = {}
+     gt_area = {}
+     nan_mask = {}
+     areas = ["_".join(["unique", ar]) for ar in area]
+     if "unique_continent" in areas:
+         areas.remove("unique_continent")
+     for ar in areas:
+         inter = np.array(gt[ar])
+         nan_mask[ar] = inter != "nan"
+         gt_area[ar] = inter[nan_mask[ar]]
+     location = reverse_geocoder.search(
+         [
+             (lat, lon)
+             for lat, lon in zip(
+                 np.degrees(pred[:, 0].cpu()), np.degrees(pred[:, 1].cpu())
+             )
+         ]
+     )
+     if "continent" in area:
+         continent = torch.load("continent.pt")
+         inter = np.array([l.get("cc", "") for l in location])[
+             nan_mask["unique_country"]
+         ]
+         df["continent"] = np.array([continent[i] for i in inter])
+         gt_area["unique_continent"] = np.array(
+             [continent[i] for i in gt_area["unique_country"]]
+         )
+
+     if "country" in area:
+         df["country"] = np.array([l.get("cc", "") for l in location])[
+             nan_mask["unique_country"]
+         ]
+     if "region" in area:
+         df["region"] = np.array(
+             ["_".join([l.get("admin1", ""), l.get("cc", "")]) for l in location]
+         )[nan_mask["unique_region"]]
+     if "sub-region" in area:
+         df["sub-region"] = np.array(
+             [
+                 "_".join([l.get("admin2", ""), l.get("admin1", ""), l.get("cc", "")])
+                 for l in location
+             ]
+         )[nan_mask["unique_sub-region"]]
+     if "city" in area:
+         df["city"] = np.array(
+             [
+                 "_".join(
+                     [
+                         l.get("name", ""),
+                         l.get("admin2", ""),
+                         l.get("admin1", ""),
+                         l.get("cc", ""),
+                     ]
+                 )
+                 for l in location
+             ]
+         )[nan_mask["unique_city"]]
+
+     return df, gt_area
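As a quick sanity check of haversine (inputs must already be in radians; this is a sketch, not part of the commit), the Paris-London great-circle distance should come out close to 343 km:

import torch
from metrics.utils import haversine

paris = torch.deg2rad(torch.tensor([[48.8566, 2.3522]]))
london = torch.deg2rad(torch.tensor([[51.5074, -0.1278]]))
print(haversine(paris, london))  # ~343 km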
models/__init__.py ADDED
File without changes
models/classification/utils_global.py ADDED
@@ -0,0 +1,177 @@
1
+ import logging
2
+ from collections import OrderedDict
3
+ from pathlib import Path
4
+ from typing import Union, List
5
+
6
+ import torch
7
+ import torchvision
8
+
9
+
10
+ def check_is_valid_torchvision_architecture(architecture: str):
11
+ """Raises a ValueError if architecture is not part of the available torchvision models
12
+ """
13
+ available = sorted(
14
+ name
15
+ for name in torchvision.models.__dict__
16
+ if name.islower()
17
+ and not name.startswith("__")
18
+ and callable(torchvision.models.__dict__[name])
19
+ )
20
+ if architecture not in available:
21
+ raise ValueError(f"{architecture} not in {available}")
22
+
23
+
24
+ def build_base_model(arch: str):
25
+
26
+ model = torchvision.models.__dict__[arch](pretrained=True)
27
+
28
+ # get input dimension before classification layer
29
+ if arch in ["mobilenet_v2"]:
30
+ nfeatures = model.classifier[-1].in_features
31
+ model = torch.nn.Sequential(*list(model.children())[:-1])
32
+ elif arch in ["densenet121", "densenet161", "densenet169"]:
33
+ nfeatures = model.classifier.in_features
34
+ model = torch.nn.Sequential(*list(model.children())[:-1])
35
+ elif "resne" in arch:
36
+ # usually all ResNet variants
37
+ nfeatures = model.fc.in_features
38
+ model = torch.nn.Sequential(*list(model.children())[:-2])
39
+ else:
40
+ raise NotImplementedError
41
+
42
+ model.avgpool = torch.nn.AdaptiveAvgPool2d(1)
43
+ model.flatten = torch.nn.Flatten(start_dim=1)
44
+ return model, nfeatures
45
+
46
+
47
+ def load_weights_if_available(
48
+ model: torch.nn.Module, classifier: torch.nn.Module, weights_path: Union[str, Path]
49
+ ):
50
+
51
+ checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)
52
+
53
+ state_dict_features = OrderedDict()
54
+ state_dict_classifier = OrderedDict()
55
+ for k, w in checkpoint["state_dict"].items():
56
+ if k.startswith("model"):
57
+ state_dict_features[k.replace("model.", "")] = w
58
+ elif k.startswith("classifier"):
59
+ state_dict_classifier[k.replace("classifier.", "")] = w
60
+ else:
61
+ logging.warning(f"Unexpected prefix in state_dict: {k}")
62
+ model.load_state_dict(state_dict_features, strict=True)
63
+ return model, classifier
64
+
65
+
66
+ def vectorized_gc_distance(latitudes, longitudes, latitudes_gt, longitudes_gt):
67
+ R = 6371
68
+ factor_rad = 0.01745329252
69
+ longitudes = factor_rad * longitudes
70
+ longitudes_gt = factor_rad * longitudes_gt
71
+ latitudes = factor_rad * latitudes
72
+ latitudes_gt = factor_rad * latitudes_gt
73
+ delta_long = longitudes_gt - longitudes
74
+ delta_lat = latitudes_gt - latitudes
75
+ subterm0 = torch.sin(delta_lat / 2) ** 2
76
+ subterm1 = torch.cos(latitudes) * torch.cos(latitudes_gt)
77
+ subterm2 = torch.sin(delta_long / 2) ** 2
78
+ subterm1 = subterm1 * subterm2
79
+ a = subterm0 + subterm1
80
+ c = 2 * torch.asin(torch.sqrt(a))
81
+ gcd = R * c
82
+ return gcd
83
+
84
+
85
+ def gcd_threshold_eval(gc_dists, thresholds=[1, 25, 200, 750, 2500]):
86
+ # calculate accuracy for the given gcd thresholds
87
+ results = {}
88
+ for thres in thresholds:
89
+ results[thres] = torch.true_divide(
90
+ torch.sum(gc_dists <= thres), len(gc_dists)
91
+ ).item()
92
+ return results
93
+
94
+
95
+ def accuracy(output, target, partitioning_shortnames: list, topk=(1, 5, 10)):
96
+ def _accuracy(output, target, topk=(1,)):
97
+ """Computes the accuracy over the k top predictions for the specified values of k"""
98
+ with torch.no_grad():
99
+ maxk = max(topk)
100
+ batch_size = target.size(0)
101
+
102
+ _, pred = output.topk(maxk, 1, True, True)
103
+ pred = pred.t()
104
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
105
+
106
+ res = {}
107
+ for k in topk:
108
+ correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
109
+ res[k] = correct_k / batch_size
110
+ return res
111
+
112
+ with torch.no_grad():
113
+ out_dict = {}
114
+ for i, pname in enumerate(partitioning_shortnames):
115
+ res_dict = _accuracy(output[i], target[i], topk=topk)
116
+ for k, v in res_dict.items():
117
+ out_dict[f"acc{k}_val/{pname}"] = v
118
+
119
+ return out_dict
120
+
121
+
122
+ def summarize_gcd_stats(pnames: List[str], outputs, hierarchy=None):
123
+ gcd_dict = {}
124
+ metric_names = [f"gcd_{p}_val" for p in pnames]
125
+ if hierarchy is not None:
126
+ metric_names.append("gcd_hierarchy_val")
127
+ for metric_name in metric_names:
128
+ distances_flat = [output[metric_name] for output in outputs]
129
+ distances_flat = torch.cat(distances_flat, dim=0)
130
+ gcd_results = gcd_threshold_eval(distances_flat)
131
+ for gcd_thres, acc in gcd_results.items():
132
+ gcd_dict[f"{metric_name}/{gcd_thres}"] = acc
133
+ return gcd_dict
134
+
135
+
136
+ def summarize_test_gcd(pnames, outputs, hierarchy=None):
137
+ def _eval(output):
138
+ # calculate acc@km for a list of given thresholds
139
+ accuracy_outputs = {}
140
+ if hierarchy is not None:
141
+ pnames.append("hierarchy")
142
+ for pname in pnames:
143
+ # concat batches of distances
144
+ distances_flat = torch.cat([x[pname] for x in output], dim=0)
145
+ # acc for all distances
146
+ acc_dict = gcd_threshold_eval(distances_flat)
147
+ accuracy_outputs[f"acc_test/{pname}"] = acc_dict
148
+ return accuracy_outputs
149
+
150
+ result = {}
151
+
152
+ if isinstance(outputs[0], dict): # only one testset
153
+ result = _eval(outputs)
154
+ elif isinstance(outputs[0], list): # multiple testsets
155
+ for testset_index, output in enumerate(outputs):
156
+ result[testset_index] = _eval(output)
157
+ else:
158
+ raise TypeError
159
+
160
+ return result
161
+
162
+
163
+ def summarize_loss_acc_stats(pnames: List[str], outputs, topk=[1, 5, 10]):
164
+
165
+ loss_acc_dict = {}
166
+ metric_names = []
167
+ for k in topk:
168
+ accuracy_names = [f"acc{k}_val/{p}" for p in pnames]
169
+ metric_names.extend(accuracy_names)
170
+ metric_names.extend([f"loss_val/{p}" for p in pnames])
171
+ for metric_name in ["loss_val/total", *metric_names]:
172
+ metric_total = 0
173
+ for output in outputs:
174
+ metric_value = output[metric_name]
175
+ metric_total += metric_value
176
+ loss_acc_dict[metric_name] = metric_total / len(outputs)
177
+ return loss_acc_dict
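A small sketch of how vectorized_gc_distance and gcd_threshold_eval compose (coordinates are passed in degrees, since the function applies the radian conversion itself; the values here are illustrative only):

import torch
from models.classification.utils_global import vectorized_gc_distance, gcd_threshold_eval

lat_pred = torch.tensor([48.8566, 40.7128])
lon_pred = torch.tensor([2.3522, -74.0060])
lat_gt = torch.tensor([48.86, 41.0])
lon_gt = torch.tensor([2.35, -74.0])
dists = vectorized_gc_distance(lat_pred, lon_pred, lat_gt, lon_gt)  # great-circle distances in km
print(gcd_threshold_eval(dists))  # fraction of samples within 1 / 25 / 200 / 750 / 2500 km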
models/eval_best_model.py ADDED
@@ -0,0 +1,62 @@
+ import os
+ from typing import Any
+ import pytorch_lightning as L
+ import torch
+ from hydra.utils import instantiate
+ from models.huggingface import Geolocalizer
+
+ class EvalModule(L.LightningModule):
+     def __init__(self, cfg):
+         super().__init__()
+         self.cfg = cfg
+         os.chdir(cfg.network.root_dir)
+         self.model = Geolocalizer.from_pretrained('osv5m/baseline')
+         self.test_metrics = instantiate(cfg.test_metrics)
+
+     def training_step(self, batch, batch_idx):
+         pred = self.model(batch)
+         pass
+
+     @torch.no_grad()
+     def validation_step(self, batch, batch_idx):
+         pred = self.model(batch)
+         pass
+
+     def on_validation_epoch_end(self):
+         pass
+
+     @torch.no_grad()
+     def test_step(self, batch, batch_idx):
+         pred = self.model.forward_tensor(batch)
+         self.test_metrics.update({"gps": pred}, batch)
+
+     def on_test_epoch_end(self):
+         metrics = self.test_metrics.compute()
+         for metric_name, metric_value in metrics.items():
+             self.log(
+                 f"test/{metric_name}",
+                 metric_value,
+                 sync_dist=True,
+                 on_step=False,
+                 on_epoch=True,
+             )
+
+     def lr_scheduler_step(self, scheduler, metric):
+         scheduler.step(self.global_step)
+
+
+ def get_parameter_names(model, forbidden_layer_types):
+     """
+     Returns the names of the model parameters that are not inside a forbidden layer.
+     Taken from HuggingFace transformers.
+     """
+     result = []
+     for name, child in model.named_children():
+         result += [
+             f"{name}.{n}"
+             for n in get_parameter_names(child, forbidden_layer_types)
+             if not isinstance(child, tuple(forbidden_layer_types))
+         ]
+     # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
+     result += list(model._parameters.keys())
+     return result
models/huggingface.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+ from torch import nn
+ from hydra.utils import instantiate
+ from omegaconf import OmegaConf
+ from huggingface_hub import PyTorchModelHubMixin
+
+ class Geolocalizer(nn.Module, PyTorchModelHubMixin):
+     def __init__(self, config):
+         super().__init__()
+         self.config = OmegaConf.create(config)
+         self.transform = instantiate(self.config.transform)
+         self.model = instantiate(self.config.model)
+         self.head = self.model.head
+         self.mid = self.model.mid
+         self.backbone = self.model.backbone
+
+     def forward(self, img: torch.Tensor):
+         output = self.head(self.mid(self.backbone({"img": img})), None)
+         return output["gps"]
+
+     def forward_tensor(self, img: torch.Tensor):
+         output = self.head(self.mid(self.backbone(img)), None)
+         return output["gps"]
+
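A hedged inference sketch using the hub checkpoint referenced by eval_best_model.py; it assumes the bundled transform accepts a PIL image and returns a CHW tensor, and the unit convention of the returned coordinates depends on the head configured in the checkpoint, so verify it against the model card before relying on it.

from PIL import Image
from models.huggingface import Geolocalizer

model = Geolocalizer.from_pretrained("osv5m/baseline")
img = Image.open("street_view.jpg")     # hypothetical input image
x = model.transform(img).unsqueeze(0)   # assumed torchvision-style transform
gps = model(x)                          # (1, 2) tensor of predicted coordinates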
models/losses.py ADDED
@@ -0,0 +1,614 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from os.path import join
6
+ from models.networks.utils import NormGPS
7
+
8
+
9
+ class L1(nn.Module):
10
+ def __init__(self):
11
+ super(L1, self).__init__()
12
+
13
+ def forward(self, x, y):
14
+ """
15
+ Args:
16
+ x: dict that contains "gps": torch.Tensor Bx2
17
+ y: dict that contains "gps": torch.Tensor Bx2
18
+ Returns:
19
+ torch.Tensor: L1 loss between x and y: torch.Tensor([B])
20
+ """
21
+ return {"L1_loss": torch.abs(x["gps"] - y["gps"]).mean(dim=-1)}
22
+
23
+
24
+ class L2(nn.Module):
25
+ def __init__(self):
26
+ super(L2, self).__init__()
27
+
28
+ def forward(self, x, y):
29
+ """
30
+ Args:
31
+ x: dict that contains "gps": torch.Tensor Bx2
32
+ y: dict that contains "gps": torch.Tensor Bx2
33
+ Returns:
34
+ torch.Tensor: L2 loss between x and y: torch.Tensor([B])
35
+ """
36
+ return {"L2_loss": ((x["gps"] - y["gps"]) ** 2).mean(dim=-1)}
37
+
38
+
39
+ class L2Hybrid(nn.Module):
40
+ def __init__(self):
41
+ super(L2Hybrid, self).__init__()
42
+ self.norm = NormGPS()
43
+
44
+ def forward(self, x, y):
45
+ """
46
+ Args:
47
+ x: dict that contains "gps": torch.Tensor Bx2
48
+ y: dict that contains "gps": torch.Tensor Bx2
49
+ Returns:
50
+ torch.Tensor: L2 loss between x and y: torch.Tensor([B])
51
+ """
52
+ return {
53
+ "L2_loss": (
54
+ (x["reg"] - (self.norm(y["gps"]) - x["center"]) * x["size"]) ** 2
55
+ ).mean(dim=-1)
56
+ }
57
+
58
+
59
+ class CrossEntropy(nn.Module):
60
+ def __init__(self):
61
+ super(CrossEntropy, self).__init__()
62
+ self.loss = nn.CrossEntropyLoss(reduction="none")
63
+
64
+ def forward(self, x, y):
65
+ """
66
+ Args:
67
+ x: dict that contains "label": torch.Tensor BxN
68
+ y: dict that contains "label": torch.Tensor BxN
69
+ Returns:
70
+ torch.Tensor: CrossEntropy loss between x and y: torch.Tensor([B])
71
+ """
72
+ return {"cross_entropy_loss": self.loss(x["label"], y["label"])}
73
+
74
+
75
+ class HierarchicalCrossEntropyQuad(nn.Module):
76
+ def __init__(self, data_path=""):
77
+ super(HierarchicalCrossEntropyQuad, self).__init__()
78
+ self.dict_losses = {"classif_loss": nn.CrossEntropyLoss(reduction="none")}
79
+ for i in range(1, 10):
80
+ self.dict_losses[f"quadtree_{i}_loss"] = nn.NLLLoss()
81
+ self.matrixes = torch.load(join(data_path, "quadtree_matrixes.pt"))
82
+ self.dicts = torch.load(join(data_path, "quadtree_dicts.pt"))
83
+ self.id_to_quad = torch.load(join(data_path, "id_to_quad_10_1000.pt"))
84
+
85
+ def forward(self, x, y):
86
+ """
87
+ Args:
88
+ x: dict that contains "label": torch.Tensor BxN
89
+ y: dict that contains "label": torch.Tensor BxN
90
+ Returns:
91
+ torch.Tensor: Hierarchical CrossEntropy for Quadtrees loss between x and y: torch.Tensor([B])
92
+ """
93
+ out = {"classif_loss": self.dict_losses["classif_loss"](x["label"], y["label"])}
94
+ probas = nn.functional.softmax(x["label"], dim=1)
95
+ device = x["label"].device
96
+ gt = self.id_to_quad[y["label"].cpu()]
97
+ for i in range(9):
98
+ logits = torch.log(torch.mm(probas, self.matrixes[i].to(device)) + 1e-10)
99
+ l = [s[: 9 - i] if len(s) >= 10 - i else s for s in gt]
100
+ out[f"quadtree_{i+1}_loss"] = self.dict_losses[f"quadtree_{i+1}_loss"](
101
+ logits, torch.tensor([self.dicts[i][item] for item in l]).to(device)
102
+ )
103
+
104
+ return out
105
+
106
+
107
+ class HierarchicalCrossEntropy(nn.Module):
108
+ def __init__(self, path=""):
109
+ super(HierarchicalCrossEntropy, self).__init__()
110
+ self.city_loss = nn.CrossEntropyLoss(reduction="none")
111
+ self.country_loss = nn.NLLLoss()
112
+ self.area_loss = nn.NLLLoss()
113
+ self.region_loss = nn.NLLLoss()
114
+ self.city_to_country = torch.load(path + "city_to_country.pt")
115
+ self.city_to_region = torch.load(path + "city_to_region.pt")
116
+ self.city_to_area = torch.load(path + "city_to_area.pt")
117
+ self.country_to_idx = torch.load(path + "country_to_idx.pt")
118
+ self.region_to_idx = torch.load(path + "region_to_idx.pt")
119
+ self.area_to_idx = torch.load(path + "area_to_idx.pt")
120
+
121
+ def forward(self, x, y):
122
+ """
123
+ Args:
124
+ x: dict that contains "label": torch.Tensor BxN
125
+ y: dict that contains "label": torch.Tensor BxN
126
+ Returns:
127
+ torch.Tensor: Hierarchical CrossEntropy loss between x and y: torch.Tensor([B])
128
+ """
129
+ country_mask = np.array(y["unique_country"]) != "NaN"
130
+ self.city_to_country = self.city_to_country.to(x["label"].device)
131
+ countries_probas = nn.functional.softmax(x["label"][country_mask], dim=1)
132
+ countries_logits = torch.log(
133
+ torch.mm(countries_probas, self.city_to_country) + 1e-10
134
+ )
135
+ country_gt = torch.tensor(
136
+ [
137
+ self.country_to_idx[item]
138
+ for item in np.array(y["unique_country"])[country_mask]
139
+ ]
140
+ ).to(x["label"].device)
141
+
142
+ region_mask = np.array(y["unique_region"]) != "NaN"
143
+ self.city_to_region = self.city_to_region.to(x["label"].device)
144
+ regions_probas = nn.functional.softmax(x["label"][region_mask], dim=1)
145
+ regions_logits = torch.log(
146
+ torch.mm(regions_probas, self.city_to_region) + 1e-10
147
+ )
148
+ region_gt = torch.tensor(
149
+ [
150
+ self.region_to_idx[item]
151
+ for item in np.array(y["unique_region"])[region_mask]
152
+ ]
153
+ ).to(x["label"].device)
154
+
155
+ area_mask = np.array(y["unique_sub-region"]) != "NaN"
156
+ self.city_to_area = self.city_to_area.to(x["label"].device)
157
+ areas_probas = nn.functional.softmax(x["label"][area_mask], dim=1)
158
+ areas_logits = torch.log(torch.mm(areas_probas, self.city_to_area) + 1e-10)
159
+ area_gt = torch.tensor(
160
+ [
161
+ self.area_to_idx[item]
162
+ for item in np.array(y["unique_sub-region"])[area_mask]
163
+ ]
164
+ ).to(x["label"].device)
165
+
166
+ return {
167
+ "cross_entropy_country_loss": self.country_loss(
168
+ countries_logits, country_gt
169
+ ),
170
+ "cross_entropy_city_loss": self.city_loss(x["label"], y["label"]),
171
+ "cross_entropy_area_loss": self.area_loss(areas_logits, area_gt),
172
+ "cross_entropy_region_loss": self.region_loss(regions_logits, region_gt),
173
+ }
174
+
175
+
176
+ class LandCoverLoss(nn.Module):
177
+ def __init__(self):
178
+ super(LandCoverLoss, self).__init__()
179
+ self.loss = nn.CrossEntropyLoss()
180
+
181
+ def forward(self, x, y):
182
+ """
183
+ Args:
184
+ x: dict that contains "land_cover": torch.Tensor BxN
185
+ y: dict that contains "land_cover": torch.Tensor BxN
186
+ Returns:
187
+ torch.Tensor: CrossEntropy loss between x and y: torch.Tensor([B])
188
+ """
189
+ return {
190
+ "land_cover_cross_entropy_loss": self.loss(x["land_cover"], y["land_cover"])
191
+ }
192
+
193
+
194
+ class RoadIndexLoss(nn.Module):
195
+ def __init__(self):
196
+ super(RoadIndexLoss, self).__init__()
197
+ self.loss = nn.MSELoss()
198
+
199
+ def forward(self, x, y):
200
+ """
201
+ Args:
202
+ x: dict that contains "road_index": torch.Tensor BxN
203
+ y: dict that contains "road_index": torch.Tensor BxN
204
+ Returns:
205
+ torch.Tensor: MSE loss between x and y: torch.Tensor([B])
206
+ """
207
+ return {"road_index_mse_loss": self.loss(x["road_index"], y["road_index"])}
208
+
209
+
210
+ class DriveSideLoss(nn.Module):
211
+ def __init__(self):
212
+ super(DriveSideLoss, self).__init__()
213
+ self.loss = nn.BCELoss()
214
+
215
+ def forward(self, x, y):
216
+ """
217
+ Args:
218
+ x: dict that contains "drive_side": torch.Tensor BxN
219
+ y: dict that contains "drive_side": torch.Tensor BxN
220
+ Returns:
221
+ torch.Tensor: BCE loss between x and y: torch.Tensor([B])
222
+ """
223
+ return {"drive_side_bce_loss": self.loss(x["drive_side"], y["drive_side"])}
224
+
225
+
226
+ class ClimateLoss(nn.Module):
227
+ def __init__(self):
228
+ super(ClimateLoss, self).__init__()
229
+ self.loss = nn.CrossEntropyLoss()
230
+
231
+ def forward(self, x, y):
232
+ """
233
+ Args:
234
+ x: dict that contains "climate": torch.Tensor BxN
235
+ y: dict that contains "climate": torch.Tensor BxN
236
+ Returns:
237
+ torch.Tensor: CrossEntropy loss between x and y: torch.Tensor([B])
238
+ """
239
+ return {"climate_cross_entropy_loss": self.loss(x["climate"], y["climate"])}
240
+
241
+
242
+ class SoilLoss(nn.Module):
243
+ def __init__(self):
244
+ super(SoilLoss, self).__init__()
245
+ self.loss = nn.CrossEntropyLoss()
246
+
247
+ def forward(self, x, y):
248
+ """
249
+ Args:
250
+ x: dict that contains "soil": torch.Tensor BxN
251
+ y: dict that contains "soil": torch.Tensor BxN
252
+ Returns:
253
+ torch.Tensor: CrossEntropy loss between x and y: torch.Tensor([B])
254
+ """
255
+ return {"soil_cross_entropy_loss": self.loss(x["soil"], y["soil"])}
256
+
257
+
258
+ class DistSeaLoss(nn.Module):
259
+ def __init__(self):
260
+ super(DistSeaLoss, self).__init__()
261
+ self.loss = nn.MSELoss()
262
+
263
+ def forward(self, x, y):
264
+ """
265
+ Args:
266
+ x: dict that contains "dist_sea": torch.Tensor BxN
267
+ y: dict that contains "dist_sea": torch.Tensor BxN
268
+ Returns:
269
+ torch.Tensor: MSE loss between x and y: torch.Tensor([B])
270
+ """
271
+ return {"dist_sea_mse_loss": self.loss(x["dist_sea"], y["dist_sea"])}
272
+
273
+
274
+ class Haversine(nn.Module):
275
+ def __init__(self):
276
+ super(Haversine, self).__init__()
277
+
278
+ def forward(self, x, y):
279
+ """
280
+ Args:
281
+ x: dict that contains "gps": torch.Tensor Bx2
282
+ y: dict that contains "gps": torch.Tensor Bx2
283
+ Returns:
284
+ torch.Tensor: Haversine loss between x and y: torch.Tensor([B])
285
+ Note:
286
+ Haversine distance doesn't contain the 2 * 6371 constant.
287
+ """
288
+ x, y = x["gps"], y["gps"]
289
+ lhs = torch.sin((x[:, 0] - y[:, 0]) / 2) ** 2
290
+ rhs = (
291
+ torch.cos(x[:, 0])
292
+ * torch.cos(y[:, 0])
293
+ * torch.sin((x[:, 1] - y[:, 1]) / 2) ** 2
294
+ )
295
+ a = lhs + rhs
296
+ return {
297
+ "haversine_loss": torch.arctan2(torch.sqrt(a), torch.sqrt(1 - a))
298
+ } # omitting 2 * 6371 since it is a constant factor
299
+
300
+
301
+ class GeoguessrLoss(Haversine):
302
+ def __init__(self):
303
+ super(GeoguessrLoss, self).__init__()
304
+
305
+ def forward(self, x, y):
306
+ distance = super().forward(x, y)["haversine_loss"]
307
+ loss = torch.exp(-distance / 1852)
308
+ return {"geoguessr_loss": loss}
309
+
310
+
311
+ class InfoNCE(nn.Module):
312
+ def __init__(self, tau=0.1):
313
+ super(InfoNCE, self).__init__()
314
+ self.tau = tau
315
+
316
+ def cosine_similarity(self, a, b, normalize=True):
317
+ if normalize:
318
+ w1 = a.norm(p=2, dim=1, keepdim=True)
319
+ w2 = b.norm(p=2, dim=1, keepdim=True)
320
+ sim_matrix = torch.mm(a, b.t()) / (w1 * w2.t()).clamp(min=1e-8)
321
+ else:
322
+ sim_matrix = torch.mm(a, b.t())
323
+ return sim_matrix
324
+
325
+ def forward(self, x, y=None):
326
+ """
327
+ neg_sim: BxB
328
+ pos_sim: Bx1
329
+ """
330
+ features = x["features"]
331
+ positive_features = x["pos_features"]
332
+ pos_sim = F.cosine_similarity(
333
+ features, positive_features, dim=1, eps=1e-8
334
+ ).unsqueeze(1)
335
+ neg_sim = self.cosine_similarity(features, features, normalize=True)
336
+
337
+ b = neg_sim.shape[0]
338
+ logits = (1 - torch.eye(b)).type_as(neg_sim) * neg_sim + torch.eye(b).type_as(
339
+ pos_sim
340
+ ) * pos_sim
341
+ logits = logits / self.tau
342
+ labels = torch.arange(b, dtype=torch.long).cuda()
343
+ loss = F.cross_entropy(logits, labels)
344
+ return {
345
+ "contrastive_loss": loss,
346
+ }
347
+
348
+
349
+ class TextNCE(nn.Module):
350
+ def __init__(self, tau=0.1, num_devices=1):
351
+ super(TextNCE, self).__init__()
352
+ self.distributed = num_devices > 1
353
+ self.tau = tau
354
+
355
+ def cosine_similarity(self, a, b, normalize=True):
356
+ if normalize:
357
+ w1 = a.norm(p=2, dim=1, keepdim=True)
358
+ w2 = b.norm(p=2, dim=1, keepdim=True)
359
+ sim_matrix = torch.mm(a, b.t()) / (w1 * w2.t()).clamp(min=1e-8)
360
+ else:
361
+ sim_matrix = torch.mm(a, b.t())
362
+ return sim_matrix
363
+
364
+ def forward(self, x, y=None):
365
+ """
366
+ neg_sim: BxB
367
+ pos_sim: Bx1
368
+ """
369
+ if self.distributed:
370
+ all_image_features = torch.cat(
371
+ torch.distributed.nn.all_gather(x["features"]), dim=0
372
+ )
373
+ all_text_features = torch.cat(
374
+ torch.distributed.nn.all_gather(x["text_features"]), dim=0
375
+ )
376
+ all_labels = torch.cat(torch.distributed.nn.all_gather(y["label"]), dim=0)
377
+ else:
378
+ all_image_features = x["features"]
379
+ all_text_features = x["text_features"]
380
+ all_labels = y["label"]
381
+ labels_u = torch.unique(all_labels)
382
+ logits = self.cosine_similarity(
383
+ all_image_features, all_text_features, normalize=True
384
+ )
385
+ rows, cols = logits.size()
386
+ indices = torch.arange(0, rows, device=all_image_features.device)
387
+ loss = torch.sum(
388
+ torch.logsumexp(
389
+ logits[indices != indices.view(-1, 1)].view(rows, cols - 1) / self.tau,
390
+ dim=1,
391
+ )
392
+ )
393
+ for label in labels_u:
394
+ if not (label == "NaN"):
395
+ # Get the positive and negative examples
396
+ idx = all_labels == label
397
+ pos_logits = logits[idx][:, idx]
398
+ # Compute the MIL-NCE loss
399
+ loss += torch.sum(-torch.logsumexp(pos_logits / self.tau, dim=1))
400
+ return {
401
+ "contrastive_loss": loss,
402
+ }
403
+
404
+
405
+ class MILNCE(nn.Module):
406
+ def __init__(self, tau=0.1, num_devices=1):
407
+ super(MILNCE, self).__init__()
408
+ self.distributed = num_devices > 1
409
+ self.tau = tau
410
+
411
+ def cosine_similarity(self, a, b, normalize=True):
412
+ if normalize:
413
+ w1 = a.norm(p=2, dim=1, keepdim=True)
414
+ w2 = b.norm(p=2, dim=1, keepdim=True)
415
+ sim_matrix = torch.mm(a, b.t()) / (w1 * w2.t()).clamp(min=1e-8)
416
+ else:
417
+ sim_matrix = torch.mm(a, b.t())
418
+ return sim_matrix
419
+
420
+ def forward(self, x, y=None):
421
+ """
422
+ Compute the MIL-NCE loss
423
+ """
424
+ if self.distributed:
425
+ all_image_features = torch.cat(
426
+ torch.distributed.nn.all_gather(x["features"]), dim=0
427
+ )
428
+ all_pos_features = torch.cat(
429
+ torch.distributed.nn.all_gather(x["pos_features"]), dim=0
430
+ )
431
+ all_labels = torch.cat(torch.distributed.nn.all_gather(y["label"]), dim=0)
432
+ else:
433
+ all_image_features = x["features"]
434
+ all_pos_features = x["pos_features"]
435
+ all_labels = y["label"]
436
+ labels_u = torch.unique(all_labels)
437
+ features = torch.cat([all_image_features, all_pos_features])
438
+ labels = torch.cat([all_labels, all_labels])
439
+ logits = self.cosine_similarity(features, features, normalize=True)
440
+ rows, cols = logits.size()
441
+ indices = torch.arange(0, rows, device=features.device)
442
+ loss = torch.sum(
443
+ torch.logsumexp(
444
+ logits[indices != indices.view(-1, 1)].view(rows, cols - 1) / self.tau,
445
+ dim=1,
446
+ )
447
+ )
448
+ for label in labels_u:
449
+ if not (label == "NaN"):
450
+ # Get the positive and negative examples
451
+ idx = labels == label
452
+ pos_logits = logits[idx][:, idx]
453
+
454
+ rows, cols = pos_logits.size()
455
+ indices = torch.arange(0, rows, device=features.device)
456
+ pos_logits = pos_logits[indices != indices.view(-1, 1)].view(
457
+ rows, cols - 1
458
+ )
459
+
460
+ # Compute the MIL-NCE loss
461
+ loss += torch.sum(-torch.logsumexp(pos_logits / self.tau, dim=1))
462
+ return {
463
+ "contrastive_loss": loss,
464
+ }
465
+
466
+
467
+ class RegionMILNCE(nn.Module):
468
+ def __init__(self, tau=0.1, num_devices=1):
469
+ super(RegionMILNCE, self).__init__()
470
+ self.distributed = num_devices > 1
471
+ self.tau = tau
472
+
473
+ def cosine_similarity(self, a, b, normalize=True):
474
+ if normalize:
475
+ w1 = a.norm(p=2, dim=1, keepdim=True)
476
+ w2 = b.norm(p=2, dim=1, keepdim=True)
477
+ sim_matrix = torch.mm(a, b.t()) / (w1 * w2.t()).clamp(min=1e-8)
478
+ else:
479
+ sim_matrix = torch.mm(a, b.t())
480
+ return sim_matrix
481
+
482
+ def forward(self, x, y=None):
483
+ """
484
+ neg_sim: BxB
485
+ pos_sim: Bx1
486
+ """
487
+ if self.distributed:
488
+ all_image_features = torch.cat(
489
+ torch.distributed.nn.all_gather(x["features"]), dim=0
490
+ )
491
+ all_pos_features = torch.cat(
492
+ torch.distributed.nn.all_gather(x["pos_features"]), dim=0
493
+ )
494
+ all_labels = torch.cat(torch.distributed.nn.all_gather(y["label"]), dim=0)
495
+ else:
496
+ all_image_features = x["features"]
497
+ all_pos_features = x["pos_features"]
498
+ all_labels = y["label"]
499
+ labels_u = torch.unique(all_labels)
500
+ features = torch.cat([all_image_features, all_pos_features])
501
+ labels = torch.cat([all_labels, all_labels])
502
+ logits = self.cosine_similarity(features, features, normalize=True)
503
+ rows, cols = logits.size()
504
+ indices = torch.arange(0, rows, device=features.device)
505
+ loss = torch.sum(
506
+ torch.logsumexp(
507
+ logits[indices != indices.view(-1, 1)].view(rows, cols - 1) / self.tau,
508
+ dim=1,
509
+ )
510
+ )
511
+ for label in labels_u:
512
+ if not (label == "NaN"):
513
+ # Get the positive and negative examples
514
+ idx = labels == label
515
+ pos_logits = logits[idx][:, idx]
516
+
517
+ rows, cols = pos_logits.size()
518
+ indices = torch.arange(0, rows, device=features.device)
519
+ pos_logits = pos_logits[indices != indices.view(-1, 1)].view(
520
+ rows, cols - 1
521
+ )
522
+
523
+ # Compute the MIL-NCE loss
524
+ loss += torch.sum(-torch.logsumexp(pos_logits / self.tau, dim=1))
525
+ return {
526
+ "contrastive_loss": loss / len(all_labels),
527
+ }
528
+
529
+
530
+ LOSSES = {
531
+ "l1": L1,
532
+ "l2": L2,
533
+ "l2_hybrid": L2Hybrid,
534
+ "haversine": Haversine,
535
+ "geoguessr": GeoguessrLoss,
536
+ "crossentropy": CrossEntropy,
537
+ "infonce": InfoNCE,
538
+ "mil-nce": MILNCE,
539
+ "text-nce": TextNCE,
540
+ "land_cover": LandCoverLoss,
541
+ "road_index": RoadIndexLoss,
542
+ "drive_side": DriveSideLoss,
543
+ "climate": ClimateLoss,
544
+ "soil": SoilLoss,
545
+ "dist_sea": DistSeaLoss,
546
+ "hierarchical": HierarchicalCrossEntropy,
547
+ "hier_quad": HierarchicalCrossEntropyQuad,
548
+ "region_mil": RegionMILNCE,
549
+ }
550
+ AVERAGE = {False: lambda x: x, True: lambda x: x.mean(dim=-1)}
551
+
552
+
553
+ class Losses(nn.Module):
554
+ """The Losses meta-object that can take a mix of losses."""
555
+
556
+ def __init__(self, mix={}, aux_data=[], path="", num_devices=1):
557
+ """Initializes the Losses object.
558
+ Args:
559
+ mix (dict): dictionary with keys "loss_name" and values weight
560
+ """
561
+ super(Losses, self).__init__()
562
+ assert len(mix)
563
+ self.aux = len(aux_data) > 0
564
+ if self.aux:
565
+ self.aux_list = aux_data
566
+ total = ["land_cover", "drive_side", "climate", "soil", "dist_sea"]
567
+ for col in self.aux_list:
568
+ total.remove(col)
569
+ for col in total:
570
+ del mix[col]
571
+ self.init_losses(mix, path, num_devices)
572
+
573
+ def init_losses(self, mix, path="", num_devices=1):
574
+ """Initializes the losses.
575
+ Args:
576
+ mix (dict): dictionary with keys "loss_name" and values weight
577
+ """
578
+ self.loss = {}
579
+ for m, v in mix.items():
580
+ m = m.lower()
581
+ if m in ["hierarchical", "hier_quad"]:
582
+ try:
583
+ self.loss[m] = (LOSSES[m](path), v)
584
+ except KeyError:
585
+ raise KeyError(f"Loss {m} not found in {LOSSES.keys()}")
586
+ elif m in ["region_mil", "mil-nce", "text-nce"]:
587
+ try:
588
+ self.loss[m] = (LOSSES[m](num_devices=num_devices), v)
589
+ except KeyError:
590
+ raise KeyError(f"Loss {m} not found in {LOSSES.keys()}")
591
+ else:
592
+ try:
593
+ self.loss[m] = (LOSSES[m](), v)
594
+ except KeyError:
595
+ raise KeyError(f"Loss {m} not found in {LOSSES.keys()}")
596
+
597
+ def forward(self, x, y, average=True):
598
+ """Computes the losses.
599
+ Args:
600
+ x: dict that contains "gps": torch.Tensor Bx2 or "label": torch.Tensor BxN
601
+ y: dict that contains "gps": torch.Tensor Bx2 or "label": torch.Tensor BxN
602
+ average (bool): whether to average the losses or not
603
+ Returns:
604
+ dict: dictionary with losses
605
+ """
606
+ output = {"loss": 0}
607
+ for loss_name, (loss, weight) in self.loss.items():
608
+ loss_output = loss(x, y)
609
+ for k, v in loss_output.items():
610
+ v = AVERAGE[average](v)
611
+ if k.endswith("_loss"):
612
+ output["loss"] += weight * v
613
+ output[k] = v
614
+ return output
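A minimal sketch of the Losses wrapper (no auxiliary data, so only the weight dictionary is needed); every returned entry ending in "_loss" is scaled by its weight and accumulated into output["loss"]:

import torch
from models.losses import Losses

criterion = Losses(mix={"l2": 1.0, "haversine": 0.5})
x = {"gps": torch.zeros(4, 2)}       # predictions, normalized (lat, lon)
y = {"gps": torch.rand(4, 2) - 0.5}  # targets
out = criterion(x, y, average=True)
print(out["loss"], out["L2_loss"], out["haversine_loss"])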
models/misc.py ADDED
@@ -0,0 +1,9 @@
+ from torch import nn  # needed for the nn.Module base class below
+
+
+ class DoNothingOptimizer(nn.Module):
+     def __init__(self, *args, **kwargs):
+         pass
+
+     def step(self, *args, **kwargs):
+         pass
+
+     def zero_grad(self, *args, **kwargs):
+         pass
models/module.py ADDED
@@ -0,0 +1,157 @@
1
+ import os
2
+ from typing import Any
3
+ import pytorch_lightning as L
4
+ import torch
5
+ import torch.nn as nn
6
+ from hydra.utils import instantiate
7
+ import copy
8
+ import pandas as pd
9
+ import numpy as np
10
+
11
+
12
+ class Geolocalizer(L.LightningModule):
13
+ def __init__(self, cfg):
14
+ super().__init__()
15
+ self.cfg = cfg
16
+ self.model = instantiate(cfg.network.instance)
17
+ if cfg.text_tuning:
18
+ self.text_model = instantiate(cfg.text_network.instance)
19
+ self.loss = instantiate(cfg.loss)
20
+ self.val_metrics = instantiate(cfg.val_metrics)
21
+ self.test_metrics = instantiate(cfg.test_metrics)
22
+ self.text_tuning = cfg.text_tuning
23
+
24
+ def training_step(self, batch, batch_idx):
25
+ pred = self.model(batch)
26
+ if self.text_tuning:
27
+ pred["text_features"] = self.text_model(batch)
28
+ loss = self.loss(pred, batch, average=True)
29
+ for metric_name, metric_value in loss.items():
30
+ self.log(
31
+ f"train/{metric_name}",
32
+ metric_value,
33
+ sync_dist=True,
34
+ on_step=True,
35
+ on_epoch=True,
36
+ )
37
+ return loss
38
+
39
+ @torch.no_grad()
40
+ def validation_step(self, batch, batch_idx):
41
+ pred = self.model(batch)
42
+ if self.text_tuning:
43
+ pred["text_features"] = self.text_model(batch)
44
+ loss = self.loss(pred, batch, average=True)["loss"]
45
+ self.val_metrics.update(pred, batch)
46
+ self.log("val/loss", loss, sync_dist=True, on_step=False, on_epoch=True)
47
+
48
+ def on_validation_epoch_end(self):
49
+ metrics = self.val_metrics.compute()
50
+ for metric_name, metric_value in metrics.items():
51
+ self.log(
52
+ f"val/{metric_name}",
53
+ metric_value,
54
+ sync_dist=True,
55
+ on_step=False,
56
+ on_epoch=True,
57
+ )
58
+
59
+ @torch.no_grad()
60
+ def test_step(self, batch, batch_idx):
61
+ pred = self.model(batch)
62
+ self.test_metrics.update(pred, batch)
63
+
64
+ def on_test_epoch_end(self):
65
+ metrics = self.test_metrics.compute()
66
+ for metric_name, metric_value in metrics.items():
67
+ self.log(
68
+ f"test/{metric_name}",
69
+ metric_value,
70
+ sync_dist=True,
71
+ on_step=False,
72
+ on_epoch=True,
73
+ )
74
+
75
+ def configure_optimizers(self):
76
+ lora_params = []
77
+ backbone_params = []
78
+ other_params = []
79
+ last_block_params = []
80
+ for name, param in self.model.named_parameters():
81
+ if "lora" in name:
82
+ lora_params.append(param)
83
+ elif "backbone" in name:
84
+ if self.cfg.optimizer.diff_backbone_last and ".11." in name:
85
+ last_block_params.append(param)
86
+ else:
87
+ backbone_params.append(param)
88
+ else:
89
+ other_params.append(param)
90
+
91
+ params_to_optimize = [{"params": other_params}]
92
+ if self.cfg.optimizer.unfreeze_lr:
93
+ params_to_optimize += [
94
+ {"params": backbone_params, "lr": self.cfg.optimizer.backbone_lr}
95
+ ]
96
+ if self.cfg.optimizer.diff_backbone_last:
97
+ params_to_optimize += [
98
+ {
99
+ "params": last_block_params,
100
+ "lr": self.cfg.optimizer.last_block_lr,
101
+ }
102
+ ]
103
+ if len(lora_params) > 0:
104
+ # LoRA params sometimes train better with a different lr (~1e-4 for CLIP)
105
+ params_to_optimize += [
106
+ {"params": lora_params, "lr": self.cfg.optimizer.lora_lr}
107
+ ]
108
+ if self.cfg.optimizer.exclude_ln_and_biases_from_weight_decay:
109
+ parameters_names_wd = get_parameter_names(self.model, [nn.LayerNorm])
110
+ parameters_names_wd = [
111
+ name for name in parameters_names_wd if "bias" not in name
112
+ ]
113
+ optimizer_grouped_parameters = [
114
+ {
115
+ "params": [
116
+ p
117
+ for n, p in self.model.named_parameters()
118
+ if n in parameters_names_wd
119
+ ],
120
+ "weight_decay": self.cfg.optimizer.optim.weight_decay,
121
+ },
122
+ {
123
+ "params": [
124
+ p
125
+ for n, p in self.model.named_parameters()
126
+ if n not in parameters_names_wd
127
+ ],
128
+ "weight_decay": 0.0,
129
+ },
130
+ ]
131
+ optimizer = instantiate(
132
+ self.cfg.optimizer.optim, optimizer_grouped_parameters
133
+ )
134
+ else:
135
+ optimizer = instantiate(self.cfg.optimizer.optim, params_to_optimize)
136
+ scheduler = instantiate(self.cfg.lr_scheduler)(optimizer)
137
+ return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
138
+
139
+ def lr_scheduler_step(self, scheduler, metric):
140
+ scheduler.step(self.global_step)
141
+
142
+
143
+ def get_parameter_names(model, forbidden_layer_types):
144
+ """
145
+ Returns the names of the model parameters that are not inside a forbidden layer.
146
+ Taken from HuggingFace transformers.
147
+ """
148
+ result = []
149
+ for name, child in model.named_children():
150
+ result += [
151
+ f"{name}.{n}"
152
+ for n in get_parameter_names(child, forbidden_layer_types)
153
+ if not isinstance(child, tuple(forbidden_layer_types))
154
+ ]
155
+ # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
156
+ result += list(model._parameters.keys())
157
+ return result
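To illustrate what get_parameter_names returns (a toy sketch; LayerNorm parameters are excluded by the forbidden-layer rule and biases are filtered out afterwards, exactly as configure_optimizers does before assigning weight decay):

import torch.nn as nn
from models.module import get_parameter_names

toy = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
decay = get_parameter_names(toy, [nn.LayerNorm])
decay = [n for n in decay if "bias" not in n]
print(decay)  # ['0.weight'] -> only this parameter keeps the configured weight decay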
models/networks/backbones.py ADDED
@@ -0,0 +1,162 @@
1
+ import torch.hub
2
+
3
+ from transformers import (
4
+ CLIPVisionModel,
5
+ CLIPVisionConfig,
6
+ CLIPModel,
7
+ CLIPProcessor,
8
+ AutoTokenizer,
9
+ CLIPTextModelWithProjection,
10
+ CLIPTextConfig,
11
+ CLIPVisionModelWithProjection,
12
+ ResNetModel,
13
+ ResNetConfig
14
+ )
15
+ from torch import nn
16
+
17
+ from PIL import Image
18
+ import requests
19
+
20
+
21
+ class CLIP(nn.Module):
22
+ def __init__(self, path):
23
+ """Initializes the CLIP model."""
24
+ super().__init__()
25
+ if path == "":
26
+ config_vision = CLIPVisionConfig()
27
+ self.clip = CLIPVisionModel(config_vision)
28
+ else:
29
+ self.clip = CLIPVisionModel.from_pretrained(path)
30
+
31
+ def forward(self, x):
32
+ """Predicts CLIP features from an image.
33
+ Args:
34
+ x (dict that contains "img": torch.Tensor): Input batch
35
+ """
36
+ features = self.clip(pixel_values=x["img"])["last_hidden_state"]
37
+ return features
38
+
39
+
40
+ class CLIPJZ(nn.Module):
41
+ def __init__(self, path):
42
+ """Initializes the CLIP model."""
43
+ super().__init__()
44
+ if path == "":
45
+ config_vision = CLIPVisionConfig()
46
+ self.clip = CLIPVisionModel(config_vision)
47
+ else:
48
+ self.clip = CLIPVisionModel.from_pretrained(path)
49
+
50
+ def forward(self, x):
51
+ """Predicts CLIP features from an image.
52
+ Args:
53
+ x (dict that contains "img": torch.Tensor): Input batch
54
+ """
55
+ features = self.clip(pixel_values=x["img"])["last_hidden_state"]
56
+ return features
57
+
58
+
59
+ class StreetCLIP(nn.Module):
60
+ def __init__(self, path):
61
+ """Initializes the CLIP model."""
62
+ super().__init__()
63
+ self.clip = CLIPModel.from_pretrained(path)
64
+ self.transform = CLIPProcessor.from_pretrained(path)
65
+
66
+ def forward(self, x):
67
+ """Predicts CLIP features from an image.
68
+ Args:
69
+ x (dict that contains "img": torch.Tensor): Input batch
70
+ """
71
+ features = self.clip.get_image_features(
72
+ **self.transform(images=x["img"], return_tensors="pt").to(x["gps"].device)
73
+ ).unsqueeze(1)
74
+ return features
75
+
76
+
77
+ class CLIPText(nn.Module):
78
+ def __init__(self, path):
79
+ """Initializes the CLIP model."""
80
+ super().__init__()
81
+ if path == "":
82
+ config_vision = CLIPVisionConfig()
83
+ self.clip = CLIPVisionModel(config_vision)
84
+ else:
85
+ self.clip = CLIPVisionModelWithProjection.from_pretrained(path)
86
+
87
+ def forward(self, x):
88
+ """Predicts CLIP features from an image.
89
+ Args:
90
+ x (dict that contains "img": torch.Tensor): Input batch
91
+ """
92
+ features = self.clip(pixel_values=x["img"])
93
+ return features.image_embeds, features.last_hidden_state
94
+
95
+
96
+ class TextEncoder(nn.Module):
97
+ def __init__(self, path):
98
+ """Initializes the CLIP text model."""
99
+ super().__init__()
100
+ if path == "":
101
+ config_vision = CLIPTextConfig()
102
+ self.clip = CLIPTextModelWithProjection(config_vision)
103
+ self.transform = AutoTokenizer()
104
+ else:
105
+ self.clip = CLIPTextModelWithProjection.from_pretrained(path)
106
+ self.transform = AutoTokenizer.from_pretrained(path)
107
+ for p in self.clip.parameters():
108
+ p.requires_grad = False
109
+ self.clip.eval()
110
+
111
+ def forward(self, x):
112
+ """Predicts CLIP features from text.
113
+ Args:
114
+ x (dict that contains "text": list): Input batch
115
+ """
116
+ features = self.clip(
117
+ **self.transform(x["text"], padding=True, return_tensors="pt").to(
118
+ x["gps"].device
119
+ )
120
+ ).text_embeds
121
+ return features
122
+
123
+
124
+ class DINOv2(nn.Module):
125
+ def __init__(self, tag) -> None:
126
+ """Initializes the DINO model."""
127
+ super().__init__()
128
+ self.dino = torch.hub.load("facebookresearch/dinov2", tag)
129
+ self.stride = 14 # ugly but dinov2 stride = 14
130
+
131
+ def forward(self, x):
132
+ """Predicts DINO features from an image."""
133
+ x = x["img"]
134
+
135
+ # crop for stride
136
+ _, _, H, W = x.shape
137
+ H_new = H - H % self.stride
138
+ W_new = W - W % self.stride
139
+ x = x[:, :, :H_new, :W_new]
140
+
141
+ # forward features
142
+ x = self.dino.forward_features(x)
143
+ x = x["x_prenorm"]
144
+ return x
145
+
146
+ class ResNet(nn.Module):
147
+ def __init__(self, path):
148
+ """Initializes the ResNet model."""
149
+ super().__init__()
150
+ if path == "":
151
+ config_vision = ResNetConfig()
152
+ self.resnet = ResNetModel(config_vision)
153
+ else:
154
+ self.resnet = ResNetModel.from_pretrained(path)
155
+
156
+ def forward(self, x):
157
+ """Predicts ResNet50 features from an image.
158
+ Args:
159
+ x (dict that contains "img": torch.Tensor): Input batch
160
+ """
161
+ features = self.resnet(x["img"])["pooler_output"]
162
+ return features.squeeze()
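A short sketch of instantiating one of these backbones from a public checkpoint; "openai/clip-vit-base-patch32" is an illustrative choice, not necessarily the one used in the project configs, and the pixel values are assumed to be already normalized by the dataloader transform.

import torch
from models.networks.backbones import CLIP

backbone = CLIP("openai/clip-vit-base-patch32")
batch = {"img": torch.rand(2, 3, 224, 224)}  # assumed pre-normalized pixel values
features = backbone(batch)                   # (2, num_patches + 1, hidden_dim) token features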
models/networks/heads/__init__.py ADDED
File without changes
models/networks/heads/auxilliary.py ADDED
@@ -0,0 +1,33 @@
+ import torch.nn as nn
+ from models.networks.utils import UnormGPS
+ from torch.nn.functional import tanh, sigmoid, softmax
+
+
+ class AuxHead(nn.Module):
+     def __init__(self, aux_data=[], use_tanh=False):
+         super().__init__()
+         self.aux_data = aux_data
+         self.unorm = UnormGPS()
+         self.use_tanh = use_tanh
+
+     def forward(self, x):
+         """Forward pass of the network.
+         x : Union[torch.Tensor, dict] with the output of the backbone.
+         """
+         if self.use_tanh:
+             gps = tanh(x["gps"])
+         else:
+             # fall back to the raw predicted coordinates when tanh squashing is disabled
+             gps = x["gps"]
+         gps = self.unorm(gps)
+         output = {"gps": gps}
+         if "land_cover" in self.aux_data:
+             output["land_cover"] = softmax(x["land_cover"])
+         if "road_index" in self.aux_data:
+             output["road_index"] = x["road_index"]
+         if "drive_side" in self.aux_data:
+             output["drive_side"] = sigmoid(x["drive_side"])
+         if "climate" in self.aux_data:
+             output["climate"] = softmax(x["climate"])
+         if "soil" in self.aux_data:
+             output["soil"] = softmax(x["soil"])
+         if "dist_sea" in self.aux_data:
+             output["dist_sea"] = x["dist_sea"]
+         return output
models/networks/heads/classification.py ADDED
@@ -0,0 +1,17 @@
+ import torch
+ import torch.nn as nn
+
+
+ class ClassificationHead(nn.Module):
+     """Classification head for the network."""
+
+     def __init__(self, id_to_gps):
+         super().__init__()
+         self.id_to_gps = id_to_gps
+
+     def forward(self, x):
+         """Forward pass of the network.
+         x : Union[torch.Tensor, dict] with the output of the backbone.
+         """
+         gps = self.id_to_gps(x.argmax(dim=-1))
+         return {"label": x, **gps}
models/networks/heads/hybrid.py ADDED
@@ -0,0 +1,194 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import pandas as pd
4
+
5
+ from models.networks.utils import UnormGPS
6
+
7
+
8
+ class HybridHead(nn.Module):
9
+ """Classification head followed by regression head for the network."""
10
+
11
+ def __init__(self, final_dim, quadtree_path, use_tanh, scale_tanh):
12
+ super().__init__()
13
+ self.final_dim = final_dim
14
+ self.use_tanh = use_tanh
15
+ self.scale_tanh = scale_tanh
16
+
17
+ self.unorm = UnormGPS()
18
+
19
+ if quadtree_path is not None:
20
+ quadtree = pd.read_csv(quadtree_path)
21
+ self.init_quadtree(quadtree)
22
+
23
+ def init_quadtree(self, quadtree):
24
+ quadtree[["min_lat", "max_lat"]] /= 90.0
25
+ quadtree[["min_lon", "max_lon"]] /= 180.0
26
+ self.register_buffer(
27
+ "cell_center",
28
+ 0.5 * torch.tensor(quadtree[["max_lat", "max_lon"]].values)
29
+ + 0.5 * torch.tensor(quadtree[["min_lat", "min_lon"]].values),
30
+ )
31
+ self.register_buffer(
32
+ "cell_size",
33
+ torch.tensor(quadtree[["max_lat", "max_lon"]].values)
34
+ - torch.tensor(quadtree[["min_lat", "min_lon"]].values),
35
+ )
36
+
37
+ def forward(self, x, gt_label):
38
+ """Forward pass of the network.
39
+ x : Union[torch.Tensor, dict] with the output of the backbone.
40
+ """
41
+
42
+ classification_logits = x[..., : self.final_dim]
43
+ classification = classification_logits.argmax(dim=-1)
44
+
45
+ regression = x[..., self.final_dim :]
46
+
47
+ if self.use_tanh:
48
+ regression = self.scale_tanh * torch.tanh(regression)
49
+
50
+ regression = regression.view(regression.shape[0], -1, 2)
51
+
52
+ if self.training:
53
+ regression = torch.gather(
54
+ regression,
55
+ 1,
56
+ gt_label.unsqueeze(-1).unsqueeze(-1).expand(regression.shape[0], 1, 2),
57
+ )[:, 0, :]
58
+ size = 2.0 / self.cell_size[gt_label]
59
+ center = self.cell_center[gt_label]
60
+ gps = (
61
+ self.cell_center[gt_label] + regression * self.cell_size[gt_label] / 2.0
62
+ )
63
+ else:
64
+ regression = torch.gather(
65
+ regression,
66
+ 1,
67
+ classification.unsqueeze(-1)
68
+ .unsqueeze(-1)
69
+ .expand(regression.shape[0], 1, 2),
70
+ )[:, 0, :]
71
+ size = 2.0 / self.cell_size[classification]
72
+ center = self.cell_center[classification]
73
+ gps = (
74
+ self.cell_center[classification]
75
+ + regression * self.cell_size[classification] / 2.0
76
+ )
77
+
78
+ gps = self.unorm(gps)
79
+
80
+ return {
81
+ "label": classification_logits,
82
+ "gps": gps,
83
+ "size": size,
84
+ "center": center,
85
+ "reg": regression,
86
+ }
87
+
88
+ class HybridHeadCentroid(nn.Module):
89
+ """Classification head followed by regression head for the network."""
90
+
91
+ def __init__(self, final_dim, quadtree_path, use_tanh, scale_tanh):
92
+ super().__init__()
93
+ self.final_dim = final_dim
94
+ self.use_tanh = use_tanh
95
+ self.scale_tanh = scale_tanh
96
+
97
+ self.unorm = UnormGPS()
98
+ if quadtree_path is not None:
99
+ quadtree = pd.read_csv(quadtree_path)
100
+ self.init_quadtree(quadtree)
101
+
102
+ def init_quadtree(self, quadtree):
103
+ quadtree[["min_lat", "max_lat", "mean_lat"]] /= 90.0
104
+ quadtree[["min_lon", "max_lon", "mean_lon"]] /= 180.0
105
+ self.cell_center = torch.tensor(quadtree[["mean_lat", "mean_lon"]].values)
106
+ self.cell_size_up = torch.tensor(quadtree[["max_lat", "max_lon"]].values) - torch.tensor(quadtree[["mean_lat", "mean_lon"]].values)
107
+ self.cell_size_down = torch.tensor(quadtree[["mean_lat", "mean_lon"]].values) - torch.tensor(quadtree[["min_lat", "min_lon"]].values)
108
+
109
+ def forward(self, x, gt_label):
110
+ """Forward pass of the network.
111
+ x : Union[torch.Tensor, dict] with the output of the backbone.
112
+ """
113
+ classification_logits = x[..., : self.final_dim]
114
+ classification = classification_logits.argmax(dim=-1)
115
+ self.cell_size_up = self.cell_size_up.to(classification.device)
116
+ self.cell_center = self.cell_center.to(classification.device)
117
+ self.cell_size_down = self.cell_size_down.to(classification.device)
118
+
119
+ regression = x[..., self.final_dim :]
120
+
121
+ if self.use_tanh:
122
+ regression = self.scale_tanh * torch.tanh(regression)
123
+
124
+ regression = regression.view(regression.shape[0], -1, 2)
125
+
126
+ if self.training:
127
+ regression = torch.gather(
128
+ regression,
129
+ 1,
130
+ gt_label.unsqueeze(-1).unsqueeze(-1).expand(regression.shape[0], 1, 2),
131
+ )[:, 0, :]
132
+ size = torch.where(
133
+ regression > 0,
134
+ self.cell_size_up[gt_label],
135
+ self.cell_size_down[gt_label],
136
+ )
137
+ center = self.cell_center[gt_label]
138
+ gps = self.cell_center[gt_label] + regression * size
139
+ else:
140
+ regression = torch.gather(
141
+ regression,
142
+ 1,
143
+ classification.unsqueeze(-1)
144
+ .unsqueeze(-1)
145
+ .expand(regression.shape[0], 1, 2),
146
+ )[:, 0, :]
147
+ size = torch.where(
148
+ regression > 0,
149
+ self.cell_size_up[classification],
150
+ self.cell_size_down[classification],
151
+ )
152
+ center = self.cell_center[classification]
153
+ gps = self.cell_center[classification] + regression * size
154
+
155
+ gps = self.unorm(gps)
156
+
157
+ return {
158
+ "label": classification_logits,
159
+ "gps": gps,
160
+ "size": 1.0 / size,
161
+ "center": center,
162
+ "reg": regression,
163
+ }
164
+
165
+
166
+ class SharedHybridHead(HybridHead):
167
+ """Classification head followed by SHARED regression head for the network."""
168
+
169
+ def forward(self, x, gt_label):
170
+ """Forward pass of the network.
171
+ x : Union[torch.Tensor, dict] with the output of the backbone.
172
+ """
173
+
174
+ classification_logits = x[..., : self.final_dim]
175
+ classification = classification_logits.argmax(dim=-1)
176
+
177
+ regression = x[..., self.final_dim :]
178
+
179
+ if self.use_tanh:
180
+ regression = self.scale_tanh * torch.tanh(regression)
181
+
182
+ if self.training:
183
+ gps = (
184
+ self.cell_center[gt_label] + regression * self.cell_size[gt_label] / 2.0
185
+ )
186
+ else:
187
+ gps = (
188
+ self.cell_center[classification]
189
+ + regression * self.cell_size[classification] / 2.0
190
+ )
191
+
192
+ gps = self.unorm(gps)
193
+
194
+ return {"label": classification_logits, "gps": gps}
models/networks/heads/id_to_gps.py ADDED
@@ -0,0 +1,33 @@
1
+ import torch
2
+ from models.networks.utils import UnormGPS
3
+ import torch.nn as nn
4
+ import numpy as np
5
+
6
+
7
+ class IdToGPS(nn.Module):
8
+ def __init__(self, id_to_gps: str):
9
+ """Map index to gps coordinates (indices can be country or city ids)"""
10
+ super().__init__()
11
+ if "quadtree" in id_to_gps:
12
+ self.id_to_gps = torch.load(
13
+ "_".join(id_to_gps.split("_")[:-4] + id_to_gps.split("_")[-3:])
14
+ )
15
+ else:
16
+ self.id_to_gps = torch.load(id_to_gps)
17
+ #self.unorm = UnormGPS()
18
+
19
+ def forward(self, x):
20
+ """Mapping from country id to gps coordinates
21
+ Args:
22
+ x: torch.Tensor with features
23
+ """
24
+
25
+ if isinstance(x, dict):
26
+ # for oracle
27
+ labels, x = x["label"], x["img"]
28
+ else:
29
+ # predicted labels
30
+ labels = x
31
+ self.id_to_gps = self.id_to_gps.to(labels.device)
32
+ #return {"gps": self.unorm(self.id_to_gps[labels])}
33
+ return {"gps": self.id_to_gps[labels]}
models/networks/heads/random.py ADDED
@@ -0,0 +1,53 @@
1
+ import pandas as pd
2
+ import torch
3
+ from torch import nn
4
+ from models.networks.utils import UnormGPS
5
+
6
+
7
+ class Random(nn.Module):
8
+ def __init__(self, num_output):
9
+ """Random"""
10
+ super().__init__()
11
+ self.num_output = num_output
12
+ self.unorm = UnormGPS()
13
+
14
+ def forward(self, x):
15
+ """Predicts GPS coordinates from an image.
16
+ Args:
17
+ x: torch.Tensor with features
18
+ """
19
+ #x = x["img"]
20
+ gps = torch.rand((x.shape[0], self.num_output), device=x.device) * 2 - 1
21
+ return {"gps": self.unorm(gps)}
22
+
23
+
24
+ class RandomCoords(nn.Module):
25
+ def __init__(self, coords_path: str):
26
+ """Randomly sample from a list of coordinates
27
+ Args:
28
+ coords_path: str with path to csv file with coordinates
29
+ """
30
+ super().__init__()
31
+ coordinates = pd.read_csv(coords_path)
32
+ longitudes = coordinates["longitude"].values / 180
33
+ latitudes = coordinates["latitude"].values / 90
34
+ self.unorm = UnormGPS()
35
+ del coordinates
36
+
37
+ self.N = len(longitudes)
38
+ assert len(longitudes) == len(latitudes)
39
+ self.coordinates = torch.stack(
40
+ [torch.tensor(latitudes), torch.tensor(longitudes)],
41
+ dim=-1,
42
+ )
43
+ del longitudes, latitudes
44
+
45
+ def forward(self, x):
46
+ """Predicts GPS coordinates from an image.
47
+ Args:
48
+ x: torch.Tensor with features
49
+ """
50
+ x = x["img"]
51
+ # randomly select a coordinate in the list
52
+ n = torch.randint(0, self.N, (x.shape[0],))
53
+ return {"gps": self.unorm(self.coordinates[n].to(x.device))}
models/networks/heads/regression.py ADDED
@@ -0,0 +1,44 @@
1
+ from models.networks.utils import UnormGPS
2
+ import torch.nn as nn
3
+ from torch.nn.functional import tanh
4
+ import torch
5
+
6
+
7
+ class RegressionHead(nn.Module):
8
+ def __init__(self, use_tanh=False):
9
+ super().__init__()
10
+ self.unorm = UnormGPS()
11
+ self.use_tanh = use_tanh
12
+
13
+ def forward(self, x):
14
+ """Forward pass of the network.
15
+ x : Union[torch.Tensor, dict] with the output of the backbone.
16
+ """
17
+ if self.use_tanh:
18
+ x = tanh(x)
19
+ gps = self.unorm(x)
20
+ return {"gps": gps}
21
+
22
+
23
+ class RegressionHeadAngle(nn.Module):
24
+ def __init__(self):
25
+ super().__init__()
26
+ self.unorm = UnormGPS()
27
+
28
+ def forward(self, x):
29
+ """Forward pass of the network.
30
+ x : Union[torch.Tensor, dict] with the output of the backbone.
31
+ """
32
+ x1 = x[:, 0].pow(2)
33
+ x2 = x[:, 1].pow(2)
34
+ x3 = x[:, 2].pow(2)
35
+ x4 = x[:, 3].pow(2)
36
+ cos_lambda = x1 / (x1 + x2)
37
+ sin_lambda = x2 / (x1 + x2)
38
+ cos_phi = x3 / (x3 + x4)
39
+ sin_phi = x4 / (x3 + x4)
40
+ lbd = torch.atan2(sin_lambda, cos_lambda)
41
+ phi = torch.atan2(sin_phi, cos_phi)
42
+ gps = torch.cat((lbd.unsqueeze(1), phi.unsqueeze(1)), dim=1)
43
+ # gps = self.unorm(x)
44
+ return {"gps": gps}
models/networks/mlp.py ADDED
@@ -0,0 +1,258 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class MLP(nn.Module):
6
+ def __init__(
7
+ self,
8
+ initial_dim=512,
9
+ hidden_dim=[128, 32, 2],
10
+ final_dim=2,
11
+ norm=nn.InstanceNorm1d,
12
+ activation=nn.ReLU,
13
+ aux_data=[],
14
+ ):
15
+ """
16
+ Initializes an MLP Classification Head
17
+ Args:
18
+ hidden_dim (list): list of hidden dimensions for the MLP
19
+ norm (nn.Module): normalization layer
20
+ activation (nn.Module): activation layer
21
+ """
22
+ super().__init__()
23
+ self.aux_data = aux_data
24
+ self.aux = len(self.aux_data) > 0
25
+ if self.aux:
26
+ hidden_dim_aux = hidden_dim
27
+ hidden_dim_aux[-1] = 128
28
+ final_dim_aux_dict = {
29
+ "land_cover": 12,
30
+ "climate": 30,
31
+ "soil": 14,
32
+ "road_index": 1,
33
+ "drive_side": 1,
34
+ "dist_sea": 1,
35
+ }
36
+ self.idx = {}
37
+ final_dim_aux = 0
38
+ for col in self.aux_data:
39
+ self.idx[col] = [
40
+ final_dim_aux + i for i in range(final_dim_aux_dict[col])
41
+ ]
42
+ final_dim_aux += final_dim_aux_dict[col]
43
+ dim = [initial_dim] + hidden_dim_aux + [final_dim_aux]
44
+ args = self.init_layers(dim, norm, activation)
45
+ self.mlp_aux = nn.Sequential(*args)
46
+ dim = [initial_dim] + hidden_dim + [final_dim]
47
+ args = self.init_layers(dim, norm, activation)
48
+ self.mlp = nn.Sequential(*args)
49
+
50
+ def init_layers(self, dim, norm, activation):
51
+ """Initializes the MLP layers."""
52
+ args = [nn.LayerNorm(dim[0])]
53
+ for i in range(len(dim) - 1):
54
+ args.append(nn.Linear(dim[i], dim[i + 1]))
55
+ if i < len(dim) - 2:
56
+ # args.append(norm(dim[i + 1]))
57
+ args.append(norm(4, dim[i + 1]))
58
+ args.append(activation())
59
+ return args
60
+
61
+ def forward(self, x):
62
+ """Predicts GPS coordinates from an image.
63
+ Args:
64
+ x: torch.Tensor with features
65
+ """
66
+ if self.aux:
67
+ out = {"gps": self.mlp(x[:, 0, :])}
68
+ x = self.mlp_aux(x[:, 0, :])
69
+ for col in list(self.idx.keys()):
70
+ out[col] = x[:, self.idx[col]]
71
+ return out
72
+ return self.mlp(x[:, 0, :])
73
+
74
+ class MLPResNet(nn.Module):
75
+ def __init__(
76
+ self,
77
+ initial_dim=512,
78
+ hidden_dim=[128, 32, 2],
79
+ final_dim=2,
80
+ norm=nn.InstanceNorm1d,
81
+ activation=nn.ReLU,
82
+ aux_data=[],
83
+ ):
84
+ """
85
+ Initializes an MLP Classification Head
86
+ Args:
87
+ hidden_dim (list): list of hidden dimensions for the MLP
88
+ norm (nn.Module): normalization layer
89
+ activation (nn.Module): activation layer
90
+ """
91
+ super().__init__()
92
+ self.aux_data = aux_data
93
+ self.aux = len(self.aux_data) > 0
94
+ if self.aux:
95
+ hidden_dim_aux = hidden_dim
96
+ hidden_dim_aux[-1] = 128
97
+ final_dim_aux_dict = {
98
+ "land_cover": 12,
99
+ "climate": 30,
100
+ "soil": 14,
101
+ "road_index": 1,
102
+ "drive_side": 1,
103
+ "dist_sea": 1,
104
+ }
105
+ self.idx = {}
106
+ final_dim_aux = 0
107
+ for col in self.aux_data:
108
+ self.idx[col] = [
109
+ final_dim_aux + i for i in range(final_dim_aux_dict[col])
110
+ ]
111
+ final_dim_aux += final_dim_aux_dict[col]
112
+ dim = [initial_dim] + hidden_dim_aux + [final_dim_aux]
113
+ args = self.init_layers(dim, norm, activation)
114
+ self.mlp_aux = nn.Sequential(*args)
115
+ dim = [initial_dim] + hidden_dim + [final_dim]
116
+ args = self.init_layers(dim, norm, activation)
117
+ self.mlp = nn.Sequential(*args)
118
+
119
+ def init_layers(self, dim, norm, activation):
120
+ """Initializes the MLP layers."""
121
+ args = [nn.LayerNorm(dim[0])]
122
+ for i in range(len(dim) - 1):
123
+ args.append(nn.Linear(dim[i], dim[i + 1]))
124
+ if i < len(dim) - 2:
125
+ # args.append(norm(dim[i + 1]))
126
+ args.append(norm(4, dim[i + 1]))
127
+ args.append(activation())
128
+ return args
129
+
130
+ def forward(self, x):
131
+ """Predicts GPS coordinates from an image.
132
+ Args:
133
+ x: torch.Tensor with features
134
+ """
135
+ if self.aux:
136
+ out = {"gps": self.mlp(x[:, 0, :])}
137
+ x = self.mlp_aux(x[:, 0, :])
138
+ for col in list(self.idx.keys()):
139
+ out[col] = x[:, self.idx[col]]
140
+ return out
141
+ return self.mlp(x)
142
+
143
+
144
+ class MLPCentroid(nn.Module):
145
+ def __init__(
146
+ self,
147
+ initial_dim=512,
148
+ hidden_dim=[128, 32, 2],
149
+ final_dim=2,
150
+ norm=nn.InstanceNorm1d,
151
+ activation=nn.ReLU,
152
+ aux_data=[],
153
+ ):
154
+ """
155
+ Initializes an MLP Classification Head
156
+ Args:
157
+ hidden_dim (list): list of hidden dimensions for the MLP
158
+ norm (nn.Module): normalization layer
159
+ activation (nn.Module): activation layer
160
+ """
161
+ super().__init__()
162
+ self.aux_data = aux_data
163
+ self.aux = len(self.aux_data) > 0
164
+ dim = [initial_dim] + hidden_dim + [final_dim // 3]
165
+ args = self.init_layers(dim, norm, activation)
166
+ self.classif = nn.Sequential(*args)
167
+ dim = [initial_dim] + hidden_dim + [2 * final_dim // 3]
168
+ args = self.init_layers(dim, norm, activation)
169
+ self.reg = nn.Sequential(*args)
170
+ # torch.nn.init.normal_(self.reg.weight, mean=0.0, std=0.01)
171
+ if self.aux:
172
+ self.dim = [initial_dim] + hidden_dim
173
+ self.predictors = {"gps": lambda feats: torch.cat([self.classif(feats), self.reg(feats)], dim=1)}  # this head has no self.mlp; mirror the non-aux output
174
+ self.init_aux(dim, norm, activation)
175
+
176
+ def init_layers(self, dim, norm, activation):
177
+ """Initializes the MLP layers."""
178
+ args = [nn.LayerNorm(dim[0])]
179
+ for i in range(len(dim) - 1):
180
+ args.append(nn.Linear(dim[i], dim[i + 1]))
181
+ if i < len(dim) - 2:
182
+ # args.append(norm(dim[i + 1]))
183
+ args.append(norm(4, dim[i + 1]))
184
+ args.append(activation())
185
+ return args
186
+
187
+ def init_aux(self, dim, norm, activation):
188
+ final_dim_aux = {
189
+ "land_cover": 12,
190
+ "climate": 30,
191
+ "soil": 14,
192
+ "road_index": 1,
193
+ "drive_side": 1,
194
+ "dist_sea": 1,
195
+ }
196
+ if "land_cover" in self.aux_data:
197
+ args = self.init_layers(
198
+ self.dim + [final_dim_aux["land_cover"]], norm, activation
199
+ )
200
+ self.land_cover = nn.Sequential(*args)
201
+ self.predictors["land_cover"] = self.land_cover
202
+ if "road_index" in self.aux_data:
203
+ args = self.init_layers(
204
+ self.dim + [final_dim_aux["road_index"]], norm, activation
205
+ )
206
+ self.road_index = nn.Sequential(*args)
207
+ self.predictors["road_index"] = self.road_index
208
+ if "drive_side" in self.aux_data:
209
+ args = self.init_layers(
210
+ self.dim + [final_dim_aux["drive_side"]], norm, activation
211
+ )
212
+ self.drive_side = nn.Sequential(*args)
213
+ self.predictors["drive_side"] = self.drive_side
214
+ if "climate" in self.aux_data:
215
+ args = self.init_layers(
216
+ self.dim + [final_dim_aux["climate"]], norm, activation
217
+ )
218
+ self.climate = nn.Sequential(*args)
219
+ self.predictors["climate"] = self.climate
220
+ if "soil" in self.aux_data:
221
+ args = self.init_layers(
222
+ self.dim + [final_dim_aux["soil"]], norm, activation
223
+ )
224
+ self.soil = nn.Sequential(*args)
225
+ self.predictors["soil"] = self.soil
226
+ if "dist_sea" in self.aux_data:
227
+ args = self.init_layers(
228
+ self.dim + [final_dim_aux["dist_sea"]], norm, activation
229
+ )
230
+ self.dist_sea = nn.Sequential(*args)
231
+ self.predictors["dist_sea"] = self.dist_sea
232
+
233
+ def forward(self, x):
234
+ """Predicts GPS coordinates from an image.
235
+ Args:
236
+ x: torch.Tensor with features
237
+ """
238
+ if self.aux:
239
+ return {
240
+ col: self.predictors[col](x[:, 0, :]) for col in self.predictors.keys()
241
+ }
242
+ return torch.cat([self.classif(x[:, 0, :]), self.reg(x[:, 0, :])], dim=1)
243
+
244
+
245
+ class Identity(nn.Module):
246
+ def __init__(
247
+ self
248
+ ):
249
+ """
250
+ Initializes an Identity module
251
+ """
252
+ super().__init__()
253
+
254
+ def forward(self, x):
255
+ """
256
+ Return same as input
257
+ """
258
+ return x
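Note (outside the diff): output layout of MLPCentroid without auxiliary data. With final_dim = 3 * num_classes, the classification branch emits final_dim // 3 logits and the regression branch 2 * final_dim // 3 offsets, concatenated for the hybrid head. Passing nn.GroupNorm matches the norm(4, dim) call in init_layers; the hidden sizes divisible by 4 are an assumption about the intended configs.

import torch
from torch import nn
from models.networks.mlp import MLPCentroid

num_classes = 10                           # toy value; the real configs use the quadtree size
head = MLPCentroid(initial_dim=512, hidden_dim=[128, 64], final_dim=3 * num_classes,
                   norm=nn.GroupNorm, activation=nn.ReLU)
tokens = torch.randn(4, 5, 512)            # (batch, tokens, dim); only the first token is used
print(head(tokens).shape)                  # torch.Size([4, 30]) = 10 class logits + 10 * 2 offsets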
models/networks/network.py ADDED
@@ -0,0 +1,335 @@
1
+ import torch
2
+ import numpy as np
3
+ from abc import ABC, abstractmethod
4
+ from torch import nn
5
+ from hydra.utils import instantiate
6
+ import copy
7
+ from peft import LoraConfig, get_peft_model
8
+ from utils.model_utils import print_trainable_parameters
9
+
10
+
11
+ def freeze(model):
12
+ """Freezes the parameters of a model."""
13
+ for p in model.parameters():
14
+ p.requires_grad = False
15
+ model.eval()
16
+
17
+
18
+ def unfreeze(model):
19
+ """Unfreezes the parameters of a model.
20
+ for p in model.parameters():
21
+ p.requires_grad = True"""
22
+ model_parameters = model.named_parameters()
23
+ for name, param in model_parameters:
24
+ if name in [
25
+ "clip.vision_model.post_layernorm.weight",
26
+ "clip.vision_model.post_layernorm.bias",
27
+ ]:
28
+ param.requires_grad = False
29
+ else:
30
+ param.requires_grad = True
31
+ model.train()
32
+
33
+
34
+ def unfreeze_last(model):
35
+ """Unfreezes the parameters of a model.
36
+ for p in model.parameters():
37
+ p.requires_grad = True"""
38
+ model_parameters = model.named_parameters()
39
+ for name, param in model_parameters:
40
+ if len(name.split(".")) > 5:
41
+ if name.split(".")[4] == "11":
42
+ param.requires_grad = True
43
+ else:
44
+ param.requires_grad = False
45
+ else:
46
+ param.requires_grad = False
47
+ model.train()
48
+
49
+
50
+ class FrozenBackbone(nn.Module):
51
+ """Freezes the backbone of a network."""
52
+
53
+ def __init__(self, backbone, mid, head):
54
+ super().__init__()
55
+ self.backbone = backbone.instance
56
+ self.mid = mid.instance
57
+ self.head = head.instance
58
+ self.target_key = head.target_key
59
+ freeze(self.backbone)
60
+
61
+ def forward(self, x):
62
+ """Forward pass of the network.
63
+ x : Union[torch.Tensor, dict] with the output of the backbone.
64
+ """
65
+ with torch.no_grad():
66
+ x = self.backbone(x)
67
+ x = self.mid(x)
68
+ x = self.head(x)
69
+ return x
70
+
71
+
72
+ class UnfrozenBackbone(nn.Module):
73
+ """Unfreezes the backbone of a network."""
74
+
75
+ def __init__(self, backbone, mid, head):
76
+ super().__init__()
77
+ self.backbone = backbone.instance
78
+ self.mid = mid.instance
79
+ self.head = head.instance
80
+ self.target_key = head.target_key
81
+ unfreeze(self.backbone)
82
+
83
+ def forward(self, x):
84
+ """Forward pass of the network.
85
+ x : Union[torch.Tensor, dict] with the output of the backbone.
86
+ """
87
+ x = self.backbone(x)
88
+ x = self.mid(x)
89
+ x = self.head(x)
90
+ return x
91
+
92
+
93
+ class UnfrozenPartBackbone(nn.Module):
94
+ """Unfreezes the backbone of a network."""
95
+
96
+ def __init__(self, backbone, mid, head):
97
+ super().__init__()
98
+ self.backbone = backbone.instance
99
+ self.mid = mid.instance
100
+ self.head = head.instance
101
+ self.target_key = head.target_key
102
+ unfreeze_last(self.backbone)
103
+
104
+ def forward(self, x):
105
+ """Forward pass of the network.
106
+ x : Union[torch.Tensor, dict] with the output of the backbone.
107
+ """
108
+ x = self.backbone(x)
109
+ x = self.mid(x)
110
+ x = self.head(x)
111
+ return x
112
+
113
+
114
+ class NoFeatureBackbone(nn.Module):
115
+ """Randomizes the backbone of a network."""
116
+
117
+ def __init__(self, head):
118
+ super().__init__()
119
+ self.head = head.instance
120
+ self.target_key = head.target_key
121
+
122
+ def forward(self, x):
123
+ """Forward pass of the network.
124
+ x : Union[torch.Tensor, dict] with the output of the backbone.
125
+ """
126
+ return self.head(x)
127
+
128
+
129
+ class ContrastiveFrozenBackbone(FrozenBackbone):
130
+ """Freezes the backbone of a network."""
131
+
132
+ def __init__(self, backbone, mid, head, mode):
133
+ super().__init__(backbone, mid, head)
134
+ self.mode = mode
135
+
136
+ def forward(self, x):
137
+ with torch.no_grad():
138
+ features = self.backbone(x)
139
+ if self.mode != "eval":
140
+ x_pos = {
141
+ k.strip("pos_"): v.clone()
142
+ if isinstance(v, torch.Tensor)
143
+ else copy.deepcopy(v)
144
+ for k, v in x.items()
145
+ if k.startswith("pos_")
146
+ }
147
+ pos_features = self.backbone(x_pos)
148
+ x = self.mid(features)
149
+ x = self.head(x)
150
+ if self.mode != "eval":
151
+ return {
152
+ "features": features[:, 0, :],
153
+ "pos_features": pos_features[:, 0, :],
154
+ **x,
155
+ }
156
+ return {
157
+ "features": features[:, 0, :],
158
+ **x,
159
+ }
160
+
161
+
162
+ class ContrastiveUnFrozenPartBackbone(UnfrozenPartBackbone):
163
+ """Freezes the backbone of a network."""
164
+
165
+ def __init__(self, backbone, mid, head, mode):
166
+ super().__init__(backbone, mid, head)
167
+ self.mode = mode
168
+
169
+ def forward(self, x):
170
+ features = self.backbone(x)
171
+ if self.mode != "eval":
172
+ x_pos = {
173
+ k.strip("pos_"): v.clone()
174
+ if isinstance(v, torch.Tensor)
175
+ else copy.deepcopy(v)
176
+ for k, v in x.items()
177
+ if k.startswith("pos_")
178
+ }
179
+ pos_features = self.backbone(x_pos)
180
+ x = self.mid(features)
181
+ x = self.head(x)
182
+ if self.mode != "eval":
183
+ return {
184
+ "features": features[:, 0, :],
185
+ "pos_features": pos_features[:, 0, :],
186
+ **x,
187
+ }
188
+ return {
189
+ "features": features[:, 0, :],
190
+ **x,
191
+ }
192
+
193
+
194
+ class ContrastiveUnFrozenBackbone(UnfrozenBackbone):
195
+ """Freezes the backbone of a network."""
196
+
197
+ def __init__(self, backbone, mid, head, mode):
198
+ super().__init__(backbone, mid, head)
199
+ self.mode = mode
200
+
201
+ def forward(self, x):
202
+ features = self.backbone(x)
203
+ if self.mode != "eval":
204
+ x_pos = {
205
+ k.strip("pos_"): v.clone()
206
+ if isinstance(v, torch.Tensor)
207
+ else copy.deepcopy(v)
208
+ for k, v in x.items()
209
+ if k.startswith("pos_")
210
+ }
211
+ pos_features = self.backbone(x_pos)
212
+ x = self.mid(features)
213
+ x = self.head(x)
214
+ if self.mode != "eval":
215
+ return {
216
+ "features": features[:, 0, :],
217
+ "pos_features": pos_features[:, 0, :],
218
+ **x,
219
+ }
220
+ return {
221
+ "features": features[:, 0, :],
222
+ **x,
223
+ }
224
+
225
+
226
+ class TextContrastiveUnFrozenBackbone(UnfrozenBackbone):
227
+ """Freezes the backbone of a network."""
228
+
229
+ def __init__(self, backbone, mid, head):
230
+ super().__init__(backbone, mid, head)
231
+
232
+ def forward(self, x):
233
+ con, features = self.backbone(x)
234
+ x = self.mid(features)
235
+ x = self.head(x)
236
+ return {
237
+ "features": con,
238
+ **x,
239
+ }
240
+
241
+
242
+ class LoraBackbone(nn.Module):
243
+ """Wraps the backbone in a PEFT model for LoRA tuning."""
244
+
245
+ def __init__(self, backbone, mid, head, r, alpha, dropout, bias):
246
+ super().__init__()
247
+ self.backbone = backbone.instance
248
+ self.mid = mid.instance
249
+ self.head = head.instance
250
+ self.target_key = head.target_key
251
+ freeze(self.backbone)
252
+
253
+ config = LoraConfig(
254
+ r=r,
255
+ lora_alpha=alpha,
256
+ lora_dropout=dropout,
257
+ bias=bias,
258
+ target_modules=["q_proj", "k_proj", "v_proj"],
259
+ )
260
+ self.backbone = get_peft_model(self.backbone, config)
261
+ print_trainable_parameters(self)
262
+
263
+ def forward(self, x):
264
+ """Forward pass of the network.
265
+ x : Union[torch.Tensor, dict] with the output of the backbone.
266
+ """
267
+ x = self.backbone(x)
268
+ x = self.mid(x)
269
+ return self.head(x)
270
+
271
+
272
+ class HybridFrozenBackbone(FrozenBackbone):
273
+ """Freezes the backbone of a network."""
274
+
275
+ def forward(self, x):
276
+ """Forward pass of the network.
277
+ x : Union[torch.Tensor, dict] with the output of the backbone.
278
+ """
279
+
280
+ gt_label = x["label"] if self.training else None
281
+
282
+ with torch.no_grad():
283
+ x = self.backbone(x)
284
+ x = self.mid(x)
285
+ x = self.head(x, gt_label)
286
+ return x
287
+
288
+
289
+ class HybridUnfrozenBackbone(UnfrozenBackbone):
290
+ """Unfreezes the backbone of a network."""
291
+
292
+ def forward(self, x):
293
+ """Forward pass of the network.
294
+ x : Union[torch.Tensor, dict] with the output of the backbone.
295
+ """
296
+
297
+ gt_label = x["label"] if self.training else None
298
+
299
+ x = self.backbone(x)
300
+ x = self.mid(x)
301
+ x = self.head(x, gt_label)
302
+ return x
303
+
304
+
305
+ class ContrastiveHybridUnFrozenBackbone(UnfrozenBackbone):
306
+ """Freezes the backbone of a network."""
307
+
308
+ def __init__(self, backbone, mid, head, mode):
309
+ super().__init__(backbone, mid, head)
310
+ self.mode = mode
311
+
312
+ def forward(self, x):
313
+ gt_label = x["label"] if self.training else None
314
+ features = self.backbone(x)
315
+ if self.mode != "eval":
316
+ x_pos = {
317
+ k.strip("pos_"): v.clone()
318
+ if isinstance(v, torch.Tensor)
319
+ else copy.deepcopy(v)
320
+ for k, v in x.items()
321
+ if k.startswith("pos_")
322
+ }
323
+ pos_features = self.backbone(x_pos)
324
+ x = self.mid(features)
325
+ x = self.head(x, gt_label)
326
+ if self.mode != "eval":
327
+ return {
328
+ "features": features[:, 0, :],
329
+ "pos_features": pos_features[:, 0, :],
330
+ **x,
331
+ }
332
+ return {
333
+ "features": features[:, 0, :],
334
+ **x,
335
+ }
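Note (outside the diff): what freeze() does to a module, shown on a throwaway layer.

from torch import nn
from models.networks.network import freeze

layer = nn.Linear(8, 2)
freeze(layer)
print(any(p.requires_grad for p in layer.parameters()))   # False
print(layer.training)                                      # False, freeze() also calls .eval()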
models/networks/utils.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch
2
+ import numpy as np
3
+ from torch import nn
4
+
5
+
6
+ class NormGPS(nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+
10
+ def forward(self, x):
11
+ """Normalize latitude longtitude radians to -1, 1.""" # not used currently
12
+ return x / torch.Tensor([np.pi * 0.5, np.pi]).unsqueeze(0).to(x.device)
13
+
14
+
15
+ class UnormGPS(nn.Module):
16
+ def __init__(self):
17
+ super().__init__()
18
+
19
+ def forward(self, x):
20
+ """Unormalize latitude longtitude radians to -1, 1."""
21
+ x = torch.clamp(x, -1, 1)
22
+ return x * torch.Tensor([np.pi * 0.5, np.pi]).unsqueeze(0).to(x.device)
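Note (outside the diff): UnormGPS clamps to [-1, 1] and rescales to radians; a one-line numeric check.

import torch
from models.networks.utils import UnormGPS

unorm = UnormGPS()
normalized = torch.tensor([[0.5, -2.0]])   # the second value is out of range and gets clamped to -1
print(unorm(normalized))                    # tensor([[ 0.7854, -3.1416]]) i.e. (pi/4, -pi)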
models/utils.py ADDED
@@ -0,0 +1,54 @@
1
+ import os
2
+ from os.path import abspath as abp
3
+ import torch
4
+ import hydra
5
+ from hydra import initialize, compose
6
+ from models.module import Geolocalizer
7
+ from omegaconf import OmegaConf, open_dict
8
+ from os.path import join
9
+ from hydra.utils import instantiate
10
+
11
+ def load_model_config(path):
12
+ # compute the config path relative to this file so hydra.initialize can load it
14
+ path = abp(path)
15
+ rel_path = os.path.relpath(path, start=os.path.split(__file__)[0])
16
+
17
+ with initialize(version_base=None, config_path=rel_path):
18
+ cfg = compose(config_name="config", overrides=[])
19
+
20
+ checkpoint = torch.load(join(path, "last.ckpt"))
21
+ del checkpoint["state_dict"][
22
+ "model.backbone.clip.vision_model.embeddings.position_ids"
23
+ ]
24
+ torch.save(checkpoint, join(path, "last2.ckpt"))
25
+
26
+ with open_dict(cfg):
27
+ cfg.checkpoint = join(path, "last2.ckpt")
28
+
29
+ cfg.num_classes = 11399
30
+ cfg.model.network.mid.instance.final_dim = cfg.num_classes * 3
31
+ cfg.model.network.head.final_dim = cfg.num_classes * 3
32
+ cfg.model.network.head.instance.quadtree_path = join(path, "quadtree_10_1000.csv")
33
+
34
+ cfg.dataset.train_dataset.path = ""
35
+ cfg.dataset.val_dataset.path = ""
36
+ cfg.dataset.test_dataset.path = ""
37
+ cfg.logger.save_dir = ""
38
+ cfg.data_dir = ""
39
+ cfg.root_dir = ""
40
+ cfg.mode = "test"
41
+ cfg.model.network.backbone.instance.path = (
42
+ "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
43
+ )
44
+ return cfg.dataset.test_transform, cfg.model, join(path, "last2.ckpt"), True
45
+
46
+ def load_model(path):
47
+ transform_config, model_config, checkpoint_path, delete = load_model_config(path)
48
+
49
+ transform = instantiate(transform_config)
50
+ model = Geolocalizer.load_from_checkpoint(checkpoint_path, cfg=model_config)
51
+ if delete:
52
+ os.remove(checkpoint_path)
53
+
54
+ return model, transform
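Note (outside the diff): load_model_config relies on omegaconf's open_dict to add keys (such as the checkpoint path) to a struct-locked hydra config; a self-contained sketch of that pattern with placeholder values.

from omegaconf import OmegaConf, open_dict

cfg = OmegaConf.create({"model": {"name": "baseline"}})
OmegaConf.set_struct(cfg, True)            # hydra-composed configs are struct-locked like this
with open_dict(cfg):                       # temporarily allow adding new keys
    cfg.checkpoint = "last2.ckpt"
print(cfg.checkpoint)                      # last2.ckpt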
scripts/download-dataset.py ADDED
@@ -0,0 +1,27 @@
1
+ import os, zipfile
2
+ from huggingface_hub import snapshot_download
3
+
4
+ # Define the base directory
5
+ base_dir = os.path.join(os.getcwd(), 'datasets')
6
+
7
+ # Ensure the base directory exists
8
+ if not os.path.exists(base_dir):
9
+ os.mkdir(base_dir)
10
+
11
+ # Define the specific dataset directory
12
+ dataset_dir = os.path.join(base_dir, "osv5m")
13
+
14
+ # Ensure the specific dataset directory exists
15
+ if not os.path.exists(dataset_dir):
16
+ os.mkdir(dataset_dir)
17
+
18
+ # Download the dataset
19
+ snapshot_download(repo_id="osv5m/osv5m", local_dir=dataset_dir, repo_type='dataset')
20
+
21
+ # Extract zip files and remove them after extraction
22
+ for root, dirs, files in os.walk(dataset_dir):
23
+ for file in files:
24
+ if file.endswith(".zip"):
25
+ with zipfile.ZipFile(os.path.join(root, file), 'r') as zip_ref:
26
+ zip_ref.extractall(root)
27
+ os.remove(os.path.join(root, file))
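Note (outside the diff): for a partial download (e.g. metadata only), snapshot_download also accepts allow_patterns; the pattern below is an assumption about the repository layout and the call needs network access.

from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="osv5m/osv5m",
    repo_type="dataset",
    local_dir="datasets/osv5m",
    allow_patterns=["*.csv"],   # fetch only the CSV metadata (assumed layout)
)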
scripts/preprocessing/enrich-metadata-adaptive-quadtrees.py ADDED
@@ -0,0 +1,225 @@
1
+ import hydra
2
+ import torch
3
+ import numpy as np
4
+ import pandas as pd
5
+ import statistics
6
+ from os.path import join, dirname
7
+ import matplotlib.pyplot as plt
8
+
9
+
10
+ class QuadTree(object):
11
+ def __init__(self, data, id="", depth=3, do_split=5000):
12
+ self.id = id
13
+ self.data = data
14
+
15
+ coord = data[["latitude", "longitude"]].to_numpy()
16
+
17
+ # if mins is None:
18
+ mins = coord.min(0)
19
+ # if maxs is None:
20
+ maxs = coord.max(0)
21
+
22
+ self.mins = np.asarray(mins)
23
+ self.maxs = np.asarray(maxs)
24
+ self.sizes = self.maxs - self.mins
25
+
26
+ self.children = []
27
+
28
+ # sort by latitude
29
+ sorted_data_lat = sorted(coord, key=lambda point: point[0])
30
+
31
+ # get the median lat
32
+ median_lat = statistics.median(point[0] for point in sorted_data_lat)
33
+
34
+ # Divide the cell into two half-cells based on the median lat
35
+ data_left = [point for point in sorted_data_lat if point[0] <= median_lat]
36
+ data_right = [point for point in sorted_data_lat if point[0] > median_lat]
37
+
38
+ # Sort the data points by longitude in each half-cell
39
+ sorted_data_left_lon = sorted(data_left, key=lambda point: point[1])
40
+ sorted_data_right_lon = sorted(data_right, key=lambda point: point[1])
41
+
42
+ # Calculate the median longitude coordinate in each half-cell
43
+ median_lon_left = statistics.median(point[1] for point in sorted_data_left_lon)
44
+ median_lon_right = statistics.median(
45
+ point[1] for point in sorted_data_right_lon
46
+ )
47
+
48
+ if (depth > 0) and (len(self.data) >= do_split):
49
+ # split the data into four quadrants
50
+ data_q1 = data[
51
+ (data["latitude"] < median_lat) & (data["longitude"] < median_lon_left)
52
+ ]
53
+ data_q2 = data[
54
+ (data["latitude"] < median_lat) & (data["longitude"] >= median_lon_left)
55
+ ]
56
+ data_q3 = data[
57
+ (data["latitude"] >= median_lat)
58
+ & (data["longitude"] < median_lon_right)
59
+ ]
60
+ data_q4 = data[
61
+ (data["latitude"] >= median_lat)
62
+ & (data["longitude"] >= median_lon_right)
63
+ ]
64
+
65
+ # recursively build a quad tree on each quadrant which has data
66
+ if data_q1.shape[0] > 0:
67
+ self.children.append(
68
+ QuadTree(
69
+ data_q1,
70
+ id + "0",
71
+ depth - 1,
72
+ do_split=do_split,
73
+ )
74
+ )
75
+ if data_q2.shape[0] > 0:
76
+ self.children.append(
77
+ QuadTree(
78
+ data_q2,
79
+ id + "1",
80
+ depth - 1,
81
+ do_split=do_split,
82
+ )
83
+ )
84
+ if data_q3.shape[0] > 0:
85
+ self.children.append(
86
+ QuadTree(
87
+ data_q3,
88
+ id + "2",
89
+ depth - 1,
90
+ do_split=do_split,
91
+ )
92
+ )
93
+ if data_q4.shape[0] > 0:
94
+ self.children.append(
95
+ QuadTree(
96
+ data_q4,
97
+ id + "3",
98
+ depth - 1,
99
+ do_split=do_split,
100
+ )
101
+ )
102
+
103
+ def unwrap(self):
104
+ if len(self.children) == 0:
105
+ return {self.id: [self.mins, self.maxs, self.data.copy()]}
106
+ else:
107
+ d = dict()
108
+ for child in self.children:
109
+ d.update(child.unwrap())
110
+ return d
111
+
112
+
113
+ def extract(qt, name_new_column):
114
+ cluster = qt.unwrap()
115
+ boundaries, data = {}, []
116
+ for i, (id, vs) in zip(np.arange(len(cluster)), cluster.items()):
117
+ (min_lat, min_lon), (max_lat, max_lon), points = vs
118
+ points[name_new_column] = int(i)
119
+ data.append(points)
120
+ boundaries[i] = (
121
+ float(min_lat),
122
+ float(min_lon),
123
+ float(max_lat),
124
+ float(max_lon),
125
+ points["latitude"].mean(),
126
+ points["longitude"].mean(),
127
+ )
128
+
129
+ data = pd.concat(data)
130
+ return boundaries, data
131
+
132
+
133
+ def vizu(name_new_column, df_train, boundaries, do_split):
134
+ plt.hist(df_train[name_new_column], bins=len(boundaries))
135
+ plt.xlabel("Cluster ID")
136
+ plt.ylabel("Number of images")
137
+ plt.title("Cluster distribution")
138
+ plt.yscale("log")
139
+ plt.ylim(10, do_split)
140
+ plt.savefig(f"{name_new_column}_distrib.png")
141
+ plt.clf()
142
+
143
+ plt.scatter(
144
+ df_train["longitude"].to_numpy(),
145
+ df_train["latitude"].to_numpy(),
146
+ c=np.random.permutation(len(boundaries))[df_train[name_new_column].to_numpy()],
147
+ cmap="tab20",
148
+ s=0.1,
149
+ alpha=0.5,
150
+ )
151
+ plt.xlabel("Longitude")
152
+ plt.ylabel("Latitude")
153
+ plt.title("Quadtree map")
154
+ plt.savefig(f"{name_new_column}_map.png")
155
+
156
+
157
+ @hydra.main(
158
+ config_path="../configs/scripts",
159
+ config_name="enrich-metadata-quadtree",
160
+ version_base=None,
161
+ )
162
+ def main(cfg):
163
+
164
+ data_path = join(cfg.data_dir, "osv5m")
165
+ name_new_column = f"adaptive_quadtree_{cfg.depth}_{cfg.do_split}"
166
+
167
+ # Create clusters from train images
168
+ train_fp = join(data_path, f"train.csv")
169
+ df_train = pd.read_csv(train_fp)
170
+
171
+ qt = QuadTree(df_train, depth=cfg.depth, do_split=cfg.do_split)
172
+ boundaries, df_train = extract(qt, name_new_column)
173
+
174
+ vizu(name_new_column, df_train, boundaries, cfg.do_split)
175
+
176
+ # Save clusters
177
+ boundaries = pd.DataFrame.from_dict(
178
+ boundaries,
179
+ orient="index",
180
+ columns=["min_lat", "min_lon", "max_lat", "max_lon", "mean_lat", "mean_lon"],
181
+ )
182
+ boundaries.to_csv(f"{name_new_column}.csv", index_label="cluster_id")
183
+
184
+ # Assign test images to clusters
185
+ test_fp = join(data_path, f"test.csv")
186
+ df_test = pd.read_csv(test_fp)
187
+
188
+ above_lat = np.expand_dims(df_test["latitude"].to_numpy(), -1) > np.expand_dims(
189
+ boundaries["min_lat"].to_numpy(), 0
190
+ )
191
+ below_lat = np.expand_dims(df_test["latitude"].to_numpy(), -1) < np.expand_dims(
192
+ boundaries["max_lat"].to_numpy(), 0
193
+ )
194
+ above_lon = np.expand_dims(df_test["longitude"].to_numpy(), -1) > np.expand_dims(
195
+ boundaries["min_lon"].to_numpy(), 0
196
+ )
197
+ below_lon = np.expand_dims(df_test["longitude"].to_numpy(), -1) < np.expand_dims(
198
+ boundaries["max_lon"].to_numpy(), 0
199
+ )
200
+
201
+ mask = np.logical_and(
202
+ np.logical_and(above_lat, below_lat), np.logical_and(above_lon, below_lon)
203
+ )
204
+
205
+ df_test[name_new_column] = np.argmax(mask, axis=1)
206
+
207
+ # save index_to_gps_quadtree file
208
+ lat = torch.tensor(boundaries["mean_lat"])
209
+ lon = torch.tensor(boundaries["mean_lon"])
210
+ coord = torch.stack([lat / 90, lon / 180], dim=-1)
211
+ torch.save(
212
+ coord,
213
+ join(
214
+ data_path, f"index_to_gps_adaptive_quadtree_{cfg.depth}_{cfg.do_split}.pt"
215
+ ),
216
+ )
217
+
218
+ # Overwrite test.csv and train.csv
219
+ if cfg.overwrite_csv:
220
+ df_train.to_csv(train_fp, index=False)
221
+ df_test.to_csv(test_fp, index=False)
222
+
223
+
224
+ if __name__ == "__main__":
225
+ main()
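Note (outside the diff): the adaptive variant splits on the median latitude, then on the median longitude of each half; a toy sketch of that computation.

import pandas as pd
import statistics

df = pd.DataFrame({
    "latitude":  [10.0, 20.0, 30.0, 40.0],
    "longitude": [ 5.0, 15.0, 25.0, 35.0],
})
median_lat = statistics.median(df["latitude"])     # 25.0
left = df[df["latitude"] <= median_lat]            # lower-latitude half
right = df[df["latitude"] > median_lat]            # upper-latitude half
print(statistics.median(left["longitude"]),        # 10.0
      statistics.median(right["longitude"]))       # 30.0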
scripts/preprocessing/enrich-metadata-quadtree.py ADDED
@@ -0,0 +1,208 @@
1
+ import hydra
2
+ import numpy as np
3
+ import pandas as pd
4
+ from os.path import join, dirname
5
+ import matplotlib.pyplot as plt
6
+ import torch
7
+
8
+
9
+ class QuadTree(object):
10
+ def __init__(self, data, mins=None, maxs=None, id="", depth=3, do_split=1000):
11
+ self.id = id
12
+ self.data = data
13
+
14
+ if mins is None:
15
+ mins = data[["latitude", "longitude"]].to_numpy().min(0)
16
+ if maxs is None:
17
+ maxs = data[["latitude", "longitude"]].to_numpy().max(0)
18
+
19
+ self.mins = np.asarray(mins)
20
+ self.maxs = np.asarray(maxs)
21
+ self.sizes = self.maxs - self.mins
22
+
23
+ self.children = []
24
+
25
+ mids = 0.5 * (self.mins + self.maxs)
26
+ xmin, ymin = self.mins
27
+ xmax, ymax = self.maxs
28
+ xmid, ymid = mids
29
+
30
+ if (depth > 0) and (len(self.data) >= do_split):
31
+ # split the data into four quadrants
32
+ data_q1 = data[(data["latitude"] < mids[0]) & (data["longitude"] < mids[1])]
33
+ data_q2 = data[
34
+ (data["latitude"] < mids[0]) & (data["longitude"] >= mids[1])
35
+ ]
36
+ data_q3 = data[
37
+ (data["latitude"] >= mids[0]) & (data["longitude"] < mids[1])
38
+ ]
39
+ data_q4 = data[
40
+ (data["latitude"] >= mids[0]) & (data["longitude"] >= mids[1])
41
+ ]
42
+
43
+ # recursively build a quad tree on each quadrant which has data
44
+ if data_q1.shape[0] > 0:
45
+ self.children.append(
46
+ QuadTree(
47
+ data_q1,
48
+ [xmin, ymin],
49
+ [xmid, ymid],
50
+ id + "0",
51
+ depth - 1,
52
+ do_split=do_split,
53
+ )
54
+ )
55
+ if data_q2.shape[0] > 0:
56
+ self.children.append(
57
+ QuadTree(
58
+ data_q2,
59
+ [xmin, ymid],
60
+ [xmid, ymax],
61
+ id + "1",
62
+ depth - 1,
63
+ do_split=do_split,
64
+ )
65
+ )
66
+ if data_q3.shape[0] > 0:
67
+ self.children.append(
68
+ QuadTree(
69
+ data_q3,
70
+ [xmid, ymin],
71
+ [xmax, ymid],
72
+ id + "2",
73
+ depth - 1,
74
+ do_split=do_split,
75
+ )
76
+ )
77
+ if data_q4.shape[0] > 0:
78
+ self.children.append(
79
+ QuadTree(
80
+ data_q4,
81
+ [xmid, ymid],
82
+ [xmax, ymax],
83
+ id + "3",
84
+ depth - 1,
85
+ do_split=do_split,
86
+ )
87
+ )
88
+
89
+ def unwrap(self):
90
+ if len(self.children) == 0:
91
+ return {self.id: [self.mins, self.maxs, self.data.copy()]}
92
+ else:
93
+ d = dict()
94
+ for child in self.children:
95
+ d.update(child.unwrap())
96
+ return d
97
+
98
+
99
+ def extract(qt, name_new_column):
100
+ cluster = qt.unwrap()
101
+ boundaries, data = {}, []
102
+ id_to_quad = np.array(list(cluster.keys()))
103
+ for i, (id, vs) in zip(np.arange(len(cluster)), cluster.items()):
104
+ (min_lat, min_lon), (max_lat, max_lon), points = vs
105
+ points[name_new_column] = int(i)
106
+ data.append(points)
107
+ boundaries[i] = (
108
+ float(min_lat),
109
+ float(min_lon),
110
+ float(max_lat),
111
+ float(max_lon),
112
+ points["latitude"].mean(),
113
+ points["longitude"].mean(),
114
+ )
115
+
116
+ data = pd.concat(data)
117
+ return boundaries, data, id_to_quad
118
+
119
+
120
+ def vizu(name_new_column, df_train, boundaries):
121
+ plt.hist(df_train[name_new_column], bins=len(boundaries))
122
+ plt.xlabel("Cluster ID")
123
+ plt.ylabel("Number of images")
124
+ plt.title("Cluster distribution")
125
+ plt.yscale("log")
126
+ plt.savefig(f"{name_new_column}_distrib.png")
127
+ plt.clf()
128
+
129
+ plt.scatter(
130
+ df_train["longitude"].to_numpy(),
131
+ df_train["latitude"].to_numpy(),
132
+ c=np.random.permutation(len(boundaries))[df_train[name_new_column].to_numpy()],
133
+ cmap="tab20",
134
+ s=0.1,
135
+ alpha=0.5,
136
+ )
137
+ plt.xlabel("Longitude")
138
+ plt.ylabel("Latitude")
139
+ plt.title("Quadtree map")
140
+ plt.savefig(f"{name_new_column}_map.png")
141
+
142
+
143
+ @hydra.main(
144
+ config_path="../configs/scripts",
145
+ config_name="enrich-metadata-quadtree",
146
+ version_base=None,
147
+ )
148
+ def main(cfg):
149
+ data_path = join(cfg.data_dir, "osv5m")
150
+ name_new_column = f"quadtree_{cfg.depth}_{cfg.do_split}"
151
+
152
+ # Create clusters from train images
153
+ train_fp = join(data_path, f"train.csv")
154
+ df_train = pd.read_csv(train_fp)
155
+
156
+ qt = QuadTree(df_train, depth=cfg.depth, do_split=cfg.do_split)
157
+ boundaries, df_train, id_to_quad = extract(qt, name_new_column)
158
+
159
+ vizu(name_new_column, df_train, boundaries)
160
+
161
+ # Save clusters
162
+ boundaries = pd.DataFrame.from_dict(
163
+ boundaries,
164
+ orient="index",
165
+ columns=["min_lat", "min_lon", "max_lat", "max_lon", "mean_lat", "mean_lon"],
166
+ )
167
+ boundaries.to_csv(f"{name_new_column}.csv", index_label="cluster_id")
168
+
169
+ # Assign test images to clusters
170
+ test_fp = join(data_path, f"test.csv")
171
+ df_test = pd.read_csv(test_fp)
172
+
173
+ above_lat = np.expand_dims(df_test["latitude"].to_numpy(), -1) > np.expand_dims(
174
+ boundaries["min_lat"].to_numpy(), 0
175
+ )
176
+ below_lat = np.expand_dims(df_test["latitude"].to_numpy(), -1) < np.expand_dims(
177
+ boundaries["max_lat"].to_numpy(), 0
178
+ )
179
+ above_lon = np.expand_dims(df_test["longitude"].to_numpy(), -1) > np.expand_dims(
180
+ boundaries["min_lon"].to_numpy(), 0
181
+ )
182
+ below_lon = np.expand_dims(df_test["longitude"].to_numpy(), -1) < np.expand_dims(
183
+ boundaries["max_lon"].to_numpy(), 0
184
+ )
185
+
186
+ mask = np.logical_and(
187
+ np.logical_and(above_lat, below_lat), np.logical_and(above_lon, below_lon)
188
+ )
189
+
190
+ df_test[name_new_column] = np.argmax(mask, axis=1)
191
+
192
+ # save index_to_gps_quadtree file
193
+ lat = torch.tensor(boundaries["mean_lat"])
194
+ lon = torch.tensor(boundaries["mean_lon"])
195
+ coord = torch.stack([lat / 90, lon / 180], dim=-1)
196
+ torch.save(
197
+ coord, join(data_path, f"index_to_gps_quadtree_{cfg.depth}_{cfg.do_split}.pt")
198
+ )
199
+
200
+ torch.save(id_to_quad, join(data_path, f"id_to_quad_{cfg.depth}_{cfg.do_split}.pt"))
201
+ # Overwrite test.csv and train.csv
202
+ if cfg.overwrite_csv:
203
+ df_train.to_csv(train_fp, index=False)
204
+ df_test.to_csv(test_fp, index=False)
205
+
206
+
207
+ if __name__ == "__main__":
208
+ main()
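Note (outside the diff): the test-image assignment above is a bounding-box containment test followed by argmax; the same logic on two toy cells and two toy points.

import numpy as np

bounds = np.array([[0.0, 0.0, 10.0, 10.0],       # (min_lat, min_lon, max_lat, max_lon), made-up cells
                   [10.0, 0.0, 20.0, 10.0]])
points = np.array([[5.0, 5.0], [15.0, 2.0]])     # (latitude, longitude) per test image

inside = (
    (points[:, :1] > bounds[:, 0]) & (points[:, :1] < bounds[:, 2])     # latitude inside the cell
    & (points[:, 1:] > bounds[:, 1]) & (points[:, 1:] < bounds[:, 3])   # longitude inside the cell
)
print(np.argmax(inside, axis=1))                  # [0 1] -> first point in cell 0, second in cell 1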
scripts/preprocessing/enrich-metadata.py ADDED
@@ -0,0 +1,123 @@
1
+ import os
2
+ import json
3
+ import joblib
4
+ import pandas as pd
5
+ import numpy as np
6
+ import reverse_geocoder
7
+ from os.path import join, dirname
8
+
9
+
10
+ class QuadTree(object):
11
+ def __init__(
12
+ self, data, mins=None, maxs=None, id="", depth=3, min_split=0, do_split=1000
13
+ ):
14
+ self.id = id
15
+ self.data = data
16
+
17
+ if mins is None:
18
+ mins = data[["latitude", "longitude"]].to_numpy().min(0)
19
+ if maxs is None:
20
+ maxs = data[["latitude", "longitude"]].to_numpy().max(0)
21
+
22
+ self.mins = np.asarray(mins)
23
+ self.maxs = np.asarray(maxs)
24
+ self.sizes = self.maxs - self.mins
25
+
26
+ self.children = []
27
+
28
+ mids = 0.5 * (self.mins + self.maxs)
29
+ xmin, ymin = self.mins
30
+ xmax, ymax = self.maxs
31
+ xmid, ymid = mids
32
+
33
+ if depth > 0 and len(self.data) >= do_split:
34
+ # split the data into four quadrants
35
+ data_q1 = data[(data["latitude"] < mids[0]) & (data["longitude"] < mids[1])]
36
+ data_q2 = data[
37
+ (data["latitude"] < mids[0]) & (data["longitude"] >= mids[1])
38
+ ]
39
+ data_q3 = data[
40
+ (data["latitude"] >= mids[0]) & (data["longitude"] < mids[1])
41
+ ]
42
+ data_q4 = data[
43
+ (data["latitude"] >= mids[0]) & (data["longitude"] >= mids[1])
44
+ ]
45
+
46
+ # recursively build a quad tree on each quadrant which has data
47
+ if data_q1.shape[0] > min_split:
48
+ self.children.append(
49
+ QuadTree(data_q1, [xmin, ymin], [xmid, ymid], id + "0", depth - 1)
50
+ )
51
+ if data_q2.shape[0] > min_split:
52
+ self.children.append(
53
+ QuadTree(data_q2, [xmin, ymid], [xmid, ymax], id + "1", depth - 1)
54
+ )
55
+ if data_q3.shape[0] > min_split:
56
+ self.children.append(
57
+ QuadTree(data_q3, [xmid, ymin], [xmax, ymid], id + "2", depth - 1)
58
+ )
59
+ if data_q4.shape[0] > min_split:
60
+ self.children.append(
61
+ QuadTree(data_q4, [xmid, ymid], [xmax, ymax], id + "3", depth - 1)
62
+ )
63
+
64
+ def unwrap(self):
65
+ if len(self.children) == 0:
66
+ return {self.id: [self.mins, self.maxs, self.data.copy()]}
67
+ else:
68
+ d = dict()
69
+ for child in self.children:
70
+ d.update(child.unwrap())
71
+ return d
72
+
73
+
74
+ def extract(qt):
75
+ cluster = qt.unwrap()
76
+ boundaries, data = {}, []
77
+ for id, vs in cluster.items():
78
+ (min_lat, min_lon), (max_lat, max_lon), points = vs
79
+ points["category"] = id
80
+ data.append(points)
81
+ boundaries[id] = (
82
+ float(min_lat),
83
+ float(min_lon),
84
+ float(max_lat),
85
+ float(max_lon),
86
+ )
87
+
88
+ data = pd.concat(data)
89
+ return boundaries, data
90
+
91
+
92
+ if __name__ == "__main__":
93
+ # merge into one DataFrame
94
+ data_path = join(dirname(dirname(__file__)), "datasets", "osv5m")
95
+ train_fp = join(data_path, f"train.csv")
96
+ test_fp = join(data_path, f"test.csv")
97
+
98
+ df_train = pd.read_csv(train_fp)
99
+ df_train["split"] = "train"
100
+
101
+ df_test = pd.read_csv(test_fp)
102
+ df_test["split"] = "test"
103
+
104
+ df = pd.concat([df_train, df_test])
105
+ size_before = df.shape[0]
106
+ qt = QuadTree(df, depth=15)
107
+ boundaries, df = extract(qt)
108
+ assert df.shape[0] == size_before
109
+
110
+ location = reverse_geocoder.search(
111
+ [(lat, lon) for lat, lon in zip(df["latitude"], df["longitude"])]
112
+ )
113
+ df["city"] = [l.get("name", "") for l in location]
114
+ df["country"] = [l.get("cc", "") for l in location]
115
+ del location
116
+
117
+ df_train = df[df["split"] == "train"].drop(["split"], axis=1)
118
+ df_test = df[df["split"] == "test"].drop(["split"], axis=1)
119
+ assert (df_train.shape[0] + df_test.shape[0]) == size_before
120
+
121
+ json.dump(boundaries, open(join(data_path, "borders.json"), "w"))
122
+ df_train.to_csv(train_fp, index=False)
123
+ df_test.to_csv(test_fp, index=False)
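Note (outside the diff): reverse_geocoder works offline on (lat, lon) tuples and returns records with a place name and country code, which is how the city and country columns above are filled; the coordinates below are illustrative.

import reverse_geocoder

hits = reverse_geocoder.search([(48.8566, 2.3522), (40.7128, -74.0060)])
print([(h.get("name", ""), h.get("cc", "")) for h in hits])
# e.g. [('Paris', 'FR'), ('New York City', 'US')]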
scripts/preprocessing/fix_namimbia.py ADDED
@@ -0,0 +1,64 @@
1
+ from os.path import join, dirname
2
+ import numpy as np
3
+ import pandas as pd
4
+
5
+ if __name__ == "__main__":
6
+ # Define the list of cities
7
+ cities = [
8
+ "Walvis Bay",
9
+ "Keetmanshoop",
10
+ "Warmbad",
11
+ "Rundu",
12
+ "Outapi",
13
+ "Karibib",
14
+ "Otjimbingwe",
15
+ "Ondangwa",
16
+ "Oranjemund",
17
+ "Maltahohe",
18
+ "Otavi",
19
+ "Outjo",
20
+ "Swakopmund",
21
+ "Gobabis",
22
+ "Karasburg",
23
+ "Opuwo",
24
+ "Hentiesbaai",
25
+ "Katima Mulilo",
26
+ "Oshikango",
27
+ "Bethanie",
28
+ "Ongandjera",
29
+ "Mariental",
30
+ "Bagani",
31
+ "Nkurenkuru",
32
+ "Usakos",
33
+ "Rehoboth",
34
+ "Aranos",
35
+ "Omaruru",
36
+ "Arandis",
37
+ "Windhoek",
38
+ "Khorixas",
39
+ "Okahandja",
40
+ "Grootfontein",
41
+ "Tsumeb",
42
+ ]
43
+
44
+ csv_dtype = {"category": str, "country": str, "city": str}
45
+ for split in ["train", "test"]:
46
+ fp = join(
47
+ dirname(dirname(__file__)), "datasets", "osv5m", f"{split}.csv"
48
+ )
49
+
50
+ # Read the CSV file into a pandas DataFrame
51
+ df = pd.read_csv(fp, dtype=csv_dtype)
52
+
53
+ # Check if the "country" column contains any of the cities in the list
54
+ mask = df["city"].isin(cities)
55
+
56
+ # If a city is found, set the corresponding rows in the "country" column to 'NMB'
57
+ df.loc[mask, "country"] = "NMB"
58
+ assert all(map(lambda x: isinstance(x, str), df["country"].unique().tolist()))
59
+
60
+ # Drop rows that are missing id, latitude or longitude
61
+ df.dropna(subset=["id", "latitude", "longitude"], inplace=True)
62
+
63
+ # Save the modified DataFrame back to the CSV file
64
+ df.to_csv(fp, index=False)
scripts/preprocessing/nearest-neighbors.py ADDED
@@ -0,0 +1,140 @@
1
+ import sys, os
2
+ import json
3
+ from PIL import Image
4
+ from tqdm import tqdm
5
+ from os.path import dirname, join
6
+
7
+ sys.path.append(dirname(dirname(__file__)))
8
+
9
+ import torch
10
+ from transformers import AutoImageProcessor, AutoModel
11
+ from transformers import CLIPProcessor, CLIPModel
12
+ from transformers import pipeline
13
+
14
+ from data.data import osv5m
15
+ from json_stream import streamable_list
16
+
17
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+
19
+
20
+ def load_model_clip():
21
+ model = CLIPModel.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K")
22
+ processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K")
23
+ return processor, model.to(DEVICE)
24
+
25
+
26
+ def load_model_dino():
27
+ model = AutoModel.from_pretrained("facebook/dinov2-base")
28
+ processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
29
+ return processor, model.to(DEVICE)
30
+
31
+
32
+ def compute_dino(processor, model, x):
33
+ inputs = processor(images=x[0], return_tensors="pt", device=DEVICE).to(DEVICE)
34
+ outputs = model(**inputs)
35
+ last_hidden_states = outputs.last_hidden_state.cpu().numpy()
36
+ for i in range(len(x[0])):
37
+ yield [last_hidden_states[i].tolist(), x[1][i], x[2][i], x[3][i]]
38
+
39
+
40
+ def compute_clip(processor, model, x):
41
+ inputs = processor(images=x[0], return_tensors="pt", device=DEVICE).to(DEVICE)
42
+ features = model.get_image_features(**inputs)
43
+ features /= features.norm(dim=-1, keepdim=True)
44
+ features = features.cpu().numpy()
45
+ for i in range(len(x[0])):
46
+ yield [features[i].tolist(), x[1][i], x[2][i], x[3][i]]
47
+
48
+
49
+ def get_batch(dataset, batch_size):
50
+ data, lats, lons, ids = [], [], [], []
51
+ for i in range(len(dataset)):
52
+ id, lat, lon = dataset.df.iloc[i]
53
+ data.append(Image.open(join(dataset.image_folder, f"{int(id)}.jpg")))
54
+ lats.append(lat)
55
+ lons.append(lon)
56
+ ids.append(id)
57
+ if len(data) == batch_size:
58
+ yield data, lats, lons, ids
59
+ data, lats, lons, ids = [], [], [], []
60
+
61
+ if len(data) > 0:
62
+ yield data, lats, lons, ids
63
+ data, lats, lons, ids = [], [], [], []
64
+
65
+
66
+ if __name__ == "__main__":
67
+ import argparse
68
+
69
+ parser = argparse.ArgumentParser()
70
+ parser.add_argument("--batch_size", type=int, default=256)
71
+ parser.add_argument("--compute_features", action="store_true")
72
+ parser.add_argument("--compute_nearest", action="store_true")
73
+ parser.add_argument("--json_path", default="features")
74
+ parser.add_argument("--which", type=str, default="clip", choices=["clip", "dino"])
75
+ args = parser.parse_args()
76
+ json_path = join(args.json_path, args.which)
77
+
78
+ os.makedirs(json_path, exist_ok=True)
79
+ if args.compute_features:
80
+ processor, model = (
81
+ load_model_clip() if args.which == "clip" else load_model_dino()
82
+ )
83
+ compute_fn = compute_clip if args.which == "clip" else compute_dino
84
+
85
+ for split in ["test"]: #'train',
86
+ # open existing json and read as dictionary
87
+ json_path_ = join(json_path, f"{split}.json")
88
+
89
+ dataset = osv5m(
90
+ "datasets/osv5m", transforms=None, split=split, dont_split=True
91
+ )
92
+
93
+ @torch.no_grad()
94
+ def compute(batch_size):
95
+ for data in tqdm(
96
+ get_batch(dataset, batch_size),
97
+ total=len(dataset) // batch_size,
98
+ desc=f"Computing {split} on {args.which}",
99
+ ):
100
+ features = compute_fn(processor, model, data)
101
+ for feature, lat, lon, id in features:
102
+ yield feature, lat, lon, id
103
+
104
+ data = streamable_list(compute(args.batch_size))
105
+ json.dump(data, open(json_path_, "w"), indent=4)
106
+
107
+ if args.compute_nearest:
108
+ from sklearn.metrics.pairwise import cosine_similarity
109
+ import numpy as np
110
+
111
+ train, test = [
112
+ json.load(open(join(json_path, f"{split}.json"), "r"))
113
+ for split in ["train", "test"]
114
+ ]
115
+
116
+ def get_neighbors(k=10):
117
+ for i, test_data in enumerate(tqdm(test)):
118
+ feature, lat, lon, id = test_data
119
+ features_train = np.stack(
120
+ [np.array(train_data[0]) for train_data in train]
121
+ )
122
+ cs = np.squeeze(
123
+ cosine_similarity(np.expand_dims(feature, axis=0), features_train),
124
+ axis=0,
125
+ )
126
+ i = np.argsort(cs)[-k:][::-1].tolist()
127
+ yield [
128
+ {n: x}
129
+ for idx in i
130
+ for n, x in zip(
131
+ ["feature", "lat", "lon", "id", "distance"],
132
+ train[idx]
133
+ + [
134
+ cs[idx],
135
+ ],
136
+ )
137
+ ]
138
+
139
+ data = streamable_list(get_neighbors())
140
+ json.dump(data, open(join(json_path, "nearest.json"), "w"), indent=4)
scripts/preprocessing/preprocess.py ADDED
@@ -0,0 +1,400 @@
 
1
+ import pandas as pd
2
+ import torch
3
+ import numpy as np
4
+ from os.path import join
5
+ import matplotlib.pyplot as plt
6
+ import hydra
7
+
8
+
9
+ class QuadTree(object):
10
+ def __init__(self, data, mins=None, maxs=None, id="", depth=3, do_split=1000):
11
+ self.id = id
12
+ self.data = data
13
+
14
+ if mins is None:
15
+ mins = data[["latitude", "longitude"]].to_numpy().min(0)
16
+ if maxs is None:
17
+ maxs = data[["latitude", "longitude"]].to_numpy().max(0)
18
+
19
+ self.mins = np.asarray(mins)
20
+ self.maxs = np.asarray(maxs)
21
+ self.sizes = self.maxs - self.mins
22
+
23
+ self.children = []
24
+
25
+ mids = 0.5 * (self.mins + self.maxs)
26
+ xmin, ymin = self.mins
27
+ xmax, ymax = self.maxs
28
+ xmid, ymid = mids
29
+
30
+ if (depth > 0) and (len(self.data) >= do_split):
31
+ # split the data into four quadrants
32
+ data_q1 = data[(data["latitude"] < mids[0]) & (data["longitude"] < mids[1])]
33
+ data_q2 = data[
34
+ (data["latitude"] < mids[0]) & (data["longitude"] >= mids[1])
35
+ ]
36
+ data_q3 = data[
37
+ (data["latitude"] >= mids[0]) & (data["longitude"] < mids[1])
38
+ ]
39
+ data_q4 = data[
40
+ (data["latitude"] >= mids[0]) & (data["longitude"] >= mids[1])
41
+ ]
42
+
43
+ # recursively build a quad tree on each quadrant which has data
44
+ if data_q1.shape[0] > 0:
45
+ self.children.append(
46
+ QuadTree(
47
+ data_q1,
48
+ [xmin, ymin],
49
+ [xmid, ymid],
50
+ id + "0",
51
+ depth - 1,
52
+ do_split=do_split,
53
+ )
54
+ )
55
+ if data_q2.shape[0] > 0:
56
+ self.children.append(
57
+ QuadTree(
58
+ data_q2,
59
+ [xmin, ymid],
60
+ [xmid, ymax],
61
+ id + "1",
62
+ depth - 1,
63
+ do_split=do_split,
64
+ )
65
+ )
66
+ if data_q3.shape[0] > 0:
67
+ self.children.append(
68
+ QuadTree(
69
+ data_q3,
70
+ [xmid, ymin],
71
+ [xmax, ymid],
72
+ id + "2",
73
+ depth - 1,
74
+ do_split=do_split,
75
+ )
76
+ )
77
+ if data_q4.shape[0] > 0:
78
+ self.children.append(
79
+ QuadTree(
80
+ data_q4,
81
+ [xmid, ymid],
82
+ [xmax, ymax],
83
+ id + "3",
84
+ depth - 1,
85
+ do_split=do_split,
86
+ )
87
+ )
88
+
89
+ def unwrap(self):
90
+ if len(self.children) == 0:
91
+ return {self.id: [self.mins, self.maxs, self.data.copy()]}
92
+ else:
93
+ d = dict()
94
+ for child in self.children:
95
+ d.update(child.unwrap())
96
+ return d
97
+
98
+
99
+ def extract(qt, name_new_column):
100
+ cluster = qt.unwrap()
101
+ boundaries, data = {}, []
102
+ id_to_quad = np.array(list(cluster.keys()))
103
+ for i, (id, vs) in zip(np.arange(len(cluster)), cluster.items()):
104
+ (min_lat, min_lon), (max_lat, max_lon), points = vs
105
+ points[name_new_column] = int(i)
106
+ data.append(points)
107
+ boundaries[i] = (
108
+ float(min_lat),
109
+ float(min_lon),
110
+ float(max_lat),
111
+ float(max_lon),
112
+ points["latitude"].mean(),
113
+ points["longitude"].mean(),
114
+ )
115
+
116
+ data = pd.concat(data)
117
+ return boundaries, data, id_to_quad
118
+
119
+
120
+ def vizu(name_new_column, df_train, boundaries, save_path):
121
+ plt.hist(df_train[name_new_column], bins=len(boundaries))
122
+ plt.xlabel("Cluster ID")
123
+ plt.ylabel("Number of images")
124
+ plt.title("Cluster distribution")
125
+ plt.yscale("log")
126
+ plt.savefig(join(save_path, f"{name_new_column}_distrib.png"))
127
+ plt.clf()
128
+
129
+ plt.scatter(
130
+ df_train["longitude"].to_numpy(),
131
+ df_train["latitude"].to_numpy(),
132
+ c=np.random.permutation(len(boundaries))[df_train[name_new_column].to_numpy()],
133
+ cmap="tab20",
134
+ s=0.1,
135
+ alpha=0.5,
136
+ )
137
+ plt.xlabel("Longitude")
138
+ plt.ylabel("Latitude")
139
+ plt.title("Quadtree map")
140
+ plt.savefig(join(save_path, f"{name_new_column}_map.png"))
141
+
142
+
143
+ @hydra.main(
144
+ config_path="../../configs/scripts",
145
+ config_name="preprocess",
146
+ version_base=None,
147
+ )
148
+ def main(cfg):
149
+ data_path = join(cfg.data_dir, "osv5m")
150
+ save_path = cfg.data_dir
151
+ name_new_column = f"quadtree_{cfg.depth}_{cfg.do_split}"
152
+
153
+ # Create clusters from train images
154
+ train_fp = join(data_path, f"train.csv")
155
+ df_train = pd.read_csv(train_fp, low_memory=False)
156
+
157
+ qt = QuadTree(df_train, depth=cfg.depth, do_split=cfg.do_split)
158
+ boundaries, df_train, id_to_quad = extract(qt, name_new_column)
159
+
160
+ vizu(name_new_column, df_train, boundaries, save_path)
161
+
162
+ # Save clusters
163
+ boundaries = pd.DataFrame.from_dict(
164
+ boundaries,
165
+ orient="index",
166
+ columns=["min_lat", "min_lon", "max_lat", "max_lon", "mean_lat", "mean_lon"],
167
+ )
168
+ boundaries.to_csv(
169
+ join(save_path, f"{name_new_column}.csv"), index_label="cluster_id"
170
+ )
171
+
172
+ # Assign test images to clusters
173
+ test_fp = join(data_path, f"test.csv")
174
+ df_test = pd.read_csv(test_fp)
175
+
176
+ above_lat = np.expand_dims(df_test["latitude"].to_numpy(), -1) > np.expand_dims(
177
+ boundaries["min_lat"].to_numpy(), 0
178
+ )
179
+ below_lat = np.expand_dims(df_test["latitude"].to_numpy(), -1) < np.expand_dims(
180
+ boundaries["max_lat"].to_numpy(), 0
181
+ )
182
+ above_lon = np.expand_dims(df_test["longitude"].to_numpy(), -1) > np.expand_dims(
183
+ boundaries["min_lon"].to_numpy(), 0
184
+ )
185
+ below_lon = np.expand_dims(df_test["longitude"].to_numpy(), -1) < np.expand_dims(
186
+ boundaries["max_lon"].to_numpy(), 0
187
+ )
188
+
189
+ mask = np.logical_and(
190
+ np.logical_and(above_lat, below_lat), np.logical_and(above_lon, below_lon)
191
+ )
192
+
193
+ df_test[name_new_column] = np.argmax(mask, axis=1)
194
+
195
+ # save index_to_gps_quadtree file
196
+ lat = torch.tensor(boundaries["mean_lat"])
197
+ lon = torch.tensor(boundaries["mean_lon"])
198
+ coord = torch.stack([lat, lon], dim=-1)
199
+ torch.save(
200
+ coord, join(save_path, f"index_to_gps_quadtree_{cfg.depth}_{cfg.do_split}.pt")
201
+ )
202
+
203
+ torch.save(id_to_quad, join(save_path, f"id_to_quad_{cfg.depth}_{cfg.do_split}.pt"))
204
+ # Overwrite test.csv and train.csv
205
+ if cfg.overwrite_csv:
206
+ df_train.to_csv(train_fp, index=False)
207
+ df_test.to_csv(test_fp, index=False)
208
+
209
+ df = pd.read_csv(join(data_path, "train.csv"), low_memory=False).fillna("NaN")
210
+ # Compute the average location for each unique country
211
+ country_avg = (
212
+ df.groupby("unique_country")[["latitude", "longitude"]].mean().reset_index()
213
+ )
214
+ country_avg.to_csv(
215
+ join(save_path, "country_center.csv"),
216
+ columns=["unique_country", "latitude", "longitude"],
217
+ index=False,
218
+ )
219
+ # Compute the average location for each unique admin1 (region)
220
+ region_avg = (
221
+ df.groupby(["unique_region"])[["latitude", "longitude"]].mean().reset_index()
222
+ )
223
+ region_avg.to_csv(
224
+ join(save_path, "region_center.csv"),
225
+ columns=["unique_region", "latitude", "longitude"],
226
+ index=False,
227
+ )
228
+ # Compute the average location for each unique admin2 (area)
229
+ area_avg = (
230
+ df.groupby(["unique_sub-region"])[["latitude", "longitude"]]
231
+ .mean()
232
+ .reset_index()
233
+ )
234
+ area_avg.to_csv(
235
+ join(save_path, "sub-region_center.csv"),
236
+ columns=["unique_sub-region", "latitude", "longitude"],
237
+ index=False,
238
+ )
239
+ # Compute the average location for each unique city
240
+ city_avg = (
241
+ df.groupby(["unique_city"])[["latitude", "longitude"]].mean().reset_index()
242
+ )
243
+ city_avg.to_csv(
244
+ join(save_path, "city_center.csv"),
245
+ columns=["unique_city", "latitude", "longitude"],
246
+ index=False,
247
+ )
248
+
249
+ for class_name in [
250
+ "unique_country",
251
+ "unique_sub-region",
252
+ "unique_region",
253
+ "unique_city",
254
+ ]:
255
+ # Load CSV data into a Pandas DataFrame
256
+ csv_file = class_name.split("_")[-1] + "_center.csv"
257
+ df = pd.read_csv(join(save_path, csv_file), low_memory=False)
258
+
259
+ splits = ["train"]
260
+ categories = sorted(
261
+ pd.concat(
262
+ [
263
+ pd.read_csv(
264
+ join(data_path, f"{split}.csv"), low_memory=False
265
+ )[class_name]
266
+ for split in splits
267
+ ]
268
+ )
269
+ .fillna("NaN")
270
+ .unique()
271
+ .tolist()
272
+ )
273
+
274
+ if "NaN" in categories:
275
+ categories.remove("NaN")
276
+
277
+ # compute the total number of categories - this name is fixed and will be used as a lookup during init
278
+ num_classes = len(categories)
279
+
280
+ # create a mapping from category to index
281
+ category_to_index = {category: i for i, category in enumerate(categories)}
282
+
283
+ dictionary = torch.zeros((num_classes, 2))
284
+ for index, row in df.iterrows():
285
+ key = row.iloc[0]
286
+ value = [row.iloc[1], row.iloc[2]]
287
+ if key in categories:
288
+ (
289
+ dictionary[category_to_index[key], 0],
290
+ dictionary[category_to_index[key], 1],
291
+ ) = np.radians(row.iloc[1]), np.radians(row.iloc[2])
292
+
293
+ # Save the PyTorch tensor to a .pt file
294
+ output_file = join(save_path, "index_to_gps_" + class_name + ".pt")
295
+ torch.save(dictionary, output_file)
296
+
297
+ train = pd.read_csv(join(data_path, "train.csv"), low_memory=False).fillna(
298
+ "NaN"
299
+ )
300
+
301
+ u = train.groupby("unique_city").sample(n=1)
302
+
303
+ country_df = (
304
+ u.pivot(index="unique_city", columns="unique_country", values="unique_city")
305
+ .notna()
306
+ .astype(int)
307
+ .fillna(0)
308
+ )
309
+ country_to_idx = {
310
+ category: i for i, category in enumerate(list(country_df.columns))
311
+ }
312
+ city_country_matrix = torch.tensor(country_df.values) / 1.0
313
+
314
+ region_df = (
315
+ u.pivot(index="unique_city", columns="unique_region", values="unique_city")
316
+ .notna()
317
+ .astype(int)
318
+ .fillna(0)
319
+ )
320
+ region_to_idx = {category: i for i, category in enumerate(list(region_df.columns))}
321
+ city_region_matrix = torch.tensor(region_df.values) / 1.0
322
+
323
+ country_df = (
324
+ u.pivot(index="unique_city", columns="unique_country", values="unique_city")
325
+ .notna()
326
+ .astype(int)
327
+ .fillna(0)
328
+ )
329
+ country_to_idx = {
330
+ category: i for i, category in enumerate(list(country_df.columns))
331
+ }
332
+ city_country_matrix = torch.tensor(country_df.values) / 1.0
333
+
334
+ output_file = join(save_path, "city_to_country.pt")
335
+ torch.save(city_country_matrix, output_file)
336
+
337
+ output_file = join(save_path, "country_to_idx.pt")
338
+ torch.save(country_to_idx, output_file)
339
+
340
+ region_df = (
341
+ u.pivot(index="unique_city", columns="unique_region", values="unique_city")
342
+ .notna()
343
+ .astype(int)
344
+ .fillna(0)
345
+ )
346
+ region_to_idx = {category: i for i, category in enumerate(list(region_df.columns))}
347
+ city_region_matrix = torch.tensor(region_df.values) / 1.0
348
+
349
+ output_file = join(save_path, "city_to_region.pt")
350
+ torch.save(city_region_matrix, output_file)
351
+
352
+ output_file = join(save_path, "region_to_idx.pt")
353
+ torch.save(region_to_idx, output_file)
354
+
355
+ area_df = (
356
+ u.pivot(index="unique_city", columns="unique_sub-region", values="unique_city")
357
+ .notna()
358
+ .astype(int)
359
+ .fillna(0)
360
+ )
361
+ area_to_idx = {category: i for i, category in enumerate(list(area_df.columns))}
362
+ city_area_matrix = torch.tensor(area_df.values) / 1.0
363
+
364
+ output_file = join(save_path, "city_to_area.pt")
365
+ torch.save(city_area_matrix, output_file)
366
+
367
+ output_file = join(save_path, "area_to_idx.pt")
368
+ torch.save(area_to_idx, output_file)
369
+ gt = torch.load(join(save_path, f"id_to_quad_{cfg.depth}_{cfg.do_split}.pt"))
370
+ matrixes = []
371
+ dicts = []
372
+ for i in range(1, cfg.depth):
373
+ # Step 2: Truncate strings to size cfg.depth - 1
374
+ l = [s[: cfg.depth - i] if len(s) >= cfg.depth + 1 - i else s for s in gt]
375
+
376
+ # Step 3: Get unique values in the modified list l
377
+ h = list(set(l))
378
+
379
+ # Step 4: Create a dictionary to map unique values to their index
380
+ h_dict = {value: index for index, value in enumerate(h)}
381
+ dicts.append(h_dict)
382
+
383
+ # Step 5: Initialize a torch matrix with zeros
384
+ matrix = torch.zeros((len(gt), len(h)))
385
+
386
+ # Step 6: Fill in the matrix with 1s based on the mapping
387
+ for h in range(len(gt)):
388
+ j = h_dict[l[h]]
389
+ matrix[h, j] = 1
390
+ matrixes.append(matrix)
391
+
392
+ output_file = join(save_path, "quadtree_matrixes.pt")
393
+ torch.save(matrixes, output_file)
394
+
395
+ output_file = join(save_path, "quadtree_dicts.pt")
396
+ torch.save(dicts, output_file)
397
+
398
+
399
+ if __name__ == "__main__":
400
+ main()
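
For reference, a minimal sketch (not part of the upload) of how the artifacts written by the script above could be consumed downstream: the `index_to_gps_quadtree_*.pt` tensor maps a predicted cluster index back to the mean GPS coordinate of its cell, and `quadtree_matrixes.pt` projects leaf cells onto coarser quadtree levels. The depth/do_split values (10 and 1000, matching the provided quadtree_10_1000.csv) and the predicted indices are illustrative assumptions.

```python
# Illustrative only: consume the quadtree artifacts saved by the preprocessing script above.
import torch

index_to_gps = torch.load("index_to_gps_quadtree_10_1000.pt")  # (num_cells, 2) mean lat/lon per leaf cell
matrixes = torch.load("quadtree_matrixes.pt")                  # list of (num_cells, num_coarser_cells) one-hot maps

predicted_cells = torch.tensor([0, 42, 7])        # hypothetical classifier outputs (cell indices)
predicted_gps = index_to_gps[predicted_cells]     # (3, 2) latitude/longitude of the cell centers

# project a one-hot leaf assignment onto the first coarser level of the hierarchy
leaf_one_hot = torch.nn.functional.one_hot(
    predicted_cells, num_classes=index_to_gps.shape[0]
).float()
coarse_assignment = leaf_one_hot @ matrixes[0]    # (3, num_coarser_cells)
```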
scripts/preprocessing/train-val-split.py ADDED

@@ -0,0 +1,15 @@
1
+ import os
2
+ from os.path import dirname, join
3
+
4
+ import pandas as pd
5
+ from sklearn.model_selection import train_test_split
6
+
7
+ if __name__ == "__main__":
8
+ data_path = join(dirname(dirname(__file__)), "datasets", "osv5m")
9
+ train_fp = join(data_path, f"train.csv")
10
+ val_fp = join(data_path, f"val.csv")
11
+ os.makedirs(dirname(val_fp), exist_ok=True)
12
+ df = pd.read_csv(train_fp, dtype={"category": str, "country": str, "city": str})
13
+ df_train, df_val = train_test_split(df, stratify=df["category"], test_size=0.1)
14
+ df_train.to_csv(train_fp, index=False)
15
+ df_val.to_csv(val_fp, index=False)
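
As a quick illustration (not part of the upload), the stratified split above should leave the per-category proportions of the train and validation CSVs nearly identical; the file paths below are placeholders.

```python
# Sanity-check sketch: compare per-category proportions of the two CSVs written above.
import pandas as pd

df_train = pd.read_csv("train.csv", dtype={"category": str})   # placeholder paths
df_val = pd.read_csv("val.csv", dtype={"category": str})

train_frac = df_train["category"].value_counts(normalize=True)
val_frac = df_val["category"].value_counts(normalize=True)
print((train_frac - val_frac).abs().max())   # expected to be close to 0 for a stratified split
```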
scripts/retrieval/backbone.py ADDED
@@ -0,0 +1,150 @@
1
+ import os
+ import torch
+ from os.path import join
2
+ import PIL
3
+ import numpy as np
4
+ import pandas as pd
5
+ import reverse_geocoder
6
+ from torch.utils.data import Dataset
7
+
8
+
9
+ class GeoDataset(Dataset):
10
+ def __init__(self, image_folder, annotation_file, transformation, tag="image_id"):
11
+ self.image_folder = image_folder
12
+ gt = pd.read_csv(annotation_file, dtype={tag: str})
13
+ files = set([f.replace(".jpg", "") for f in os.listdir(image_folder)])
14
+ gt = gt[gt[tag].isin(files)]
15
+ self.processor = transformation
16
+ self.gt = [
17
+ (g[1][tag], g[1]["latitude"], g[1]["longitude"]) for g in gt.iterrows()
18
+ ]
19
+ self.tag = tag
20
+
21
+ def fid(self, i):
22
+ return self.gt[i][0]
23
+
24
+ def latlon(self, i):
25
+ return self.gt[i][1]
26
+
27
+ def __len__(self):
28
+ return len(self.gt)
29
+
30
+ def __getitem__(self, idx):
31
+ fp = join(self.image_folder, self.gt[idx][0] + ".jpg")
32
+ return self.processor(self, idx, fp)
33
+
34
+
35
+ def load_plonk(path):
36
+ import hydra
37
+ from hydra import initialize, compose
38
+ from models.module import Geolocalizer
39
+ from omegaconf import OmegaConf, open_dict
40
+ from os.path import join
41
+ from hydra.utils import instantiate
42
+
43
+ # load config from path
44
+ # make path relative to current_dir
45
+ with initialize(version_base=None, config_path="osv5m__best_model"):
46
+ cfg = compose(config_name="config", overrides=[])
47
+
48
+ checkpoint = torch.load(join(path, "last.ckpt"))
49
+ del checkpoint["state_dict"][
50
+ "model.backbone.clip.vision_model.embeddings.position_ids"
51
+ ]
52
+ torch.save(checkpoint, join(path, "last2.ckpt"))
53
+
54
+ with open_dict(cfg):
55
+ cfg.checkpoint = join(path, "last2.ckpt")
56
+
57
+ cfg.num_classes = 11399
58
+ cfg.model.network.mid.instance.final_dim = cfg.num_classes * 3
59
+ cfg.model.network.head.final_dim = cfg.num_classes * 3
60
+ cfg.model.network.head.instance.quadtree_path = join(path, "quadtree_10_1000.csv")
61
+
62
+ cfg.dataset.train_dataset.path = ""
63
+ cfg.dataset.val_dataset.path = ""
64
+ cfg.dataset.test_dataset.path = ""
65
+ cfg.logger.save_dir = ""
66
+ cfg.data_dir = ""
67
+ cfg.root_dir = ""
68
+ cfg.mode = "test"
69
+ cfg.model.network.backbone.instance.path = (
70
+ "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
71
+ )
72
+ transform = instantiate(cfg.dataset.test_transform)
73
+ model = Geolocalizer.load_from_checkpoint(join(path, "last2.ckpt"), cfg=cfg.model)
74
+ os.remove(join(path, "last2.ckpt"))
75
+
76
+ @torch.no_grad()
77
+ def inference(model, x):
78
+ return x[0], model.model.backbone({"img": x[1].to(model.device)})[:, 0, :].cpu()
79
+
80
+ def collate_fn(batch):
81
+ return [b[0] for b in batch], torch.stack([b[1] for b in batch], dim=0)
82
+
83
+ def operate(self, idx, fp):
84
+ proc = self.processor(PIL.Image.open(fp))
85
+ return self.gt[idx][0], proc
86
+
87
+ return model, operate, inference, collate_fn
88
+
89
+
90
+ def load_clip(which):
91
+ # We evaluate on:
92
+ # - "openai/clip-vit-base-patch32"
93
+ # - "openai/clip-vit-large-patch14-336"
94
+ # - "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
95
+ # - "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
96
+ # - "geolocal/StreetCLIP"
97
+ from transformers import CLIPProcessor, CLIPModel
98
+
99
+ @torch.no_grad()
100
+ def inference(model, img):
101
+ image_ids = img.data.pop("image_id")
102
+ image_input = img.to(model.device)
103
+ image_input["pixel_values"] = image_input["pixel_values"].squeeze(1)
104
+ features = model.get_image_features(**image_input)
105
+ features /= features.norm(dim=-1, keepdim=True)
106
+ return image_ids, features.cpu()
107
+
108
+ processor = CLIPProcessor.from_pretrained(which)
109
+
110
+ def operate(self, idx, fp):
111
+ pil = PIL.Image.open(fp)
112
+ proc = processor(images=pil, return_tensors="pt")
113
+ proc["image_id"] = self.gt[idx][0]
114
+ return proc
115
+
116
+ return CLIPModel.from_pretrained(which), operate, inference, None
117
+
118
+
119
+ def load_dino(which):
120
+ # We evaluate on:
121
+ # - 'facebook/dinov2-large'
122
+ from transformers import AutoImageProcessor, AutoModel
123
+
124
+ @torch.no_grad()
125
+ def inference(model, img):
126
+ image_ids = img.data.pop("image_id")
127
+ image_input = img.to(model.device)
128
+ image_input["pixel_values"] = image_input["pixel_values"].squeeze(1)
129
+ features = model(**image_input).last_hidden_state[:, 0]
130
+ features /= features.norm(dim=-1, keepdim=True)
131
+ return image_ids, features.cpu()
132
+
133
+ processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
134
+
135
+ def operate(self, idx, fp):
136
+ pil = PIL.Image.open(fp)
137
+ proc = processor(images=pil, return_tensors="pt")
138
+ proc["image_id"] = self.gt[idx][0]
139
+ return proc
140
+
141
+ return AutoModel.from_pretrained("facebook/dinov2-large"), operate, inference, None
142
+
143
+
144
+ def get_backbone(name):
145
+ if os.path.isdir(name):
146
+ return load_plonk(name)
147
+ elif "clip" in name.lower():
148
+ return load_clip(name)
149
+ elif "dino" in name.lower():
150
+ return load_dino(name)
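
A minimal usage sketch (not part of the upload): `get_backbone` returns a `(model, operate, inference, collate_fn)` tuple; with a CLIP checkpoint, `inference` pops the `image_id` field and returns it alongside L2-normalised features. The image path `example.jpg` is a placeholder assumption.

```python
# Illustrative only: embed a single image with a CLIP backbone from get_backbone.
import torch
from PIL import Image
from transformers import CLIPProcessor

model, operate, inference, collate_fn = get_backbone("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

inputs = processor(images=Image.open("example.jpg"), return_tensors="pt")  # placeholder image
inputs["image_id"] = "example"                 # inference() pops this key and returns it with the features
image_id, features = inference(model, inputs)
print(image_id, features.shape)                # e.g. ('example', torch.Size([1, 512]))
```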
scripts/retrieval/retrieval.py ADDED
@@ -0,0 +1,143 @@
1
+ import os
2
+ import sys
3
+ import PIL
4
+ import json
5
+ import torch
6
+ import numpy as np
7
+ import pandas as pd
8
+ import operator
9
+
10
+ from PIL import Image
11
+ from itertools import cycle
12
+ from tqdm.auto import tqdm, trange
13
+ from os.path import join
14
+ from PIL import Image
15
+
16
+ from tqdm import tqdm
17
+ from torch.utils.data import Dataset, DataLoader
18
+ from torch.nn import functional as F
19
+
20
+ from backbone import GeoDataset, get_backbone
21
+ from utils import haversine, get_filenames, get_match_values, compute_print_accuracy
22
+
23
+
24
+ def compute_features(path, data_dir, csv_file, tag, args):
+ # build the backbone first so its image transform can be passed to the dataset
+ model, transform, inference, collate_fn = get_backbone(args.name)
+ data = GeoDataset(data_dir, csv_file, transform, tag=tag)
+ # only (re)compute features if the output folder is missing or incomplete
+ if not os.path.isdir(path) or len(os.listdir(path)) != len(data):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+ dataloader = DataLoader(
32
+ data,
33
+ batch_size=args.batch_size,
34
+ shuffle=False,
35
+ num_workers=8,
36
+ collate_fn=collate_fn,
37
+ )
38
+ model = model.to(device)
39
+ os.makedirs(path, exist_ok=True)
40
+
41
+ for i, x in enumerate(tqdm(dataloader)):
42
+ image_ids, features = inference(model, x)
43
+ # save features as numpy array
44
+ for j, image_id in zip(range(features.shape[0]), image_ids):
45
+ np.save(join(path, f"{image_id}.npy"), features[j].unsqueeze(0).numpy())
46
+
47
+
48
+ def get_results(args, train_test):
49
+ import joblib
50
+
51
+ if not os.path.isfile(join(args.features_parent, ".cache", "1-nn.pkl")):
52
+ import faiss, glob, bisect
53
+
54
+ # import sys; sys.exit(0)
55
+ indexes = [
56
+ get_filenames(args.features_parent, idx) for idx in tqdm(range(1, 6), desc="Loading indexes...")
57
+ ]
58
+
59
+ train_gt = pd.read_csv(
60
+ join(args.data_parent, args.annotation_file), dtype={"image_id": str}
61
+ )[["image_id", "latitude", "longitude"]]
62
+ test_gt = pd.read_csv(test_path_csv, dtype={"id": str})[
63
+ ["id", "latitude", "longitude"]
64
+ ]
65
+
66
+ # make a map between image_id and lat/lon
67
+ train_gt = {
68
+ g[1]["image_id"]: np.array([g[1]["latitude"], g[1]["longitude"]])
69
+ for g in tqdm(
70
+ train_gt.iterrows(), total=len(train_gt), desc="Loading train_gt"
71
+ )
72
+ }
73
+ test_gt = {
74
+ g[1]["id"]: np.array([g[1]["latitude"], g[1]["longitude"]])
75
+ for g in tqdm(
76
+ test_gt.iterrows(), total=len(test_gt), desc="Loading test_gt"
77
+ )
78
+ }
79
+
80
+ train_test = []
81
+ os.makedirs(join(args.features_parent, ".cache"), exist_ok=True)
82
+ for f in tqdm(os.listdir(test_features_dir)):
83
+ query_vector = np.load(join(test_features_dir, f))
84
+
85
+ neighbors = []
86
+ for index, ids in indexes:
87
+ distances, indices = index.search(query_vector, 1)
88
+ distances, indices = np.squeeze(distances), np.squeeze(indices)
89
+ bisect.insort(
90
+ neighbors, (ids[indices], distances), key=operator.itemgetter(1)
91
+ )
92
+
93
+ neighbors = list(reversed(neighbors))
94
+ train_gps = train_gt[neighbors[0][0].replace(".npy", "")][None, :]
95
+ test_gps = test_gt[f.replace(".npy", "")][None, :]
96
+ train_test.append((train_gps, test_gps))
97
+ joblib.dump(train_test, join(args.features_parent, ".cache", "1-nn.pkl"))
98
+ else:
99
+ train_test = joblib.load(join(args.features_parent, ".cache", "1-nn.pkl"))
100
+
101
+ return train_test
102
+
103
+
104
+ if __name__ == "__main__":
105
+ # make a train/eval argparser
106
+ import argparse
107
+
108
+ parser = argparse.ArgumentParser()
109
+ parser.add_argument("--id", type=int, default=1) # maybe need to remove/refactor
110
+ parser.add_argument("--batch_size", type=int, default=512)
111
+ parser.add_argument(
112
+ "--annotation_file", type=str, required=False, default="train.csv"
113
+ )
114
+ parser.add_argument("--name", type=str, default="openai/clip-vit-base-patch32")
115
+ parser.add_argument("--features_parent", type=str, default="faiss/")
116
+ parser.add_argument("--data_parent", type=str, default="data/")
117
+ parser.add_argument("--test", action="store_true")
118
+
119
+ args = parser.parse_args()
120
+ args.features_parent = join(args.features_parent, args.name)
121
+ if args.test:
122
+ csv_file = join(args.data_parent, "test.csv")
123
+ data_dir = join(args.data_parent, "test")
124
+ path = join(args.features_parent, "features-test")
125
+ # module-level names read inside get_results
+ test_path_csv, test_features_dir = csv_file, path
+ compute_features(path, data_dir, csv_file, tag="id", args=args)
+ train_test = get_results(args, None)
128
+
129
+ from collections import Counter
130
+
131
+ N, pos = Counter(), Counter()
132
+ for train_gps, test_gps in tqdm(train_test, desc="Computing accuracy..."):
133
+ get_match_values(train_gps, test_gps, N, pos)
134
+
135
+ for train_gps, test_gps in tqdm(train_test, desc="Computing haversine..."):
136
+ haversine(train_gps, test_gps, N, pos)
137
+
138
+ compute_print_accuracy(N, pos)
139
+ else:
140
+ csv_file = join(args.data_parent, args.annotation_file)
141
+ path = join(args.features_parent, f"features-{args.id}")
142
+ data_dir = join(args.data_parent, f"images-{args.id}", "train")
143
+ compute_features(path, data_dir, csv_file, tag="image_id", args=args)
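
The nearest-neighbour lookup above relies on autofaiss/faiss indexes built over the saved `.npy` features. Below is a toy sketch (not part of the upload) of the same 1-NN query with plain faiss and random stand-in vectors.

```python
# Toy 1-NN retrieval with faiss, mirroring index.search(query_vector, 1) above.
import faiss
import numpy as np

train_feats = np.random.randn(100, 512).astype("float32")
train_feats /= np.linalg.norm(train_feats, axis=1, keepdims=True)

index = faiss.IndexFlatIP(512)          # inner product on L2-normalised vectors = cosine similarity
index.add(train_feats)

query = train_feats[:1]                 # a query identical to the first stored vector
distances, indices = index.search(query, 1)
print(indices[0, 0], distances[0, 0])   # -> 0 and a similarity close to 1.0
```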
scripts/retrieval/street-clip-zero-shot.py ADDED
@@ -0,0 +1,299 @@
1
+ import traceback
2
+ import os
3
+ import sys
4
+ import PIL
5
+ import json
6
+ import torch
7
+ import numpy as np
8
+ import pandas as pd
9
+ import operator
10
+ import joblib
11
+ import reverse_geocoder
12
+
13
+ from PIL import Image
14
+ from itertools import cycle
15
+ from tqdm.auto import tqdm, trange
16
+ from os.path import join
17
+ from PIL import Image
18
+
19
+ from tqdm import tqdm
20
+ from collections import Counter
21
+ from transformers import CLIPProcessor, CLIPModel
22
+ from torch.utils.data import Dataset, DataLoader
23
+ from torch.nn import functional as F
24
+ from utils import haversine
25
+
26
+
27
+ class GeoDataset(Dataset):
28
+ def __init__(self, image_folder, annotation_file, tag="image_id"):
29
+ self.image_folder = image_folder
30
+ gt = pd.read_csv(annotation_file, dtype={tag: str})
31
+ files = set([f.replace(".jpg", "") for f in os.listdir(image_folder)])
32
+ gt = gt[gt[tag].isin(files)]
33
+ self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
34
+ self.gt = [
35
+ (g[1][tag], g[1]["latitude"], g[1]["longitude"]) for g in gt.iterrows()
36
+ ]
37
+ self.tag = tag
38
+
39
+ def fid(self, i):
40
+ return self.gt[i][0]
41
+
42
+ def latlon(self, i):
43
+ return self.gt[i][1]
44
+
45
+ def __len__(self):
46
+ return len(self.gt)
47
+
48
+ def __getitem__(self, idx):
49
+ fp = join(self.image_folder, self.gt[idx][0] + ".jpg")
50
+ pil = PIL.Image.open(fp)
51
+ proc = self.processor(images=pil, return_tensors="pt")
52
+ proc["image_id"] = self.gt[idx][0]
53
+ return proc
54
+
55
+
56
+ @torch.no_grad()
57
+ def compute_features_clip(img, model):
58
+ image_ids = img.data.pop("image_id")
59
+ image_input = img.to(model.device)
60
+ image_input["pixel_values"] = image_input["pixel_values"].squeeze(1)
61
+ features = model.get_image_features(**image_input)
62
+ features /= features.norm(dim=-1, keepdim=True)
63
+ return image_ids, features.cpu()
64
+
65
+
66
+ def get_prompts(country, region, sub_region, city):
67
+ a = country if country != "" else None
68
+ b, c, d = None, None, None
69
+ if a is not None:
70
+ b = country + ", " + region if region != "" else None
71
+ if b is not None:
72
+ c = (
73
+ country + ", " + region + ", " + sub_region
74
+ if sub_region != ""
75
+ else None
76
+ )
77
+ d = (
78
+ country + ", " + region + ", " + sub_region + ", " + city
79
+ if city != ""
80
+ else None
81
+ )
82
+ return a, b, c, d
83
+
84
+
85
+ if __name__ == "__main__":
86
+ # make a train/eval argparser
87
+ import argparse
88
+
89
+ parser = argparse.ArgumentParser()
90
+ parser.add_argument(
91
+ "--annotation_file", type=str, required=False, default="train.csv"
92
+ )
93
+ parser.add_argument(
94
+ "--features_parent", type=str, default="/home/isig/gaia-v2/faiss/street-clip"
95
+ )
96
+ parser.add_argument(
97
+ "--data_parent", type=str, default="/home/isig/gaia-v2/loic-data/"
98
+ )
99
+
100
+ args = parser.parse_args()
101
+ test_path_csv = join(args.data_parent, "test.csv")
102
+ test_image_dir = join(args.data_parent, "test")
103
+ save_path = join(args.features_parent, "indexes/test.index")
104
+ test_features_dir = join(args.features_parent, "indexes/features-test")
105
+
106
+ processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
107
+ device = "cuda" if torch.cuda.is_available() else "cpu"
108
+ model = CLIPModel.from_pretrained("geolocal/StreetCLIP").to(device)
109
+
110
+ @torch.no_grad()
111
+ def compute_text_features_clip(text):
112
+ text_pt = processor(text=text, return_tensors="pt").to(device)
113
+ features = model.get_text_features(**text_pt)
114
+ features /= features.norm(dim=-1, keepdim=True)
115
+ return features.cpu().squeeze(0).numpy()
116
+
117
+ import country_converter as coco
118
+
119
+ if not os.path.isfile("text_street-clip-features.pkl"):
120
+ if not os.path.isfile("rg_cities1000.csv"):
121
+ os.system(
122
+ "wget https://raw.githubusercontent.com/thampiman/reverse-geocoder/master/reverse_geocoder/rg_cities1000.csv"
123
+ )
124
+
125
+ cities = pd.read_csv("rg_cities1000.csv")
126
+ cities = cities[["lat", "lon", "name", "admin1", "admin2", "cc"]]
127
+ reprs = {0: {}, 1: {}, 2: {}, 3: {}}
128
+ for line in tqdm(
129
+ cities.iterrows(), total=len(cities), desc="Creating hierarchy"
130
+ ):
131
+ lat, lon, city, region, sub_region, cc = line[1]
132
+ try:
133
+ city, region, sub_region, cc = [
134
+ ("" if pd.isna(x) else x)
135
+ for x in [
136
+ city,
137
+ region,
138
+ sub_region,
139
+ coco.convert(cc, to="name_short"),
140
+ ]
141
+ ]
142
+ a, b, c, d = get_prompts(cc, region, sub_region, city)
143
+ if a is not None:
144
+ if a not in reprs[0]:
145
+ reprs[0][a] = {
146
+ "gps": {(lat, lon)},
147
+ "embedding": compute_text_features_clip(a),
148
+ }
149
+ else:
150
+ reprs[0][a]["gps"].add((lat, lon))
151
+
152
+ if b is not None:
153
+ if b not in reprs[1]:
154
+ reprs[1][b] = {
155
+ "gps": {(lat, lon)},
156
+ "embedding": compute_text_features_clip(b),
157
+ }
158
+ else:
159
+ reprs[1][b]["gps"].add((lat, lon))
160
+
161
+ if c is not None:
162
+ if c not in reprs[2]:
163
+ reprs[2][c] = {
164
+ "gps": {(lat, lon)},
165
+ "embedding": compute_text_features_clip(c),
166
+ }
167
+ else:
168
+ reprs[2][c]["gps"].add((lat, lon))
169
+
170
+ if d is not None:
171
+ if d not in reprs[3]:
172
+ reprs[3][d] = {
173
+ "gps": {(lat, lon)},
174
+ "embedding": compute_text_features_clip(
175
+ d.replace(", , ", ", ")
176
+ ),
177
+ }
178
+ else:
179
+ reprs[3][d]["gps"].add((lat, lon))
180
+ except Exception as e:
181
+ # print stack trace into file log.txt
182
+ with open("log.txt", "a") as f:
183
+ print(traceback.format_exc(), file=f)
184
+
185
+ reprs[-1] = {"": {"gps": (0, 0), "embedding": compute_text_features_clip("")}}
186
+
187
+ # compute mean for gps of all 'a' and 'b' and 'c' and 'd'
188
+ for i in range(4):
189
+ for k in reprs[i].keys():
190
+ reprs[i][k]["gps"] = tuple(
191
+ np.array(list(reprs[i][k]["gps"])).mean(axis=0).tolist()
192
+ )
193
+
194
+ joblib.dump(reprs, "text_street-clip-features.pkl")
195
+ else:
196
+ reprs = joblib.load("text_street-clip-features.pkl")
197
+
198
+ def get_loc(x):
199
+ location = reverse_geocoder.search(x[0].tolist())[0]
200
+ country = coco.convert(names=location["cc"], to="name_short")
201
+ region = location.get("admin1", "")
202
+ sub_region = location.get("admin2", "")
203
+ city = location.get("name", "")
204
+ a, b, c, d = get_prompts(country, region, sub_region, city)
205
+ return a, b, c, d
206
+
207
+ def matches(embed, repr, control, gt, sw=None):
208
+ first_max = max(
209
+ (
210
+ (k, embed.dot(v["embedding"]))
211
+ for k, v in repr.items()
212
+ if sw is None or k.startswith(sw)
213
+ ),
214
+ key=operator.itemgetter(1),
215
+ )
216
+ if first_max[1] > embed.dot(control["embedding"]):
217
+ return repr[first_max[0]]["gps"], gt == first_max[0]
218
+ else:
219
+ return control["gps"], False
220
+
221
+ def get_match_values(gt, embed, N, pos):
222
+ xa, xb, xc, xd = get_loc(gt)
223
+
224
+ if xa is not None:
225
+ N["country"] += 1
226
+ gps, flag = matches(embed, reprs[0], reprs[-1][""], xa)
227
+ if flag:
228
+ pos["country"] += 1
229
+ if xb is not None:
230
+ N["region"] += 1
231
+ gps, flag = matches(embed, reprs[1], reprs[0][xa], xb, sw=xa)
232
+ if flag:
233
+ pos["region"] += 1
234
+ if xc is not None:
235
+ N["sub-region"] += 1
236
+ gps, flag = matches(
237
+ embed, reprs[2], reprs[1][xb], xc, sw=xb
238
+ )
239
+ if flag:
240
+ pos["sub-region"] += 1
241
+ if xd is not None:
242
+ N["city"] += 1
243
+ gps, flag = matches(
244
+ embed, reprs[3], reprs[2][xc], xd, sw=xc
245
+ )
246
+ if flag:
247
+ pos["city"] += 1
248
+ else:
249
+ if xd is not None:
250
+ N["city"] += 1
251
+ gps, flag = matches(
252
+ embed, reprs[3], reprs[1][xb], xd, sw=xb + ", "
253
+ )
254
+ if flag:
255
+ pos["city"] += 1
256
+
257
+ haversine(np.array(gps)[None, :], np.array(gt), N, pos)
258
+
259
+ def compute_print_accuracy(N, pos):
260
+ for k in N.keys():
261
+ pos[k] /= N[k]
262
+
263
+ # pretty-print accuracy in percentage with 2 floating points
264
+ print(
265
+ f'Accuracy: {pos["country"]*100.0:.2f} (country), {pos["region"]*100.0:.2f} (region), {pos["sub-region"]*100.0:.2f} (sub-region), {pos["city"]*100.0:.2f} (city)'
266
+ )
267
+ print(
268
+ f'Haversine: {pos["haversine"]:.2f} (haversine), {pos["geoguessr"]:.2f} (geoguessr)'
269
+ )
270
+
271
+ import joblib
272
+
273
+ data = GeoDataset(test_image_dir, test_path_csv, tag="id")
274
+ test_gt = pd.read_csv(test_path_csv, dtype={"id": str})[
275
+ ["id", "latitude", "longitude"]
276
+ ]
277
+ test_gt = {
278
+ g[1]["id"]: np.array([g[1]["latitude"], g[1]["longitude"]])
279
+ for g in tqdm(test_gt.iterrows(), total=len(test_gt), desc="Loading test_gt")
280
+ }
281
+
282
+ with open("/home/isig/gaia-v2/loic/plonk/test3_indices.txt", "r") as f:
283
+ # read lines
284
+ lines = f.readlines()
285
+ # remove whitespace characters like `\n` at the end of each line
286
+ lines = [l.strip() for l in lines]
287
+ # and convert to set
288
+ lines = set(lines)
289
+
290
+ train_test = []
291
+ N, pos = Counter(), Counter()
292
+ for f in tqdm(os.listdir(test_features_dir)):
293
+ if f.replace(".npy", "") not in lines:
294
+ continue
295
+ query_vector = np.squeeze(np.load(join(test_features_dir, f)))
296
+ test_gps = test_gt[f.replace(".npy", "")][None, :]
297
+ get_match_values(test_gps, query_vector, N, pos)
298
+
299
+ compute_print_accuracy(N, pos)
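
The script above scores each image embedding against a hierarchy of location prompts. A condensed sketch (not part of the upload) of that zero-shot matching step with StreetCLIP follows; the image path and the three example countries are assumptions.

```python
# Illustrative zero-shot country scoring with StreetCLIP, as in matches() above.
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
model = CLIPModel.from_pretrained("geolocal/StreetCLIP")

image = processor(images=Image.open("example.jpg"), return_tensors="pt")  # placeholder image
candidates = ["France", "Japan", "Brazil"]                                # toy candidate prompts
text = processor(text=candidates, return_tensors="pt", padding=True)

with torch.no_grad():
    img_feat = model.get_image_features(**image)
    txt_feat = model.get_text_features(**text)
img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)

scores = (img_feat @ txt_feat.T).squeeze(0)     # cosine similarity of the image against each prompt
print(candidates[scores.argmax().item()])
```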
scripts/retrieval/utils.py ADDED
@@ -0,0 +1,113 @@
1
+ import os
+ from os.path import join
+
+ import numpy as np
+ import reverse_geocoder
+ from tqdm import tqdm
4
+
5
+
6
+ def get_loc(x):
7
+ location = reverse_geocoder.search(x[0].tolist())[0]
8
+ country = location.get("cc", "")
9
+ region = location.get("admin1", "")
10
+ sub_region = location.get("admin2", "")
11
+ city = location.get("name", "")
12
+
13
+ a = country if country != "" else None
14
+ b, c, d = None, None, None
15
+ if a is not None:
16
+ b = country + "," + region if region != "" else None
17
+ if b is not None:
18
+ c = country + "," + region + "," + sub_region if sub_region != "" else None
19
+ d = (
20
+ country + "," + region + "," + sub_region + "," + city
21
+ if city != ""
22
+ else None
23
+ )
24
+
25
+ return a, b, c, d
26
+
27
+
28
+ def get_match_values(pred, gt, N, pos):
29
+ xa, xb, xc, xd = get_loc(gt)
30
+ ya, yb, yc, yd = get_loc(pred)
31
+
32
+ if xa is not None:
33
+ N["country"] += 1
34
+ if xa == ya:
35
+ pos["country"] += 1
36
+ if xb is not None:
37
+ N["region"] += 1
38
+ if xb == yb:
39
+ pos["region"] += 1
40
+ if xc is not None:
41
+ N["sub-region"] += 1
42
+ if xc == yc:
43
+ pos["sub-region"] += 1
44
+ if xd is not None:
45
+ N["city"] += 1
46
+ if xd == yd:
47
+ pos["city"] += 1
48
+
49
+
50
+ def compute_print_accuracy(N, pos):
51
+ for k in N.keys():
52
+ pos[k] /= N[k]
53
+
54
+ # pretty-print accuracy in percentage with 2 floating points
55
+ print(
56
+ f'Accuracy: {pos["country"]*100.0:.2f} (country), {pos["region"]*100.0:.2f} (region), {pos["sub-region"]*100.0:.2f} (sub-region), {pos["city"]*100.0:.2f} (city)'
57
+ )
58
+ print(
59
+ f'Haversine: {pos["haversine"]:.2f} (haversine), {pos["geoguessr"]:.2f} (geoguessr)'
60
+ )
61
+
62
+
63
+ def get_filenames(features_parent, idx):
64
+ from autofaiss import build_index
65
+
66
+ path = join(features_parent, f"features-{idx}/")
67
+ files = [f for f in os.listdir(path)]
68
+ full_files = [join(path, f) for f in os.listdir(path)]
69
+ index = build_index(
70
+ embeddings=np.concatenate([np.load(f) for f in tqdm(full_files)], axis=0),
71
+ nb_cores=12,
72
+ save_on_disk=False,
73
+ )[0]
74
+ return index, files
75
+
76
+
77
+ def normalize(x):
78
+ """Used to put all lat lon inside ±90 and ±180."""
+ lat, lon = x[:, 0], x[:, 1]
+ lat = (lat + 90) % 360 - 90
+ # vectorised wrap: latitudes pushed past the pole are reflected back and their longitude shifted by 180
+ lon = np.where(lat > 90, lon + 180, lon)
+ lat = np.where(lat > 90, 180 - lat, lat)
84
+ lon = (lon + 180) % 360 - 180
85
+ return np.stack([lat, lon], axis=1)
86
+
87
+
88
+ def haversine(pred, gt, N, p):
89
+ # expects inputs to be np arrays in (lat, lon) format as radians
90
+ # N x 2
91
+ pred = np.radians(normalize(pred))
92
+ gt = np.radians(normalize(gt))
93
+
94
+ # calculate the difference in latitude and longitude between the predicted and ground truth points
95
+ lat_diff = pred[:, 0] - gt[:, 0]
96
+ lon_diff = pred[:, 1] - gt[:, 1]
97
+
98
+ # calculate the haversine formula components
99
+ lhs = np.sin(lat_diff / 2) ** 2
100
+ rhs = np.cos(pred[:, 0]) * np.cos(gt[:, 0]) * np.sin(lon_diff / 2) ** 2
101
+ a = lhs + rhs
102
+
103
+ # calculate the final distance using the haversine formula
104
+ c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
105
+
106
+ haversine_distance = 6371 * c[0]
107
+ geoguessr_sum = 5000 * np.exp(-haversine_distance / 1492.7)
108
+
109
+ N["geoguessr"] += 1
110
+ p["geoguessr"] += geoguessr_sum
111
+
112
+ N["haversine"] += 1
113
+ p["haversine"] += haversine_distance
utils/__init__.py ADDED
File without changes
utils/image_processing.py ADDED
@@ -0,0 +1,58 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torchvision
4
+
5
+
6
+ def remap_image_torch(image):
7
+ image_torch = ((image + 1) / 2.0) * 255.0
8
+ image_torch = torch.clip(image_torch, 0, 255).to(torch.uint8)
9
+ return image_torch
10
+
11
+
12
+ class CenterCrop(torch.nn.Module):
13
+ """Crops the given image at the center. Allows to crop to the maximum possible size.
14
+ Args:
15
+ size (sequence or int): Desired output size of the crop. If size is an
16
+ int instead of sequence like (h, w), a square crop (size, size) is
17
+ made.
18
+ ratio (str): Desired aspect ratio of the crop; the largest possible crop with this ratio is taken.
19
+ """
20
+
21
+ def __init__(self, size=None, ratio="1:1"):
22
+ super().__init__()
23
+ self.size = size
24
+ self.ratio = ratio
25
+
26
+ def forward(self, img):
27
+ """
28
+ Args:
29
+ img (PIL Image or Tensor): Image to be cropped.
30
+
31
+ Returns:
32
+ PIL Image or Tensor: Cropped image.
33
+ """
34
+ if self.size is None:
35
+ if isinstance(img, torch.Tensor):
36
+ h, w = img.shape[-2:]
37
+ else:
38
+ w, h = img.size
39
+ ratio = self.ratio.split(":")
40
+ ratio = float(ratio[0]) / float(ratio[1])
41
+ ratioed_w = int(h * ratio)
42
+ ratioed_h = int(w / ratio)
43
+ if w >= h:
44
+ if ratioed_h <= h:
45
+ size = (ratioed_h, w)
46
+ else:
47
+ size = (h, ratioed_w)
48
+ else:
49
+ if ratioed_w <= w:
50
+ size = (h, ratioed_w)
51
+ else:
52
+ size = (ratioed_h, w)
53
+ else:
54
+ size = self.size
55
+ return torchvision.transforms.functional.center_crop(img, size)
56
+
57
+ def __repr__(self) -> str:
58
+ return f"{self.__class__.__name__}(size={self.size})"
utils/lr_scheduler.py ADDED
@@ -0,0 +1,96 @@
1
+ import math
2
+
3
+
4
+ class WarmupLR:
5
+ """
6
+ Linear Warmup learning rate scheduler. After warmup, learning rate is
7
+ constant.
8
+
9
+ Args:
10
+ optimizer (torch.optim.Optimizer): optimizer
11
+ warmup_steps (int): number of warmup steps
12
+
13
+ """
14
+
15
+ def __init__(self, optimizer, warmup_steps):
16
+ self.optimizer = optimizer
17
+ self.warmup_steps = warmup_steps
18
+ self.base_lr = None
19
+
20
+ def get_lr(self, lr, step):
21
+ return lr * min(step / max(self.warmup_steps, 1), 1.0)
22
+
23
+ def step(self, step):
24
+ if self.base_lr is None:
25
+ self.base_lr = [
26
+ param_group["lr"] for param_group in self.optimizer.param_groups
27
+ ]
28
+ for param_group, base_lr_group in zip(
29
+ self.optimizer.param_groups, self.base_lr
30
+ ):
31
+ param_group["lr"] = self.get_lr(base_lr_group, step)
32
+
33
+ def state_dict(self):
34
+ return {
35
+ key: value for key, value in self.__dict__.items() if key != "optimizer"
36
+ }
37
+
38
+ def load_state_dict(self, state_dict):
39
+ self.__dict__.update(state_dict)
40
+
41
+
42
+ class WarmupCosineDecayLR:
43
+ """
44
+ Linear warmup learning rate scheduler. After the warmup phase, the
+ learning rate follows a cosine decay.
47
+
48
+ Args:
49
+ optimizer (torch.optim.Optimizer): optimizer
50
+ warmup_steps (int): number of warmup steps
51
+ total_steps (int): total number of steps
52
+ rate (float): cosine decay rate
53
+ """
54
+
55
+ def __init__(self, optimizer, warmup_steps, total_steps, rate=1.0):
56
+ self.optimizer = optimizer
57
+ self.warmup_steps = warmup_steps
58
+ self.base_lr = None
59
+ self.total_steps = total_steps
60
+ self.rate = rate
61
+
62
+ def get_lr(self, lr, step):
63
+ if step < self.warmup_steps:
64
+ return lr * min(step / max(self.warmup_steps, 1), 1.0)
65
+ else:
66
+ return (
67
+ 0.5
68
+ * lr
69
+ * (
70
+ 1
71
+ + math.cos(
72
+ self.rate
73
+ * math.pi
74
+ * (step - self.warmup_steps)
75
+ / (self.total_steps - self.warmup_steps)
76
+ )
77
+ )
78
+ )
79
+
80
+ def step(self, step):
81
+ if self.base_lr is None:
82
+ self.base_lr = [
83
+ param_group["lr"] for param_group in self.optimizer.param_groups
84
+ ]
85
+ for param_group, base_lr_group in zip(
86
+ self.optimizer.param_groups, self.base_lr
87
+ ):
88
+ param_group["lr"] = self.get_lr(base_lr_group, step)
89
+
90
+ def state_dict(self):
91
+ return {
92
+ key: value for key, value in self.__dict__.items() if key != "optimizer"
93
+ }
94
+
95
+ def load_state_dict(self, state_dict):
96
+ self.__dict__.update(state_dict)
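
A usage sketch (not part of the upload): the schedulers are stepped manually with the global step index, which is how the warmup and decay shape shows up in the learning rate. The optimizer, step values, and hyper-parameters below are illustrative.

```python
# Illustrative only: inspect the learning-rate schedule produced by WarmupCosineDecayLR.
import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = WarmupCosineDecayLR(optimizer, warmup_steps=100, total_steps=1000)

for step in [0, 50, 100, 550, 999]:
    scheduler.step(step)
    print(step, optimizer.param_groups[0]["lr"])
# expected: 0.0, 5e-4 (mid-warmup), 1e-3 (end of warmup), 5e-4 (mid-decay), ~0 (end of schedule)
```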
utils/model_utils.py ADDED
@@ -0,0 +1,14 @@
1
+ def print_trainable_parameters(model):
2
+ """
3
+ Prints the number and percentage of trainable parameters in the model.
4
+ Useful for tracking % parameters trained for LoRA.
5
+ """
6
+ trainable_params = 0
7
+ all_param = 0
8
+ for _, param in model.named_parameters():
9
+ all_param += param.numel()
10
+ if param.requires_grad:
11
+ trainable_params += param.numel()
12
+ print(
13
+ f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
14
+ )
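
A small usage sketch (not part of the upload): freeze part of a toy model and report the trainable fraction; the layer sizes are arbitrary.

```python
# Illustrative only: print_trainable_parameters after freezing the first layer.
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 4))
for p in model[0].parameters():      # freeze the first layer
    p.requires_grad = False

print_trainable_parameters(model)
# trainable params: 132 || all params: 676 || trainable%: ~19.5
```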
utils/quadtree_10_1000.csv ADDED
The diff for this file is too large to render. See raw diff