import torch
from torch_npu.testing.testcase import TestCase, run_tests
import op_extension


class TestCustomAdd(TestCase):

    def test_add_custom_ops(self):
        length = [8, 2048]
        x = torch.randint(low=1, high=100, size=length, device='cpu', dtype=torch.int)
        y = torch.randint(low=1, high=100, size=length, device='cpu', dtype=torch.int)

        x_npu = x.npu()
        y_npu = y.npu()
        output = op_extension.ops.custom_add(x_npu, y_npu)
        cpuout = torch.add(x, y)

        self.assertRtolEqual(output, cpuout)


class TestCustomTrig(TestCase):
    def test_trig_custom_ops(self):
        length = [8, 2048]
        x = torch.rand(length, device='npu', dtype=torch.float32)
        out_sin = torch.empty_like(x)
        out_cos = torch.empty_like(x)
        x_npu = x.npu()
        out_sin_npu = out_sin.npu()
        out_cos_npu = out_cos.npu()

        out_tan = op_extension.ops.custom_trig(x_npu, out_sin_npu, out_cos_npu)
        out_sin_cpu = torch.sin(x)
        out_cos_cpu = torch.cos(x)
        out_tan_cpu = torch.tan(x)
        self.assertRtolEqual(out_sin_npu, out_sin_cpu)
        self.assertRtolEqual(out_cos_npu, out_cos_cpu)
        self.assertRtolEqual(out_tan, out_tan_cpu)


if __name__ == "__main__":
    run_tests()