import os
import stat
import functools
import hashlib
from typing import (List, Dict, Optional, Set, Callable, Any,
Union, TypeVar, Iterable, Tuple)
from collections import defaultdict
import yaml
from codegen.code_template import CodeTemplate
from codegen.model import (NativeFunction, SelfArgument,
TensorOptionsArguments,
assert_never)
from codegen.api.types.signatures import NativeSignature
from codegen.context import native_function_manager
from codegen.utils import concatMap, context
# Generic type variable for the items handed to FileManager.write_sharded
# (used in its `items`, `key_fn` and `env_callable` annotations).
T = TypeVar('T')
@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    """Load the template stored at ``template_fn``.

    Results are memoised per path (unbounded ``lru_cache``), so asking for
    the same template again returns the cached ``CodeTemplate`` object
    instead of calling ``CodeTemplate.from_file`` a second time.
    """
    loaded = CodeTemplate.from_file(template_fn)
    return loaded
def string_stable_hash(s: str) -> int:
    """Return a hash of ``s`` that is identical across interpreter runs.

    The string is encoded as Latin-1, hashed with SHA-256, and the digest is
    interpreted as a little-endian unsigned integer.

    Raises:
        UnicodeEncodeError: if ``s`` contains a character that Latin-1
            cannot represent.
    """
    hasher = hashlib.sha256()
    hasher.update(bytes(s, 'latin1'))
    return int.from_bytes(hasher.digest(), byteorder='little')
