Spaces:
Sleeping
Sleeping
File size: 1,059 Bytes
94f372a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import torch
from models.networks.utils import UnormGPS
import torch.nn as nn
import numpy as np
class IdToGPS(nn.Module):
    """Look up GPS coordinates for ids (indices can be country or city ids).

    The lookup table is loaded with ``torch.load`` and indexed directly by the
    incoming ids, so row ``i`` of the table holds the coordinates for id ``i``.
    """

    def __init__(self, id_to_gps: str):
        """Load the id -> GPS lookup table.

        Args:
            id_to_gps: path to a file readable by ``torch.load`` holding the
                table. If the path contains ``"quadtree"``, the
                fourth-from-last ``"_"``-separated token is removed before
                loading (e.g. ``quadtree_X_a_b_c.pt`` -> ``quadtree_a_b_c.pt``).
        """
        super().__init__()
        if "quadtree" in id_to_gps:
            # NOTE(review): presumably strips a run/config-specific token so
            # several configs share one table file — confirm against the
            # naming used when the table is saved.
            parts = id_to_gps.split("_")
            id_to_gps = "_".join(parts[:-4] + parts[-3:])
        self.id_to_gps = torch.load(id_to_gps)

    def forward(self, x):
        """Map ids to GPS coordinates.

        Args:
            x: either a tensor of ids (predicted labels, used directly as
                indices into the table), or a batch dict whose ``"label"``
                entry holds the ids (oracle mode). Other dict entries are
                ignored.

        Returns:
            dict with key ``"gps"``: the table rows selected by the ids, as
            stored in the loaded file (no un-normalisation is applied here).
        """
        labels = x["label"] if isinstance(x, dict) else x
        # Keep the table on the labels' device so later calls skip the
        # transfer (``Tensor.to`` returns self when the device already matches).
        self.id_to_gps = self.id_to_gps.to(labels.device)
        return {"gps": self.id_to_gps[labels]}
|