|
|
|
import numpy as np |
|
from typing import List |
|
import fvcore.nn.weight_init as weight_init |
|
import torch |
|
from torch import nn |
|
|
|
from detectron2.config import configurable |
|
from detectron2.layers import Conv2d, ShapeSpec, get_norm |
|
from detectron2.utils.registry import Registry |
|
|
|
__all__ = ["FastRCNNConvFCHead", "build_box_head", "ROI_BOX_HEAD_REGISTRY"] |
|
|
|
# Module-level registry of box-head implementations. `build_box_head` looks a
# head up here by the name stored in `cfg.MODEL.ROI_BOX_HEAD.NAME`.
ROI_BOX_HEAD_REGISTRY = Registry("ROI_BOX_HEAD")
ROI_BOX_HEAD_REGISTRY.__doc__ = """
Registry for box heads, which make box predictions from per-region features.

The registered object will be called with `obj(cfg, input_shape)`.
"""
|
|
|
|
|
|
|
|
|
|
|
@ROI_BOX_HEAD_REGISTRY.register()
class FastRCNNConvFCHead(nn.Sequential):
    """
    Box head built from a stack of 3x3 conv layers (each followed by norm and
    ReLU) and then a stack of fc layers (each followed by ReLU).
    """

    @configurable
    def __init__(
        self,
        input_shape: ShapeSpec,
        *,
        conv_dims: List[int],
        fc_dims: List[int],
        conv_norm="",
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (ShapeSpec): shape of the input feature.
            conv_dims (list[int]): output channels of each conv layer, in order.
            fc_dims (list[int]): output dimensions of each fc layer, in order.
            conv_norm (str or callable): normalization for the conv layers.
                See :func:`detectron2.layers.get_norm` for supported types.
        """
        super().__init__()
        # At least one layer (conv or fc) must be requested.
        assert len(conv_dims) + len(fc_dims) > 0

        # Shape of the feature leaving the most recently added layer. It is a
        # (channels, height, width) tuple until the first fc layer is added,
        # after which it becomes a plain int.
        self._output_size = (input_shape.channels, input_shape.height, input_shape.width)

        self.conv_norm_relus = []
        for idx, out_channels in enumerate(conv_dims, start=1):
            in_channels, height, width = self._output_size
            block = Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                # The conv bias is only enabled when no norm is configured.
                bias=not conv_norm,
                norm=get_norm(conv_norm, out_channels),
                activation=nn.ReLU(),
            )
            self.add_module(f"conv{idx}", block)
            self.conv_norm_relus.append(block)
            # Only the channel count changes; height and width are carried over.
            self._output_size = (out_channels, height, width)

        self.fcs = []
        if len(fc_dims) > 0:
            # A single flatten layer sits in front of the first fc layer.
            self.add_module("flatten", nn.Flatten())
        for idx, out_features in enumerate(fc_dims, start=1):
            in_features = int(np.prod(self._output_size))
            linear = nn.Linear(in_features, out_features)
            self.add_module(f"fc{idx}", linear)
            self.add_module(f"fc_relu{idx}", nn.ReLU())
            self.fcs.append(linear)
            self._output_size = out_features

        # Initialization is a separate pass once every layer exists:
        # conv blocks first, then fc layers.
        for conv_block in self.conv_norm_relus:
            weight_init.c2_msra_fill(conv_block)
        for linear in self.fcs:
            weight_init.c2_xavier_fill(linear)

    @classmethod
    def from_config(cls, cfg, input_shape):
        """
        Translate `cfg.MODEL.ROI_BOX_HEAD` into the keyword arguments of `__init__`.

        Returns:
            dict: the `input_shape`, `conv_dims`, `fc_dims` and `conv_norm` arguments.
        """
        head_cfg = cfg.MODEL.ROI_BOX_HEAD
        return {
            "input_shape": input_shape,
            "conv_dims": [head_cfg.CONV_DIM] * head_cfg.NUM_CONV,
            "fc_dims": [head_cfg.FC_DIM] * head_cfg.NUM_FC,
            "conv_norm": head_cfg.NORM,
        }

    def forward(self, x):
        """Feed `x` through every registered sub-module in turn and return the result."""
        for module in self:
            x = module(x)
        return x

    @property
    @torch.jit.unused
    def output_shape(self):
        """
        Returns:
            ShapeSpec: the output feature shape. Only `channels` is set when the
            recorded output size is an int; otherwise channels, height and width
            are all set.
        """
        size = self._output_size
        if not isinstance(size, int):
            return ShapeSpec(channels=size[0], height=size[1], width=size[2])
        return ShapeSpec(channels=size)
|
|
|
|
|
def build_box_head(cfg, input_shape):
    """
    Instantiate the box head registered under `cfg.MODEL.ROI_BOX_HEAD.NAME`.

    The registered object is looked up in `ROI_BOX_HEAD_REGISTRY` and called
    with `(cfg, input_shape)`; whatever that call returns is returned as-is.
    """
    head_factory = ROI_BOX_HEAD_REGISTRY.get(cfg.MODEL.ROI_BOX_HEAD.NAME)
    return head_factory(cfg, input_shape)
|
|