import re
import stat
from collections import defaultdict
from typing import Tuple, List, Iterable, Iterator, Callable, Sequence, TypeVar, Optional, Dict
from enum import Enum
import contextlib
import textwrap
import os
import warnings
import yaml
try:
from yaml import CSafeLoader as Loader
except ImportError:
from yaml import SafeLoader as Loader
try:
from yaml import CSafeDumper as Dumper
except ImportError:
from yaml import SafeDumper as Dumper
from codegen.model import NativeFunction, FunctionSchema
from codegen.api import cpp
# Alias for the yaml Dumper selected by the import fallback above
# (CSafeDumper when importable, otherwise SafeDumper).
YamlDumper = Dumper

# op key (str(FunctionSchema.name)) -> C++ wrapper name. Written by
# parse_opplugin_yaml; read by is_op_valid / get_opplugin_wrap_name.
GLOBAL_STRUCTURED_OP_INFO_CACHE = defaultdict(str)

# File name of the merged yaml that merge_custom_yaml writes next to the PTA yaml.
CUSTOM_YAML_NAME = "npu_native_functions_by_codegen.yaml"

# Keys that field_tag keeps on each dict entry.
FIELDS_TO_USE = ["func"]

# Op names dropped from op-plugin's "official" section by merge_custom_yaml
# (presumably because they are implemented manually; confirm with owners).
MANUAL_OPS = [
    "argmin",
    "argmax",
    "nan_to_num",
    "nan_to_num_",
    "nan_to_num.out",
    "_embedding_bag_dense_backward",
    "matmul_backward",
]
class YamlLoader(Loader):
    """Safe yaml loader that rejects mappings containing duplicate keys.

    PyYAML's base constructor keeps the last value when a key repeats; this
    loader raises instead so duplicated entries are caught early.
    """

    def construct_mapping(self, node, deep=False):
        """Build the mapping for ``node``, raising on duplicate keys.

        Raises:
            KeyError: if the same key appears more than once in ``node``.
        """
        seen_keys = set()
        for key_node, _ in node.value:
            key = self.construct_object(key_node, deep=deep)
            try:
                is_duplicate = key in seen_keys
            except TypeError:
                # Unhashable key (e.g. a sequence): the base constructor called
                # below reports it with a ConstructorError.
                break
            if is_duplicate:
                # yaml Marks are zero-based; report the 1-based line of the
                # repeated key instead of the start of the enclosing mapping.
                raise KeyError(
                    f"Found a duplicate key in the yaml. key={key}, "
                    f"line={key_node.start_mark.line + 1}"
                )
            seen_keys.add(key)
        return super().construct_mapping(node, deep=deep)
class Target(Enum):
    """Codegen output targets; values are 1..7 in declaration order.

    What each target produces is decided by callers outside this section
    (presumably which piece of C++ gets emitted) -- confirm there.
    """
    DEFINITION = 1
    DECLARATION = 2
    REGISTRATION = 3
    ANONYMOUS_DEFINITION = 4
    ANONYMOUS_DEFINITION_UNSUPPORT = 5
    NAMESPACED_DEFINITION = 6
    NAMESPACED_DECLARATION = 7


