import unittest
from collections import namedtuple

import numpy as np
import torch
import torch_npu
from data_cache import golden_data_cache
from torch_npu.testing.testcase import TestCase, run_tests

import mx_driving


# pylint: disable=too-many-return-values
@golden_data_cache(__file__)
def cpu_gen_inputs(shape, dtype):
    """Generate random CPU inputs for a multi-scale deformable attention run.

    Args:
        shape: Six-element sequence
            ``(bs, num_queries, embed_dims, num_heads, num_levels, num_points)``.
        dtype: Torch dtype the floating-point tensors are cast to.

    Returns:
        A 7-tuple ``(shapes, num_keys, value, sampling_locations,
        attention_weights, offset, grad_output)`` where

        * ``shapes`` is a ``(num_levels, 2)`` tensor whose rows are all ``[60, 40]``;
        * ``num_keys`` is the sum over levels of the product of each row;
        * ``value`` is ``(bs, num_keys, num_heads, embed_dims)``;
        * ``sampling_locations`` is
          ``(bs, num_queries, num_heads, num_levels, num_points, 2)``;
        * ``attention_weights`` is
          ``(bs, num_queries, num_heads, num_levels, num_points)``;
        * ``offset`` holds the start index of each level along the key axis;
        * ``grad_output`` is ``(bs, num_queries, num_heads * embed_dims)``.
    """
    bs, num_queries, embed_dims, num_heads, num_levels, num_points = shape

    # Every level uses the same 60 x 40 feature map.
    shapes = torch.tensor([60, 40] * num_levels).reshape(num_levels, 2)
    keys_per_level = shapes.prod(1)
    num_keys = int(keys_per_level.sum())

    # The torch.rand calls stay in this order so that seeded runs draw the
    # same numbers for each tensor.
    value = (torch.rand(bs, num_keys, num_heads, embed_dims) * 0.01).to(dtype)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2).to(dtype)
    # The small additive term keeps every weight strictly positive.
    attention_weights = (torch.rand(bs, num_queries, num_heads, num_levels, num_points) + 1e-5).to(dtype)
    grad_output = (torch.rand(bs, num_queries, num_heads * embed_dims) * 1e-3).to(dtype)

    # Exclusive prefix sum of the per-level key counts: [0, n0, n0 + n1, ...].
    leading_zero = shapes.new_zeros((1,))
    offset = torch.cat((leading_zero, keys_per_level.cumsum(0)[:-1]))

    return shapes, num_keys, value, sampling_locations, attention_weights, offset, grad_output


@golden_data_cache(__file__)
def multi_scale_deformable_attn_pytorch(
    value: torch.Tensor,
    value_spatial_shapes: torch.Tensor,
    sampling_locations: torch.Tensor,
    attention_weights: torch.Tensor,
) -> torch.Tensor:
    """Reference multi-scale deformable attention built on ``grid_sample``.

    Args:
        value: Tensor unpacked as ``(bs, num_keys, num_heads, embed_dims)``.
        value_spatial_shapes: Iterable of ``(H, W)`` pairs, one per level; the
            products ``H * W`` are the split sizes along ``value``'s dim 1.
        sampling_locations: Tensor unpacked as
            ``(bs, num_queries, num_heads, num_levels, num_points, 2)``.
        attention_weights: Tensor reshaped to
            ``(bs * num_heads, 1, num_queries, num_levels * num_points)``.

    Returns:
        Contiguous tensor of shape ``(bs, num_queries, num_heads * embed_dims)``.
    """
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape

    level_sizes = [height * width for height, width in value_spatial_shapes]
    per_level_values = value.split(level_sizes, dim=1)

    # 2 * x - 1 maps [0, 1] onto the [-1, 1] coordinate range used by grid_sample.
    grids = 2 * sampling_locations - 1

    sampled_per_level = []
    for level, (height, width) in enumerate(value_spatial_shapes):
        # (bs, H*W, heads, dims) -> (bs * heads, dims, H, W)
        level_value = per_level_values[level].flatten(2).transpose(1, 2)
        level_value = level_value.reshape(bs * num_heads, embed_dims, height, width)

        # (bs, queries, heads, points, 2) -> (bs * heads, queries, points, 2)
        level_grid = grids[:, :, :, level].transpose(1, 2).flatten(0, 1)

        sampled_per_level.append(
            torch.nn.functional.grid_sample(
                level_value, level_grid, mode="bilinear", padding_mode="zeros", align_corners=False
            )
        )

    weights = attention_weights.transpose(1, 2).reshape(
        bs * num_heads, 1, num_queries, num_levels * num_points
    )
    # Stack levels, merge (level, point) into one axis, then take the weighted sum.
    weighted_samples = torch.stack(sampled_per_level, dim=-2).flatten(-2) * weights
    output = weighted_samples.sum(-1).view(bs, num_heads * embed_dims, num_queries)
    return output.transpose(1, 2).contiguous()


