import os
from collections import defaultdict
from typing import List, Dict, Sequence, Tuple
import yaml
from torchnpugen.code_template import CodeTemplate
from torchnpugen.gen import FileManager
from torchnpugen.model import BaseTy, SchemaKind, BaseType, NativeFunction
from torchnpugen.op_codegen_utils import concatMap, PathManager
from torchnpugen.context import native_function_manager
from torchnpugen.api.types import NativeSignature
from torchnpugen.api import cpp
from .model import STRUCTURED_GEN_OPAPI_ALLOWED_KEYS, StructInfo, ResInfo, filt_input_tensor
# Top-level sections of the struct YAML whose operators are considered for
# op_api generation (iterated by filt_op_branch).
USED_KEYS = ['official', 'custom', 'autograd']
# Operator names (the text before "(" in an entry's `func`) collected from the
# YAML 'symint' section by filt_op_branch; consulted when choosing signatures.
SYMINT_OPS = set()
# The target PyTorch version is read from the environment at import time, so
# importing this module raises when the variable is missing or empty.
_torch_ver = os.environ.get('PYTORCH_VERSION')
if not _torch_ver:
    raise RuntimeError('PYTORCH_VERSION environment variable is required')
# Dot-separated components kept as strings; only the first two (major, minor)
# are used by is_support_version.
PYTORCH_VERSION = _torch_ver.split('.')
if len(PYTORCH_VERSION) < 2:
    raise RuntimeError('PYTORCH_VERSION must contain major and minor components')
# Full body of a generated op_api function. Every ${...} placeholder is filled
# in by compute_op_api_definition; placeholders that render to an empty string
# leave blank lines that remove_empty_lines strips afterwards.
ACLNN_FUNCTIONS_DEFINITION = CodeTemplate("""\
${return_type} ${func_name}(${args_str})
{
${integral_identity_guard}
${do_compatibility}
${new_params_def}
${cpu_scalar_h2d_code}
${compute_names}
${define_size_or_dtype}
${check_or_apply_tensor}
${cpu_scalar_op_code}
${exec_cmd}(${aclnnargs});
${infer_name}
return${result};
}
""")
# Thin forwarding function, emitted by compute_op_api_definition for in-place
# schemas whose struct entry has `structured_inherit` set: it just calls the
# op_api function named by ${func_name_out}.
ACLNN_FUNCTIONS_DELEGATE = CodeTemplate("""\
${return_type} ${func_name}(${args_str})
{
return op_api::${func_name_out}(${args_str_out});
}
""")
# DO_COMPATIBILITY macro invocation pairing an aclnn kernel name with the
# acl_op implementation of the same function; emitted only when the struct
# entry's `acl_op` is truthy.
DO_COMPATIBILITY = CodeTemplate("""\
DO_COMPATIBILITY(${aclnnname}, acl_op::${func_name}(${args_exprs_str}));
""")
# Early return to the acl_op implementation when ${tensor} has an integral
# (non-bool) scalar type; emitted when `integral_identity_tensor` is configured.
INTEGRAL_IDENTITY_GUARD = CodeTemplate("""\
if (at::isIntegralType(${tensor}.scalar_type(), /*includeBool=*/false)) {
return acl_op::${func_name}(${args_exprs_str});
}
""")
# Allocation of one result tensor; used for schemas that are neither `out`
# nor in-place.
APPLY_TENSOR = CodeTemplate("""\
at::Tensor ${result_name} = npu_preparation::apply_tensor_without_format(${size},
${dtype});
""")
# Validation of a caller-provided result tensor; used for `out` schemas.
CHECK_TENSOR = CodeTemplate("""\
npu_preparation::check_tensor({${input}}, ${result_name}, ${dtype}, ${size});
""")
# Name propagation when a result's `infer_name` holds a single entry.
INFER_NAME = CodeTemplate("""\
at::namedinference::propagate_names(${result_name}, ${infer_func});
""")
# Name propagation when a result's `infer_name` lists several ", "-separated
# entries; relies on `maybe_names` declared by COMPUTE_NAME_GROUP.
INFER_NAME_GROUP = CodeTemplate("""\
at::namedinference::propagate_names_if_nonempty(${result_name}, maybe_names);
""")
# Declares `tensor_list` and `maybe_names` for the multi-entry `infer_name`
# case handled by INFER_NAME_GROUP.
COMPUTE_NAME_GROUP = CodeTemplate("""\
std::vector<at::Tensor> tensor_list = {${tensor_list}};
auto maybe_names = op_plugin::utils::compute_names_npu(tensor_list);
""")
def is_support_version(op):
    """Return whether ``op`` enables op_api for the configured PyTorch version.

    The decision is driven by the entry's ``op_api`` field, compared against
    ``v<major>.<minor>`` built from the module-level ``PYTORCH_VERSION``:

    * missing / ``None``  -> not supported;
    * ``'all_version'``   -> supported;
    * a list              -> supported when the current version is greater
      than or equal to the list's first element (later elements are ignored);
    * a string            -> supported when the current version equals one of
      its comma/whitespace separated tokens;
    * anything else       -> plain ``in`` containment.

    Args:
        op: Mapping describing one operator entry.

    Returns:
        ``True`` if the operator should be generated for this version.
    """
    op_api_version = op.get('op_api', None)
    version = f"v{PYTORCH_VERSION[0]}.{PYTORCH_VERSION[1]}"

    def parse_version(v):
        return tuple(map(int, v.lstrip('v').split('.')))

    if op_api_version is None:
        return False
    if op_api_version == 'all_version':
        return True
    if isinstance(op_api_version, list):
        return parse_version(version) >= parse_version(op_api_version[0])
    if isinstance(op_api_version, str):
        # Compare whole tokens: a raw substring test would let "v2.1" match
        # inside "v2.10" (or "v1.1" inside "v1.11").
        listed_versions = op_api_version.replace(',', ' ').split()
        return version in listed_versions
    return version in op_api_version