class FileManager:
    """Write generated files below ``install_dir`` and remember their names.

    Files are only rewritten when their contents change.  When ``dry_run``
    is true, ``write_with_template`` (and therefore ``write`` and
    ``write_sharded``) only records the output names and writes nothing.
    """
    # Directory prefix prepended (with "/") to every generated file name.
    install_dir: str
    # Directory that template file names are resolved against.
    template_dir: str
    # When true, names are recorded but no environment is built or written.
    dry_run: bool
    # "<install_dir>/<name>" of every file requested so far.
    filenames: Set[str]

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run

    @staticmethod
    def _remove_path_safety(filepath: str) -> None:
        """Delete ``filepath`` if it exists, refusing to touch symbolic links.

        Raises:
            RuntimeError: if ``filepath`` is a symbolic link.
        """
        if os.path.islink(filepath):
            raise RuntimeError(f"Invalid path is a soft chain: {filepath}")
        if os.path.exists(filepath):
            os.remove(filepath)

    @staticmethod
    def _write_if_changed(filename: str, contents: str) -> None:
        """Write ``contents`` to ``filename`` unless the file already holds it.

        The path is resolved with ``os.path.realpath`` first.  On a change the
        old file is removed, a new one is created readable/writable by the
        owner only, and after writing its mode is set to read+execute for
        owner and group.
        """
        old_contents: Optional[str]
        filepath = os.path.realpath(filename)
        try:
            with open(filepath, 'r') as f:
                old_contents = f.read()
        except OSError:
            # Missing or unreadable file: treat as "no previous contents".
            old_contents = None
        if contents != old_contents:
            FileManager._remove_path_safety(filepath)
            flags = os.O_RDWR | os.O_CREAT
            mode = stat.S_IWUSR | stat.S_IRUSR
            with os.fdopen(os.open(filepath, flags, mode), "w") as f:
                f.write(contents)
            os.chmod(filepath, stat.S_IRUSR | stat.S_IEXEC | stat.S_IRGRP | stat.S_IXGRP)

    def write_with_template(self, filename: str, template_fn: str,
                            env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> None:
        """Generate ``<install_dir>/<filename>`` from ``template_fn``.

        ``env_callable`` is only invoked when not in dry-run mode.  If it
        returns a dict, that dict is updated in place (``generated_comment``
        is filled in when absent and ``legacy_th_headers`` is reset to an
        empty list) and substituted into the template; if it returns a
        string, the string is written verbatim.

        Raises:
            ValueError: if the same output name was already requested.
        """
        filename = '{}/{}'.format(self.install_dir, filename)
        if filename in self.filenames:
            raise ValueError(f"duplicate file write {filename}")
        self.filenames.add(filename)
        if self.dry_run:
            return
        env = env_callable()
        if isinstance(env, dict):
            if 'generated_comment' not in env:
                # Built from two pieces so this source file itself does not
                # contain the contiguous marker text.
                comment = "@" + "generated by tools/codegen/gen.py"
                comment += " from {}".format(os.path.basename(template_fn))
                env['generated_comment'] = comment
            env['legacy_th_headers'] = []
            template = _read_template(os.path.join(self.template_dir, template_fn))
            self._write_if_changed(filename, template.substitute(env))
        elif isinstance(env, str):
            self._write_if_changed(filename, env)
        else:
            assert_never(env)

    def write(self, filename: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> None:
        """Generate ``filename`` from the template that has the same name."""
        self.write_with_template(filename, filename, env_callable)

    def write_sharded(
            self,
            filename: str,
            items: Iterable[T],
            *,
            key_fn: Callable[[T], str],
            env_callable: Callable[[T], Dict[str, List[str]]],
            num_shards: int,
            base_env: Optional[Dict[str, Any]] = None,
            sharded_keys: Set[str]
    ) -> None:
        """Distribute ``items`` over ``num_shards`` output files plus one
        "Everything" file that receives all of them.

        Each item goes to shard ``string_stable_hash(key_fn(item)) %
        num_shards``.  The lists returned by ``env_callable(item)`` are
        appended to that shard's environment and to the "Everything"
        environment.  Output names are ``<base>_<i><ext>`` and
        ``<base>Everything<ext>``; all are rendered from template
        ``filename``.

        Raises:
            TypeError: if a sharded key present in ``base_env`` is not a list.
            KeyError: if ``env_callable`` returns a key not in ``sharded_keys``.
        """
        everything: Dict[str, Any] = {'shard_id': 'Everything'}
        shards: List[Dict[str, Any]] = [{'shard_id': f'_{i}'} for i in range(num_shards)]
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    if not isinstance(shard[key], list):
                        raise TypeError("sharded keys in base_env must be a list")
                    # Copy so shards do not share (and jointly mutate) the
                    # list object that came from base_env.
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
            for k, v in from_.items():
                if k not in sharded_keys:
                    raise KeyError(f"undeclared sharded key {k}")
                into[k] += v

        for item in items:
            key = key_fn(item)
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)
            merge_env(shards[sid], env)
            merge_env(everything, env)

        dot_pos = filename.rfind('.')
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard['shard_id']
            # Bind the current shard as a default so the callable does not
            # depend on the loop variable's later values.
            self.write_with_template(f"{base_filename}{shard_id}{extension}",
                                     filename,
                                     lambda shard=shard: shard)

        # The "Everything" file is written but deliberately left out of the
        # tracked output names.
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}")

    def write_outputs(self, filename: str) -> None:
        """Write a file containing the list of all outputs which are
        generated by this script.

        Names are sorted and each one is followed by ";".
        """
        self._write_if_changed(
            filename,
            ''.join(name + ";" for name in sorted(self.filenames)))