@golden_data_cache(__file__)
def multi_scale_deformable_attn_pytorch_grad(
    cpu_output, cpu_grad_output, cpu_value, cpu_sampling_locations, cpu_attention_weights
):
    """Backpropagate through the CPU reference output and collect input gradients.

    Args:
        cpu_output: Output tensor whose autograd graph is walked.
        cpu_grad_output: Upstream gradient passed to ``backward``.
        cpu_value: Tensor whose ``.grad`` is read after the backward pass.
        cpu_sampling_locations: Tensor whose ``.grad`` is read after the backward pass.
        cpu_attention_weights: Tensor whose ``.grad`` is read after the backward pass.

    Returns:
        Tuple ``(grad_value, grad_sampling_locations, grad_attention_weights)``
        of float32 NumPy arrays.
    """
    cpu_output.backward(cpu_grad_output)
    differentiated = (cpu_value, cpu_sampling_locations, cpu_attention_weights)
    return tuple(tensor.grad.float().numpy() for tensor in differentiated)


# Forward output plus the gradients of the three differentiable inputs.
ExecResults = namedtuple(
    "ExecResults",
    "output grad_value grad_sampling_locations grad_attention_weights",
)
# Tensors handed to one execution path (CPU reference or NPU operator).
Inputs = namedtuple(
    "Inputs",
    "value shapes offset sampling_locations attention_weights grad_output",
)


