import torch
from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
from testutils import TestUtils
import torch_npu




class TestReshape(TestUtils):
    B, N, S, D = (1, 12, 256, 8)

    def op_calc(self, a, b):
        a = a.reshape(self.S, self.B, self.N * self.D)
        b = b.reshape(self.S, self.B, self.N * self.D)
        y = a + b
        return y

    @parametrize('shape', [(1, 12, 256, 8)])
    @parametrize('dtype', ['float32', 'int32', 'float16', 'bfloat16', 'int64'])
    def test_view_cases(self, shape, dtype):
        a = self._generate_tensor(shape, dtype)
        b = self._generate_tensor(shape, dtype)

        std_reshape = self.op_calc(a, b)

        compiled_op_calc = torch.compile(self.op_calc, backend="inductor")
        inductor_reshape = compiled_op_calc(a, b)

        self.assertEqual(std_reshape, inductor_reshape, atol=1e-3, rtol=1e-3)


instantiate_parametrized_tests(TestReshape)

if __name__ == "__main__":
    run_tests()