import os
import subprocess
import sys
from torch_npu.testing.common_utils import SupportedDevices
from torch_npu.testing.testcase import TestCase, run_tests
class TestInfNanMode(TestCase):
    """Check that inf/nan mode support is always reported on Ascend950.

    ``torch_npu.npu.utils.is_support_inf_nan()`` is expected to return True
    on this device no matter how the INF_NAN_MODE_* environment variables
    are set.
    """

    # Environment variables whose (lack of) influence is under test. They are
    # removed from the inherited environment so each case controls them
    # exactly instead of picking up whatever the parent shell exported.
    _INF_NAN_ENV_VARS = ('INF_NAN_MODE_ENABLE', 'INF_NAN_MODE_FORCE_DISABLE')

    # Upper bound, in seconds, for each child interpreter so a hang fails the
    # test instead of stalling the whole suite.
    _SUBPROCESS_TIMEOUT = 300

    @SupportedDevices(['Ascend950'])
    def test_is_support_inf_nan_always_enabled(self):
        """is_support_inf_nan() reports True in the current process."""
        import torch_npu.npu.utils as utils
        self.assertTrue(utils.is_support_inf_nan())

    @SupportedDevices(['Ascend950'])
    def test_env_vars_ignored(self):
        """INF_NAN_MODE_* variables must not change is_support_inf_nan().

        Each case runs in a fresh Python interpreter so the variables are set
        before torch_npu is imported (presumably the mode is read at
        import/initialization time -- confirm against torch_npu). The child
        exits non-zero if its assertion fails.
        """
        code = (
            "import torch_npu.npu.utils as utils;"
            "assert utils.is_support_inf_nan()"
        )
        env_cases = [
            ({'INF_NAN_MODE_ENABLE': '1'}, "INF_NAN_MODE_ENABLE=1"),
            ({'INF_NAN_MODE_ENABLE': '0'}, "INF_NAN_MODE_ENABLE=0"),
            ({'INF_NAN_MODE_FORCE_DISABLE': '0'}, "INF_NAN_MODE_FORCE_DISABLE=0"),
            ({'INF_NAN_MODE_FORCE_DISABLE': '1'}, "INF_NAN_MODE_FORCE_DISABLE=1"),
            ({'INF_NAN_MODE_ENABLE': '0', 'INF_NAN_MODE_FORCE_DISABLE': '1'},
             "INF_NAN_MODE_ENABLE=0 + INF_NAN_MODE_FORCE_DISABLE=1"),
        ]
        # Inherited environment with the variables under test removed, so
        # only the ones a case sets explicitly are visible to the child.
        base_env = {
            key: value for key, value in os.environ.items()
            if key not in self._INF_NAN_ENV_VARS
        }
        for env_vars, desc in env_cases:
            # subTest reports every failing case rather than stopping at the first.
            with self.subTest(env=desc):
                env = dict(base_env)
                env.update(env_vars)
                try:
                    result = subprocess.run(
                        [sys.executable, '-c', code],
                        env=env,
                        capture_output=True,
                        text=True,
                        timeout=self._SUBPROCESS_TIMEOUT,
                    )
                except subprocess.TimeoutExpired as exc:
                    self.fail(
                        f"{desc}: child process timed out after "
                        f"{exc.timeout} seconds.\n"
                        f"stdout: {exc.stdout}\nstderr: {exc.stderr}"
                    )
                self.assertEqual(
                    result.returncode, 0,
                    f"{desc} should be ignored on Ascend950.\n"
                    f"stdout: {result.stdout}\nstderr: {result.stderr}"
                )
if __name__ == "__main__":
    # Executed only when this file is run directly as a script; importing the
    # module (e.g. by a test collector) does not trigger the runner.
    run_tests()