"""
Copyright (c) OpenMMLab. All rights reserved.
"""
from typing import Any, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch_npu
from torch.autograd import Function
import mx_driving._C
class BorderAlignFunction(Function):
    """Autograd wrapper around the ``mx_driving._C`` border-align kernels.

    ``forward`` runs ``mx_driving._C.border_align`` into a zero-initialised
    buffer, max-reduces that buffer over its pooled-sample axis, and saves the
    argmax indices so ``backward`` can route gradients through
    ``mx_driving._C.border_align_backward``.
    """

    @staticmethod
    def forward(ctx: Any, feature_map: torch.Tensor, rois: torch.Tensor, pooled_size: int) -> torch.Tensor:
        """Border-align ``feature_map`` using ``rois``.

        Args:
            ctx: Autograd context; receives ``pooled_size``, the input size and
                the tensors ``backward`` needs.
            feature_map: 4-D tensor unpacked as ``(B, C, H, W)``. ``C`` must be
                divisible by 4 because the result is split into 4 channel groups.
            rois: Boxes consumed by the native kernel. NOTE(review): the output
                buffer is sized for ``H * W`` boxes, so ``rois`` presumably has
                shape ``(B, H * W, 4)`` -- confirm against the kernel.
            pooled_size: Pooling size; the kernel fills ``pooled_size + 1``
                slots per output element, which are then max-reduced.

        Returns:
            Tensor of shape ``(B, C // 4, H * W, 4)``.

        Raises:
            ValueError: if an input tensor is empty or ``pooled_size`` is 0,
                if ``pooled_size`` is negative, or if ``C`` is not divisible by 4.
        """
        # ValueError subclasses Exception, so existing ``except Exception``
        # handlers keep working; the message is unchanged.
        if torch.numel(feature_map) == 0 or torch.numel(rois) == 0 or pooled_size == 0:
            raise ValueError("Error! Input Tensor can not be a empty Tensor! \n")
        # A negative value would size the sample axis as <= 0 and still be
        # handed to the native kernel.
        if pooled_size < 0:
            raise ValueError(f"pooled_size must be a positive integer, got {pooled_size}.")
        batch_size, num_channels, data_height, data_width = feature_map.size()
        # The result is split into 4 channel groups below; fail before the
        # kernel launch rather than on the reshape afterwards.
        if num_channels % 4 != 0:
            raise ValueError(
                f"The channel dimension of feature_map must be divisible by 4, got {num_channels}."
            )
        ctx.pooled_size = pooled_size
        ctx.feature_size = feature_map.size()

        num_boxes = data_height * data_width
        # Allocate directly on the input's device (same default dtype as
        # before) instead of building on the host and copying over.
        output = torch.zeros(
            batch_size, num_boxes, pooled_size + 1, num_channels,
            device=feature_map.device,
        )
        mx_driving._C.border_align(feature_map, rois, output, pooled_size)

        # Max over the pooled-sample axis; the argmax is kept for backward.
        npu_outputs, index = output.max(dim=-2)
        # (B, H*W, C) -> (B, H*W, 4, C//4) -> (B, C//4, H*W, 4)
        grouped_shape = [batch_size, num_boxes, 4, num_channels // 4]
        npu_outputs = npu_outputs.reshape(grouped_shape).permute([0, 3, 1, 2]).contiguous()
        # Indices are converted to int32 before being saved for the backward kernel.
        index = index.int().reshape(grouped_shape).permute([0, 3, 1, 2]).contiguous()
        ctx.save_for_backward(rois, index)
        return npu_outputs

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[Optional[torch.Tensor], None, None]:
        """Propagate ``grad_output`` back to ``feature_map``.

        Only ``feature_map`` receives a gradient; ``rois`` and ``pooled_size``
        get ``None``.
        """
        rois, argmax_idx = ctx.saved_tensors
        _, _, height, width = ctx.feature_size
        grad_input = mx_driving._C.border_align_backward(
            grad_output.contiguous(),
            rois,
            argmax_idx,
            ctx.pooled_size,
            height,
            width,
        )
        return grad_input, None, None


# Functional entry point: border_align(feature_map, rois, pooled_size).
border_align = BorderAlignFunction.apply