import json
import os
import re
import unittest
import warnings
import torch
from torch.testing._internal import common_utils, common_device_type
from torch.testing._internal.opinfo.core import OpInfo
from torch.testing._internal.common_utils import remove_device_and_dtype_suffixes, TEST_WITH_SLOW, \
IS_SANDCASTLE, TEST_SKIP_FAST, RERUN_DISABLED_TESTS, DISABLED_TESTS_FILE, SLOW_TESTS_FILE, maybe_load_json, \
TEST_MPS, IS_FBCODE
from torch.testing._internal.common_dtype import floating_and_complex_types_and
import torch_npu
from torch_npu.testing._npu_testing_utils import update_skip_list, get_decorators
# Nothing is exported through ``from <module> import *``.
# NOTE(review): the module appears to exist for the patching calls made at
# import time at the bottom of the file rather than for its names - confirm.
__all__ = []
def _get_tests_dict():
    """Load the disabled-test and slow-test tables that gate test execution.

    Sources are consulted in this order, later ones overriding earlier ones:

    1. The ``SLOW_TESTS_FILE`` / ``DISABLED_TESTS_FILE`` environment variables.
    2. The ``SLOW_TESTS_FILE`` / ``DISABLED_TESTS_FILE`` constants imported
       from ``common_utils``. When a table is loaded from one of these, the
       matching environment variable is set to the same path.

    The disabled-test file is decoded with a platform filter (see
    ``_drop_other_platform``); the slow-test file is loaded unfiltered.

    Returns:
        tuple: ``(disabled_tests_dict, slow_tests_dict)``; each is ``{}`` when
        no source provided it.
    """

    def _running_on_910a():
        # Either name fragment selects the "910A" side of the platform filter.
        name = torch_npu.npu.get_device_name(0)
        return "Ascend910A" in name or "Ascend910P" in name

    def _drop_other_platform(data):
        # Used as a json ``object_hook``, so it runs on every decoded JSON
        # object. Entries with fewer than two elements are always dropped;
        # entries whose second element mentions the *other* platform tag are
        # dropped as well.
        # NOTE(review): presumably each value is [issue, platforms] - confirm.
        excluded_tag = "A2" if _running_on_910a() else "910A"
        kept = {}
        for key, val in data.items():
            if len(val) <= 1:
                continue
            if val[1] and excluded_tag in val[1]:
                continue
            kept[key] = val
        return kept

    def _read_disabled(filename):
        # A missing file only warns and yields an empty table.
        if not os.path.isfile(filename):
            warnings.warn(f"Attempted to load json file {filename} but it does not exist.")
            return {}
        with open(filename) as handle:
            return json.load(handle, object_hook=_drop_other_platform)

    slow_tests_dict = {}
    disabled_tests_dict = {}

    # First source: paths supplied through the environment.
    slow_env_path = os.getenv("SLOW_TESTS_FILE", "")
    if slow_env_path:
        slow_tests_dict = maybe_load_json(slow_env_path)
    disabled_env_path = os.getenv("DISABLED_TESTS_FILE", "")
    if disabled_env_path:
        disabled_tests_dict = _read_disabled(disabled_env_path)

    # Second source: the common_utils constants, which win over the env paths.
    if SLOW_TESTS_FILE:
        if not os.path.exists(SLOW_TESTS_FILE):
            warnings.warn(f'slow test file provided but not found: {SLOW_TESTS_FILE}')
        else:
            with open(SLOW_TESTS_FILE) as slow_handle:
                slow_tests_dict = json.load(slow_handle)
                os.environ['SLOW_TESTS_FILE'] = SLOW_TESTS_FILE

    if DISABLED_TESTS_FILE:
        if not os.path.exists(DISABLED_TESTS_FILE):
            warnings.warn(f'disabled test file provided but not found: {DISABLED_TESTS_FILE}')
        else:
            disabled_tests_dict = _read_disabled(DISABLED_TESTS_FILE)
            os.environ['DISABLED_TESTS_FILE'] = DISABLED_TESTS_FILE

    return disabled_tests_dict, slow_tests_dict
