import numpy as np

import torch



import torch_npu

from torch_npu.testing.testcase import TestCase, run_tests

from torch_npu.testing.common_utils import create_common_tensor





class TestFormatCast(TestCase):
    """Tests for ``torch_npu.npu_format_cast``.

    A format cast should change how the tensor is stored on the NPU, not its
    values. The test therefore compares the cast result against an identity
    reference, and also checks the reported storage format of the cast result.
    """

    def supported_op_exec(self, input1):
        """Return a detached CPU copy of ``input1`` passed through ``nn.Identity``.

        This is the reference result for the value comparison.
        """
        # nn.Identity ignores any constructor arguments, so none are passed.
        m = torch.nn.Identity()
        output = m(input1)
        return output.cpu().detach()

    def custom_op_exec(self, input1, acl_format):
        """Cast ``input1`` to ``acl_format`` on the NPU and return a detached CPU copy.

        Also asserts that the NPU-side result reports ``acl_format``. Without
        this check the value comparison alone would still pass if the cast
        were a silent no-op.
        """
        output = torch_npu.npu_format_cast(input1, acl_format)
        self.assertEqual(torch_npu.get_npu_format(output), acl_format)
        return output.cpu().detach()

    def test_npu_format_cast(self, device="npu"):
        # NOTE(review): `device` is unused; kept so the signature stays unchanged.
        # item = [dtype, input format, shape] as consumed by create_common_tensor.
        # Presumably format 0 is ACL NCHW and 3 is NC1HWC0 -- TODO confirm.
        item = [np.float16, 0, (2, 2, 4, 4)]
        # (-1, 1) is presumably the value range of the generated data; only the
        # NPU copy is used here.
        _, npu_input = create_common_tensor(item, -1, 1)
        acl_format = 3

        supported_output = self.supported_op_exec(npu_input)
        custom_output = self.custom_op_exec(npu_input, acl_format)
        self.assertRtolEqual(supported_output, custom_output)





if __name__ == "__main__":

    # Run this module's tests through torch_npu's test runner when executed as a script.
    run_tests()