# Copyright (c) OpenMMLab. All rights reserved. import copy import math from typing import List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule, is_norm from mmcv.ops import batched_nms from mmengine.model import (BaseModule, bias_init_with_prob, constant_init, normal_init) from mmengine.structures import InstanceData from torch import Tensor from mmdet.models.layers.transformer import inverse_sigmoid from mmdet.models.utils import (filter_scores_and_topk, multi_apply, select_single_mlvl, sigmoid_geometric_mean) from mmdet.registry import MODELS from mmdet.structures.bbox import (cat_boxes, distance2bbox, get_box_tensor, get_box_wh, scale_boxes) from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean from mmdet.models.dense_heads.rtmdet_head import RTMDetHead from mmdet.models.dense_heads.rtmdet_ins_head import RTMDetInsHead, RTMDetInsSepBNHead, MaskFeatModule from mmdet.utils import AvoidCUDAOOM def sthgoeswrong(logits): return torch.any(torch.isnan(logits)) or torch.any(torch.isinf(logits)) from time import time @MODELS.register_module(force=True) class RTMDetInsHeadCustom(RTMDetInsHead): def loss_by_feat(self, cls_scores: List[Tensor], bbox_preds: List[Tensor], kernel_preds: List[Tensor], mask_feat: Tensor, batch_gt_instances: InstanceList, batch_img_metas: List[dict], batch_gt_instances_ignore: OptInstanceList = None): """Compute losses of the head. Args: cls_scores (list[Tensor]): Box scores for each scale level Has shape (N, num_anchors * num_classes, H, W) bbox_preds (list[Tensor]): Decoded box for each scale level with shape (N, num_anchors * 4, H, W) in [tl_x, tl_y, br_x, br_y] format. batch_gt_instances (list[:obj:`InstanceData`]): Batch of gt_instance. It usually includes ``bboxes`` and ``labels`` attributes. batch_img_metas (list[dict]): Meta information of each image, e.g., image size, scaling factor, etc. batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): Batch of gt_instances_ignore. It includes ``bboxes`` attribute data that is ignored during training and testing. Defaults to None. Returns: dict[str, Tensor]: A dictionary of loss components. """ num_imgs = len(batch_img_metas) featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] assert len(featmap_sizes) == self.prior_generator.num_levels device = cls_scores[0].device anchor_list, valid_flag_list = self.get_anchors( featmap_sizes, batch_img_metas, device=device) flatten_cls_scores = torch.cat([ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.cls_out_channels) for cls_score in cls_scores ], 1) flatten_kernels = torch.cat([ kernel_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_gen_params) for kernel_pred in kernel_preds ], 1) decoded_bboxes = [] for anchor, bbox_pred in zip(anchor_list[0], bbox_preds): anchor = anchor.reshape(-1, 4) bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) bbox_pred = distance2bbox(anchor, bbox_pred) decoded_bboxes.append(bbox_pred) flatten_bboxes = torch.cat(decoded_bboxes, 1) for gt_instances in batch_gt_instances: gt_instances.masks = gt_instances.masks.to_tensor( dtype=torch.bool, device=device) cls_reg_targets = self.get_targets( flatten_cls_scores, flatten_bboxes, anchor_list, valid_flag_list, batch_gt_instances, batch_img_metas, batch_gt_instances_ignore=batch_gt_instances_ignore) (anchor_list, labels_list, label_weights_list, bbox_targets_list, assign_metrics_list, sampling_results_list) = cls_reg_targets losses_cls, losses_bbox,\ cls_avg_factors, bbox_avg_factors = multi_apply( self.loss_by_feat_single, cls_scores, decoded_bboxes, labels_list, label_weights_list, bbox_targets_list, assign_metrics_list, self.prior_generator.strides) cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item() losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls)) bbox_avg_factor = reduce_mean( sum(bbox_avg_factors)).clamp_(min=1).item() losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox)) loss_mask = self.loss_mask_by_feat(mask_feat, flatten_kernels, sampling_results_list, batch_gt_instances) loss = dict( loss_cls=losses_cls, loss_bbox=losses_bbox, loss_mask=loss_mask) return loss def _mask_predict_by_feat_single(self, mask_feat: Tensor, kernels: Tensor, priors: Tensor) -> Tensor: ori_maskfeat = mask_feat num_inst = priors.shape[0] h, w = mask_feat.size()[-2:] if num_inst < 1: return torch.empty( size=(num_inst, h, w), dtype=mask_feat.dtype, device=mask_feat.device) if len(mask_feat.shape) < 4: mask_feat.unsqueeze(0) coord = self.prior_generator.single_level_grid_priors( (h, w), level_idx=0, device=mask_feat.device).reshape(1, -1, 2) num_inst = priors.shape[0] points = priors[:, :2].reshape(-1, 1, 2) strides = priors[:, 2:].reshape(-1, 1, 2) relative_coord = (points - coord).permute(0, 2, 1) / ( strides[..., 0].reshape(-1, 1, 1) * 8) relative_coord = relative_coord.reshape(num_inst, 2, h, w) mask_feat = torch.cat( [relative_coord, mask_feat.repeat(num_inst, 1, 1, 1)], dim=1) weights, biases = self.parse_dynamic_params(kernels) fp16_used = weights[0].dtype == torch.float16 n_layers = len(weights) x = mask_feat.reshape(1, -1, h, w) for i, (weight, bias) in enumerate(zip(weights, biases)): with torch.cuda.amp.autocast(enabled=False): if fp16_used: weight = weight.to(torch.float32) bias = bias.to(torch.float32) x = F.conv2d( x, weight, bias=bias, stride=1, padding=0, groups=num_inst) if i < n_layers - 1: x = F.relu(x) if fp16_used: x = torch.clip(x, -8192, 8192) if sthgoeswrong(x): torch.save({'mask_feat': ori_maskfeat, 'kernels': kernels, 'priors': priors}, 'maskhead_nan_input.pt') raise Exception('Mask Head NaN') x = x.reshape(num_inst, h, w) return x def loss_mask_by_feat(self, mask_feats: Tensor, flatten_kernels: Tensor, sampling_results_list: list, batch_gt_instances: InstanceList) -> Tensor: batch_pos_mask_logits = [] pos_gt_masks = [] ignore_masks = [] for idx, (mask_feat, kernels, sampling_results, gt_instances) in enumerate( zip(mask_feats, flatten_kernels, sampling_results_list, batch_gt_instances)): pos_priors = sampling_results.pos_priors pos_inds = sampling_results.pos_inds pos_kernels = kernels[pos_inds] # n_pos, num_gen_params pos_mask_logits = self._mask_predict_by_feat_single( mask_feat, pos_kernels, pos_priors) if gt_instances.masks.numel() == 0: gt_masks = torch.empty_like(gt_instances.masks) if gt_masks.shape[0] > 0: ignore = torch.zeros(gt_masks.shape[0], dtype=torch.bool).to(device=gt_masks.device) ignore_masks.append(ignore) else: gt_masks = gt_instances.masks[ sampling_results.pos_assigned_gt_inds, :] ignore_masks.append(gt_instances.ignore_mask[sampling_results.pos_assigned_gt_inds]) batch_pos_mask_logits.append(pos_mask_logits) pos_gt_masks.append(gt_masks) pos_gt_masks = torch.cat(pos_gt_masks, 0) batch_pos_mask_logits = torch.cat(batch_pos_mask_logits, 0) ignore_masks = torch.logical_not(torch.cat(ignore_masks, 0)) pos_gt_masks = pos_gt_masks[ignore_masks] batch_pos_mask_logits = batch_pos_mask_logits[ignore_masks] # avg_factor num_pos = batch_pos_mask_logits.shape[0] num_pos = reduce_mean(mask_feats.new_tensor([num_pos ])).clamp_(min=1).item() if batch_pos_mask_logits.shape[0] == 0: return mask_feats.sum() * 0 scale = self.prior_generator.strides[0][0] // self.mask_loss_stride # upsample pred masks batch_pos_mask_logits = F.interpolate( batch_pos_mask_logits.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0) # downsample gt masks pos_gt_masks = pos_gt_masks[:, self.mask_loss_stride // 2::self.mask_loss_stride, self.mask_loss_stride // 2::self.mask_loss_stride] loss_mask = self.loss_mask( batch_pos_mask_logits, pos_gt_masks, weight=None, avg_factor=num_pos) return loss_mask @MODELS.register_module() class RTMDetInsSepBNHeadCustom(RTMDetInsSepBNHead): def _mask_predict_by_feat_single(self, mask_feat: Tensor, kernels: Tensor, priors: Tensor) -> Tensor: ori_maskfeat = mask_feat num_inst = priors.shape[0] h, w = mask_feat.size()[-2:] if num_inst < 1: return torch.empty( size=(num_inst, h, w), dtype=mask_feat.dtype, device=mask_feat.device) if len(mask_feat.shape) < 4: mask_feat.unsqueeze(0) coord = self.prior_generator.single_level_grid_priors( (h, w), level_idx=0, device=mask_feat.device).reshape(1, -1, 2) num_inst = priors.shape[0] points = priors[:, :2].reshape(-1, 1, 2) strides = priors[:, 2:].reshape(-1, 1, 2) relative_coord = (points - coord).permute(0, 2, 1) / ( strides[..., 0].reshape(-1, 1, 1) * 8) relative_coord = relative_coord.reshape(num_inst, 2, h, w) mask_feat = torch.cat( [relative_coord, mask_feat.repeat(num_inst, 1, 1, 1)], dim=1) weights, biases = self.parse_dynamic_params(kernels) fp16_used = weights[0].dtype == torch.float16 n_layers = len(weights) x = mask_feat.reshape(1, -1, h, w) for i, (weight, bias) in enumerate(zip(weights, biases)): with torch.cuda.amp.autocast(enabled=False): if fp16_used: weight = weight.to(torch.float32) bias = bias.to(torch.float32) x = F.conv2d( x, weight, bias=bias, stride=1, padding=0, groups=num_inst) if i < n_layers - 1: x = F.relu(x) if fp16_used: x = torch.clip(x, -8192, 8192) if sthgoeswrong(x): torch.save({'mask_feat': ori_maskfeat, 'kernels': kernels, 'priors': priors}, 'maskhead_nan_input.pt') raise Exception('Mask Head NaN') x = x.reshape(num_inst, h, w) return x @AvoidCUDAOOM.retry_if_cuda_oom def loss_mask_by_feat(self, mask_feats: Tensor, flatten_kernels: Tensor, sampling_results_list: list, batch_gt_instances: InstanceList) -> Tensor: batch_pos_mask_logits = [] pos_gt_masks = [] ignore_masks = [] for idx, (mask_feat, kernels, sampling_results, gt_instances) in enumerate( zip(mask_feats, flatten_kernels, sampling_results_list, batch_gt_instances)): pos_priors = sampling_results.pos_priors pos_inds = sampling_results.pos_inds pos_kernels = kernels[pos_inds] # n_pos, num_gen_params pos_mask_logits = self._mask_predict_by_feat_single( mask_feat, pos_kernels, pos_priors) if gt_instances.masks.numel() == 0: gt_masks = torch.empty_like(gt_instances.masks) # if gt_masks.shape[0] > 0: # ignore = torch.zeros(gt_masks.shape[0], dtype=torch.bool).to(device=gt_masks.device) # ignore_masks.append(ignore) else: msk = torch.logical_not(gt_instances.ignore_mask[sampling_results.pos_assigned_gt_inds]) gt_masks = gt_instances.masks[ sampling_results.pos_assigned_gt_inds, :][msk] pos_mask_logits = pos_mask_logits[msk] # ignore_masks.append(gt_instances.ignore_mask[sampling_results.pos_assigned_gt_inds]) batch_pos_mask_logits.append(pos_mask_logits) pos_gt_masks.append(gt_masks) pos_gt_masks = torch.cat(pos_gt_masks, 0) batch_pos_mask_logits = torch.cat(batch_pos_mask_logits, 0) # ignore_masks = torch.logical_not(torch.cat(ignore_masks, 0)) # pos_gt_masks = pos_gt_masks[ignore_masks] # batch_pos_mask_logits = batch_pos_mask_logits[ignore_masks] # avg_factor num_pos = batch_pos_mask_logits.shape[0] num_pos = reduce_mean(mask_feats.new_tensor([num_pos ])).clamp_(min=1).item() if batch_pos_mask_logits.shape[0] == 0: return mask_feats.sum() * 0 scale = self.prior_generator.strides[0][0] // self.mask_loss_stride # upsample pred masks batch_pos_mask_logits = F.interpolate( batch_pos_mask_logits.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=False).squeeze(0) # downsample gt masks pos_gt_masks = pos_gt_masks[:, self.mask_loss_stride // 2::self.mask_loss_stride, self.mask_loss_stride // 2::self.mask_loss_stride] loss_mask = self.loss_mask( batch_pos_mask_logits, pos_gt_masks, weight=None, avg_factor=num_pos) return loss_mask