class TestMultiScaleDeformableAttnFunction(TestCase):
    """Check ``mx_driving.multi_scale_deformable_attn`` against the CPU reference.

    Each comparison test runs the forward and backward pass on both paths and
    compares the output and the three input gradients with ``assertRtolEqual``.
    """

    # Comparison order used by every test: output first, then the gradients.
    _COMPARED_FIELDS = ("output", "grad_value", "grad_attention_weights", "grad_sampling_locations")

    def gen_inputs(self, shape, dtype, data=None):
        """Build matching CPU and NPU input sets.

        Args:
            shape: ``(bs, num_queries, embed_dims, num_heads, num_levels, num_points)``.
            dtype: Torch dtype of the generated (and NPU) floating-point tensors.
            data: Optional scalar; when given, every floating-point input is
                replaced by a tensor filled with this value.

        Returns:
            ``(cpu_inputs, npu_inputs)`` as two ``Inputs`` tuples. The CPU
            floating-point tensors are float64 and the CPU ``offset`` is ``None``.
        """
        shapes, _, value, sampling_locations, attention_weights, offset, grad_output = cpu_gen_inputs(shape, dtype)
        if data is not None:
            # Build new tensors instead of filling in place: the originals come
            # from a cached generator and may be shared with other callers
            # (NOTE(review): sharing depends on golden_data_cache -- confirm).
            value = torch.full_like(value, data)
            sampling_locations = torch.full_like(sampling_locations, data)
            attention_weights = torch.full_like(attention_weights, data)
            grad_output = torch.full_like(grad_output, data)

        # The CPU reference runs in float64 regardless of the requested dtype.
        cpu_value = value.double()
        cpu_shapes = shapes.long()
        cpu_sampling_locations = sampling_locations.double()
        cpu_attention_weights = attention_weights.double()
        cpu_grad_output = grad_output.double()

        cpu_value.requires_grad_()
        cpu_sampling_locations.requires_grad_()
        cpu_attention_weights.requires_grad_()

        npu_value = value.npu()
        npu_shapes = shapes.npu()
        npu_offset = offset.npu()
        npu_sampling_locations = sampling_locations.npu()
        npu_attention_weights = attention_weights.npu()
        npu_grad_output = grad_output.npu()

        npu_value.requires_grad_()
        npu_sampling_locations.requires_grad_()
        npu_attention_weights.requires_grad_()

        cpu_inputs = Inputs(
            cpu_value, cpu_shapes, None, cpu_sampling_locations, cpu_attention_weights, cpu_grad_output
        )
        npu_inputs = Inputs(
            npu_value, npu_shapes, npu_offset, npu_sampling_locations, npu_attention_weights, npu_grad_output
        )
        return cpu_inputs, npu_inputs

    def cpu_to_exec(self, cpu_inputs):
        """Run the CPU reference forward and backward; return float32 NumPy results."""
        cpu_value = cpu_inputs.value
        cpu_shapes = cpu_inputs.shapes
        cpu_sampling_locations = cpu_inputs.sampling_locations
        cpu_attention_weights = cpu_inputs.attention_weights
        cpu_grad_output = cpu_inputs.grad_output
        cpu_output = multi_scale_deformable_attn_pytorch(
            cpu_value, cpu_shapes, cpu_sampling_locations, cpu_attention_weights
        )
        grad_value, grad_sampling_locations, grad_attention_weights = multi_scale_deformable_attn_pytorch_grad(
            cpu_output, cpu_grad_output, cpu_value, cpu_sampling_locations, cpu_attention_weights
        )
        return ExecResults(
            output=cpu_output.detach().float().numpy(),
            grad_value=grad_value,
            grad_sampling_locations=grad_sampling_locations,
            grad_attention_weights=grad_attention_weights,
        )

    def npu_to_exec(self, npu_inputs):
        """Run the NPU operator forward and backward; return NumPy results."""
        npu_value = npu_inputs.value
        npu_shapes = npu_inputs.shapes
        npu_offset = npu_inputs.offset
        npu_sampling_locations = npu_inputs.sampling_locations
        npu_attention_weights = npu_inputs.attention_weights
        npu_grad_output = npu_inputs.grad_output
        npu_output = mx_driving.multi_scale_deformable_attn(
            npu_value, npu_shapes, npu_offset, npu_sampling_locations, npu_attention_weights
        )
        npu_output.backward(npu_grad_output)
        return ExecResults(
            output=npu_output.detach().cpu().numpy(),
            grad_value=npu_value.grad.cpu().numpy(),
            grad_sampling_locations=npu_sampling_locations.grad.cpu().numpy(),
            grad_attention_weights=npu_attention_weights.grad.cpu().numpy(),
        )

    def _check_against_cpu(self, shape, dtype, reference_cast=None):
        """Run both paths for ``shape``/``dtype`` and compare every result field.

        Args:
            shape: Problem shape forwarded to ``gen_inputs``.
            dtype: Torch dtype forwarded to ``gen_inputs``.
            reference_cast: Optional NumPy dtype the CPU results are cast to
                before comparison (used when the NPU results are not float32).
        """
        cpu_inputs, npu_inputs = self.gen_inputs(shape, dtype)
        cpu_results = self.cpu_to_exec(cpu_inputs)
        npu_results = self.npu_to_exec(npu_inputs)
        for field in self._COMPARED_FIELDS:
            expected = getattr(cpu_results, field)
            if reference_cast is not None:
                expected = expected.astype(reference_cast)
            self.assertRtolEqual(expected, getattr(npu_results, field))

    def _check_all_zero(self, shape, dtype, data):
        """Fill every floating-point input with ``data`` and expect all-zero NPU results."""
        _, npu_inputs = self.gen_inputs(shape, dtype, data)
        npu_results = self.npu_to_exec(npu_inputs)
        for field in self._COMPARED_FIELDS:
            actual = getattr(npu_results, field)
            self.assertRtolEqual(np.zeros_like(actual), actual)

    # fast_mode: num_heads * num_points * num_levels <= 64
    def test_fast_mode(self):
        self._check_against_cpu([6, 9680, 32, 8, 1, 8], torch.float32)

    def test_embed_32(self):
        self._check_against_cpu([6, 9680, 32, 8, 4, 4], torch.float32)

    def test_embed_unaligned(self):
        self._check_against_cpu([6, 9680, 37, 4, 5, 3], torch.float32)

    def test_embed_16(self):
        self._check_against_cpu([1, 27216, 16, 5, 3, 1], torch.float32)

    def test_embed_64(self):
        self._check_against_cpu([1, 1450, 64, 6, 1, 2], torch.float32)

    def test_fully_embed_64(self):
        self._check_against_cpu([1, 1450, 64, 8, 8, 8], torch.float32)

    def test_fully_embed_128(self):
        self._check_against_cpu([1, 1450, 128, 8, 8, 8], torch.float32)

    def test_fully_embed_256(self):
        self._check_against_cpu([1, 1450, 256, 8, 8, 8], torch.float32)

    def test_fully_unaligned(self):
        self._check_against_cpu([1, 1450, 147, 7, 3, 21], torch.float32)

    def test_point_16(self):
        self._check_against_cpu([1, 1890, 32, 7, 4, 16], torch.float32)

    def test_fp16(self):
        # CPU results are float32 arrays; cast them down to match the fp16 NPU results.
        self._check_against_cpu([6, 9680, 32, 8, 4, 4], torch.float16, reference_cast=np.float16)

    def test_nan(self):
        self._check_all_zero([6, 9680, 32, 8, 4, 4], torch.float32, float('nan'))

    def test_inf(self):
        self._check_all_zero([6, 9680, 32, 8, 4, 4], torch.float32, float('inf'))


if __name__ == "__main__":
    run_tests()