import pandas as pd
import torch
import torch.nn as nn

from models.networks.utils import UnormGPS


class HybridHead(nn.Module):
    """Classification head followed by a per-cell regression head.

    The first ``final_dim`` entries of the last input dimension are read as
    classification logits over quadtree cells. The remaining entries are
    reshaped to ``(batch, -1, 2)`` and read as one 2-vector offset per cell,
    expressed relative to the cell centre in units of half a cell extent.

    Args:
        final_dim: Number of classification logits at the start of the last
            dimension of the input.
        quadtree_path: Path of a CSV file with the columns ``min_lat``,
            ``max_lat``, ``min_lon`` and ``max_lon``. When ``None`` no cell
            tables are built and ``forward`` cannot be used until
            ``init_quadtree`` has been called.
        use_tanh: Whether to squash the raw regression output with ``tanh``.
        scale_tanh: Multiplier applied to the ``tanh`` output.
    """

    def __init__(self, final_dim, quadtree_path, use_tanh, scale_tanh):
        super().__init__()
        self.final_dim = final_dim
        self.use_tanh = use_tanh
        self.scale_tanh = scale_tanh
        # NOTE(review): UnormGPS is defined outside this file; presumably it
        # maps the normalised (lat, lon) back to real coordinates - confirm.
        self.unorm = UnormGPS()

        if quadtree_path is not None:
            quadtree = pd.read_csv(quadtree_path)
            self.init_quadtree(quadtree)

    def init_quadtree(self, quadtree):
        """Build the ``cell_center`` and ``cell_size`` buffers.

        Latitudes are divided by 90 and longitudes by 180 before the tables
        are built. The caller's DataFrame is left untouched.

        Args:
            quadtree: DataFrame with one row per cell and the columns
                ``min_lat``, ``max_lat``, ``min_lon`` and ``max_lon``.
        """
        # Normalise a private copy: dividing the caller's frame in place would
        # scale it a second time if the same frame were passed in again.
        cells = quadtree[["min_lat", "max_lat", "min_lon", "max_lon"]].copy()
        cells[["min_lat", "max_lat"]] /= 90.0
        cells[["min_lon", "max_lon"]] /= 180.0

        upper = torch.tensor(cells[["max_lat", "max_lon"]].values)
        lower = torch.tensor(cells[["min_lat", "min_lon"]].values)

        self.register_buffer("cell_center", 0.5 * upper + 0.5 * lower)
        self.register_buffer("cell_size", upper - lower)

    def forward(self, x, gt_label):
        """Predict a cell and a position inside that cell.

        In training mode the regression output and the cell geometry are
        selected with ``gt_label``; otherwise with the arg-max of the
        classification logits.

        Args:
            x: Backbone output tensor. It is sliced along its last dimension
                and the regression part is viewed as ``(x.shape[0], -1, 2)``
                (assumes a 2-D ``(batch, final_dim + 2 * n_cells)`` layout -
                TODO confirm against the backbone).
            gt_label: Integer tensor with one cell index per batch element.
                Only read in training mode.

        Returns:
            dict with the keys
            ``"label"`` (classification logits),
            ``"gps"`` (output of ``self.unorm`` on the predicted position),
            ``"size"`` (``2 / cell_size`` of the selected cell),
            ``"center"`` (centre of the selected cell) and
            ``"reg"`` (regression 2-vector of the selected cell).
        """
        classification_logits = x[..., : self.final_dim]
        classification = classification_logits.argmax(dim=-1)

        regression = x[..., self.final_dim :]
        if self.use_tanh:
            regression = self.scale_tanh * torch.tanh(regression)
        regression = regression.view(regression.shape[0], -1, 2)

        # Teacher forcing: while training, the ground-truth cell selects the
        # regressor so the offset is learned relative to the correct cell.
        label = gt_label if self.training else classification

        index = label.unsqueeze(-1).unsqueeze(-1).expand(regression.shape[0], 1, 2)
        regression = torch.gather(regression, 1, index)[:, 0, :]

        cell_size = self.cell_size[label]
        center = self.cell_center[label]
        size = 2.0 / cell_size
        gps = center + regression * cell_size / 2.0
        gps = self.unorm(gps)

        return {
            "label": classification_logits,
            "gps": gps,
            "size": size,
            "center": center,
            "reg": regression,
        }