# Whole-word match template: "{}" is a str.format placeholder, presumably for
# an (escaped) identifier; each side allows a string boundary or a non-word char.
IDENT_REGEX = r'(^|\W){}($|\W)'
class PathManager:
    """Path sanity checks used before reading or replacing codegen files."""

    @classmethod
    def check_path_owner_consistent(cls, path: str):
        """
        Function Description:
            check that the path exists and warn when it is not owned by the
            user running the current process
        Parameter:
            path: the path to check
        Exception Description:
            RuntimeError when the path does not exist; an owner mismatch only
            emits a warning
        """
        if not os.path.exists(path):
            msg = f"The path does not exist: {path}"
            raise RuntimeError(msg)
        if os.stat(path).st_uid != os.getuid():
            warnings.warn(f"Warning: The {path} owner does not match the current user.")

    @classmethod
    def check_directory_path_readable(cls, path):
        """
        Function Description:
            check that the path exists, is not a symbolic link and is readable
            by the current process (despite the name, file paths are accepted)
        Parameter:
            path: the path to check
        Exception Description:
            RuntimeError when the path does not exist, is a symbolic link or
            lacks read permission
        """
        cls.check_path_owner_consistent(path)
        if os.path.islink(path):
            msg = f"Invalid path is a soft chain: {path}"
            raise RuntimeError(msg)
        if not os.access(path, os.R_OK):
            msg = f"The path permission check failed: {path}"
            raise RuntimeError(msg)

    @classmethod
    def remove_path_safety(cls, path: str):
        """
        Function Description:
            delete the file at path if it exists, refusing symbolic links
        Parameter:
            path: the file to delete; a missing path is a no-op
        Exception Description:
            RuntimeError when the path is a symbolic link; errors from
            os.remove (e.g. for a directory) propagate
        """
        if os.path.islink(path):
            raise RuntimeError(f"Invalid path is a soft chain: {path}")
        if os.path.exists(path):
            os.remove(path)
def split_name_params(schema: str) -> Tuple[str, List[str]]:
    """Split a schema string into its base op name and raw parameter strings.

    Any ``.overload`` suffix is discarded. The text from the ``(`` after the
    name up to the last ``)`` in the string is split on ``', '``, so an empty
    parameter list yields ``['']``.

    Raises:
        RuntimeError: if ``schema`` does not start with ``name(...)``.
    """
    parsed = re.match(r'(\w+)(\.\w+)?\((.*)\)', schema)
    if not parsed:
        raise RuntimeError(f'Unsupported function schema: {schema}')
    base_name, param_text = parsed.group(1), parsed.group(3)
    return base_name, param_text.split(', ')
T = TypeVar('T')
S = TypeVar('S')


def map_maybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily apply ``func`` to each element of ``xs``, dropping ``None`` results.

    Falsy results other than ``None`` (``0``, ``''``, ``[]``) are kept.
    """
    for mapped in map(func, xs):
        if mapped is None:
            continue
        yield mapped
def concat_map(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily apply ``func`` to each element of ``xs`` and flatten one level."""
    for chunk in map(func, xs):
        for element in chunk:
            yield element
@contextlib.contextmanager
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
    """Append extra context to any exception raised inside the ``with`` body.

    ``msg_fn`` is called only when an exception occurs. Its text, indented by
    one space, is appended on a new line after the exception's first argument
    (or becomes the only argument if there was none). The same exception
    object is then re-raised.
    """
    try:
        yield
    except Exception as err:
        detail = textwrap.indent(msg_fn(), ' ')
        if err.args:
            detail = f'{err.args[0]}\n{detail}'
        err.args = (detail, *err.args[1:])
        raise
def parse_npu_yaml(custom_path: str, use_line_loader=True) -> Dict:
    """Load a yaml file and return its top-level content.

    Args:
        custom_path: path of the yaml file to read.
        use_line_loader: currently unused; kept so existing callers that pass
            it keep working.

    Returns:
        The parsed document, or ``{}`` when ``custom_path`` does not exist or
        the file holds no document.

    Raises:
        RuntimeError: if the file is a symbolic link or not readable
            (see PathManager.check_directory_path_readable).
    """
    if not os.path.exists(custom_path):
        return {}
    PathManager.check_directory_path_readable(custom_path)
    with open(custom_path, 'r', encoding='utf-8') as yaml_file:
        source_es = yaml.safe_load(yaml_file)
    # safe_load returns None for an empty document; callers use the result as a
    # dict (.items()/.get()), so normalise that case to an empty mapping.
    if source_es is None:
        return {}
    return source_es