def filt_op_branch(struct_ops: Dict) -> List[Dict]:
    """Select the operator entries that should get generated op_api code.

    Entries are gathered from the sections named in ``USED_KEYS`` and kept
    when they carry a ``gen_opapi`` mapping and pass ``is_support_version``.
    As a side effect, the names of version-supported entries in the optional
    ``symint`` section are added to the module-level ``SYMINT_OPS`` set.

    Args:
        struct_ops: Parsed YAML mapping of section name to a list of operator
            entries. A section may be missing or ``None``.

    Returns:
        The list of operator entries selected for generation.

    Raises:
        RuntimeError: If a selected entry has a non-mapping ``gen_opapi``,
            also sets ``structured_delegate``, or is ``structured`` and uses
            ``gen_opapi`` keys outside ``STRUCTURED_GEN_OPAPI_ALLOWED_KEYS``.
    """
    support_ops = []
    for key in USED_KEYS:
        # Missing and empty (None) sections are both treated as "no entries",
        # matching how the optional 'symint' section is handled below.
        section = struct_ops.get(key)
        if section is not None:
            support_ops += section

    def filt_gen_opapi(op) -> bool:
        if 'gen_opapi' not in op or not is_support_version(op):
            return False
        gen_opapi_info = op.get('gen_opapi')
        if not isinstance(gen_opapi_info, dict):
            raise RuntimeError(f"{op.get('func')} has invalid gen_opapi configuration")
        if op.get('structured_delegate') is not None:
            raise RuntimeError(
                f"{op.get('func')} specifies structured_delegate, so gen_opapi is not allowed"
            )
        if bool(op.get('structured', False)):
            invalid_keys = set(gen_opapi_info) - STRUCTURED_GEN_OPAPI_ALLOWED_KEYS
            if invalid_keys:
                raise RuntimeError(
                    f"{op.get('func')} is marked structured, so gen_opapi only supports "
                    f"{sorted(STRUCTURED_GEN_OPAPI_ALLOWED_KEYS)}. Unsupported keys: {sorted(invalid_keys)}"
                )
        return True

    # Register symint operator names before filtering so the set is populated
    # even if validation of a later entry raises.
    for op in struct_ops.get('symint') or []:
        if is_support_version(op):
            SYMINT_OPS.add(op['func'].split("(")[0])

    return [op for op in support_ops if filt_gen_opapi(op)]
