"""
Copyright (c) Megvii Inc. All rights reserved.
Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
Modification by: Huawei Developers
Modification date: 2024-06-04 
Modification Description: 
Modification 1. Add support for Ascend NPU
"""

import torch
from torch.autograd import Function

import mx_driving._C


class AdsVoxelPoolingFunction(Function):
    """Autograd function wrapping the ``mx_driving._C`` voxel pooling kernels.

    ``forward`` calls ``mx_driving._C.voxel_pooling_train`` and ``backward``
    calls ``mx_driving._C.voxel_pool_train_backward``. Only
    ``input_features`` receives a gradient; ``geom_xyz`` and ``voxel_num``
    get ``None``.
    """

    @staticmethod
    def forward(ctx, geom_xyz, input_features, voxel_num):
        """Run the forward voxel pooling kernel.

        Args:
            ctx: Autograd context used to pass state to ``backward``.
            geom_xyz: Tensor of point coordinates. It is flattened to
                ``(batch, num_points, geom_xyz.shape[-1])``.
                NOTE(review): presumably integer voxel indices with a last
                dimension of 3 (``pos_memo`` is allocated with 3) -- confirm
                against the kernel.
            input_features: Tensor of per-point features. It is flattened to
                ``(batch, num_points, num_channels)`` using the batch size of
                ``geom_xyz``.
            voxel_num: Indexable with at least three entries. Entries 0 and 1
                size the third and second dimension of the output buffer;
                all three are forwarded to the kernel.

        Returns:
            The kernel result permuted with ``(0, 3, 1, 2)``. Assuming the
            kernel result keeps the layout of the output buffer
            ``(batch, voxel_num[1], voxel_num[0], num_channels)``, this is
            ``(batch, num_channels, voxel_num[1], voxel_num[0])``.

        Raises:
            ValueError: If there are no points, or if ``geom_xyz`` and
                ``input_features`` do not flatten to the same point count.
        """
        # backward only needs the caller's original feature shape to restore
        # the gradient layout, so remember the shape instead of allocating
        # (and keeping alive) a full-size zero tensor.
        input_features_shape = input_features.shape

        geom_xyz = geom_xyz.reshape(geom_xyz.shape[0], -1, geom_xyz.shape[-1])
        input_features = input_features.reshape(geom_xyz.shape[0], -1, input_features.shape[-1])

        batch_size, num_points, num_channels = input_features.shape
        if num_points == 0:
            raise ValueError("Error! Number of points can not be zero.\n")
        # num_points is taken from input_features but the kernel also reads
        # geom_xyz, so both must describe the same number of points.
        if geom_xyz.shape[1] != num_points:
            raise ValueError(
                "Error! geom_xyz and input_features must have the same number of points, "
                f"got {geom_xyz.shape[1]} and {num_points}.\n"
            )

        output_features = input_features.new_zeros(batch_size, voxel_num[1], voxel_num[0], num_channels)
        # Filled with -1 before being handed to the kernel.
        # NOTE(review): presumably -1 marks points without an assigned voxel
        # position -- confirm against the kernel.
        pos_memo = geom_xyz.new_ones(batch_size, num_points, 3) * -1
        pos, result = mx_driving._C.voxel_pooling_train(
            input_features,
            geom_xyz,
            output_features,
            pos_memo,
            batch_size,
            num_points,
            num_channels,
            voxel_num[0],
            voxel_num[1],
            voxel_num[2],
        )
        ctx.save_for_backward(pos)
        ctx.input_features_shape = input_features_shape
        return result.permute(0, 3, 1, 2)

    @staticmethod
    def backward(ctx, grad_output_features):
        """Compute the gradient with respect to ``input_features``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output_features: Gradient of the forward output; dimensions
                1, 2 and 3 are read as channels, ``H`` and ``W``.

        Returns:
            ``(None, grad_input_features, None)`` where
            ``grad_input_features`` has the shape ``input_features`` had when
            it was passed to ``forward``.
        """
        (pos_memo,) = ctx.saved_tensors

        batch_size = pos_memo.shape[0]
        num_points = pos_memo.shape[1]
        num_channels = grad_output_features.shape[1]
        H = grad_output_features.shape[2]
        W = grad_output_features.shape[3]

        result = mx_driving._C.voxel_pool_train_backward(
            grad_output_features, pos_memo, batch_size, num_points, num_channels, H, W
        )
        # Undo the flattening done in forward.
        grad_input_features = result.reshape(ctx.input_features_shape)
        return None, grad_input_features, None


# Functional entry point: call as
# npu_voxel_pooling_train(geom_xyz, input_features, voxel_num).
# Custom autograd Functions are invoked through ``.apply``, not by
# instantiating the class.
npu_voxel_pooling_train = AdsVoxelPoolingFunction.apply