# -*- coding: utf-8 -*-
# Loss functions (PyTorch and own defined)
#
# Own defined loss functions:
# xentropy_loss, dice_loss, mse_loss and msge_loss (https://github.com/vqdang/hover_net)
# WeightedBaseLoss, MAEWeighted, MSEWeighted, BCEWeighted, CEWeighted (https://github.com/okunator/cellseg_models.pytorch)
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artifical Intelligence in Medicine,
# University Medicine Essen

import torch
import torch.nn.functional as F
from typing import List, Tuple

from torch import nn
from torch.nn.modules.loss import _Loss

from base_ml.base_utils import filter2D, gaussian_kernel2d


class XentropyLoss(_Loss):
    """Cross entropy loss"""

    def __init__(self, reduction: str = "mean") -> None:
        super().__init__(size_average=None, reduce=None, reduction=reduction)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Assumes NCHW shape of array, must be torch.float32 dtype

        Args:
            input (torch.Tensor): Prediction array with shape (N, C, H, W) with N being the batch-size,
                H the height, W the width and C the number of classes
            target (torch.Tensor): Ground truth array with shape (N, C, H, W) with N being the batch-size,
                H the height, W the width and C the number of classes

        Returns:
            torch.Tensor: Cross entropy loss, with shape () [scalar], grad_fn = MeanBackward0
        """
        # reshape to NHWC so class probabilities live in the last dimension
        input = input.permute(0, 2, 3, 1)
        target = target.permute(0, 2, 3, 1)

        epsilon = 10e-8  # numerical stability constant (1e-7)
        # scale preds so that the class probs of each sample sum to 1
        pred = input / torch.sum(input, -1, keepdim=True)
        # manual computation of crossentropy; clamp avoids log(0)
        pred = torch.clamp(pred, epsilon, 1.0 - epsilon)
        loss = -torch.sum((target * torch.log(pred)), -1, keepdim=True)
        loss = loss.mean() if self.reduction == "mean" else loss.sum()

        return loss


class DiceLoss(_Loss):
    """Dice loss

    Args:
        smooth (float, optional): Smoothing value. Defaults to 1e-3.
    """

    def __init__(self, smooth: float = 1e-3) -> None:
        super().__init__(size_average=None, reduce=None, reduction="mean")
        self.smooth = smooth

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Assumes NCHW shape of array, must be torch.float32 dtype

        `pred` and `true` must be of torch.float32.

        Args:
            input (torch.Tensor): Prediction array with shape (N, C, H, W) with N being the batch-size,
                H the height, W the width and C the number of classes
            target (torch.Tensor): Ground truth array with shape (N, C, H, W) with N being the batch-size,
                H the height, W the width and C the number of classes

        Returns:
            torch.Tensor: Dice loss, with shape () [scalar], grad_fn=SumBackward0
        """
        input = input.permute(0, 2, 3, 1)
        target = target.permute(0, 2, 3, 1)

        # per-class intersection and sums over batch + spatial dims
        inse = torch.sum(input * target, (0, 1, 2))
        pred_sum = torch.sum(input, (0, 1, 2))
        target_sum = torch.sum(target, (0, 1, 2))
        loss = 1.0 - (2.0 * inse + self.smooth) / (pred_sum + target_sum + self.smooth)
        loss = torch.sum(loss)

        return loss