# Names taken from the 'symint' yaml section (the text before the first "("
# of each entry's 'func').  Filled by parse_native_yaml_struct and consulted
# by gen_function_declaration / gen_return to decide on the "_symint" suffix.
SYMINT_SET: Set[str] = set()
def parse_native_yaml_struct(
    es: object,
) -> List[NativeFunction]:
    """Convert the loaded native-functions yaml into ``NativeFunction`` objects.

    Args:
        es: The parsed yaml document.  A falsy value yields an empty list;
            otherwise it must contain the keys 'symint', 'official' and
            'custom' (each may be empty/None) and may contain 'quant'.

    Returns:
        One ``NativeFunction`` per entry of the 'official', 'custom' and
        'quant' sections, in that order.

    Side effects:
        Adds the operator name of every 'symint' entry (the part of its
        'func' before the first "(") to the module-level ``SYMINT_SET``.

    Raises:
        AssertionError: if a required section key is missing.
        TypeError: if a non-empty function section is not a list.
    """
    global SYMINT_SET

    rs: List[NativeFunction] = []
    if not es:
        return rs

    for required in ('symint', 'official', 'custom'):
        if required not in es:
            raise AssertionError(f"Can't find {required} in yaml.")

    for e in es['symint'] or ():
        SYMINT_SET.add(e['func'].split("(")[0])

    all_funcs = []
    for section in ('official', 'custom', 'quant'):
        entries = es[section] if section in es else None
        if not entries:
            continue
        # Check before extending: `list += str/dict` would silently splat
        # characters or keys into all_funcs instead of failing.
        if not isinstance(entries, list):
            raise TypeError(f"'{section}' section must be a list")
        all_funcs += entries

    for e in all_funcs:
        funcs = e.get("func")
        # Attach the offending yaml entry to any error raised while parsing it.
        with context(lambda: f"in:\n {funcs}"):
            func, _ = NativeFunction.from_yaml(e)
            rs.append(func)
    return rs
def gen_function_declaration(
    f: NativeFunction,
    backend_decalarations: Dict,
):
    """Append the hidden C++ declaration(s) for ``f`` to the per-backend lists.

    The declared name is the base operator name, plus "_out" for out
    variants, plus "_symint" when the full operator name is in
    ``SYMINT_SET``.  The same declaration is appended under "op_api" and
    "acl_op"; when ``f.sparse`` is set, a "_sparse"-suffixed declaration is
    additionally appended under "sparse".
    """
    global SYMINT_SET
    with native_function_manager(f):
        base_name = str(f.func.name.name)
        out_suffix = "_out" if f.func.is_out_fn() else ""
        uses_symint = str(f.func.name) in SYMINT_SET
        symint_suffix = "_symint" if uses_symint else ""
        decl_name = base_name + out_suffix + symint_suffix

        signature = NativeSignature(f.func, prefix='', symint=uses_symint)
        declaration = f"OP_PLUGIN_HIDDEN {signature.decl(name=decl_name)};"
        for backend in ("op_api", "acl_op"):
            backend_decalarations[backend].append(declaration)

        if f.sparse is None:
            return
        sparse_name = decl_name + "_sparse"
        backend_decalarations["sparse"].append(f"OP_PLUGIN_HIDDEN {signature.decl(name=sparse_name)};")