def remove_empty_lines(text):
    """Drop blank and whitespace-only lines from ``text``.

    The remaining lines are joined with ``\\n`` and the result always ends
    with two newline characters.
    """
    kept_rows = [row for row in text.split('\n') if row.strip() != '']
    return '\n'.join(kept_rows) + '\n\n'
def parse_struct_yaml(path, native_functions: Sequence[NativeFunction]) -> List[StructInfo]:
    """Load the struct YAML at ``path`` and build ``StructInfo`` objects.

    The path is resolved with ``os.path.realpath`` and passed to
    ``PathManager.check_directory_path_readable`` before being opened. The
    parsed document is filtered by ``filt_op_branch`` and handed, together
    with ``native_functions``, to ``StructInfo.from_yaml``.
    """
    resolved_path = os.path.realpath(path)
    PathManager.check_directory_path_readable(resolved_path)
    with open(resolved_path, 'r', encoding='utf-8') as yaml_stream:
        raw_ops = yaml.safe_load(yaml_stream)
    return StructInfo.from_yaml(filt_op_branch(raw_ops), native_functions)
def gen_size_dtype_map(resinfos: List['ResInfo']) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Assign a generated C++ variable name to each distinct size/dtype expression.

    Results sharing the same ``size`` (or ``dtype``) expression share one
    variable. Names are numbered in first-seen order: ``output_size_<i>`` and
    ``output_dtype_<i>``.

    Args:
        resinfos: Result descriptions exposing ``size`` and ``dtype``.

    Returns:
        ``(size_map, dtype_map)``, each mapping an expression to its variable
        name, in insertion order.
    """
    # Plain dicts: the previous ``defaultdict(None)`` had no default factory,
    # so it never supplied defaults and only obscured the intent.
    size_map: Dict[str, str] = {}
    dtype_map: Dict[str, str] = {}
    for resinfo in resinfos:
        if resinfo.size not in size_map:
            size_map[resinfo.size] = f"output_size_{len(size_map)}"
        if resinfo.dtype not in dtype_map:
            dtype_map[resinfo.dtype] = f"output_dtype_{len(dtype_map)}"
    return size_map, dtype_map
def _gen_cpu_scalar_op_code(struct: StructInfo, args, exec_cmd: str) -> str:
if not struct.cpu_scalar_op:
return ""
branch_code_parts = []
copy_set = set(struct.cpu_scalar_h2d) if struct.cpu_scalar_h2d else set()
result_names = [res.name for res in struct.results] if struct.results else []
cmd_args_list = [arg.strip() for arg in struct.cmd_args.split(',')]
original_params = cmd_args_list[1:]
scalar_params = {}
for op in struct.cpu_scalar_op:
if op.param not in scalar_params:
scalar_params[op.param] = f"{op.param}_scalar"
scalar_decls = "".join(
f"at::Scalar {var} = {param}.item();\n"
for param, var in scalar_params.items()
)
for op in struct.cpu_scalar_op:
params_list = list(original_params)
for i, p in enumerate(params_list):
if p == op.param:
params_list[i] = scalar_params[op.param]
elif p in copy_set:
params_list[i] = f"{p}_cp"
for result_name in result_names:
for i, p in enumerate(params_list):
if p == "out":
params_list[i] = result_name
params_str = ", ".join(params_list)
condition = f"npu_preparation::IsCPUScalar({op.param})"
branch_code_parts.append(
f"if ({condition}) {{\n"
f" {scalar_decls}"
f" {exec_cmd}({op.exec_cmd}, {params_str});\n"
f" return{struct.return_args};\n"
f"}}\n"
)
return "".join(branch_code_parts)
def _gen_cpu_scalar_h2d_code(struct: StructInfo, args) -> str:
if not struct.cpu_scalar_h2d:
return ""
copy_code_parts = []
arg_names = [a.name for a in args]
tensor_arg_names = [a.name for a in args
if hasattr(a, 'argument') and hasattr(a.argument, 'type')
and isinstance(a.argument.type, BaseType) and a.argument.type.name == BaseTy.Tensor]
copy_set = set(struct.cpu_scalar_h2d)
for param in struct.cpu_scalar_h2d:
if param not in arg_names:
raise RuntimeError(f"cpu_scalar_h2d contains invalid parameter: {param}")
non_copy_tensor_args = [n for n in tensor_arg_names if n not in copy_set]
other_copy_params = [p for p in struct.cpu_scalar_h2d if p != param]
if non_copy_tensor_args:
device_expr = f"{non_copy_tensor_args[0]}.device()"
elif other_copy_params:
device_expr = f"{other_copy_params[0]}.device()"
else:
device_expr = "c10::Device(c10::DeviceType::PrivateUse1)"
copy_code_parts.append(
f"at::Tensor {param}_cp = {param};\n"
f"if (npu_preparation::IsCPUScalar({param})) {{\n"
f" at::Scalar {param}_scalar = {param}.item();\n"
f" {param}_cp = npu_preparation::copy_scalar_to_device({param}_scalar, {param}.scalar_type(), {device_expr});\n"
f"}}\n"
)
return "".join(copy_code_parts)
def _replace_cmd_args_with_copy(struct: StructInfo, args) -> str:
if not struct.cpu_scalar_h2d:
return struct.cmd_args
cmd_args_list = [arg.strip() for arg in struct.cmd_args.split(',')]
scalar_set = set(struct.cpu_scalar_h2d)
new_args_list = []
for arg in cmd_args_list:
if arg in scalar_set:
new_args_list.append(f"{arg}_cp")
else:
new_args_list.append(arg)
return ", ".join(new_args_list)
def _replace_scalar_refs_in_expr(expr: str, cpu_scalar_h2d: List[str]) -> str:
if not cpu_scalar_h2d or expr is None:
return expr
result = expr
for param in cpu_scalar_h2d:
result = result.replace(f"{param}.sizes()", f"{param}_cp.sizes()")
result = result.replace(f"{param}.scalar_type()", f"{param}_cp.scalar_type()")
result = result.replace(f"{param}.device()", f"{param}_cp.device()")
result = result.replace(f"{param}.options()", f"{param}_cp.options()")
result = result.replace(f"({param},", f"({param}_cp,")
result = result.replace(f", {param},", f", {param}_cp,")
result = result.replace(f", {param})", f", {param}_cp)")
result = result.replace(f"({param})", f"({param}_cp)")
if result == expr:
for param in cpu_scalar_h2d:
if result == param:
result = f"{param}_cp"
break
return result
def compute_op_api_definition(struct: StructInfo, env_aclnn_extension_switch: bool):
    """Render the C++ op_api function definition for one struct entry.

    For an in-place schema whose entry has ``structured_inherit`` set, a thin
    delegate that forwards to the inherited function is produced (the self
    argument is appended as the extra trailing argument). Otherwise the full
    definition is assembled from ``ACLNN_FUNCTIONS_DEFINITION`` and blank
    lines are stripped with ``remove_empty_lines``.

    Args:
        struct: Struct configuration for the operator.
        env_aclnn_extension_switch: Selects the launch macro:
            ``EXEC_NPU_CMD_EXT`` when true, ``EXEC_NPU_CMD`` otherwise.

    Returns:
        A one-element list containing the generated C++ source text.

    Raises:
        RuntimeError: If ``struct.cmd_args`` (after its first token) names a
            parameter that is not a function argument, a result name, or a
            key of ``struct.new_params``.
    """
    f = struct.func
    is_symint = struct.name in SYMINT_OPS
    with native_function_manager(f):
        kind = f.func.kind()
        sig = NativeSignature(f.func, prefix='', symint=is_symint)
        name = cpp.name(f.func)
        name = name + '_symint' if is_symint else name
        args = sig.arguments()
        args_str = ', '.join(a.defn() for a in args)
        args_exprs_str = ', '.join(a.name for a in args)
        return_type = cpp.returns_type(f.func.returns).cpp_type()
        res_infos = struct.results

        if struct.structured_inherit is not None and kind == SchemaKind.inplace:
            delegate_function = struct.structured_inherit
            delegate_name = cpp.name(delegate_function.func)
            delegate_name = delegate_name + '_symint' if is_symint else delegate_name
            # The delegate receives `self` once more as its trailing argument.
            delegate_args_exprs_str = f'{args_exprs_str}, {f.func.arguments.self_arg.argument.name}'
            return [
                ACLNN_FUNCTIONS_DELEGATE.substitute(
                    return_type=return_type,
                    func_name=name,
                    args_str=args_str,
                    func_name_out=delegate_name,
                    args_str_out=delegate_args_exprs_str,
                )
            ]

        do_compatibility = (
            DO_COMPATIBILITY.substitute(aclnnname=struct.aclnn_name, func_name=name, args_exprs_str=args_exprs_str)
            if struct.acl_op
            else ""
        )

        integral_identity_guard = ""
        if struct.integral_identity_tensor is not None:
            integral_identity_guard = INTEGRAL_IDENTITY_GUARD.substitute(
                tensor=struct.integral_identity_tensor,
                func_name=name,
                args_exprs_str=args_exprs_str,
            )

        # Kernels used only by the CPU-scalar branches need their own
        # compatibility check when they differ from the main kernel.
        if struct.cpu_scalar_op and struct.acl_op:
            for op in struct.cpu_scalar_op:
                if op.exec_cmd != struct.aclnn_name:
                    do_compatibility += DO_COMPATIBILITY.substitute(
                        aclnnname=op.exec_cmd,
                        func_name=name,
                        args_exprs_str=args_exprs_str,
                    )

        tensor_arguments_list = filt_input_tensor(f.func.arguments.flat_non_out)
        if struct.cpu_scalar_h2d:
            scalar_set = set(struct.cpu_scalar_h2d)
            tensor_arguments_list = [f"{t}_cp" if t in scalar_set else t for t in tensor_arguments_list]
        tensor_arguments = ", ".join(tensor_arguments_list)

        new_params_def = "".join(
            [f"auto {para_name} = {para_def};\n" for para_name, para_def in struct.new_params.items()]
        )

        exec_cmd = "EXEC_NPU_CMD_EXT" if env_aclnn_extension_switch else "EXEC_NPU_CMD"
        cpu_scalar_op_code = _gen_cpu_scalar_op_code(struct, args, exec_cmd)
        cpu_scalar_h2d_code = _gen_cpu_scalar_h2d_code(struct, args)
        cmd_args = _replace_cmd_args_with_copy(struct, args)

        valid_param_set = (
            set(struct.new_params.keys())
            | {arg.name for arg in args}
            | {res.name for res in res_infos}
        )
        # The first cmd_args token is skipped; the rest must be known names.
        aclnn_params_set = {arg.strip() for arg in struct.cmd_args.split(',')[1:]}
        if not aclnn_params_set.issubset(valid_param_set):
            raise RuntimeError(
                f"exec configuration field contains invalid parameters{aclnn_params_set - valid_param_set}"
            )

        if struct.cpu_scalar_h2d:
            # Result size/dtype/option expressions must refer to the `_cp`
            # copies of the params listed in cpu_scalar_h2d.
            res_infos = [
                ResInfo(
                    name=ri.name,
                    size=_replace_scalar_refs_in_expr(ri.size, struct.cpu_scalar_h2d),
                    dtype=_replace_scalar_refs_in_expr(ri.dtype, struct.cpu_scalar_h2d),
                    option=_replace_scalar_refs_in_expr(ri.option, struct.cpu_scalar_h2d),
                    infer_name=ri.infer_name,
                )
                for ri in res_infos
            ]
        size_map, dtype_map = gen_size_dtype_map(res_infos)

        define_size = "".join([f"auto {name} = {size};\n" for size, name in size_map.items()])
        define_dtype = "".join([f"auto {name} = {dtype};\n" for dtype, name in dtype_map.items()])
        define_size_or_dtype = "" if kind == SchemaKind.inplace else "".join([define_size, define_dtype])

        if kind == SchemaKind.out:
            apply_tensor_list = list(
                concatMap(
                    lambda res_info: [
                        CHECK_TENSOR.substitute(
                            input=tensor_arguments,
                            result_name=res_info.name,
                            size=size_map[res_info.size],
                            dtype=dtype_map[res_info.dtype],
                        )
                    ],
                    res_infos,
                )
            )
        elif kind == SchemaKind.inplace:
            apply_tensor_list = []
        else:
            apply_tensor_list = list(
                concatMap(
                    lambda res_info: [
                        APPLY_TENSOR.substitute(
                            result_name=res_info.name,
                            size=size_map[res_info.size],
                            dtype=f'{res_info.option}.options().dtype({dtype_map[res_info.dtype]})',
                        )
                    ],
                    res_infos,
                )
            )

        compute_name_list = []
        infer_name_list = []
        for res_info in res_infos:
            if res_info.infer_name is None:
                continue
            if len(res_info.infer_name.split(", ")) == 1:
                infer_name_list.append(
                    INFER_NAME.substitute(result_name=res_info.name, infer_func=res_info.infer_name)
                )
            else:
                # Several sources: compute the names first, then propagate.
                compute_name_list.append(COMPUTE_NAME_GROUP.substitute(tensor_list=res_info.infer_name))
                infer_name_list.append(INFER_NAME_GROUP.substitute(result_name=res_info.name))

        compute_names = "".join(compute_name_list)
        infer_name = "".join(infer_name_list)
        apply_tensor = "".join(apply_tensor_list)

        return [
            remove_empty_lines(
                ACLNN_FUNCTIONS_DEFINITION.substitute(
                    return_type=return_type,
                    func_name=name,
                    integral_identity_guard=integral_identity_guard,
                    new_params_def=new_params_def,
                    define_size_or_dtype=define_size_or_dtype,
                    args_str=args_str,
                    check_or_apply_tensor=apply_tensor,
                    do_compatibility=do_compatibility,
                    compute_names=compute_names,
                    exec_cmd=exec_cmd,
                    aclnnargs=cmd_args,
                    result=struct.return_args,
                    infer_name=infer_name,
                    cpu_scalar_op_code=cpu_scalar_op_code,
                    cpu_scalar_h2d_code=cpu_scalar_h2d_code,
                )
            )
        ]
def gen_op_api(fm: FileManager, struct_functions: Sequence[StructInfo], env_aclnn_extension_switch: bool) -> None:
    """Write ``StructKernelNpuOpApi.cpp`` from its template via ``fm``.

    The template environment supplies ``includes`` (a header list that
    depends on ``env_aclnn_extension_switch``) and ``op_api_definition`` (the
    concatenated output of ``compute_op_api_definition`` for every entry in
    ``struct_functions``). The environment is built lazily by the callable
    handed to ``fm.write_with_template``.
    """
    if env_aclnn_extension_switch:
        include_lines = (
            '#include <torch/library.h>',
            '#include <torch/csrc/autograd/custom_function.h>',
            '#include <torch/extension.h>',
            '#include <ATen/native/TypeProperties.h>',
            '#include "op_plugin/include/npu_cpp_extension.h" ',
        )
    else:
        include_lines = (
            '#include <ATen/native/TypeProperties.h>',
            '#include "op_plugin/AclOpsInterface.h"',
            '#include "op_plugin/OpApiInterface.h"',
            '#include "op_plugin/utils/op_api_common.h" ',
        )
    # The last entry keeps its trailing space so the text matches the
    # header block byte for byte.
    includes = "\n".join(include_lines)

    def render_definition(struct_function):
        return compute_op_api_definition(struct_function, env_aclnn_extension_switch)

    def struct_kernel_template_env():
        definitions = list(concatMap(render_definition, struct_functions))
        return {
            'includes': includes,
            'op_api_definition': definitions,
        }

    target_name = 'StructKernelNpuOpApi.cpp'
    fm.write_with_template(target_name, target_name, struct_kernel_template_env)