"""
Copyright (c) OpenMMLab. All rights reserved.
Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
Modification by: Huawei Developers
Modification date: 2024-06-04 
Modification Description: 
Modification 1. Add support for Ascend NPU
"""

from typing import Any, Optional, Tuple

import torch
import torch_npu
from torch.autograd import Function

import mx_driving._C


class ThreeInterpolateFunction(Function):
    """Autograd wrapper around the NPU three-interpolate kernels.

    ``forward`` dispatches to ``mx_driving._C.npu_three_interpolate`` and
    ``backward`` to ``mx_driving._C.npu_three_interpolate_backward``.
    Only ``features`` receives a gradient; ``indices`` and ``weight`` always
    get ``None`` (so ``weight`` is treated as non-differentiable here).
    """

    @staticmethod
    def forward(ctx: Any, features: torch.Tensor, indices: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """Run the forward NPU kernel.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            features: Source features, unpacked as ``(b, c, m)``.
            indices: Neighbour indices; ``indices.size(1)`` is taken as the
                number of target points ``n``. Presumably ``(b, n, 3)`` --
                TODO confirm against the kernel.
            weight: Interpolation weights, presumably shaped like ``indices``
                -- TODO confirm.

        Returns:
            The kernel output. ``backward`` unpacks the gradient of this
            output as ``(b, c, n)``.
        """
        # NOTE(review): assumes the NPU kernel expects dense memory -- confirm.
        # ``.contiguous()`` returns the tensor itself when it already is.
        features = features.contiguous()
        indices = indices.contiguous()
        weight = weight.contiguous()

        b, c, m = features.size()
        n = indices.size(1)

        # Use save_for_backward rather than a plain ctx attribute so autograd
        # can detect in-place modification of these tensors before backward
        # runs, and so they are not kept alive through a ctx reference cycle.
        ctx.save_for_backward(indices, weight)
        ctx.m = m

        return mx_driving._C.npu_three_interpolate(b, c, m, n, features, indices, weight)

    @staticmethod
    def backward(ctx: Any, grad_out: torch.Tensor) -> Tuple[Optional[torch.Tensor], None, None]:
        """Propagate ``grad_out`` back to ``features``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_out: Gradient w.r.t. the forward output, unpacked as
                ``(b, c, n)``.

        Returns:
            ``(grad_features, None, None)``. No gradient is produced for
            ``indices`` or ``weight``. ``grad_features`` is ``None`` when
            ``features`` does not require grad; otherwise it has the same
            dtype as ``grad_out``.
        """
        # Skip the kernel launch entirely when features needs no gradient
        # (e.g. backward was triggered only because weight requires grad).
        if not ctx.needs_input_grad[0]:
            return None, None, None

        idx, weight = ctx.saved_tensors
        m = ctx.m
        b, c, n = grad_out.size()

        grad_out_dtype = grad_out.dtype
        # The backward kernel is fed float32 gradient and weights.
        # detach() replaces the discouraged ``.data`` access.
        grad_out_data = grad_out.detach().contiguous().to(torch.float)
        weight = weight.to(torch.float)

        grad_features = mx_driving._C.npu_three_interpolate_backward(b, c, n, m, grad_out_data, idx, weight)

        # Hand the gradient back in the incoming dtype (half, bfloat16, ...).
        if grad_features.dtype != grad_out_dtype:
            grad_features = grad_features.to(grad_out_dtype)

        return grad_features, None, None


# Functional entry point: three_interpolate(features, indices, weight) runs
# ThreeInterpolateFunction through autograd via Function.apply.
three_interpolate = ThreeInterpolateFunction.apply