class MSELossMaps(_Loss):
    """Calculate mean squared error loss for combined horizontal and vertical maps of segmentation tasks."""

    def __init__(self) -> None:
        super().__init__(size_average=None, reduce=None, reduction="mean")

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Loss calculation

        Args:
            input (torch.Tensor): Prediction of combined horizontal and vertical maps
                with shape (N, 2, H, W), channel 0 is vertical and channel 1 is horizontal
            target (torch.Tensor): Ground truth of combined horizontal and vertical maps
                with shape (N, 2, H, W), channel 0 is vertical and channel 1 is horizontal

        Returns:
            torch.Tensor: Mean squared error averaged over all pixels, shape () [scalar]
        """
        loss = input - target
        loss = (loss * loss).mean()
        return loss


class MSGELossMaps(_Loss):
    """Mean squared gradient error for horizontal/vertical maps (HoVer-Net style)."""

    def __init__(self) -> None:
        super().__init__(size_average=None, reduce=None, reduction="mean")

    def get_sobel_kernel(
        self, size: int, device: str
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get sobel kernel with a given size.

        Args:
            size (int): Kernel size (must be odd)
            device (str): Cuda device

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Horizontal and vertical sobel kernel,
                each with shape (size, size)
        """
        assert size % 2 == 1, "Must be odd, get size=%d" % size

        h_range = torch.arange(
            -size // 2 + 1,
            size // 2 + 1,
            dtype=torch.float32,
            device=device,
            requires_grad=False,
        )
        v_range = torch.arange(
            -size // 2 + 1,
            size // 2 + 1,
            dtype=torch.float32,
            device=device,
            requires_grad=False,
        )
        h, v = torch.meshgrid(h_range, v_range, indexing="ij")
        kernel_h = h / (h * h + v * v + 1.0e-15)
        kernel_v = v / (h * h + v * v + 1.0e-15)

        return kernel_h, kernel_v

    def get_gradient_hv(self, hv: torch.Tensor, device: str) -> torch.Tensor:
        """For calculating gradient of horizontal and vertical prediction map

        Args:
            hv (torch.Tensor): horizontal and vertical map, NHWC with channel 0 = horizontal,
                channel 1 = vertical
            device (str): CUDA device

        Returns:
            torch.Tensor: Gradient, NHWC with 2 channels
        """
        kernel_h, kernel_v = self.get_sobel_kernel(5, device=device)
        kernel_h = kernel_h.view(1, 1, 5, 5)  # constant
        kernel_v = kernel_v.view(1, 1, 5, 5)  # constant

        h_ch = hv[..., 0].unsqueeze(1)  # Nx1xHxW
        v_ch = hv[..., 1].unsqueeze(1)  # Nx1xHxW

        # can only apply in NCHW mode
        h_dh_ch = F.conv2d(h_ch, kernel_h, padding=2)
        v_dv_ch = F.conv2d(v_ch, kernel_v, padding=2)
        dhv = torch.cat([h_dh_ch, v_dv_ch], dim=1)
        dhv = dhv.permute(0, 2, 3, 1).contiguous()  # to NHWC

        return dhv

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        focus: torch.Tensor,
        device: str,
    ) -> torch.Tensor:
        """MSGE (Gradient of MSE) loss

        Args:
            input (torch.Tensor): Input with shape (B, C, H, W)
            target (torch.Tensor): Target with shape (B, C, H, W)
            focus (torch.Tensor): Focus, type of masking (B, C, H, W); channel 1 selects
                the foreground region the loss is restricted to
            device (str): CUDA device to work with.

        Returns:
            torch.Tensor: MSGE loss
        """
        input = input.permute(0, 2, 3, 1)
        target = target.permute(0, 2, 3, 1)
        focus = focus.permute(0, 2, 3, 1)
        focus = focus[..., 1]

        focus = (focus[..., None]).float()  # assume input NHW
        focus = torch.cat([focus, focus], axis=-1).to(device)
        true_grad = self.get_gradient_hv(target, device)
        pred_grad = self.get_gradient_hv(input, device)
        loss = pred_grad - true_grad
        loss = focus * (loss * loss)
        # artificial reduce_mean with focused region
        loss = loss.sum() / (focus.sum() + 1.0e-8)
        return loss


