#

import numpy as np
import torch

import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor

DEVICE_NAME = torch_npu.npu.get_device_name(0)


class TestConvolutionTranspose(TestCase):

    def supported_op_exec(self, input1, weight, bias, padding, output_padding, stride, dilation, groups):
        dim = input1.dim()
        if dim == 4:
            output = torch.nn.functional.conv_transpose2d(input1, weight, bias, stride, padding,
                                                          output_padding, groups, dilation)
        elif dim == 5:
            output = torch.nn.functional.conv_transpose3d(input1, weight, bias, stride, padding,
                                                          output_padding, groups, dilation)
        return output.cpu().detach()

    def custom_op_exec(self, input1, weight, bias, padding, output_padding, stride, dilation, groups):
        output = torch_npu.npu_convolution_transpose(input1, weight, bias, padding, output_padding,
                                                     stride, dilation, groups)
        return output.cpu().detach()

    def test_npu_convolution_transpose(self):
        items = [[[np.float32, 0, [1, 3, 3, 3]], [np.float32, 0, [3, 2, 3, 3]], [np.float32, 2, [2]],
                  [1, 1], [0, 0], [0, 0], [1, 1], 1],
                 [[np.float16, 2, [20, 16, 50, 10, 20]], [np.float16, 2, [16, 33, 3, 3, 3]], None,
                  [0, 0, 0], [0, 0, 0], [1, 1, 1], [1, 1, 1], 1]]
        if "Ascend910A" in DEVICE_NAME or "Ascend910P" in DEVICE_NAME:
            items0 = [[np.float32, 2, [20, 16, 50, 10, 20]], [np.float32, 2, [16, 33, 3, 3, 3]], None,
                      [0, 0, 0], [0, 0, 0], [1, 1, 1], [1, 1, 1], 1]
            items.append(items0)
        for item in items:
            _, npu_input = create_common_tensor(item[0], 0, 0.001)
            _, weight = create_common_tensor(item[1], 0, 0.001)
            _, bias = create_common_tensor(item[2], 1, 200) if item[2] else _, None
            padding = item[3]
            output_padding = item[4]
            stride = item[5]
            dilation = item[6]
            groups = item[7]

        supported_output = self.supported_op_exec(npu_input, weight, bias, padding, output_padding,
                                                  stride, dilation, groups)
        custom_output = self.custom_op_exec(npu_input, weight, bias, padding, output_padding,
                                            stride, dilation, groups)
        self.assertRtolEqual(supported_output, custom_output)


if __name__ == "__main__":
    run_tests()