import torch
import torch.nn as nn
from data_cache import golden_data_cache
from torch_npu.testing.testcase import TestCase, run_tests

import mx_driving.fused


@golden_data_cache(__file__)
def gen_inputs(shape, dtype):
    """Build the random CPU input tensor shared by the CPU and NPU runs.

    The global torch RNG is re-seeded with a fixed value on every call so the
    same ``shape``/``dtype`` pair always yields the same data.

    Args:
        shape: Size of the tensor to create.
        dtype: torch dtype of the tensor to create.

    Returns:
        A CPU tensor sampled by ``torch.rand`` with the requested shape and dtype.
    """
    # Fixed seed keeps the generated data reproducible across runs.
    torch.manual_seed(123)
    return torch.rand(shape, dtype=dtype)


@golden_data_cache(__file__)
def cpu_to_exec(x_data_cpu):
    """Compute the CPU reference result with ``nn.MaxPool2d``.

    Args:
        x_data_cpu: Input tensor on the CPU.

    Returns:
        The float32 output of a max pool with kernel size 3, stride 2 and
        padding 1 applied to ``x_data_cpu``.
    """
    # Cast to float32 first so the reference is always computed in float32,
    # whatever dtype the input was generated with.
    x_fp32 = x_data_cpu.float()
    reference_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    return reference_pool(x_fp32)


def npu_to_exec(x_data_cpu):
    """Run ``mx_driving.fused.npu_max_pool2d`` on an NPU copy of the input.

    Args:
        x_data_cpu: Input tensor on the CPU; it is moved to the NPU with
            ``.npu()`` before the operator is called.

    Returns:
        Whatever ``npu_max_pool2d`` returns for the arguments ``(x, 3, 2, 1)``.
    """
    x_data_npu = x_data_cpu.npu()
    # NOTE(review): 3, 2, 1 presumably mirror kernel_size/stride/padding of
    # the CPU reference -- confirm against the operator's signature.
    return mx_driving.fused.npu_max_pool2d(x_data_npu, 3, 2, 1)


class TestNpuMaxPool2d(TestCase):
    """Checks ``npu_max_pool2d`` against the CPU ``nn.MaxPool2d`` reference."""

    def test_npu_max_pool2d(self):
        """Compare NPU and CPU results for every shape/dtype combination.

        Each combination runs in its own ``subTest`` so that a failure report
        names the offending shape and dtype, and the remaining combinations
        are still exercised instead of being skipped after the first failure.
        """
        dtype_list = [torch.float32, torch.float16]
        shape_list = [
            [18, 64, 464, 800],
            [6, 64, 464, 800],
            [7, 32, 46400, 18],
            [2, 16, 42, 24785],
            [1, 16, 3, 3],
            [1, 8, 100, 100],
            [1, 1, 1, 1],
        ]

        # Same iteration order as before: every dtype for a shape, then the
        # next shape.
        for shape in shape_list:
            for dtype in dtype_list:
                with self.subTest(shape=shape, dtype=dtype):
                    x_data_cpu = gen_inputs(shape, dtype)

                    cpu_output = cpu_to_exec(x_data_cpu)
                    npu_output = npu_to_exec(x_data_cpu)

                    # The CPU reference is float32, so cast the NPU result
                    # before comparing.
                    self.assertRtolEqual(cpu_output, npu_output.float())


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()