def _check_if_enable_npu(test: unittest.TestCase):
    """Decide whether ``test`` may run, raising ``unittest.SkipTest`` if not.

    The disabled/slow tables are re-read through ``_get_tests_dict`` on every
    call. Table keys look like ``"test_name (path.to.ClassName)"``; a key
    matches when the test's class name starts with the key's class name (or
    with its NPU <-> PRIVATEUSE1 spelling) and the key's test name equals the
    method name, with or without its device/dtype suffixes.

    Args:
        test: The test case instance about to be executed.

    Raises:
        unittest.SkipTest: if the test is listed as slow and ``TEST_WITH_SLOW``
            is off; if it is listed as disabled (outside Sandcastle) and
            ``RERUN_DISABLED_TESTS`` is off; if it is *not* disabled while
            ``RERUN_DISABLED_TESTS`` is on; or if ``TEST_SKIP_FAST`` is on and
            the test is not marked slow.
    """
    disabled_tests_dict, slow_tests_dict = _get_tests_dict()
    method_name = test._testMethodName
    classname = str(test.__class__).split("'")[1].split(".")[-1]
    sanitized_testname = remove_device_and_dtype_suffixes(method_name)

    def _is_target(target: str):
        # Keys that mention __main__ are split only at the space preceding
        # "(__main__", so a test name containing spaces stays in one piece.
        if "__main__" in target:
            parts = re.split(" (?=\\(__main__)", target)
        else:
            parts = target.split()
        if len(parts) < 2:
            # Malformed key: no class part to compare against.
            return False

        wanted_testname = parts[0]
        # Drop the surrounding parentheses, keep only the final dotted component.
        wanted_classname = parts[1][1:-1].split(".")[-1]

        # Accept either device spelling of the class name.
        if "PRIVATEUSE1" in wanted_classname:
            alias = wanted_classname.replace("PRIVATEUSE1", "NPU")
        elif "NPU" in wanted_classname:
            alias = wanted_classname.replace("NPU", "PRIVATEUSE1")
        else:
            alias = wanted_classname

        if not classname.startswith((wanted_classname, alias)):
            return False
        return wanted_testname in (method_name, sanitized_testname)

    if any(map(_is_target, slow_tests_dict.keys())):
        # Remember the slow marker on the test function for the fast-skip check below.
        getattr(test, method_name).__dict__['slow_test'] = True
        if not TEST_WITH_SLOW:
            raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")

    if not IS_SANDCASTLE:
        is_disabled = any(_is_target(name) for name, _ in disabled_tests_dict.items())
        if RERUN_DISABLED_TESTS:
            # Verification mode: only the disabled tests are allowed to run.
            if not is_disabled:
                raise unittest.SkipTest(
                    "Test is enabled but --rerun-disabled-tests verification mode is set, so only"
                    " disabled tests are run"
                )
        elif is_disabled:
            raise unittest.SkipTest("this test is disabled now")

    if TEST_SKIP_FAST and hasattr(test, method_name):
        marked_slow = getattr(test, method_name).__dict__.get('slow_test', False)
        if not marked_slow:
            raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST")
def _supported_dtypes(self, device_type):
    """Return the op's ``dtypes`` attribute, whatever ``device_type`` is.

    Installed on ``OpInfo`` as ``supported_dtypes``; ``device_type`` is
    accepted for signature compatibility and is not consulted.
    """
    declared_dtypes = self.dtypes
    return declared_dtypes
def _supported_backward_dtypes(self, device_type):
    """Return the dtypes the op supports in backward, ignoring ``device_type``.

    The result is ``self.backward_dtypes`` restricted to floating and complex
    dtypes plus ``bfloat16``, ``float16`` and ``complex32``. Ops that do not
    support autograd yield an empty set.

    Returns:
        set: the permitted backward dtypes.
    """
    if self.supports_autograd:
        requested = self.backward_dtypes
        differentiable = set(
            floating_and_complex_types_and(torch.bfloat16, torch.float16, torch.complex32)
        )
        return differentiable.intersection(requested)
    return set()
def _test_for_npu():
    """Redirect the device-related test switches to the privateuse1 backend.

    Sets ``PYTORCH_TESTING_DEVICE_FOR_CUSTOM`` to ``privateuse1`` and
    ``PYTORCH_TESTING_DEVICE_EXCEPT_FOR`` to ``cuda,cpu`` in the process
    environment, then aliases ``common_device_type.onlyCUDA`` to
    ``onlyPRIVATEUSE1`` and ``common_utils.TEST_CUDA`` to ``TEST_PRIVATEUSE1``.
    NOTE(review): presumably this makes CUDA-only tests run on the NPU device
    instead - confirm against PyTorch's common_device_type.
    """
    device_env = (
        ('PYTORCH_TESTING_DEVICE_FOR_CUSTOM', 'privateuse1'),
        ('PYTORCH_TESTING_DEVICE_EXCEPT_FOR', 'cuda,cpu'),
    )
    for variable, setting in device_env:
        os.environ[variable] = setting

    setattr(common_device_type, "onlyCUDA", common_device_type.onlyPRIVATEUSE1)
    setattr(common_utils, "TEST_CUDA", common_utils.TEST_PRIVATEUSE1)
def _apply_test_patchs():
    """Install the NPU-specific overrides into PyTorch's test infrastructure.

    Calls ``update_skip_list()`` first, then replaces three ``OpInfo``
    attributes (``get_decorators``, ``supported_dtypes``,
    ``supported_backward_dtypes``) and ``common_utils.check_if_enable`` with
    the implementations from torch_npu and this module.
    """
    update_skip_list()

    opinfo_overrides = {
        "get_decorators": get_decorators,
        "supported_dtypes": _supported_dtypes,
        "supported_backward_dtypes": _supported_backward_dtypes,
    }
    for attr_name, replacement in opinfo_overrides.items():
        setattr(OpInfo, attr_name, replacement)

    setattr(common_utils, "check_if_enable", _check_if_enable_npu)
# Both patch steps run at module level, so importing this module is what
# activates them; no explicit call is needed from the importer.
_apply_test_patchs()
_test_for_npu()