import torch

import torch_npu



from torch_npu.testing.testcase import TestCase, run_tests

from torch_npu.testing.common_utils import create_common_tensor, SupportedDevices





class TestRFFT(TestCase):
    """Compare ``torch.fft.rfft`` computed on an NPU tensor with the CPU result."""

    def supported_op_exec(self, x, length, dim, norm):
        """Return the reference ``torch.fft.rfft`` result computed on ``x`` as given.

        Args:
            x: Input tensor; the test passes a CPU tensor.
            length: Signal length forwarded as ``n``.
            dim: Dimension along which the transform is taken.
            norm: Normalization mode string forwarded to ``torch.fft.rfft``.
        """
        return torch.fft.rfft(input=x, n=length, dim=dim, norm=norm)

    def custom_op_exec(self, x, length, dim, norm):
        """Return ``torch.fft.rfft`` evaluated on ``x`` after moving it to the NPU.

        Takes the same arguments as ``supported_op_exec``; only the device of
        the input differs.
        """
        return torch.fft.rfft(input=x.npu(), n=length, dim=dim, norm=norm)

    # The device restriction belongs on the test case itself rather than on a
    # helper, so it is applied before any test work (tensor creation, CPU FFT)
    # starts and does not depend on which helper happens to be called first.
    # NOTE(review): SupportedDevices is assumed to skip on devices not listed.
    @SupportedDevices(['Ascend910B'])
    def test_npu_rfft_meta(self):
        """Check that the NPU rfft of a random float32 tensor matches the CPU one."""
        shape = [64, 64, 1024]
        dim = -1
        # Transform over the full extent of the last axis.
        length = shape[-1]
        norm = "backward"
        x = torch.randn(shape, dtype=torch.float32)

        supported_output = self.supported_op_exec(x, length, dim, norm)
        custom_output = self.custom_op_exec(x, length, dim, norm)

        # The real and imaginary components are compared separately.
        self.assertRtolEqual(supported_output.real, custom_output.real)
        self.assertRtolEqual(supported_output.imag, custom_output.imag)



# Script entry point: hand control to run_tests (imported from
# torch_npu.testing.testcase) when this file is executed directly.
if __name__ == "__main__":

    run_tests()