# Copyright (c) OpenMMLab. All rights reserved.
import random

import torch
import torch.distributed as dist
import torch.nn.functional as F
from mmcv.runner import get_dist_info

from ...utils import log_img_scale
from ..builder import DETECTORS
from .single_stage import SingleStageDetector


@DETECTORS.register_module()
class YOLOX(SingleStageDetector):
    r"""Implementation of `YOLOX: Exceeding YOLO Series in 2021
    <https://arxiv.org/abs/2107.08430>`_

    Note: Considering the trade-off between training speed and accuracy,
    multi-scale training is temporarily kept. More elegant implementation
    will be adopted in the future.

    Args:
        backbone (nn.Module): The backbone module.
        neck (nn.Module): The neck module.
        bbox_head (nn.Module): The bbox head module.
        train_cfg (obj:`ConfigDict`, optional): The training config
            of YOLOX. Default: None.
        test_cfg (obj:`ConfigDict`, optional): The testing config
            of YOLOX. Default: None.
        pretrained (str, optional): model pretrained path.
            Default: None.
        input_size (tuple): The model default input image size. The shape
            order should be (height, width). Default: (640, 640).
        size_multiplier (int): Image size multiplication factor.
            Default: 32.
        random_size_range (tuple): The (low, high) range, both ends
            inclusive, from which a multi-scale size factor is drawn during
            training. The real training image size is this factor
            multiplied by ``size_multiplier``. Default: (15, 25).
        random_size_interval (int): Number of training iterations between
            two consecutive changes of the training image size.
            Default: 10.
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 input_size=(640, 640),
                 size_multiplier=32,
                 random_size_range=(15, 25),
                 random_size_interval=10,
                 init_cfg=None):
        super(YOLOX, self).__init__(backbone, neck, bbox_head, train_cfg,
                                    test_cfg, pretrained, init_cfg)
        log_img_scale(input_size, skip_square=True)
        # Rank / world size are used by ``_random_resize`` so that every
        # process trains at the same sampled resolution.
        self.rank, self.world_size = get_dist_info()
        # ``_default_input_size`` never changes; ``_input_size`` is the size
        # the next training batch is resized to and is updated during
        # multi-scale training.
        self._default_input_size = input_size
        self._input_size = input_size
        self._random_size_range = random_size_range
        self._random_size_interval = random_size_interval
        self._size_multiplier = size_multiplier
        # Count of ``forward_train`` calls completed so far.
        self._progress_in_iter = 0

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None):
        """
        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Each item are the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # Multi-scale training: resize the batch (and its boxes) to the
        # currently selected input size before running the detector.
        img, gt_bboxes = self._preprocess(img, gt_bboxes)

        losses = super(YOLOX, self).forward_train(img, img_metas, gt_bboxes,
                                                  gt_labels, gt_bboxes_ignore)

        # Every ``random_size_interval`` iterations pick a new input size;
        # it takes effect on the next call through ``_preprocess``.
        if (self._progress_in_iter + 1) % self._random_size_interval == 0:
            self._input_size = self._random_resize(device=img.device)
        self._progress_in_iter += 1

        return losses

    def _preprocess(self, img, gt_bboxes):
        """Resize a training batch to the current multi-scale input size.

        Resizing only happens while ``self._input_size`` differs from
        ``self._default_input_size``. Box coordinates are rescaled by the
        ratio between the target size and the batch's actual spatial size,
        so they stay aligned with the interpolated image even when the
        incoming batch is not exactly ``self._default_input_size``.

        NOTE(review): ``img_metas`` entries (e.g. 'img_shape', 'pad_shape')
        are not rescaled here; confirm downstream consumers do not rely on
        them matching the resized batch.

        Args:
            img (Tensor): Batch of images; the last two dims are treated as
                (height, width).
            gt_bboxes (list[Tensor]): Per-image boxes in
                [tl_x, tl_y, br_x, br_y] format. Modified in place when the
                batch is resized.

        Returns:
            tuple[Tensor, list[Tensor]]: The (possibly resized) images and
            the (possibly rescaled) boxes.
        """
        target_h, target_w = self._input_size
        default_h, default_w = self._default_input_size
        if target_h == default_h and target_w == default_w:
            # Multi-scale is not active for this iteration.
            return img, gt_bboxes

        # Use the real batch size, not the configured default: interpolate
        # maps (src_h, src_w) -> (target_h, target_w), so boxes must be
        # scaled by exactly the same ratios to stay aligned with the pixels.
        src_h, src_w = img.shape[-2:]
        scale_y = target_h / src_h
        scale_x = target_w / src_w
        img = F.interpolate(
            img,
            size=self._input_size,
            mode='bilinear',
            align_corners=False)
        for gt_bbox in gt_bboxes:
            gt_bbox[..., 0::2] = gt_bbox[..., 0::2] * scale_x
            gt_bbox[..., 1::2] = gt_bbox[..., 1::2] * scale_y
        return img, gt_bboxes

    def _random_resize(self, device):
        """Sample a new training input size shared by all ranks.

        Rank 0 draws a factor uniformly from ``self._random_size_range``
        (both ends inclusive) and builds a (height, width) pair that keeps
        the aspect ratio of ``self._default_input_size``, with both sides
        expressed in units of ``self._size_multiplier``. When running with
        more than one process the pair is broadcast from rank 0, so every
        rank returns the same size.

        Args:
            device (torch.device): Device on which the shared size tensor
                is allocated (``forward_train`` passes the image device).

        Returns:
            tuple[int, int]: The new input size as (height, width).
        """
        # Zero-initialised so ranks other than 0 never hold uninitialised
        # memory before the broadcast overwrites it.
        tensor = torch.zeros(2, dtype=torch.long, device=device)

        if self.rank == 0:
            size = random.randint(*self._random_size_range)
            aspect_ratio = float(
                self._default_input_size[1]) / self._default_input_size[0]
            size = (self._size_multiplier * size,
                    self._size_multiplier * int(aspect_ratio * size))
            tensor[0] = size[0]
            tensor[1] = size[1]

        if self.world_size > 1:
            dist.barrier()
            dist.broadcast(tensor, 0)

        input_size = (tensor[0].item(), tensor[1].item())
        return input_size