import importlib
import inspect
import json
import os
import re
import unittest
from typing import Callable
from itertools import chain
from pathlib import Path
import pkgutil
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch._utils_internal import get_file_path_2
import torch_npu
def set_failure_list(api_str, value, signature, failure_list):
failure_list.append(f"# {api_str}:")
failure_list.append(f" - function signature is different: ")
failure_list.append(f" - the base signature is {value}.")
failure_list.append(f" - now it is {signature}.")
def is_not_compatibility(base_str, new_str):
base_io_params = base_str.split("->")
new_io_params = new_str.split("->")
base_input_params = base_io_params[0].strip()
new_input_params = new_io_params[0].strip()
base_out_params = "" if len(base_io_params) == 1 else base_io_params[1].strip()
new_out_params = "" if len(new_io_params) == 1 else new_io_params[1].strip()
if base_out_params != new_out_params:
return True
base_params = base_input_params[1:-1].split(",")
new_params = new_input_params[1:-1].split(",")
base_diff_params = set(base_params) - set(new_params)
if base_diff_params:
return True
new_diff_params = set(new_params) - set(base_params)
for elem in new_diff_params:
if "=" not in elem:
return True
base_arr = [elem for elem in base_params if "=" not in elem]
new_arr = [elem for elem in new_params if "=" not in elem]
i = 0
while i < len(base_arr):
if base_arr[i] != new_arr[i]:
return True
i += 1
return False
def get_func_from_yaml(yaml_path):
content = []
with open(yaml_path, 'r') as f:
for line in f.readlines():
if " func:" in line:
string = line.split("func:")[1].strip()
content.append(string)
return content
def parse_func_str(func_str):
if "(" in func_str:
func_name = func_str.split("(")[0]
signature = func_str.split(func_name)[1]
else:
func_name = func_str
signature = ""
return func_name, signature
def func_from_yaml(content, base_schema, failure_list):
torch_npu_path = torch_npu.__path__[0]
yaml1_path = os.path.join(torch_npu_path, "csrc/aten/npu_native_functions_by_codegen.yaml")
yaml2_path = os.path.join(torch_npu_path, "csrc/aten/npu_native_functions.yaml")
yaml1_content = get_func_from_yaml(yaml1_path)
yaml2_content = get_func_from_yaml(yaml2_path)
op_funcs = set(yaml1_content) - set(yaml2_content)
for func in op_funcs:
func_name, signature = parse_func_str(func)
if "func: " + func_name in base_schema:
value = base_schema["func: " + func_name]["signature"]
if is_not_compatibility(value, signature):
set_failure_list("func: " + func_name, value, signature, failure_list)
content["func: " + func_name] = {"signature": signature}
def _get_test_torch_version():
torch_npu_version = torch_npu.__version__
version_list = torch_npu_version.split('.')
if len(version_list) > 2:
return f'v{version_list[0]}.{version_list[1]}'
else:
raise RuntimeError("Invalid torch_npu version.")
def api_signature(obj, api_str, content, base_schema, failure_list):
signature = inspect.signature(obj)
signature = str(signature)
if api_str in base_schema.keys():
value = base_schema[api_str]["signature"]
if is_not_compatibility(value, signature):
set_failure_list(api_str, value, signature, failure_list)
content[api_str] = {"signature": signature}
def _discover_path_importables(pkg_pth, pkg_name):
"""Yield all importables under a given path and package.
This is like pkgutil.walk_packages, but does *not* skip over namespace
packages.
"""
for dir_path, _d, file_names in os.walk(pkg_pth):
pkg_dir_path = Path(dir_path)
if pkg_dir_path.parts[-1] == '__pycache__':
continue
if all(Path(_).suffix != '.py' for _ in file_names):
continue
rel_pt = pkg_dir_path.relative_to(pkg_pth)
pkg_pref = '.'.join((pkg_name,) + rel_pt.parts)
yield from (
pkg_path
for _, pkg_path, _ in pkgutil.walk_packages(
(str(pkg_dir_path),), prefix=f'{pkg_pref}.',
)
)
def _find_all_importables(pkg):
"""Find all importables in the project.
Return them in order.
"""
return sorted(
set(
chain.from_iterable(
_discover_path_importables(Path(p), pkg.__name__)
for p in pkg.__path__
),
),
)
class TestOpApiCompatibility(TestCase):
@staticmethod
def _is_mod_public(modname):
split_strs = modname.split('.')
for elem in split_strs:
if elem.startswith("_"):
return False
return True
@staticmethod
def _deleted_apis(base_funcs, now_funcs, failure_list):
deleted_apis = set(base_funcs) - set(now_funcs)
for func in deleted_apis:
failure_list.append(f"# {func}:")
failure_list.append(f" - {func} has been deleted.")
@staticmethod
def _newly_apis(base_funcs, now_funcs, failure_list, content):
newly_apis = set(now_funcs) - set(base_funcs)
for func in newly_apis:
failure_list.append(f"# {func}:")
failure_list.append(f" - {func} is new. Please add it to the torch_npu_OpApi_schema_all.json")
signature = content[func]["signature"]
failure_list.append(f" - it's signature is {signature}.")
@unittest.skipIf(_get_test_torch_version() in ["v1.11", "v2.0", "v2.2"],
"Skipping test for these torch versions.")
def test_op_api_compatibility(self):
failure_list = []
version_tag = _get_test_torch_version()
with open(get_file_path_2(os.path.dirname(os.path.dirname(__file__)),
'allowlist_for_publicAPI.json')) as json_file:
allow_dict_info = json.load(json_file)
allow_dict = {}
if "torch_npu" in allow_dict_info and "all_version" in allow_dict_info["torch_npu"]:
allow_dict["torch_npu"] = allow_dict_info["torch_npu"]["all_version"]
if version_tag in allow_dict_info["torch_npu"] and allow_dict_info["torch_npu"][version_tag]:
allow_dict["torch_npu"].extend(allow_dict_info["torch_npu"][version_tag])
base_schema = {}
with open(get_file_path_2(os.path.dirname(__file__), "torch_npu_OpApi_schema_all.json")) as fp:
base_schema0 = json.load(fp)
for key, value in base_schema0.items():
if key.startswith("op_api:"):
if not ("all_version" in value['version'] or version_tag in value['version']):
if 'newest' not in value['version']:
continue
else:
idx = value['version'].index("newest") - 1
if int(version_tag.replace(".", "")[1:]) < int(value['version'][idx].replace(".", "")[1:]):
continue
key = key.replace("op_api: ", "").strip()
func_name, signature = parse_func_str(key)
if func_name not in base_schema:
base_schema[func_name] = dict()
base_schema[func_name]["signature"] = signature
content = {}
def test_module(modname):
try:
if "__main__" in modname or \
modname in ["torch_npu.dynamo.torchair.core._backend",
"torch_npu.dynamo.torchair.core._torchair"]:
return
mod = importlib.import_module(modname)
except Exception:
return
if not self._is_mod_public(modname):
return
def check_one_element(elem, modname, mod, *, is_public):
obj = getattr(mod, elem)
if not (isinstance(obj, (Callable, torch.dtype)) or inspect.isclass(obj)):
return
elem_module = getattr(obj, '__module__', None)
if not elem_module == "torch._ops.npu":
return
if (modname in allow_dict and elem in allow_dict[modname]) or elem in torch_npu.__all__:
api_str = f"{modname}.{elem}"
api_signature(obj, api_str, content, base_schema, failure_list)
if hasattr(mod, '__all__'):
public_api = mod.__all__
all_api = dir(mod)
for elem in all_api:
check_one_element(elem, modname, mod, is_public=elem in public_api)
else:
all_api = dir(mod)
for elem in all_api:
if not elem.startswith('_'):
check_one_element(elem, modname, mod, is_public=True)
for modname in _find_all_importables(torch_npu):
test_module(modname)
test_module('torch_npu')
base_funcs = base_schema.keys()
now_funcs = content.keys()
self._deleted_apis(base_funcs, now_funcs, failure_list)
self._newly_apis(base_funcs, now_funcs, failure_list, content)
msg = "All the APIs below do not meet the compatibility guidelines. "
msg += "If the change timeline has been reached, you can modify the torch_npu_OpApi_schema_all.json to make it OK."
msg += "\n\nFull list:\n"
msg += "\n".join(map(str, failure_list))
self.assertTrue(not failure_list, msg)
@unittest.skipIf(_get_test_torch_version() in ["v1.11", "v2.0", "v2.2"],
"Skipping test for these torch versions.")
def test_op_func_compatibility(self):
failure_list = []
all_base_schema = {}
with open(get_file_path_2(os.path.dirname(__file__), "torch_npu_OpApi_schema_all.json")) as fp:
all_base_schema0 = json.load(fp)
for key, value in all_base_schema0.items():
if not key.startswith("op_api:"):
all_base_schema[key] = value
version_tag = _get_test_torch_version()
base_schema = {}
for key, value in all_base_schema.items():
if not("all_version" in value['version'] or version_tag in value['version']):
if 'newest' not in value['version']:
continue
else:
idx = value['version'].index("newest") - 1
if int(version_tag.replace(".", "")[1:]) < int(value['version'][idx].replace(".", "")[1:]):
continue
func_name, signature = parse_func_str(key)
if func_name not in base_schema:
base_schema[func_name] = dict()
base_schema[func_name]["signature"] = signature
content = {}
func_from_yaml(content, base_schema, failure_list)
base_funcs = base_schema.keys()
now_funcs = content.keys()
self._deleted_apis(base_funcs, now_funcs, failure_list)
self._newly_apis(base_funcs, now_funcs, failure_list, content)
msg = "All the OpAPIs below do not meet the compatibility guidelines. "
msg += "If the change timeline has been reached, you can modify the torch_npu_OpApi_schema_all.json to make it OK."
msg += "\n\nFull list:\n"
msg += "\n".join(map(str, failure_list))
self.assertTrue(not failure_list, msg)
if __name__ == '__main__':
run_tests()