import numpy as np
import torch
import torch_npu
from torch.autograd import Function
import mx_driving._C
def _require_non_empty(*tensors: torch.Tensor) -> None:
    """Raise ``ValueError`` if any of ``tensors`` contains zero elements."""
    if any(tensor.numel() == 0 for tensor in tensors):
        raise ValueError("Error! Input Tensor can not be an empty Tensor.\n")


class AdsDeformableAggregation(Function):
    """Autograd ``Function`` wrapping the ``mx_driving._C`` deformable-aggregation ops.

    ``forward`` dispatches to ``mx_driving._C.npu_deformable_aggregation`` and
    ``backward`` to ``mx_driving._C.npu_deformable_aggregation_backward``.
    Gradients are returned for ``mc_ms_feat``, ``sampling_location`` and
    ``weights``; ``spatial_shape`` and ``scale_start_index`` receive ``None``.
    """

    @staticmethod
    def forward(
        ctx,
        mc_ms_feat: torch.Tensor,
        spatial_shape: torch.Tensor,
        scale_start_index: torch.Tensor,
        sampling_location: torch.Tensor,
        weights: torch.Tensor,
    ):
        """Run the forward aggregation op and save its inputs for ``backward``.

        Args:
            ctx: Autograd context used to stash tensors for the backward pass.
            mc_ms_feat: Feature tensor; made contiguous before the op call.
            spatial_shape: Cast to a contiguous int32 tensor before the op call.
            scale_start_index: Cast to a contiguous int32 tensor before the
                op call.
            sampling_location: Made contiguous before the op call.
            weights: Made contiguous before the op call.

        Returns:
            Whatever ``mx_driving._C.npu_deformable_aggregation`` returns for
            the normalised inputs.

        Raises:
            ValueError: If ``mc_ms_feat``, ``spatial_shape``,
                ``sampling_location`` or ``weights`` has no elements.
        """
        # Validate everything up front so an empty tensor never reaches the
        # forward kernel and is only discovered later, during backward.
        _require_non_empty(mc_ms_feat, spatial_shape, sampling_location, weights)

        mc_ms_feat = mc_ms_feat.contiguous()
        spatial_shape = spatial_shape.contiguous().int()
        scale_start_index = scale_start_index.contiguous().int()
        sampling_location = sampling_location.contiguous()
        weights = weights.contiguous()

        output = mx_driving._C.npu_deformable_aggregation(
            mc_ms_feat,
            spatial_shape,
            scale_start_index,
            sampling_location,
            weights,
        )
        # Save the normalised (contiguous / int32) tensors so backward can
        # pass them to the gradient op without converting them again.
        ctx.save_for_backward(
            mc_ms_feat,
            spatial_shape,
            scale_start_index,
            sampling_location,
            weights,
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for ``mc_ms_feat``, ``sampling_location`` and ``weights``.

        Args:
            ctx: Autograd context holding the tensors saved by ``forward``.
            grad_output: Gradient with respect to the forward output; made
                contiguous before being passed to the backward op.

        Returns:
            A 5-tuple aligned with the inputs of ``forward``:
            ``(grad_mc_ms_feat, None, None, grad_sampling_location, grad_weights)``.
        """
        # These tensors were validated and made contiguous (and int32 where
        # needed) in forward before being saved, so they are used as-is.
        (
            mc_ms_feat,
            spatial_shape,
            scale_start_index,
            sampling_location,
            weights,
        ) = ctx.saved_tensors

        # Zero-initialised buffers are handed to the backward op, which
        # returns the three gradient tensors.
        grad_mc_ms_feat = torch.zeros_like(mc_ms_feat)
        grad_sampling_location = torch.zeros_like(sampling_location)
        grad_weights = torch.zeros_like(weights)
        grad_mc_ms_feat, grad_sampling_location, grad_weights = mx_driving._C.npu_deformable_aggregation_backward(
            mc_ms_feat,
            spatial_shape,
            scale_start_index,
            sampling_location,
            weights,
            grad_output.contiguous(),
            grad_mc_ms_feat,
            grad_sampling_location,
            grad_weights,
        )
        # spatial_shape and scale_start_index get no gradient.
        return (
            grad_mc_ms_feat,
            None,
            None,
            grad_sampling_location,
            grad_weights,
        )
# Functional entry points: both module-level names are bound to
# AdsDeformableAggregation.apply, so either can be called with the same
# arguments as AdsDeformableAggregation.forward (minus ctx).
deformable_aggregation = AdsDeformableAggregation.apply
npu_deformable_aggregation = AdsDeformableAggregation.apply