Spaces:
Runtime error
Runtime error
from torch.autograd import Function | |
from ..utils import ext_loader | |
ext_module = ext_loader.load_ext( | |
'_ext', ['assign_score_withk_forward', 'assign_score_withk_backward']) | |
class AssignScoreWithK(Function): | |
r"""Perform weighted sum to generate output features according to scores. | |
Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/ | |
scene_seg/lib/paconv_lib/src/gpu>`_. | |
This is a memory-efficient CUDA implementation of assign_scores operation, | |
which first transform all point features with weight bank, then assemble | |
neighbor features with ``knn_idx`` and perform weighted sum of ``scores``. | |
See the `paper <https://arxiv.org/pdf/2103.14635.pdf>`_ appendix Sec. D for | |
more detailed descriptions. | |
Note: | |
This implementation assumes using ``neighbor`` kernel input, which is | |
(point_features - center_features, point_features). | |
See https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/ | |
pointnet2/paconv.py#L128 for more details. | |
""" | |
def forward(ctx, | |
scores, | |
point_features, | |
center_features, | |
knn_idx, | |
aggregate='sum'): | |
""" | |
Args: | |
scores (torch.Tensor): (B, npoint, K, M), predicted scores to | |
aggregate weight matrices in the weight bank. | |
``npoint`` is the number of sampled centers. | |
``K`` is the number of queried neighbors. | |
``M`` is the number of weight matrices in the weight bank. | |
point_features (torch.Tensor): (B, N, M, out_dim) | |
Pre-computed point features to be aggregated. | |
center_features (torch.Tensor): (B, N, M, out_dim) | |
Pre-computed center features to be aggregated. | |
knn_idx (torch.Tensor): (B, npoint, K), index of sampled kNN. | |
We assume the first idx in each row is the idx of the center. | |
aggregate (str, optional): Aggregation method. | |
Can be 'sum', 'avg' or 'max'. Defaults: 'sum'. | |
Returns: | |
torch.Tensor: (B, out_dim, npoint, K), the aggregated features. | |
""" | |
agg = {'sum': 0, 'avg': 1, 'max': 2} | |
B, N, M, out_dim = point_features.size() | |
_, npoint, K, _ = scores.size() | |
output = point_features.new_zeros((B, out_dim, npoint, K)) | |
ext_module.assign_score_withk_forward( | |
point_features.contiguous(), | |
center_features.contiguous(), | |
scores.contiguous(), | |
knn_idx.contiguous(), | |
output, | |
B=B, | |
N0=N, | |
N1=npoint, | |
M=M, | |
K=K, | |
O=out_dim, | |
aggregate=agg[aggregate]) | |
ctx.save_for_backward(output, point_features, center_features, scores, | |
knn_idx) | |
ctx.agg = agg[aggregate] | |
return output | |
def backward(ctx, grad_out): | |
""" | |
Args: | |
grad_out (torch.Tensor): (B, out_dim, npoint, K) | |
Returns: | |
grad_scores (torch.Tensor): (B, npoint, K, M) | |
grad_point_features (torch.Tensor): (B, N, M, out_dim) | |
grad_center_features (torch.Tensor): (B, N, M, out_dim) | |
""" | |
_, point_features, center_features, scores, knn_idx = ctx.saved_tensors | |
agg = ctx.agg | |
B, N, M, out_dim = point_features.size() | |
_, npoint, K, _ = scores.size() | |
grad_point_features = point_features.new_zeros(point_features.shape) | |
grad_center_features = center_features.new_zeros(center_features.shape) | |
grad_scores = scores.new_zeros(scores.shape) | |
ext_module.assign_score_withk_backward( | |
grad_out.contiguous(), | |
point_features.contiguous(), | |
center_features.contiguous(), | |
scores.contiguous(), | |
knn_idx.contiguous(), | |
grad_point_features, | |
grad_center_features, | |
grad_scores, | |
B=B, | |
N0=N, | |
N1=npoint, | |
M=M, | |
K=K, | |
O=out_dim, | |
aggregate=agg) | |
return grad_scores, grad_point_features, \ | |
grad_center_features, None, None | |
assign_score_withk = AssignScoreWithK.apply | |