import math
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.nn.modules.utils import _pair, _single
import torch_npu

# Names exported by ``from <module> import *``.
# NOTE(review): the ``DCNv2`` alias defined at the bottom of this file is not
# listed here -- confirm whether that omission is intentional.
__all__ = [
    "ModulatedDeformConv2dFunction",
    "ModulatedDeformConv",
]



class ModulatedDeformConv2dFunction(Function):
    """Autograd wrapper around the NPU modulated deformable conv kernels.

    ``forward`` dispatches to ``torch_npu.npu_deformable_conv2d`` and
    ``backward`` to ``torch_npu.npu_deformable_conv2dbk``. Offsets are
    re-ordered along the channel dimension with the caller-supplied index
    tensors before being handed to the kernel, and the resulting offset
    gradient is re-ordered back with the inverse index.
    """

    @staticmethod
    def _npu_geometry(stride, padding, dilation):
        """Expand int-or-pair conv settings into the 4-element NPU lists.

        Args:
            stride (int or tuple): Stride, a single int or an ``(h, w)`` pair.
            padding (int or tuple): Padding, a single int or an ``(h, w)`` pair.
            dilation (int or tuple): Dilation, a single int or an ``(h, w)`` pair.

        Returns:
            tuple: ``(stride_list, padding_list, dilation_list)`` where
            stride/dilation are ``[1, 1, h, w]`` and padding is
            ``[pad_h, pad_h, pad_w, pad_w]``. For plain ints this is exactly
            the ``[1, 1, s, s]`` / ``[p, p, p, p]`` layout used previously.
        """
        stride_h, stride_w = _pair(stride)
        pad_h, pad_w = _pair(padding)
        dilation_h, dilation_w = _pair(dilation)
        return (
            [1, 1, stride_h, stride_w],
            [pad_h, pad_h, pad_w, pad_w],
            [1, 1, dilation_h, dilation_w],
        )

    @staticmethod
    def forward(ctx,
                input_tensor,
                offset_ori,
                mask,
                weight,
                bias=None,
                with_bias=False,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                deformable_groups=1,
                sort_index_for_npu_fp=None,
                sort_index_for_npu_bp=None,
                ):
        """Run the NPU modulated deformable convolution.

        Args:
            ctx: Autograd context.
            input_tensor (Tensor): Input feature map; cast to float32.
            offset_ori (Tensor): Sampling offsets; cast to float32 and then
                re-ordered along dim 1 with ``sort_index_for_npu_fp``.
            mask (Tensor): Modulation mask; cast to float32 and concatenated
                after the re-ordered offsets along dim 1.
            weight (Tensor): Convolution weight.
            bias (Tensor, optional): Bias tensor forwarded to the NPU op.
            with_bias (bool): If False, ``backward`` reports no bias gradient.
            stride (int or tuple): Convolution stride.
            padding (int or tuple): Convolution padding.
            dilation (int or tuple): Convolution dilation.
            groups (int): Forwarded to the NPU op as ``groups``.
            deformable_groups (int): Forwarded as ``deformable_groups``.
            sort_index_for_npu_fp (Tensor): Channel index applied to the
                offsets before the kernel call.
            sort_index_for_npu_bp (Tensor): Channel index applied to the
                offset gradient in ``backward``.

        Returns:
            Tensor: The first output of ``torch_npu.npu_deformable_conv2d``.
        """
        input_tensor = input_tensor.float()
        offset_ori = offset_ori.float()
        mask = mask.float()

        (ctx.npu_stride,
         ctx.npu_padding,
         ctx.npu_dilation) = ModulatedDeformConv2dFunction._npu_geometry(
            stride, padding, dilation)
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.sort_index_for_npu_bp = sort_index_for_npu_bp
        ctx.with_bias = with_bias

        offset = offset_ori.index_select(1, sort_index_for_npu_fp)
        offset_all = torch.cat([offset, mask], dim=1)
        # NOTE(review): kernel_size is passed as [shape[3], shape[2]]; this is
        # indistinguishable for square kernels -- confirm the order the NPU op
        # expects before relying on non-square kernels.
        output, offset_out = torch_npu.npu_deformable_conv2d(
            input_tensor, weight, offset_all, bias,
            kernel_size=[weight.shape[3], weight.shape[2]],
            stride=ctx.npu_stride,
            padding=ctx.npu_padding,
            dilation=ctx.npu_dilation,
            groups=ctx.groups, deformable_groups=ctx.deformable_groups,
            modulated=True)

        # Decide from the autograd bookkeeping of the *original* arguments
        # (input, offset, mask, weight, bias). Tensors created inside forward
        # (the float casts, the index_select result) never carry
        # requires_grad here, so checking them could skip the save and leave
        # backward with nothing to unpack.
        if any(ctx.needs_input_grad[:5]):
            ctx.save_for_backward(input_tensor, weight, offset_out, offset_all)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for input, offset, mask, weight and bias.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output (Tensor): Gradient w.r.t. the forward output.

        Returns:
            tuple: One entry per ``forward`` argument; the eight trailing
            non-tensor arguments receive ``None``. The bias gradient is
            ``None`` when ``with_bias`` was False.
        """
        input_tensor, weight, offset_out, offset_all = ctx.saved_tensors
        grad_input, grad_weight, grad_offset_all, grad_bias = torch_npu.npu_deformable_conv2dbk(
            input_tensor, grad_output, offset_out, weight, offset_all,
            kernel_size=[weight.shape[3], weight.shape[2]],
            stride=ctx.npu_stride,
            padding=ctx.npu_padding,
            dilation=ctx.npu_dilation,
            groups=ctx.groups, deformable_groups=ctx.deformable_groups, modulated=True)
        # The index tensor has one entry per offset channel, so this both
        # picks the offset part of the combined gradient and restores the
        # caller's channel order.
        grad_offset = grad_offset_all.index_select(1, ctx.sort_index_for_npu_bp)
        # Everything after the offset channels belongs to the mask.
        grad_mask = grad_offset_all[:, grad_offset.shape[1]:, :, :]
        if not ctx.with_bias:
            grad_bias = None

        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
                None, None, None, None, None, None, None, None)


class ModulatedDeformConv(nn.Module):
    r"""Applies an NPU based Modulated Deformable 2D convolution operation.

    The implementation of this ModulatedDeformConv is mainly based
    on the implementation of mmcv for design and reconstruction.

    In ModulatedDeformConv2dFunction, the forward and backward are customized,
    and the input tensor is reconstructed to match the NPU based function.

    It is worth mentioning that DeformConv(DCNv1) is also implemented
    by setting modulated==False. Due to the difference between input
    and initialization, there is no additional implementation here.

    .. note::
        ModulatedDeformConv only implements operations under fp32 data types.
        Notice, weight and bias in conv_offset must be initialized to 0.

    Args:
        in_channels (int): Number of channels in the input image.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple): Size of the convolving kernel.
        stride (int or tuple): Stride of the convolution. Default: 1.
        padding (int or tuple): Zero-padding added to both sides of the input.
            Default: 0.
        dilation (int or tuple): Spacing between kernel elements. Default: 1.
        groups (int): Number of blocked connections from input
            channels to output channels. Default: 1.
        deformable_groups (int): Number of deformable group partitions.
            Default: 1.
        bias (bool): If True, adds a learnable bias to the output.
            Default: True.
        pack (bool): If True, conv_offset and mask will be included in this
            module. If False, ``forward`` expects a ``(x, offset, mask)``
            tuple instead of a single tensor. Default: True.

    Examples::
        >>> m = ModulatedDeformConv(32, 32, 1)
        >>> input_tensor = torch.randn(2, 32, 5, 5)
        >>> output = m(input_tensor)


        >>> x = torch.randn(2, 32, 7, 7)
        >>> model = ModulatedDeformConv(32, 32, 3, 2, 1)

        >>> torch.npu.set_device(0)
        >>> x = x.npu()
        >>> model = model.npu()

        >>> o = model(x)
        >>> l = o.sum()
        >>> l.backward()
        >>> print(l)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=True,
                 pack=True,
                 ):
        super(ModulatedDeformConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.deformable_groups = deformable_groups
        self.with_bias = bias
        self.pack = pack
        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels // groups, *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            # Placeholder zeros handed to the NPU op; a plain tensor, so it is
            # neither trained nor part of the state dict.
            self.bias = torch.zeros(self.weight.shape[0])
        if self.pack:
            # Predicts offsets (2 values per kernel tap) and mask (1 value per
            # kernel tap) for every deformable group. It must use the same
            # stride, padding AND dilation as the deformable convolution,
            # otherwise its output has a different spatial size than the
            # deformable conv output whenever dilation != 1.
            self.conv_offset = nn.Conv2d(
                self.in_channels,
                self.deformable_groups * 3 * self.kernel_size[0] *
                self.kernel_size[1],
                kernel_size=self.kernel_size,
                stride=_pair(self.stride),
                padding=_pair(self.padding),
                dilation=_pair(self.dilation),
                bias=True)
        # Number of offset channels; the remaining conv_offset channels are
        # the mask.
        self.split_num = self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1]
        natural_order = list(range(self.split_num))
        # Forward permutation: odd-indexed offset channels first, then the
        # even-indexed ones.
        sort_index_for_npu_fp = natural_order[1::2] + natural_order[::2]
        # Inverse permutation, used to restore the channel order of the
        # offset gradient.
        position_of = {channel: pos for pos, channel in enumerate(sort_index_for_npu_fp)}
        sort_index_for_npu_bp = [position_of[channel] for channel in natural_order]
        self.sort_index_for_npu_fp = torch.IntTensor(sort_index_for_npu_fp)
        self.sort_index_for_npu_bp = torch.IntTensor(sort_index_for_npu_bp)
        self.sort_index_for_npu_todevice = False
        self.init_param()

    def init_param(self):
        """Initialise weight uniformly, and zero the bias and conv_offset."""
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.zero_()

        if self.pack:
            self.conv_offset.weight.data.zero_()
            self.conv_offset.bias.data.zero_()

    def _sync_aux_tensors(self, device):
        """Move the non-registered helper tensors onto ``device`` if needed.

        The index tensors (and the placeholder bias used when ``bias=False``)
        are plain attributes, so ``Module.to`` does not move them. The device
        is checked on every call rather than once, so the module keeps working
        if it is moved to another device after its first forward pass.
        """
        if self.sort_index_for_npu_fp.device != device:
            self.sort_index_for_npu_fp = self.sort_index_for_npu_fp.to(device)
        if self.sort_index_for_npu_bp.device != device:
            self.sort_index_for_npu_bp = self.sort_index_for_npu_bp.to(device)
        # A learnable bias is a registered Parameter and follows the module;
        # only the placeholder tensor has to be moved by hand.
        if not isinstance(self.bias, nn.Parameter) and self.bias.device != device:
            self.bias = self.bias.to(device)
        # Kept for backward compatibility with code that inspects the flag.
        self.sort_index_for_npu_todevice = True

    def forward(self, x):
        """Apply the modulated deformable convolution.

        Args:
            x (Tensor or tuple): The input feature map when ``pack`` is True;
                otherwise an ``(x, offset, mask)`` tuple.

        Returns:
            Tensor: Result of ``ModulatedDeformConv2dFunction.apply``.
        """
        if self.pack:
            out = self.conv_offset(x)
            offset = out[:, :self.split_num, ...]
            mask = torch.sigmoid(out[:, self.split_num:, ...])
        else:
            x, offset, mask = x

        self._sync_aux_tensors(x.device)

        return ModulatedDeformConv2dFunction.apply(
            x, offset, mask, self.weight, self.bias, self.with_bias,
            self.stride, self.padding, self.dilation,
            self.groups, self.deformable_groups,
            self.sort_index_for_npu_fp,
            self.sort_index_for_npu_bp,
        )


# Short alias: ``DCNv2`` is bound to the same class as ``ModulatedDeformConv``.
DCNv2 = ModulatedDeformConv