yunusserhat's picture
Upload 40 files
94f372a verified
import torch.nn as nn
from models.networks.utils import UnormGPS
from torch.nn.functional import tanh, sigmoid, softmax
class AuxHead(nn.Module):
    """Turn raw backbone outputs into GPS and auxiliary predictions.

    The ``"gps"`` output is optionally squashed with ``tanh`` and then passed
    through ``UnormGPS``. Each supported auxiliary head that is listed in
    ``aux_data`` gets its output activation applied (softmax, sigmoid, or
    none).

    Args:
        aux_data: Names of the auxiliary heads to emit. Supported names are
            ``"land_cover"``, ``"road_index"``, ``"drive_side"``,
            ``"climate"``, ``"soil"`` and ``"dist_sea"``; any other names are
            ignored. ``None`` (the default) means no auxiliary heads.
        use_tanh: If True, apply ``tanh`` to the GPS output before
            un-normalising it.
    """

    def __init__(self, aux_data=None, use_tanh=False):
        super().__init__()
        # ``None`` instead of a ``[]`` default: a literal list default would be
        # one object shared by every AuxHead instance.
        self.aux_data = [] if aux_data is None else aux_data
        self.unorm = UnormGPS()
        self.use_tanh = use_tanh

    @staticmethod
    def _softmax(logits):
        """Softmax over the last dimension.

        Passing ``dim`` explicitly avoids PyTorch's deprecated implicit-dim
        choice, which uses dim 0 for 3-D inputs. This assumes the class
        dimension is last, e.g. (batch, classes) — TODO confirm for
        higher-rank heads.
        """
        return softmax(logits, dim=-1)

    def forward(self, x):
        """Forward pass of the head.

        Args:
            x: Mapping from head name to tensor, as produced by the backbone.
                It must contain ``"gps"`` and every supported key listed in
                ``self.aux_data``.

        Returns:
            dict: ``"gps"`` (un-normalised, after the optional tanh) followed
            by each requested auxiliary head, with its activation applied.
        """
        gps = x["gps"]
        if self.use_tanh:
            gps = tanh(gps)
        output = {"gps": self.unorm(gps)}

        # Ordered table of supported auxiliary heads -> output activation.
        # ``None`` means the raw value is forwarded unchanged. Iteration order
        # fixes the key order of ``output``.
        activations = {
            "land_cover": self._softmax,
            "road_index": None,
            "drive_side": sigmoid,
            "climate": self._softmax,
            "soil": self._softmax,
            "dist_sea": None,
        }
        for key, activation in activations.items():
            if key in self.aux_data:
                value = x[key]
                output[key] = value if activation is None else activation(value)
        return output