File size: 1,211 Bytes
94f372a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch.nn as nn
from models.networks.utils import UnormGPS
from torch.nn.functional import tanh, sigmoid, softmax


class AuxHead(nn.Module):
    """Turn raw backbone outputs into the final GPS and auxiliary predictions.

    Args:
        aux_data: names of the auxiliary heads to include in the output.
            Recognised names are ``"land_cover"``, ``"road_index"``,
            ``"drive_side"``, ``"climate"``, ``"soil"`` and ``"dist_sea"``;
            any other entry is silently ignored. ``None`` (the default)
            means no auxiliary outputs.
        use_tanh: if True, squash the raw ``"gps"`` prediction with ``tanh``
            before un-normalising it; otherwise it is un-normalised as-is.
    """

    # (head name, activation) pairs, in the order they appear in the output
    # dict. ``None`` means the raw backbone value is passed through untouched.
    # softmax gets an explicit ``dim=-1``: calling it without ``dim`` is
    # deprecated (it warns), and for 1-D/2-D logits the implicit choice is
    # the last axis anyway. Assumes class logits live on the last axis —
    # TODO confirm for higher-rank inputs.
    _AUX_HEADS = (
        ("land_cover", lambda logits: softmax(logits, dim=-1)),
        ("road_index", None),
        ("drive_side", sigmoid),
        ("climate", lambda logits: softmax(logits, dim=-1)),
        ("soil", lambda logits: softmax(logits, dim=-1)),
        ("dist_sea", None),
    )

    def __init__(self, aux_data=None, use_tanh=False):
        super().__init__()
        # Fresh list per instance: a shared ``[]`` default would leak
        # mutations between every AuxHead built without ``aux_data``.
        self.aux_data = [] if aux_data is None else aux_data
        # Presumably maps normalised coordinates back to GPS space — see
        # UnormGPS in models.networks.utils to confirm.
        self.unorm = UnormGPS()
        self.use_tanh = use_tanh

    def forward(self, x):
        """Forward pass of the head.

        Args:
            x: dict of backbone outputs keyed by head name. Must contain
                ``"gps"`` plus every recognised name listed in
                ``self.aux_data``.

        Returns:
            dict with ``"gps"`` (optionally tanh-squashed, then passed
            through ``self.unorm``) followed by each requested auxiliary
            head: softmax for land_cover/climate/soil, sigmoid for
            drive_side, raw values for road_index/dist_sea.

        Raises:
            KeyError: if a required key is missing from ``x``.
        """
        # Always bind ``gps`` first; previously it was only assigned inside
        # the tanh branch, so ``use_tanh=False`` crashed.
        gps = x["gps"]
        if self.use_tanh:
            gps = tanh(gps)
        output = {"gps": self.unorm(gps)}
        for name, activation in self._AUX_HEADS:
            if name not in self.aux_data:
                continue
            value = x[name]
            output[name] = value if activation is None else activation(value)
        return output