class FocalTverskyLoss(nn.Module):
    """FocalTverskyLoss

    PyTorch implementation of the Focal Tversky Loss Function for multiple classes
    doi: 10.1109/ISBI.2019.8759329
    Abraham, N., & Khan, N. M. (2019).
    A Novel Focal Tversky Loss Function With Improved Attention U-Net for Lesion Segmentation.
    In International Symposium on Biomedical Imaging. https://doi.org/10.1109/isbi.2019.8759329

    @ Fabian Hörst, fabian.hoerst@uk-essen.de
    Institute for Artifical Intelligence in Medicine,
    University Medicine Essen

    Args:
        alpha_t (float, optional): Alpha parameter for tversky loss (multiplied with false-negatives). Defaults to 0.7.
        beta_t (float, optional): Beta parameter for tversky loss (multiplied with false-positives). Defaults to 0.3.
        gamma_f (float, optional): Gamma Focal parameter. Defaults to 4/3.
        smooth (float, optional): Smooting factor. Defaults to 0.000001.
    """

    def __init__(
        self,
        alpha_t: float = 0.7,
        beta_t: float = 0.3,
        gamma_f: float = 4 / 3,
        smooth: float = 1e-6,
    ) -> None:
        super().__init__()
        self.alpha_t = alpha_t
        self.beta_t = beta_t
        self.gamma_f = gamma_f
        self.smooth = smooth
        self.num_classes = 2  # binary-optimized variant; see MCFocalTverskyLoss for C > 2

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Loss calculation

        Args:
            input (torch.Tensor): Predictions, logits (without Softmax). Shape: (B, C, H, W)
            target (torch.Tensor): Targets, either flattened (Shape: (B, H, W))
                or as one-hot encoded (Shape: (B, C, H, W)).

        Raises:
            ValueError: Error if there is a shape missmatch

        Returns:
            torch.Tensor: FocalTverskyLoss (weighted)
        """
        input = input.permute(0, 2, 3, 1)
        if input.shape[-1] != self.num_classes:
            raise ValueError(
                "Predictions must be a logit tensor with the last dimension shape beeing equal to the number of classes"
            )
        if len(target.shape) != len(input.shape):
            # convert the targets to onehot
            target = F.one_hot(target, num_classes=self.num_classes)

        # flatten
        target = target.permute(0, 2, 3, 1)
        target = target.view(-1)
        input = torch.softmax(input, dim=-1).view(-1)

        # calculate true positives, false positives and false negatives
        tp = (input * target).sum()
        fp = ((1 - target) * input).sum()
        fn = (target * (1 - input)).sum()

        Tversky = (tp + self.smooth) / (
            tp + self.alpha_t * fn + self.beta_t * fp + self.smooth
        )
        FocalTversky = (1 - Tversky) ** self.gamma_f

        return FocalTversky


class MCFocalTverskyLoss(FocalTverskyLoss):
    """Multiclass FocalTverskyLoss

    PyTorch implementation of the Focal Tversky Loss Function for multiple classes
    doi: 10.1109/ISBI.2019.8759329
    Abraham, N., & Khan, N. M. (2019).
    A Novel Focal Tversky Loss Function With Improved Attention U-Net for Lesion Segmentation.
    In International Symposium on Biomedical Imaging. https://doi.org/10.1109/isbi.2019.8759329

    @ Fabian Hörst, fabian.hoerst@uk-essen.de
    Institute for Artifical Intelligence in Medicine,
    University Medicine Essen

    Args:
        alpha_t (float, optional): Alpha parameter for tversky loss (multiplied with false-negatives). Defaults to 0.7.
        beta_t (float, optional): Beta parameter for tversky loss (multiplied with false-positives). Defaults to 0.3.
        gamma_f (float, optional): Gamma Focal parameter. Defaults to 4/3.
        smooth (float, optional): Smooting factor. Defaults to 0.000001.
        num_classes (int, optional): Number of output classes. For binary segmentation, prefer FocalTverskyLoss (speed optimized). Defaults to 2.
        class_weights (List[int], optional): Weights for each class. If not provided, equal weight. Length must be equal to num_classes. Defaults to None.
    """

    def __init__(
        self,
        alpha_t: float = 0.7,
        beta_t: float = 0.3,
        gamma_f: float = 4 / 3,
        smooth: float = 0.000001,
        num_classes: int = 2,
        class_weights: List[int] = None,
    ) -> None:
        super().__init__(alpha_t, beta_t, gamma_f, smooth)
        self.num_classes = num_classes
        if class_weights is None:
            self.class_weights = [1 for i in range(self.num_classes)]
        else:
            assert (
                len(class_weights) == self.num_classes
            ), "Please provide matching weights"
            self.class_weights = class_weights
        self.class_weights = torch.Tensor(self.class_weights)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Loss calculation

        Args:
            input (torch.Tensor): Predictions, logits (without Softmax). Shape: (B, num_classes, H, W)
            target (torch.Tensor): Targets, either flattened (Shape: (B, H, W))
                or as one-hot encoded (Shape: (B, num_classes, H, W)).

        Raises:
            ValueError: Error if there is a shape missmatch

        Returns:
            torch.Tensor: FocalTverskyLoss (weighted)
        """
        input = input.permute(0, 2, 3, 1)
        if input.shape[-1] != self.num_classes:
            raise ValueError(
                "Predictions must be a logit tensor with the last dimension shape beeing equal to the number of classes"
            )
        if len(target.shape) != len(input.shape):
            # convert the targets to onehot
            target = F.one_hot(target, num_classes=self.num_classes)

        target = target.permute(0, 2, 3, 1)
        # Softmax
        input = torch.softmax(input, dim=-1)

        # Reshape to (C, ...) and flatten per class so tp/fp/fn are per-class vectors
        input = torch.permute(input, (3, 1, 2, 0))
        target = torch.permute(target, (3, 1, 2, 0))

        input = torch.flatten(input, start_dim=1)
        target = torch.flatten(target, start_dim=1)

        tp = torch.sum(input * target, 1)
        fp = torch.sum((1 - target) * input, 1)
        fn = torch.sum(target * (1 - input), 1)

        Tversky = (tp + self.smooth) / (
            tp + self.alpha_t * fn + self.beta_t * fp + self.smooth
        )
        FocalTversky = (1 - Tversky) ** self.gamma_f

        self.class_weights = self.class_weights.to(FocalTversky.device)
        return torch.sum(self.class_weights * FocalTversky)