def merge_yaml(base_data, additional_data):
    """Recursively merge ``additional_data`` into ``base_data`` (in place).

    - dicts: keys of ``additional_data`` are renamed via
      ``{"official": "supported"}`` (at every nesting level); missing keys are
      added, existing ones are merged recursively.
    - lists: items of ``additional_data`` not already in ``base_data`` are
      appended in order.
    - ``None`` on either side (an empty yaml section such as ``custom:``)
      yields the other side unchanged.
    - any other type: ``base_data`` wins.

    Returns:
        The merged value; ``base_data`` itself when it is a dict or a list.
    """
    # Empty yaml sections parse as None. Without these guards a None base
    # would silently discard the other side's entries, and a None addition
    # would raise TypeError when iterated.
    if base_data is None:
        return additional_data
    if additional_data is None:
        return base_data
    key_aliases = {"official": "supported"}
    if isinstance(base_data, dict):
        for key, value in additional_data.items():
            mapped_key = key_aliases.get(key, key)
            if mapped_key in base_data:
                base_data[mapped_key] = merge_yaml(base_data[mapped_key], value)
            else:
                base_data[mapped_key] = value
    elif isinstance(base_data, list):
        for item in additional_data:
            # Items are usually dicts (unhashable), so a list scan is used.
            if item not in base_data:
                base_data.append(item)
    return base_data
def merge_custom_yaml(pta_path, op_plugin_path):
    """Merge the op-plugin yaml into the PTA yaml and write the result to disk.

    Op-plugin entries take precedence: a PTA entry whose op name appears in any
    op-plugin list is dropped before merging. Entries of op-plugin's "official"
    section whose name is in MANUAL_OPS are dropped as well. The merged data is
    written to ``gen_custom_yaml_path(pta_path)`` (an existing file there is
    removed first), and the file mode is then set to 0o550.

    Args:
        pta_path: path of the PTA yaml; also decides the output directory.
        op_plugin_path: path of the op-plugin yaml.

    Returns:
        The merged yaml data.
    """
    def parse_op_name(value):
        # Entries are either {"func": "<schema>", ...} dicts or bare names.
        return value["func"].split("(")[0] if isinstance(value, dict) else value

    pta_es = parse_npu_yaml(pta_path) or {}
    op_es = parse_npu_yaml(op_plugin_path) or {}

    # All op names provided by op-plugin; a set keeps the filtering linear.
    op_plugin_names = {
        parse_op_name(op)
        for ops in op_es.values()
        if isinstance(ops, list)
        for op in ops
    }
    for key, ops in pta_es.items():
        if isinstance(ops, list):
            pta_es[key] = [op for op in ops if parse_op_name(op) not in op_plugin_names]

    # An empty "official:" section parses as None; treat it as no entries.
    op_es["official"] = [op for op in op_es.get("official") or []
                         if parse_op_name(op) not in MANUAL_OPS]

    merged_yaml = merge_yaml(pta_es, op_es)
    merged_yaml_path = gen_custom_yaml_path(pta_path)
    PathManager.remove_path_safety(merged_yaml_path)
    # Create with owner read/write only; O_TRUNC guards against a file that
    # reappeared between the removal above and this open.
    fd = os.open(merged_yaml_path, os.O_RDWR | os.O_CREAT | os.O_TRUNC, stat.S_IWUSR | stat.S_IRUSR)
    with os.fdopen(fd, "w") as outfile:
        yaml.dump(merged_yaml, outfile, default_flow_style=False, width=float("inf"))
    # NOTE(review): 0o550 also grants owner/group execute on a data file; kept
    # unchanged, confirm whether the execute bits are intended.
    os.chmod(merged_yaml_path, stat.S_IRUSR | stat.S_IEXEC | stat.S_IRGRP | stat.S_IXGRP)
    return merged_yaml
def field_tag(custom_es):
    """Reduce every dict entry of ``custom_es`` to the keys in FIELDS_TO_USE.

    The list is modified in place and also returned. Dict entries are replaced
    by new filtered dicts; non-dict entries are left untouched.
    """
    for position, entry in enumerate(custom_es):
        if isinstance(entry, dict):
            custom_es[position] = {field: entry[field] for field in FIELDS_TO_USE if field in entry}
    return custom_es
