"""
This file is mainly used to transform failed test names into DecorateInfo Object.
"""
from pathlib import Path
import os
import yaml
from torchgen.code_template import CodeTemplate
from torchgen.gen import FileManager
from codegen.autograd.utils import VERSION_PART
from codegen.utils import PathManager
project_path = Path(os.path.dirname(__file__)).parent
op_plugin_info_path = os.path.realpath(os.path.join(
project_path,
f'third_party/op-plugin/test/test_v{VERSION_PART[0]}r{VERSION_PART[1]}_ops',
"unsupported_ops_info.yaml"))
torch_npu_info_path = os.path.realpath(os.path.join(project_path, "test", "unsupported_ops_info.yaml"))
unsupported_dict_path = op_plugin_info_path if os.path.exists(op_plugin_info_path) else torch_npu_info_path
PathManager.check_directory_path_readable(unsupported_dict_path)
with open(unsupported_dict_path, "r") as f:
unsupported_summary_dict = yaml.safe_load(f)
skip_list = dict()
class2test = {
"TestCommon":[
"test_compare_cpu",
"test_dtypes",
"test_errors",
"test_multiple_devices",
"test_non_standard_bool_values",
"test_noncontiguous_samples",
"test_numpy_ref",
"test_out",
"test_out_integral_dtype",
"test_out_warning",
"test_pointwise_tag_coverage",
"test_variant_consistency_eager",
"test_complex_half_reference_testing"
],
"TestCompositeCompliance":[
"test_backward",
"test_forward_ad",
"test_operator",
],
"TestMathBits":[
"test_conj_view",
"test_neg_conj_view",
"test_neg_view",
],
"TestFakeTensor":[
"test_fake",
"test_fake_autocast",
"test_fake_crossref_backward_amp",
"test_fake_crossref_backward_no_amp",
"test_pointwise_ops",
],
"TestTags":["test_tags"]
}
def get_class_name(func_name):
for key, value in class2test.items():
if func_name in value:
return key
raise RuntimeError("Can't find corresponding class name of test: {}".format(func_name))
def update_skip_list(class_name, op_name, func_name, dtype):
skip_reason = "npu test skipped!"
template = f"""\nDecorateInfo(unittest.skip("{skip_reason}"), \'{class_name}\', \'{func_name}\', dtypes=(torch.{dtype},))"""
template_ = f"""\nDecorateInfo(unittest.skip("{skip_reason}"), \'{class_name}\', \'{func_name}\')"""
if dtype:
if op_name not in skip_list:
skip_list[op_name] = [template]
else:
skip_list[op_name].append(template)
else:
if op_name not in skip_list:
skip_list[op_name] = [template_]
else:
skip_list[op_name].append(template_)
def gen_ops_info(summary_dict):
for op_name in summary_dict:
for func_name in summary_dict[op_name]:
for dtype in summary_dict[op_name][func_name]:
update_skip_list(get_class_name(func_name), op_name, func_name, dtype)
skip_template = CodeTemplate(
"""\n'${op_name}': [${decorators}]"""
)
fm = FileManager(os.path.join("torch_npu", "testing"), os.path.join("codegen", "templates"), False)
fm.write_with_template(f"_npu_testing_utils.py", "npu_testing_utils.py", lambda:{
"skip_detail": [skip_template.substitute(op_name=op, decorators=doc) for op, doc in skip_list.items()]
})
if __name__ == "__main__":
gen_ops_info(unsupported_summary_dict)