class WeightedBaseLoss(nn.Module):
    """Init a base class for weighted cross entropy based losses.

    Enables weighting for object instance edges and classes.

    Adapted/Copied from: https://github.com/okunator/cellseg_models.pytorch (10.5281/zenodo.7064617)

    Args:
        apply_sd (bool, optional): If True, Spectral decoupling regularization will be applied to the
            loss matrix. Defaults to False.
        apply_ls (bool, optional): If True, Label smoothing will be applied to the target. Defaults to False.
        apply_svls (bool, optional): If True, spatially varying label smoothing will be applied to the target.
            Defaults to False.
        apply_mask (bool, optional): If True, a mask will be applied to the loss matrix. Mask shape: (B, H, W).
            Defaults to False.
        class_weights (torch.Tensor, optional): Class weights. A tensor of shape (C, ). Defaults to None.
        edge_weight (float, optional): Weight for the object instance border pixels. Defaults to None.
    """

    def __init__(
        self,
        apply_sd: bool = False,
        apply_ls: bool = False,
        apply_svls: bool = False,
        apply_mask: bool = False,
        class_weights: torch.Tensor = None,
        edge_weight: float = None,
        **kwargs,
    ) -> None:
        super().__init__()
        self.apply_sd = apply_sd
        self.apply_ls = apply_ls
        self.apply_svls = apply_svls
        self.apply_mask = apply_mask
        self.class_weights = class_weights
        self.edge_weight = edge_weight

    def apply_spectral_decouple(
        self, loss_matrix: torch.Tensor, yhat: torch.Tensor, lam: float = 0.01
    ) -> torch.Tensor:
        """Apply spectral decoupling L2 norm after the loss.

        https://arxiv.org/abs/2011.09468

        Args:
            loss_matrix (torch.Tensor): Pixelwise losses. A tensor of shape (B, H, W).
            yhat (torch.Tensor): The pixel predictions of the model. Shape (B, C, H, W).
            lam (float, optional): Lambda constant. Defaults to 0.01.

        Returns:
            torch.Tensor: SD-regularized loss matrix. Same shape as input.
        """
        return loss_matrix + (lam / 2) * (yhat**2).mean(axis=1)

    def apply_ls_to_target(
        self,
        target: torch.Tensor,
        num_classes: int,
        label_smoothing: float = 0.1,
    ) -> torch.Tensor:
        """Apply label smoothing to the target.

        Args:
            target (torch.Tensor): The target one hot tensor. Shape (B, C, H, W).
            num_classes (int): Number of classes in the data.
            label_smoothing (float, optional): The smoothing coeff alpha. Defaults to 0.1.

        Returns:
            torch.Tensor: Label smoothed target. Same shape as input.
        """
        return target * (1 - label_smoothing) + label_smoothing / num_classes

    def apply_svls_to_target(
        self,
        target: torch.Tensor,
        num_classes: int,
        kernel_size: int = 5,
        sigma: int = 3,
        **kwargs,
    ) -> torch.Tensor:
        """Apply spatially varying label smoothihng to target map.

        https://arxiv.org/abs/2104.05788

        Args:
            target (torch.Tensor): The target one hot tensor. Shape (B, C, H, W).
            num_classes (int): Number of classes in the data.
            kernel_size (int, optional): Size of a square kernel. Defaults to 5.
            sigma (int, optional): The std of the gaussian. Defaults to 3.

        Returns:
            torch.Tensor: Label smoothed target. Same shape as input.
        """
        my, mx = kernel_size // 2, kernel_size // 2
        gaussian_kernel = gaussian_kernel2d(
            kernel_size, sigma, num_classes, device=target.device
        )
        # replace the kernel center with (1 - center) so the center pixel keeps most mass
        neighborsum = (1 - gaussian_kernel[..., my, mx]) + 1e-16
        gaussian_kernel = gaussian_kernel.clone()
        gaussian_kernel[..., my, mx] = neighborsum
        svls_kernel = gaussian_kernel / neighborsum[0]

        return filter2D(target.float(), svls_kernel) / svls_kernel[0].sum()

    def apply_class_weights(
        self, loss_matrix: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        """Multiply pixelwise loss matrix by the class weights.

        NOTE: No normalization

        Args:
            loss_matrix (torch.Tensor): Pixelwise losses. A tensor of shape (B, H, W).
            target (torch.Tensor): The target mask. Shape (B, H, W).

        Returns:
            torch.Tensor: The loss matrix scaled with the weight matrix. Shape (B, H, W).
        """
        weight_mat = self.class_weights[target.long()].to(target.device)  # to (B, H, W)
        loss = loss_matrix * weight_mat

        return loss

    def apply_edge_weights(
        self, loss_matrix: torch.Tensor, weight_map: torch.Tensor
    ) -> torch.Tensor:
        """Apply weights to the object boundaries.

        Basically just computes `edge_weight`**`weight_map`.

        Args:
            loss_matrix (torch.Tensor): Pixelwise losses. A tensor of shape (B, H, W).
            weight_map (torch.Tensor): Map that points to the pixels that will be weighted.
                Shape (B, H, W).

        Returns:
            torch.Tensor: The loss matrix scaled with the nuclear boundary weights.
                Shape (B, H, W).
        """
        return loss_matrix * self.edge_weight**weight_map

    def apply_mask_weight(
        self, loss_matrix: torch.Tensor, mask: torch.Tensor, norm: bool = True
    ) -> torch.Tensor:
        """Apply a mask to the loss matrix.

        Args:
            loss_matrix (torch.Tensor): Pixelwise losses. A tensor of shape (B, H, W).
            mask (torch.Tensor): The mask. Shape (B, H, W).
            norm (bool, optional): If True, the loss matrix will be normalized by the mean of the mask.
                Defaults to True.

        Returns:
            torch.Tensor: The loss matrix scaled with the mask. Shape (B, H, W).
        """
        loss_matrix *= mask
        if norm:
            norm_mask = torch.mean(mask.float()) + 1e-7
            loss_matrix /= norm_mask

        return loss_matrix

    def extra_repr(self) -> str:
        """Add info to print."""
        s = "apply_sd={apply_sd}, apply_ls={apply_ls}, apply_svls={apply_svls}, apply_mask={apply_mask}, class_weights={class_weights}, edge_weight={edge_weight}"  # noqa
        return s.format(**self.__dict__)


class MAEWeighted(WeightedBaseLoss):
    """Compute the MAE loss. Used in the stardist method.

    Stardist:
    https://arxiv.org/pdf/1806.03535.pdf
    Adapted/Copied from: https://github.com/okunator/cellseg_models.pytorch (10.5281/zenodo.7064617)

    NOTE: We have added the option to apply spectral decoupling and edge weights
    to the loss matrix.

    Args:
        alpha (float, optional): Weight regulizer b/w [0,1]. In stardist repo, this is the
            'train_background_reg' parameter. Defaults to 1e-4.
        apply_sd (bool, optional): If True, Spectral decoupling regularization will be applied to the
            loss matrix. Defaults to False.
        apply_mask (bool, optional): If True, a mask will be applied to the loss matrix.
            Mask shape: (B, H, W). Defaults to False.
        edge_weight (float, optional): Weight that is added to object borders. Defaults to None.
    """

    def __init__(
        self,
        alpha: float = 1e-4,
        apply_sd: bool = False,
        apply_mask: bool = False,
        edge_weight: float = None,
        **kwargs,
    ) -> None:
        # no label smoothing / svls / class weights for the MAE variant
        super().__init__(apply_sd, False, False, apply_mask, None, edge_weight)
        self.alpha = alpha
        self.eps = 1e-7

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        target_weight: torch.Tensor = None,
        mask: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """Compute the masked MAE loss.

        Args:
            input (torch.Tensor): The prediction map. Shape (B, C, H, W).
            target (torch.Tensor): The ground truth annotations. Shape (B, H, W).
            target_weight (torch.Tensor, optional): The edge weight map. Shape (B, H, W). Defaults to None.
            mask (torch.Tensor, optional): The mask map. Shape (B, H, W). Defaults to None.

        Raises:
            ValueError: Pred and target shapes must match.

        Returns:
            torch.Tensor: Computed MAE loss (scalar).
        """
        yhat = input
        n_classes = yhat.shape[1]
        if target.size() != yhat.size():
            target = target.unsqueeze(1).repeat_interleave(n_classes, dim=1)

        if not yhat.shape == target.shape:
            raise ValueError(
                f"Pred and target shapes must match. Got: {yhat.shape}, {target.shape}"
            )

        # compute the MAE loss with alpha as weight
        mae_loss = torch.mean(torch.abs(target - yhat), axis=1)  # (B, H, W)

        if self.apply_mask and mask is not None:
            mae_loss = self.apply_mask_weight(mae_loss, mask, norm=True)  # (B, H, W)

            # add the background regularization (requires a mask, hence guarded here
            # instead of unconditionally; the unguarded version crashed with mask=None)
            if self.alpha > 0:
                reg = torch.mean(((1 - mask).unsqueeze(1)) * torch.abs(yhat), axis=1)
                mae_loss += self.alpha * reg

        if self.apply_sd:
            mae_loss = self.apply_spectral_decouple(mae_loss, yhat)

        if self.edge_weight is not None:
            mae_loss = self.apply_edge_weights(mae_loss, target_weight)

        return mae_loss.mean()


class MSEWeighted(WeightedBaseLoss):
    """MSE-loss.

    Args:
        apply_sd (bool, optional): If True, Spectral decoupling regularization will be applied to the
            loss matrix. Defaults to False.
        apply_ls (bool, optional): If True, Label smoothing will be applied to the target. Defaults to False.
        apply_svls (bool, optional): If True, spatially varying label smoothing will be applied to the target.
            Defaults to False.
        apply_mask (bool, optional): If True, a mask will be applied to the loss matrix. Mask shape: (B, H, W).
            Defaults to False.
        edge_weight (float, optional): Weight that is added to object borders. Defaults to None.
        class_weights (torch.Tensor, optional): Class weights. A tensor of shape (n_classes,). Defaults to None.
    """

    def __init__(
        self,
        apply_sd: bool = False,
        apply_ls: bool = False,
        apply_svls: bool = False,
        apply_mask: bool = False,
        edge_weight: float = None,
        class_weights: torch.Tensor = None,
        **kwargs,
    ) -> None:
        super().__init__(
            apply_sd, apply_ls, apply_svls, apply_mask, class_weights, edge_weight
        )

    @staticmethod
    def tensor_one_hot(type_map: torch.Tensor, n_classes: int) -> torch.Tensor:
        """Convert a segmentation mask into one-hot-format.

        I.e. Takes in a segmentation mask of shape (B, H, W) and reshapes it
        into a tensor of shape (B, C, H, W).

        Args:
            type_map (torch.Tensor): Multi-label Segmentation mask. Shape (B, H, W).
            n_classes (int): Number of classes. (Zero-class included.)

        Raises:
            TypeError: Input `type_map` should have dtype: torch.int64.

        Returns:
            torch.Tensor: A one hot tensor. Shape: (B, C, H, W). Dtype: torch.FloatTensor.
        """
        if not type_map.dtype == torch.int64:
            raise TypeError(
                f"""
                Input `type_map` should have dtype: torch.int64. Got: {type_map.dtype}."""
            )

        one_hot = torch.zeros(
            type_map.shape[0],
            n_classes,
            *type_map.shape[1:],
            device=type_map.device,
            dtype=type_map.dtype,
        )

        # +1e-7 keeps the one-hot strictly positive (used downstream with log)
        return one_hot.scatter_(dim=1, index=type_map.unsqueeze(1), value=1.0) + 1e-7

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        target_weight: torch.Tensor = None,
        mask: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """Compute the MSE-loss.

        Args:
            input (torch.Tensor): The prediction map. Shape (B, C, H, W).
            target (torch.Tensor): The ground truth annotations. Shape (B, H, W).
            target_weight (torch.Tensor, optional): The edge weight map. Shape (B, H, W). Defaults to None.
            mask (torch.Tensor, optional): The mask map. Shape (B, H, W). Defaults to None.

        Returns:
            torch.Tensor: Computed MSE loss (scalar).
        """
        yhat = input
        target_one_hot = target
        num_classes = yhat.shape[1]

        if target.size() != yhat.size():
            if target.dtype == torch.float32:
                target_one_hot = target.unsqueeze(1)
            else:
                target_one_hot = MSEWeighted.tensor_one_hot(target, num_classes)

        if self.apply_svls:
            target_one_hot = self.apply_svls_to_target(
                target_one_hot, num_classes, **kwargs
            )

        if self.apply_ls:
            target_one_hot = self.apply_ls_to_target(
                target_one_hot, num_classes, **kwargs
            )

        mse = F.mse_loss(yhat, target_one_hot, reduction="none")  # (B, C, H, W)
        mse = torch.mean(mse, dim=1)  # to (B, H, W)

        if self.apply_mask and mask is not None:
            mse = self.apply_mask_weight(mse, mask, norm=False)  # (B, H, W)

        if self.apply_sd:
            mse = self.apply_spectral_decouple(mse, yhat)

        if self.class_weights is not None:
            mse = self.apply_class_weights(mse, target)

        if self.edge_weight is not None:
            mse = self.apply_edge_weights(mse, target_weight)

        return torch.mean(mse)


class BCEWeighted(WeightedBaseLoss):
    def __init__(
        self,
        apply_sd: bool = False,
        apply_ls: bool = False,
        apply_svls: bool = False,
        apply_mask: bool = False,
        edge_weight: float = None,
        class_weights: torch.Tensor = None,
        **kwargs,
    ) -> None:
        """Binary cross entropy loss with weighting and other tricks.

        Parameters
        ----------
        apply_sd : bool, default=False
            If True, Spectral decoupling regularization will be applied  to the
            loss matrix.
        apply_ls : bool, default=False
            If True, Label smoothing will be applied to the target.
        apply_svls : bool, default=False
            If True, spatially varying label smoothing will be applied to the target
        apply_mask : bool, default=False
            If True, a mask will be applied to the loss matrix. Mask shape: (B, H, W)
        edge_weight : float, default=None
            Weight that is added to object borders.
        class_weights : torch.Tensor, default=None
            Class weights. A tensor of shape (n_classes,).
        """
        super().__init__(
            apply_sd, apply_ls, apply_svls, apply_mask, class_weights, edge_weight
        )
        self.eps = 1e-8

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        target_weight: torch.Tensor = None,
        mask: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """Compute binary cross entropy loss.

        Parameters
        ----------
        input : torch.Tensor
            The prediction map (logits). Shape (B, C, H, W).
        target : torch.Tensor
            the ground truth annotations. Shape (B, H, W).
        target_weight : torch.Tensor, default=None
            The edge weight map. Shape (B, H, W).
        mask : torch.Tensor, default=None
            The mask map. Shape (B, H, W).

        Returns
        -------
        torch.Tensor:
            Computed BCE loss (scalar).
        """
        yhat = input
        num_classes = yhat.shape[1]
        # NOTE(review): clipping logits to [eps, 1-eps] looks like a probability-domain
        # clamp applied to logits (it forces all logits into ~[0, 1] before
        # binary_cross_entropy_with_logits). Kept as-is to preserve training numerics
        # of the original implementation — confirm intent before changing.
        yhat = torch.clip(yhat, self.eps, 1.0 - self.eps)

        if target.size() != yhat.size():
            target = target.unsqueeze(1).repeat_interleave(num_classes, dim=1)

        if self.apply_svls:
            target = self.apply_svls_to_target(target, num_classes, **kwargs)

        if self.apply_ls:
            target = self.apply_ls_to_target(target, num_classes, **kwargs)

        bce = F.binary_cross_entropy_with_logits(
            yhat.float(), target.float(), reduction="none"
        )  # (B, C, H, W)
        bce = torch.mean(bce, dim=1)  # (B, H, W)

        if self.apply_mask and mask is not None:
            bce = self.apply_mask_weight(bce, mask, norm=False)  # (B, H, W)

        if self.apply_sd:
            bce = self.apply_spectral_decouple(bce, yhat)

        if self.class_weights is not None:
            bce = self.apply_class_weights(bce, target)

        if self.edge_weight is not None:
            bce = self.apply_edge_weights(bce, target_weight)

        return torch.mean(bce)


class CEWeighted(WeightedBaseLoss):
    def __init__(
        self,
        apply_sd: bool = False,
        apply_ls: bool = False,
        apply_svls: bool = False,
        apply_mask: bool = False,
        edge_weight: float = None,
        class_weights: torch.Tensor = None,
        **kwargs,
    ) -> None:
        """Cross-Entropy loss with weighting.

        Parameters
        ----------
        apply_sd : bool, default=False
            If True, Spectral decoupling regularization will be applied  to the
            loss matrix.
        apply_ls : bool, default=False
            If True, Label smoothing will be applied to the target.
        apply_svls : bool, default=False
            If True, spatially varying label smoothing will be applied to the target
        apply_mask : bool, default=False
            If True, a mask will be applied to the loss matrix. Mask shape: (B, H, W)
        edge_weight : float, default=None
            Weight that is added to object borders.
        class_weights : torch.Tensor, default=None
            Class weights. A tensor of shape (n_classes,).
        """
        super().__init__(
            apply_sd, apply_ls, apply_svls, apply_mask, class_weights, edge_weight
        )
        self.eps = 1e-8

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        target_weight: torch.Tensor = None,
        mask: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """Compute the cross entropy loss.

        Parameters
        ----------
        input : torch.Tensor
            The prediction map (logits). Shape (B, C, H, W).
        target : torch.Tensor
            the ground truth annotations. Shape (B, H, W).
        target_weight : torch.Tensor, default=None
            The edge weight map. Shape (B, H, W).
        mask : torch.Tensor, default=None
            The mask map. Shape (B, H, W).

        Returns
        -------
        torch.Tensor:
            Computed CE loss (scalar).
        """
        yhat = input  # TODO: remove doubled Softmax -> this function needs logits instead of softmax output
        input_soft = F.softmax(yhat, dim=1) + self.eps  # (B, C, H, W)
        num_classes = yhat.shape[1]
        # NOTE(review): the `and` below means a (B, H, W) target is only one-hotted
        # when H != num_classes — presumably `or`/pure rank check was intended; kept
        # as-is to preserve behavior, confirm against callers.
        if len(target.shape) != len(yhat.shape) and target.shape[1] != num_classes:
            target_one_hot = MSEWeighted.tensor_one_hot(
                target, num_classes
            )  # (B, C, H, W)
        else:
            target_one_hot = target
            target = torch.argmax(target, dim=1)

        assert target_one_hot.shape == yhat.shape

        if self.apply_svls:
            target_one_hot = self.apply_svls_to_target(
                target_one_hot, num_classes, **kwargs
            )

        if self.apply_ls:
            target_one_hot = self.apply_ls_to_target(
                target_one_hot, num_classes, **kwargs
            )

        loss = -torch.sum(target_one_hot * torch.log(input_soft), dim=1)  # (B, H, W)

        if self.apply_mask and mask is not None:
            loss = self.apply_mask_weight(loss, mask, norm=False)  # (B, H, W)

        if self.apply_sd:
            loss = self.apply_spectral_decouple(loss, yhat)

        if self.class_weights is not None:
            loss = self.apply_class_weights(loss, target)

        if self.edge_weight is not None:
            loss = self.apply_edge_weights(loss, target_weight)

        return loss.mean()


### Stardist loss functions
class L1LossWeighted(nn.Module):
    """Pixelwise L1 loss with an optional per-pixel weight map (Stardist)."""

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        target_weight: torch.Tensor = None,
    ) -> torch.Tensor:
        """Compute the (optionally weighted) mean L1 loss.

        Args:
            input (torch.Tensor): Prediction map. Shape (B, C, H, W).
            target (torch.Tensor): Ground truth map. Shape (B, C, H, W).
            target_weight (torch.Tensor, optional): Per-pixel weights. Shape (B, H, W).
                Defaults to None.

        Returns:
            torch.Tensor: Weighted L1 loss (scalar).
        """
        # reduction="none" keeps the elementwise loss (replaces the deprecated and
        # contradictory `size_average=True, reduce=False` flag combination)
        l1loss = F.l1_loss(input, target, reduction="none")
        l1loss = torch.mean(l1loss, dim=1)  # to (B, H, W)
        if target_weight is not None:
            l1loss = torch.mean(target_weight * l1loss)
        else:
            l1loss = torch.mean(l1loss)
        return l1loss