def gen_return(
    f: NativeFunction,
    deprecated_dict: Dict,
) -> List[Optional[str]]:
    """Build the C++ wrapper definition(s) that forward ``f`` to a backend.

    Args:
        f: The native function to generate a wrapper for.
        deprecated_dict: Maps a full operator name (``str(f.func.name)``) to
            the name of its replacement, or to ``None`` when there is none.
            Operators present here get a ``TORCH_WARN_ONCE`` line.

    Returns:
        A list of C++ source strings.  At most one comes from the
        "op_api"/"acl_op" branches below (chosen by ``f.impl_ns``), plus one
        more when ``f.sparse`` is not None.  The list is empty when none of
        those conditions hold.
    """
    ret = []
    with native_function_manager(f):
        has_symint = False
        op_name_with_overload = str(f.func.name)
        op_name = str(f.func.name.name)
        global SYMINT_SET
        # Wrapper name: base name, "_out" for out variants, "_symint" when
        # the full operator name was listed in the yaml 'symint' section.
        if f.func.is_out_fn():
            op_name += "_out"
        if str(f.func.name) in SYMINT_SET:
            op_name += "_symint"
            has_symint = True
        sig = NativeSignature(f.func, prefix='', symint=has_symint)
        args_exprs_str = ', '.join(a.name for a in sig.arguments())
        # Name of the backend function to call; defaults to the wrapper name.
        impl_name = f.impl_name
        if not f.impl_name:
            impl_name = op_name
        deprecated_warn = ""
        if op_name_with_overload in deprecated_dict.keys():
            deprecated_func = f'torch_npu.{str(f.func.name.name)}'
            deprecated_replace = deprecated_dict[op_name_with_overload]
            if deprecated_replace is not None:
                # The trailing backslash continues the string literal onto the
                # next source line, so the message is emitted as one line.
                deprecated_warn += f'TORCH_WARN_ONCE("{deprecated_func} is deprecated and will be removed in future version. \
Use {deprecated_replace} instead.");'
            else:
                deprecated_warn += f'TORCH_WARN_ONCE("{deprecated_func} is deprecated and will be removed in future version.");'
        # Per tensor-like argument `x` these collect, respectively:
        #   format_check    -> " && x_base_format"        (condition suffix)
        #   format_display  -> ", !x_base_format"         (log arguments)
        #   place_holder    -> ", x is internal format: %d" (log format text)
        #   format_for_args -> the `bool x_base_format = ...;` declaration
        format_check = []
        format_display = []
        place_holder = []
        format_for_args = []
        inputs_list = ""
        is_aclnn_only = "c10_npu::IsAclnnOnly()"
        for a in sig.arguments():
            argument = a.argument
            # Unwrap SelfArgument to reach the underlying argument's type.
            if isinstance(a.argument, SelfArgument):
                argument = a.argument.argument
            if not isinstance(a.argument, TensorOptionsArguments) and argument.type.is_tensor_like():
                format_for_args.append(
                    f" bool {a.name}_base_format = at_npu::native::FormatHelper::IsOpInputBaseFormat({a.name});\n")
                format_check.append(f" && {a.name}_base_format")
                format_display.append(f", !{a.name}_base_format")
                place_holder.append(f", {a.name} is internal format: %d")
            inputs_list += f"{a.name}, "
        # Drop the trailing ", " (a no-op when there were no arguments).
        inputs_list = inputs_list[:-2]
        if "op_api" in f.impl_ns and "acl_op" in f.impl_ns:
            # Both backends available: the generated code picks one at runtime.
            if not f.internal_format_opapi:
                # op_api is used only when JIT compile is disabled AND every
                # tensor-like input is in base format; otherwise acl_op,
                # which is rejected on aclnn-only devices.
                ret.append(f"""{sig.defn(name=op_name)}{{
{deprecated_warn}
bool is_jit_disable = at_npu::native::env::CheckJitDisable();
{"".join(format_for_args)}
OP_EXEC_LOG({impl_name}, "exec aten", {inputs_list});
ASCEND_LOGI("{impl_name} exec with jit compile: %d{"".join(place_holder)}",
!is_jit_disable{"".join(format_display)});
if (is_jit_disable{"".join(format_check)}) {{
return op_api::{impl_name}({args_exprs_str});
}} else {{
if ({is_aclnn_only}) {{
TORCH_CHECK(false,
"Current device only support aclnn operator, and current operator {impl_name} do not support internal format.",
PTA_ERROR(ErrCode::NOT_SUPPORT));
}}
return acl_op::{impl_name}({args_exprs_str});
}}
}}
""")
            else:
                # No per-argument format checks: the choice depends on the
                # JIT-disable flag alone.
                ret.append(f"""{sig.defn(name=op_name)}{{
{deprecated_warn}
bool is_jit_disable = at_npu::native::env::CheckJitDisable();
OP_EXEC_LOG({impl_name}, "exec aten", {inputs_list});
ASCEND_LOGI("{impl_name} exec with jit compile: %d", !is_jit_disable);
if (is_jit_disable) {{
return op_api::{impl_name}({args_exprs_str});
}} else {{
return acl_op::{impl_name}({args_exprs_str});
}}
}}
""")
        elif "op_api" in f.impl_ns:
            # NOTE(review): the namespace is taken from the first impl_ns
            # entry, presumably "op_api" here - confirm impl_ns cannot hold
            # other entries ahead of it.
            ns = f.impl_ns[0]
            if f.internal_format_opapi:
                ret.append(f"""{sig.defn(name=op_name)}{{
{deprecated_warn}
OP_EXEC_LOG({impl_name}, "exec aten", {inputs_list});
return {ns}::{impl_name}({args_exprs_str});
}}
""")
            else:
                # The condition below turns " && a_base_format && b_base_format"
                # into "!a_base_format || !b_base_format": every " && " becomes
                # " || !" and the first " || " is then removed.
                # NOTE(review): with no tensor-like argument format_check is
                # empty and the emitted condition is `if ()` - confirm such
                # operators never reach this branch.
                ret.append(f"""{sig.defn(name=op_name)}{{
{deprecated_warn}
{"".join(format_for_args)}
OP_EXEC_LOG({impl_name}, "exec aten", {inputs_list});
if ({("".join(format_check)).replace(" && ", " || !").replace(" || ", "", 1)}) {{
TORCH_CHECK(false,
"Current operator {impl_name} do not support internal format. ",
PTA_ERROR(ErrCode::NOT_SUPPORT));
}}
return {ns}::{impl_name}({args_exprs_str});
}}
""")
        elif "acl_op" in f.impl_ns:
            # acl_op only: the generated code fails on aclnn-only devices.
            ns = f.impl_ns[0]
            ret.append(f"""{sig.defn(name=op_name)}{{
{deprecated_warn}
OP_EXEC_LOG({impl_name}, "exec aten", {inputs_list});
if ({is_aclnn_only}) {{
TORCH_CHECK(false,
"Current device only support aclnn operator, "
"but current operator {impl_name} do not have aclnn implementation",
PTA_ERROR(ErrCode::NOT_SUPPORT));
}}
return {ns}::{impl_name}({args_exprs_str});
}}
""")
        # Independent of the branches above: an extra "<op_name>_sparse"
        # wrapper forwarding to sparse::<impl_name>_sparse.
        if f.sparse is not None:
            ret.append(f"""{sig.defn(name=op_name + "_sparse")}{{
{deprecated_warn}
OP_EXEC_LOG({impl_name}, "exec aten", {inputs_list});
return sparse::{impl_name}_sparse({args_exprs_str});
}}
""")
    return ret
