Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
from torch import nn | |
from torch.autograd import Function | |
from ..utils import ext_loader | |
ext_module = ext_loader.load_ext('_ext', [ | |
'top_pool_forward', 'top_pool_backward', 'bottom_pool_forward', | |
'bottom_pool_backward', 'left_pool_forward', 'left_pool_backward', | |
'right_pool_forward', 'right_pool_backward' | |
]) | |
_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3} | |
class TopPoolFunction(Function): | |
def symbolic(g, input): | |
output = g.op( | |
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top'])) | |
return output | |
def forward(ctx, input): | |
output = ext_module.top_pool_forward(input) | |
ctx.save_for_backward(input) | |
return output | |
def backward(ctx, grad_output): | |
input, = ctx.saved_tensors | |
output = ext_module.top_pool_backward(input, grad_output) | |
return output | |
class BottomPoolFunction(Function): | |
def symbolic(g, input): | |
output = g.op( | |
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom'])) | |
return output | |
def forward(ctx, input): | |
output = ext_module.bottom_pool_forward(input) | |
ctx.save_for_backward(input) | |
return output | |
def backward(ctx, grad_output): | |
input, = ctx.saved_tensors | |
output = ext_module.bottom_pool_backward(input, grad_output) | |
return output | |
class LeftPoolFunction(Function): | |
def symbolic(g, input): | |
output = g.op( | |
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left'])) | |
return output | |
def forward(ctx, input): | |
output = ext_module.left_pool_forward(input) | |
ctx.save_for_backward(input) | |
return output | |
def backward(ctx, grad_output): | |
input, = ctx.saved_tensors | |
output = ext_module.left_pool_backward(input, grad_output) | |
return output | |
class RightPoolFunction(Function): | |
def symbolic(g, input): | |
output = g.op( | |
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right'])) | |
return output | |
def forward(ctx, input): | |
output = ext_module.right_pool_forward(input) | |
ctx.save_for_backward(input) | |
return output | |
def backward(ctx, grad_output): | |
input, = ctx.saved_tensors | |
output = ext_module.right_pool_backward(input, grad_output) | |
return output | |
class CornerPool(nn.Module): | |
"""Corner Pooling. | |
Corner Pooling is a new type of pooling layer that helps a | |
convolutional network better localize corners of bounding boxes. | |
Please refer to https://arxiv.org/abs/1808.01244 for more details. | |
Code is modified from https://github.com/princeton-vl/CornerNet-Lite. | |
Args: | |
mode(str): Pooling orientation for the pooling layer | |
- 'bottom': Bottom Pooling | |
- 'left': Left Pooling | |
- 'right': Right Pooling | |
- 'top': Top Pooling | |
Returns: | |
Feature map after pooling. | |
""" | |
pool_functions = { | |
'bottom': BottomPoolFunction, | |
'left': LeftPoolFunction, | |
'right': RightPoolFunction, | |
'top': TopPoolFunction, | |
} | |
cummax_dim_flip = { | |
'bottom': (2, False), | |
'left': (3, True), | |
'right': (3, False), | |
'top': (2, True), | |
} | |
def __init__(self, mode): | |
super(CornerPool, self).__init__() | |
assert mode in self.pool_functions | |
self.mode = mode | |
self.corner_pool = self.pool_functions[mode] | |
def forward(self, x): | |
if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0': | |
if torch.onnx.is_in_onnx_export(): | |
assert torch.__version__ >= '1.7.0', \ | |
'When `cummax` serves as an intermediate component whose '\ | |
'outputs is used as inputs for another modules, it\'s '\ | |
'expected that pytorch version must be >= 1.7.0, '\ | |
'otherwise Error appears like: `RuntimeError: tuple '\ | |
'appears in op that does not forward tuples, unsupported '\ | |
'kind: prim::PythonOp`.' | |
dim, flip = self.cummax_dim_flip[self.mode] | |
if flip: | |
x = x.flip(dim) | |
pool_tensor, _ = torch.cummax(x, dim=dim) | |
if flip: | |
pool_tensor = pool_tensor.flip(dim) | |
return pool_tensor | |
else: | |
return self.corner_pool.apply(x) | |