def retrieve_loss_fn(loss_name: str, **kwargs) -> _Loss:
    """Return the loss function with given name defined in the LOSS_DICT and initialize with kwargs

    kwargs must match with the parameters defined in the initialization method of the
    selected loss object

    Args:
        loss_name (str): Name of the loss function (key in LOSS_DICT)

    Returns:
        _Loss: Loss
    """
    loss_fn = LOSS_DICT[loss_name]
    loss_fn = loss_fn(**kwargs)

    return loss_fn


LOSS_DICT = {
    "xentropy_loss": XentropyLoss,
    "dice_loss": DiceLoss,
    "mse_loss_maps": MSELossMaps,
    "msge_loss_maps": MSGELossMaps,
    "FocalTverskyLoss": FocalTverskyLoss,
    "MCFocalTverskyLoss": MCFocalTverskyLoss,
    "CrossEntropyLoss": nn.CrossEntropyLoss,  # input logits, targets
    "L1Loss": nn.L1Loss,
    "MSELoss": nn.MSELoss,
    "CTCLoss": nn.CTCLoss,  # probability
    "NLLLoss": nn.NLLLoss,  # log-probabilities of each class
    "PoissonNLLLoss": nn.PoissonNLLLoss,
    "GaussianNLLLoss": nn.GaussianNLLLoss,
    "KLDivLoss": nn.KLDivLoss,  # argument input in log-space
    "BCELoss": nn.BCELoss,  # probabilities
    "BCEWithLogitsLoss": nn.BCEWithLogitsLoss,  # logits
    "MarginRankingLoss": nn.MarginRankingLoss,
    "HingeEmbeddingLoss": nn.HingeEmbeddingLoss,
    "MultiLabelMarginLoss": nn.MultiLabelMarginLoss,
    "HuberLoss": nn.HuberLoss,
    "SmoothL1Loss": nn.SmoothL1Loss,
    "SoftMarginLoss": nn.SoftMarginLoss,  # logits
    "MultiLabelSoftMarginLoss": nn.MultiLabelSoftMarginLoss,
    "CosineEmbeddingLoss": nn.CosineEmbeddingLoss,
    "MultiMarginLoss": nn.MultiMarginLoss,
    "TripletMarginLoss": nn.TripletMarginLoss,
    "TripletMarginWithDistanceLoss": nn.TripletMarginWithDistanceLoss,
    "MAEWeighted": MAEWeighted,
    "MSEWeighted": MSEWeighted,
    "BCEWeighted": BCEWeighted,  # logits
    "CEWeighted": CEWeighted,  # logits
    "L1LossWeighted": L1LossWeighted,
}