import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor


class TestAdaptiveMaxPool2dBackward(TestCase):
    """Compare AdaptiveMaxPool2d input gradients computed on CPU and on NPU."""

    @staticmethod
    def _input_grad(input_tensor, output_size):
        """Return the input gradient of AdaptiveMaxPool2d for ``input_tensor``.

        The backward pass is seeded with the pooled output itself as the
        upstream gradient. The gradient is computed on a fresh detached leaf
        copy of ``input_tensor``. This keeps repeated calls with the same
        tensor from accumulating into a stale ``.grad``. It also leaves the
        caller's tensor (its ``requires_grad`` flag and ``.grad``) untouched.

        Args:
            input_tensor: tensor to pool; it stays on whatever device it is on.
            output_size: target output size forwarded to ``nn.AdaptiveMaxPool2d``.

        Returns:
            The gradient w.r.t. the leaf copy, on the same device as the input.
        """
        leaf = input_tensor.detach().clone().requires_grad_(True)
        output = nn.AdaptiveMaxPool2d(output_size)(leaf)
        output.backward(output)
        return leaf.grad

    def cpu_op_exec(self, input_tensor, output_size):
        """Return the AdaptiveMaxPool2d input gradient for a CPU tensor."""
        return self._input_grad(input_tensor, output_size)

    def npu_op_exec(self, input_tensor, output_size):
        """Return the AdaptiveMaxPool2d input gradient for an NPU tensor, moved to CPU."""
        return self._input_grad(input_tensor, output_size).to("cpu")

    def test_adaptiveMaxPool2d_shape_format_fp32_6(self):
        # NOTE: despite "fp32" in the name, the inputs are float16. Only the
        # CPU reference is computed in float32 (presumably because CPU half
        # support for this op is limited); it is cast back to float16 before
        # being compared with the NPU result.
        format_list = [0, 3]  # presumably NPU storage-format ids; verify against create_common_tensor
        shape_list = [(1, 3, 8, 9)]
        shape_format = [[np.float16, i, j] for i in format_list for j in shape_list]
        output_list = [(2, 3)]
        for item in shape_format:
            # Presumably draws values in [0, 100); verify against create_common_tensor.
            cpu_input, npu_input = create_common_tensor(item, 0, 100)
            # Loop-invariant: convert once per input rather than once per output size.
            cpu_input = cpu_input.to(torch.float32)
            for output_size in output_list:
                cpu_output = self.cpu_op_exec(cpu_input, output_size).to(torch.float16)
                npu_output = self.npu_op_exec(npu_input, output_size)
                self.assertRtolEqual(cpu_output, npu_output)

    def test_adaptiveMaxPool2d_backward_case_in_photo2cartoon(self):
        """Global (output size 1) adaptive max pool backward in float32, CPU vs NPU."""
        cpu_x = torch.rand(1, 256, 31, 31)
        npu_x = cpu_x.npu()
        cpu_x.requires_grad = True
        npu_x.requires_grad = True
        cpu_out = F.adaptive_max_pool2d(cpu_x, 1)
        npu_out = F.adaptive_max_pool2d(npu_x, 1)
        cpu_out.backward(torch.ones_like(cpu_out))
        npu_out.backward(torch.ones_like(npu_out))
        # 0.0003 is passed to assertRtolEqual as an explicit tolerance.
        self.assertRtolEqual(cpu_x.grad, npu_x.grad.cpu(), 0.0003)

    def test_adaptiveMaxPool2d_backward_case_in_photo2cartoon_fp16(self):
        """Global adaptive max pool backward in float16; the CPU reference runs in float32."""
        cpu_x = torch.rand(1, 256, 31, 31).half()
        npu_x = cpu_x.npu()
        cpu_x.requires_grad = True
        npu_x.requires_grad = True
        # The CPU forward is upcast to float32 and the result cast back to half.
        # Autograd still routes the gradient to the half-precision leaf cpu_x.
        cpu_out = F.adaptive_max_pool2d(cpu_x.float(), 1).half()
        npu_out = F.adaptive_max_pool2d(npu_x, 1)
        cpu_out.backward(torch.ones_like(cpu_out))
        npu_out.backward(torch.ones_like(npu_out))
        # NOTE(review): float16 random inputs can contain duplicate maxima in a
        # channel. If CPU and NPU break such ties differently, the gradient
        # lands on different positions -- confirm tie-breaking matches if this
        # test is flaky.
        self.assertRtolEqual(cpu_x.grad, npu_x.grad.cpu())


if __name__ == "__main__":
    # Executed only when this file is run directly (not on import):
    # delegate to the test entry point imported from torch_npu.testing.testcase.
    run_tests()