import torch

import torch_npu



from torch_npu.testing.testcase import TestCase, run_tests





class TestFloatStatus(TestCase):
    """Exercise the NPU floating-point status API.

    Covers the alloc / get / clear cycle of ``torch_npu.npu_*_float_status``.
    An fp16 overflow must leave a non-zero value in element 0 of the status,
    and clearing it must bring that element back to zero.
    """

    @staticmethod
    def _get_float_status():
        """Allocate a status buffer and return ``npu_get_float_status`` of it.

        A fresh 8-element ``torch.zeros`` tensor on the NPU is passed through
        ``npu_alloc_float_status`` and then ``npu_get_float_status``. The
        returned tensor is what the test inspects and later hands to
        ``npu_clear_float_status``.
        """
        status_input = torch.zeros(8).npu()
        float_status = torch_npu.npu_alloc_float_status(status_input)
        return torch_npu.npu_get_float_status(float_status)

    def _first_status_value(self):
        """Return element 0 of the current float status as a Python number."""
        return self._get_float_status().cpu()[0].item()

    def test_float_status(self, device="npu"):
        """An fp16 overflow sets the float status; clearing it resets it.

        ``device`` is unused and kept only so the signature stays unchanged.
        """
        # Start from a cleared status so the non-zero check below can only be
        # satisfied by the overflow this test triggers, not by a flag that
        # earlier work may have left set.
        torch_npu.npu_clear_float_status(self._get_float_status())

        # 40000 + 40000 = 80000 exceeds the largest finite fp16 value (65504),
        # so this add overflows on the device. Only the side effect on the
        # float status matters; the sum itself is discarded.
        half_tensor = torch.tensor([40000.0], dtype=torch.float16).npu()
        _ = half_tensor + half_tensor

        status_after_overflow = self._first_status_value()
        self.assertTrue(
            status_after_overflow != 0,
            f"expected float status to be set after fp16 overflow, "
            f"got {status_after_overflow}",
        )

        # The clear call's return value is not needed; only its effect is.
        torch_npu.npu_clear_float_status(self._get_float_status())

        status_after_clear = self._first_status_value()
        self.assertTrue(
            status_after_clear == 0,
            f"expected float status to be zero after clear, "
            f"got {status_after_clear}",
        )





if __name__ == "__main__":
    # Executed only when this file is run directly (not when imported):
    # hand control to the torch_npu test runner.
    run_tests()