def filt_exposed_api(custom_path: str):
    """Return the base names of ops marked ``exposed: true`` in a yaml file.

    Only the ``custom`` and ``custom_autograd`` sections are scanned; non-dict
    entries are ignored. The name is the ``func`` schema text before ``(`` and
    before any ``.overload`` suffix. The list is sorted so that the order does
    not depend on per-process string hash randomisation.
    """
    source_es = parse_npu_yaml(custom_path) or {}
    # An empty section ("custom:") parses as None.
    custom_es = (source_es.get('custom') or []) + (source_es.get('custom_autograd') or [])
    exposed_set = {
        es.get('func').split('(')[0].split('.')[0]
        for es in custom_es
        if isinstance(es, dict) and es.get('exposed', False)
    }
    return sorted(exposed_set)
def parse_opplugin_yaml(custom_path: str) -> None:
    """Record the C++ wrapper name of every op listed in an op yaml file.

    Entries of the ``custom``, ``official``, ``autograd`` and
    ``custom_autograd`` sections are parsed with FunctionSchema.parse. For each
    one, GLOBAL_STRUCTURED_OP_INFO_CACHE[str(schema.name)] is set to
    cpp.name(schema). If a different wrapper name was already cached, a
    message is printed and the new name replaces it.

    Raises:
        TypeError: if an entry is not a mapping or its ``func`` is not a str.
    """
    source_es = parse_npu_yaml(custom_path) or {}
    support_keys = ['custom', 'official', 'autograd', 'custom_autograd']
    support_ops = []
    for key in support_keys:
        # Empty sections ("custom:") parse as None and contribute nothing.
        support_ops.extend(source_es.get(key) or [])

    # The cache is only mutated (never rebound), so no `global` statement is needed.
    for x in support_ops:
        if not isinstance(x, dict):
            raise TypeError(f'not a dict : {x}')
        funcs = x.get("func", None)
        if not isinstance(funcs, str):
            raise TypeError(f'not a str : {funcs}')
        func = FunctionSchema.parse(funcs)
        wrap_name = cpp.name(func)
        op_key = str(func.name)
        cur_wrap_name = GLOBAL_STRUCTURED_OP_INFO_CACHE.get(op_key, "")
        if cur_wrap_name and cur_wrap_name != wrap_name:
            print(f"Find different wrap_name for {cur_wrap_name} and {wrap_name} between pta and opplugin, ",
                  f"with {wrap_name} being used as the actual wrap_name")
        GLOBAL_STRUCTURED_OP_INFO_CACHE[op_key] = wrap_name
def enable_opplugin() -> bool:
    """Return whether ``../third_party/op-plugin/op_plugin`` exists.

    The path is resolved relative to the directory containing this file.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.exists(os.path.join(this_dir, '../third_party/op-plugin/op_plugin'))
def is_op_valid(op_key: str) -> bool:
    """Return True if ``op_key`` has an entry in GLOBAL_STRUCTURED_OP_INFO_CACHE.

    The cache is a defaultdict, but a membership test never inserts a key, so
    this check has no side effects.
    """
    return op_key in GLOBAL_STRUCTURED_OP_INFO_CACHE
def get_opplugin_wrap_name(func) -> str:
    """Return the cached C++ wrapper name for an op, or ``""`` if unknown.

    Args:
        func: a NativeFunction (``str(func.func.name)`` is used as the key), or
            any other value used as the key as-is (presumably an op key str
            such as "add.Tensor").
    """
    op_key = str(func.func.name) if isinstance(func, NativeFunction) else func
    return GLOBAL_STRUCTURED_OP_INFO_CACHE.get(op_key, "")
def gen_custom_yaml_path(original_path, codegen_yaml_filename=CUSTOM_YAML_NAME):
    """Return ``codegen_yaml_filename`` joined onto the directory of ``original_path``."""
    parent_dir = os.path.dirname(original_path)
    return os.path.join(parent_dir, codegen_yaml_filename)