def parse_native_yaml(
    path: str,
    deprecate_path: str,
) -> Tuple[Dict[str, list], List[Optional[str]]]:
    """Load the operator yaml files and generate declarations and wrappers.

    Args:
        path: Native-functions yaml, parsed by ``parse_native_yaml_struct``.
        deprecate_path: Yaml file with a top-level 'deprecated' key holding
            entries that have a 'name' and optionally a 'replace'.  An empty
            (null) 'deprecated' value is treated as "nothing deprecated".

    Returns:
        A pair ``(backend_declarations, dispatch_registrations_body)``: the
        per-backend declaration lists filled by ``gen_function_declaration``,
        and the de-duplicated, sorted wrapper definitions from ``gen_return``.

    Raises:
        KeyError: if the deprecation yaml has no 'deprecated' key.
    """
    # Both files are parsed with yaml.safe_load.
    with open(path, "r") as native_file:
        es = yaml.safe_load(native_file)
    with open(deprecate_path, "r") as deprecate_file:
        dp = yaml.safe_load(deprecate_file)

    res = parse_native_yaml_struct(es)

    backend_declarations = defaultdict(list)
    for native_func in res:
        gen_function_declaration(native_func, backend_declarations)

    # `deprecated:` with no entries loads as None; treat it like an empty
    # list, mirroring how empty function sections are tolerated.
    deprecated_dict = {
        item.get("name"): item.get("replace", None)
        for item in (dp["deprecated"] or [])
    }

    dispatch_registrations_body = sorted(
        set(concatMap(lambda fn: gen_return(fn, deprecated_dict), res)))
    return backend_declarations, dispatch_registrations_body