class HybridHeadCentroid(nn.Module):
    """Classification head followed by a regression head around cell centroids.

    Same layout as ``HybridHead``, but the reference point of each cell is
    its ``mean_lat`` / ``mean_lon`` instead of the box centre. The offset is
    scaled by the distance from that point to the upper bound of the cell
    where the regression component is positive, and by the distance to the
    lower bound otherwise.

    Args:
        final_dim: Number of classification logits at the start of the last
            dimension of the input.
        quadtree_path: Path of a CSV file with the columns ``min_lat``,
            ``max_lat``, ``mean_lat``, ``min_lon``, ``max_lon`` and
            ``mean_lon``. When ``None`` no cell tables are built.
        use_tanh: Whether to squash the raw regression output with ``tanh``.
        scale_tanh: Multiplier applied to the ``tanh`` output.
    """

    def __init__(self, final_dim, quadtree_path, use_tanh, scale_tanh):
        super().__init__()
        self.final_dim = final_dim
        self.use_tanh = use_tanh
        self.scale_tanh = scale_tanh
        # NOTE(review): UnormGPS is defined outside this file; presumably it
        # maps the normalised (lat, lon) back to real coordinates - confirm.
        self.unorm = UnormGPS()

        if quadtree_path is not None:
            quadtree = pd.read_csv(quadtree_path)
            self.init_quadtree(quadtree)

    def init_quadtree(self, quadtree):
        """Build ``cell_center``, ``cell_size_up`` and ``cell_size_down``.

        Latitudes are divided by 90 and longitudes by 180 before the tables
        are built. The caller's DataFrame is left untouched.

        Args:
            quadtree: DataFrame with one row per cell and the columns
                ``min_lat``, ``max_lat``, ``mean_lat``, ``min_lon``,
                ``max_lon`` and ``mean_lon``.
        """
        # Normalise a private copy: dividing the caller's frame in place would
        # scale it a second time if the same frame were passed in again.
        cells = quadtree[
            ["min_lat", "max_lat", "mean_lat", "min_lon", "max_lon", "mean_lon"]
        ].copy()
        cells[["min_lat", "max_lat", "mean_lat"]] /= 90.0
        cells[["min_lon", "max_lon", "mean_lon"]] /= 180.0

        centroid = torch.tensor(cells[["mean_lat", "mean_lon"]].values)
        upper = torch.tensor(cells[["max_lat", "max_lon"]].values)
        lower = torch.tensor(cells[["min_lat", "min_lon"]].values)

        # Buffers follow the module through .to() / .cuda() and are replicated
        # by DataParallel, which plain tensor attributes are not.
        # persistent=False keeps them out of state_dict, so checkpoints saved
        # before this change still load with strict=True.
        self.register_buffer("cell_center", centroid, persistent=False)
        self.register_buffer("cell_size_up", upper - centroid, persistent=False)
        self.register_buffer("cell_size_down", centroid - lower, persistent=False)

    def forward(self, x, gt_label):
        """Predict a cell and a position relative to that cell's centroid.

        In training mode the regression output and the cell geometry are
        selected with ``gt_label``; otherwise with the arg-max of the
        classification logits.

        Args:
            x: Backbone output tensor. It is sliced along its last dimension
                and the regression part is viewed as ``(x.shape[0], -1, 2)``
                (assumes a 2-D ``(batch, final_dim + 2 * n_cells)`` layout -
                TODO confirm against the backbone).
            gt_label: Integer tensor with one cell index per batch element.
                Only read in training mode.

        Returns:
            dict with the keys
            ``"label"`` (classification logits),
            ``"gps"`` (output of ``self.unorm`` on the predicted position),
            ``"size"`` (reciprocal of the per-component scale that was used),
            ``"center"`` (centroid of the selected cell) and
            ``"reg"`` (regression 2-vector of the selected cell).
        """
        classification_logits = x[..., : self.final_dim]
        classification = classification_logits.argmax(dim=-1)

        # Fallback for a head that was never moved together with its input:
        # .to() returns the same tensor when the device already matches, and
        # the module's own state is not rewritten during the forward pass.
        device = classification.device
        cell_center = self.cell_center.to(device)
        cell_size_up = self.cell_size_up.to(device)
        cell_size_down = self.cell_size_down.to(device)

        regression = x[..., self.final_dim :]
        if self.use_tanh:
            regression = self.scale_tanh * torch.tanh(regression)
        regression = regression.view(regression.shape[0], -1, 2)

        # Teacher forcing: while training, the ground-truth cell selects the
        # regressor so the offset is learned relative to the correct cell.
        label = gt_label if self.training else classification

        index = label.unsqueeze(-1).unsqueeze(-1).expand(regression.shape[0], 1, 2)
        regression = torch.gather(regression, 1, index)[:, 0, :]

        # The centroid is generally not the middle of the box, so each
        # component is scaled by the extent on the side it points towards.
        size = torch.where(
            regression > 0,
            cell_size_up[label],
            cell_size_down[label],
        )
        center = cell_center[label]
        gps = center + regression * size
        gps = self.unorm(gps)

        return {
            "label": classification_logits,
            "gps": gps,
            "size": 1.0 / size,
            "center": center,
            "reg": regression,
        }


class SharedHybridHead(HybridHead):
    """Classification head followed by SHARED regression head for the network.

    Unlike ``HybridHead``, the regression part of the input is not reshaped
    or gathered per cell: the same regression output is used whichever cell
    is selected.
    """

    def forward(self, x, gt_label):
        """Predict a cell and a position inside that cell.

        In training mode the cell geometry is selected with ``gt_label``;
        otherwise with the arg-max of the classification logits.

        Args:
            x: Backbone output tensor, sliced along its last dimension into
                ``final_dim`` logits followed by the regression output
                (presumably 2 values per sample - TODO confirm).
            gt_label: Integer tensor of cell indices. Only read in training
                mode.

        Returns:
            dict with the keys ``"label"`` (classification logits) and
            ``"gps"`` (output of ``self.unorm`` on the predicted position).
        """
        classification_logits = x[..., : self.final_dim]
        classification = classification_logits.argmax(dim=-1)

        regression = x[..., self.final_dim :]
        if self.use_tanh:
            regression = self.scale_tanh * torch.tanh(regression)

        label = gt_label if self.training else classification
        gps = self.cell_center[label] + regression * self.cell_size[label] / 2.0
        gps = self.unorm(gps)

        return {"label": classification_logits, "gps": gps}