import unittest
import torch
import numpy as np

import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor
from torch_npu.testing.common_utils import SupportedDevices


@SupportedDevices(['Ascend910B'])
class TestArgSort(TestCase):
    """Compare ``torch.argsort`` on the NPU input against the CPU reference.

    Each case builds a CPU/NPU tensor pair with ``create_common_tensor``,
    runs ``torch.argsort`` on both and checks that the two results agree.

    For ``stable=True`` the index tensors must be identical.  For the
    non-stable variants the order of equal elements is not guaranteed by
    ``torch.argsort``, so the comparison is done on the values selected by
    the indices (plus a permutation check) instead of on the raw indices.
    """

    # Bounds forwarded to create_common_tensor
    # (presumably the value range of the generated data -- TODO confirm).
    _MIN_VALUE = -100
    _MAX_VALUE = 100

    # (shape, dim, descending); a shape-only entry exercises the defaults.
    _PLAIN_CASES = (
        ((8, 4, 3, 9), 2, False),
        ((2, 3),),
        ((1, 7), 0, True),
        ((1, 5, 6), 1, False),
    )

    # (shape, stable, dim, descending)
    _STABLE_CASES = (
        ((8, 4, 3, 9), True, 2, False),
        ((1, 7), False, 0, True),
        ((1, 5, 6), True, 1, False),
    )

    def cpu_op_exec(self, input1, dim, descending):
        """Return ``torch.argsort(input1, dim, descending)`` as a numpy array."""
        output = torch.argsort(input1, dim=dim, descending=descending)
        return output.numpy()

    def npu_op_exec(self, input1, dim, descending):
        """Same as ``cpu_op_exec`` but moves the result to CPU before converting."""
        output = torch.argsort(input1, dim=dim, descending=descending)
        return output.cpu().numpy()

    def cpu_default_op_exec(self, input1):
        """Return ``torch.argsort(input1)`` (all defaults) as a numpy array."""
        output = torch.argsort(input1)
        return output.numpy()

    def npu_default_op_exec(self, input1):
        """Same as ``cpu_default_op_exec`` but moves the result to CPU first."""
        output = torch.argsort(input1)
        return output.cpu().numpy()

    def cpu_op_exec_stable(self, input1, stable, dim, descending):
        """Return ``torch.argsort`` with an explicit ``stable`` flag as numpy."""
        output = torch.argsort(input1, stable=stable, dim=dim, descending=descending)
        return output.numpy()

    def npu_op_exec_stable(self, input1, stable, dim, descending):
        """Same as ``cpu_op_exec_stable`` but moves the result to CPU first."""
        output = torch.argsort(input1, stable=stable, dim=dim, descending=descending)
        return output.cpu().numpy()

    def _make_inputs(self, dtype, shape):
        """Build the (cpu, npu) tensor pair for one case."""
        # [dtype, 0, shape] mirrors the item layout of the original case
        # tables (the middle 0 is presumably the NPU format id -- TODO confirm).
        return create_common_tensor([dtype, 0, shape], self._MIN_VALUE, self._MAX_VALUE)

    def _assert_same_ordering(self, cpu_input, cpu_indices, npu_indices, dim, stable):
        """Assert that the NPU indices describe the same ordering as the CPU ones.

        Args:
            cpu_input: CPU tensor that was sorted.
            cpu_indices: numpy index array produced on CPU.
            npu_indices: numpy index array produced on NPU.
            dim: dimension that was sorted along.
            stable: whether a stable sort was requested.
        """
        if stable:
            # A stable sort has a unique answer: indices must match exactly.
            self.assertRtolEqual(cpu_indices, npu_indices)
            return

        # Without stability, equal elements may legally come back in any
        # order, so compare what the indices select rather than the indices.
        values = cpu_input.numpy()
        cpu_sorted = np.take_along_axis(values, cpu_indices, axis=dim)
        npu_sorted = np.take_along_axis(values, npu_indices, axis=dim)
        self.assertRtolEqual(cpu_sorted, npu_sorted)
        # Both results must use the same set of indices along ``dim``, which
        # rules out an index being repeated or dropped on the NPU side.
        self.assertRtolEqual(np.sort(cpu_indices, axis=dim), np.sort(npu_indices, axis=dim))

    def _check_plain_cases(self, dtype):
        """Run every ``_PLAIN_CASES`` entry for ``dtype``."""
        for case in self._PLAIN_CASES:
            shape, sort_args = case[0], case[1:]
            with self.subTest(dtype=dtype.__name__, shape=shape, args=sort_args):
                cpu_input1, npu_input1 = self._make_inputs(dtype, shape)
                if sort_args:
                    dim, descending = sort_args
                    cpu_output = self.cpu_op_exec(cpu_input1, dim, descending)
                    npu_output = self.npu_op_exec(npu_input1, dim, descending)
                else:
                    dim = -1  # torch.argsort sorts along the last dim by default
                    cpu_output = self.cpu_default_op_exec(cpu_input1)
                    npu_output = self.npu_default_op_exec(npu_input1)
                self._assert_same_ordering(cpu_input1, cpu_output, npu_output, dim, stable=False)

    def _check_stable_cases(self, dtype):
        """Run every ``_STABLE_CASES`` entry for ``dtype``."""
        for shape, stable, dim, descending in self._STABLE_CASES:
            with self.subTest(dtype=dtype.__name__, shape=shape, stable=stable,
                              dim=dim, descending=descending):
                cpu_input1, npu_input1 = self._make_inputs(dtype, shape)
                cpu_output = self.cpu_op_exec_stable(cpu_input1, stable, dim, descending)
                npu_output = self.npu_op_exec_stable(npu_input1, stable, dim, descending)
                self._assert_same_ordering(cpu_input1, cpu_output, npu_output, dim, stable)

    def test_sort_shape_format_fp32(self):
        """argsort with and without explicit dim/descending on float32."""
        self._check_plain_cases(np.float32)

    def test_sort_shape_format_fp16(self):
        """argsort with and without explicit dim/descending on float16."""
        self._check_plain_cases(np.float16)

    @unittest.skipIf("1.11.0" in torch.__version__,
                "OP `argsort.stable` is not supported on torch v1.11.0, skip this ut for this torch version")
    def test_sort_stable_shape_format_fp32(self):
        """argsort with an explicit ``stable`` flag on float32."""
        self._check_stable_cases(np.float32)

    @unittest.skipIf("1.11.0" in torch.__version__,
                "OP `argsort.stable` is not supported on torch v1.11.0, skip this ut for this torch version")
    def test_sort_stable_shape_format_fp16(self):
        """argsort with an explicit ``stable`` flag on float16."""
        self._check_stable_cases(np.float16)

# When executed directly as a script, hand control to the run_tests helper
# imported from torch_npu.testing.testcase.
if __name__ == "__main__":
    run_tests()