import torch
from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
from testutils import TestUtils
import torch_npu


class TestSumAdd(TestUtils):
    def op_calc(self, input_element, dim, input_element2):
        tmp = torch.sum(input_element, dim)
        return tmp + input_element2

    @parametrize('shape', [(32, 64, 128, 2048)])
    @parametrize('dim', [0, 1, 2, 3])
    @parametrize('dtype', ['float32'])
    def test_reduction_cases_shapes(self, shape, dim, dtype):
        input_element = self._generate_tensor(shape, dtype)
        if dim == -1 or dim == 3:
            input_element2 = torch.full(size=(32, 64, 128), fill_value=1000.0, dtype=torch.float32, device=torch.device("npu"))
        elif dim == 2:
            input_element2 = torch.full(size=(32, 64, 2048), fill_value=1000.0, dtype=torch.float32, device=torch.device("npu"))
        elif dim == 1:
            input_element2 = torch.full(size=(32, 128, 2048), fill_value=1000.0, dtype=torch.float32, device=torch.device("npu"))
        else:
            input_element2 = torch.full(size=(64, 128, 2048), fill_value=1000.0, dtype=torch.float32, device=torch.device("npu"))

        std_sum = self.op_calc(input_element, dim, input_element2)

        compiled_op_calc = torch.compile(self.op_calc, backend="inductor")
        inductor_sum = compiled_op_calc(input_element, dim, input_element2)

        self.assertEqual(std_sum, inductor_sum, atol=1e-1, rtol=1e-1)


instantiate_parametrized_tests(TestSumAdd)

if __name__ == "__main__":
    run_tests()