import torch.nn as nn from models.networks.utils import UnormGPS from torch.nn.functional import tanh, sigmoid, softmax class AuxHead(nn.Module): def __init__(self, aux_data=[], use_tanh=False): super().__init__() self.aux_data = aux_data self.unorm = UnormGPS() self.use_tanh = use_tanh def forward(self, x): """Forward pass of the network. x : Union[torch.Tensor, dict] with the output of the backbone. """ if self.use_tanh: gps = tanh(x["gps"]) gps = self.unorm(gps) output = {"gps": gps} if "land_cover" in self.aux_data: output["land_cover"] = softmax(x["land_cover"]) if "road_index" in self.aux_data: output["road_index"] = x["road_index"] if "drive_side" in self.aux_data: output["drive_side"] = sigmoid(x["drive_side"]) if "climate" in self.aux_data: output["climate"] = softmax(x["climate"]) if "soil" in self.aux_data: output["soil"] = softmax(x["soil"]) if "dist_sea" in self.aux_data: output["dist_sea"] = x["dist_sea"] return output