from typing import List, Dict, Sequence
import copy
from dataclasses import dataclass
from torchnpugen.model import BaseTy, SchemaKind, BaseType, Argument, NativeFunction, ListType
from torchnpugen.context import native_function_manager
from torchnpugen.api.types import NativeSignature
from torchnpugen.api import cpp
# NOTE(review): not referenced in this chunk; by its name this presumably lists the
# gen_opapi keys permitted for structured ops — confirm against its users.
STRUCTURED_GEN_OPAPI_ALLOWED_KEYS = set(('new_params', 'exec'))
def filt_input_tensor(arguments: Sequence[Argument]) -> List[str]:
    """Collect C++ expressions naming the tensor inputs among ``arguments``.

    A plain ``Tensor`` argument contributes its own name; a ``Tensor[]`` list
    argument contributes ``'<name>[0]'`` (its first element). ResInfo.parse uses
    these expressions for ``.sizes()`` / ``.scalar_type()`` lookups and takes the
    first one as the result's option tensor.

    Args:
        arguments: Flattened schema arguments to scan.

    Returns:
        The tensor expressions, in argument order.
    """
    input_tensors = []
    for arg in arguments:
        arg_type = arg.type
        if isinstance(arg_type, BaseType) and arg_type.name == BaseTy.Tensor:
            input_tensors.append(arg.name)
        # Only a list whose element is a plain BaseType can be ``Tensor[]``. Other
        # element types (e.g. the optional element of ``Tensor?[]``) are not BaseType
        # and may not expose ``name``, so they are filtered out before comparing.
        elif (
            isinstance(arg_type, ListType)
            and isinstance(arg_type.elem, BaseType)
            and arg_type.elem.name == BaseTy.Tensor
        ):
            input_tensors.append(f'{arg.name}[0]')
    return input_tensors
@dataclass(frozen=True)
class ResInfo:
    """Describes one result tensor of an op-api function for code generation.

    ``size``, ``dtype`` and ``option`` hold C++ expression text (for example
    ``self.sizes()`` or ``self.scalar_type()``) built by :meth:`parse`.
    """

    # Result name: a key of the yaml results dict, or an out-argument name.
    name: str
    # Size expression; may be None for inplace results, which are not validated.
    size: str
    # Dtype expression; may be None for inplace results, which are not validated.
    dtype: str
    # The first tensor-input expression found by filt_input_tensor.
    option: str
    # Value of the optional ``name`` key in the result's yaml entry.
    infer_name: str = None

    @staticmethod
    def parse(results: Dict[str, Dict[str, str]], f: 'NativeFunction') -> List['ResInfo']:
        """Build ResInfo entries for ``f`` from its yaml result descriptions.

        Args:
            results: Maps result name to a dict with optional ``size``, ``dtype``
                and ``name`` keys. The dicts are not modified.
            f: The native function the results belong to.

        Returns:
            One ResInfo per result: in ``cpp.return_names(f)`` order for out
            functions, otherwise in ``results`` order. A ``size``/``dtype`` that
            names a tensor input becomes ``<t>.sizes()`` / ``<t>.scalar_type()``.

        Raises:
            RuntimeError: if a functional op's result count does not match its
                Tensor returns, a non-out result entry is missing, a non-inplace
                result has no ``size``/``dtype``, or ``f`` has no tensor input to
                take result options from.
        """
        kind = f.func.kind()
        # Count plain ``Tensor`` returns only; list/optional return types are not
        # BaseType and need not expose ``name``.
        tensor_number = sum(
            1
            for ret in f.func.returns
            if isinstance(ret.type, BaseType) and ret.type.name == BaseTy.Tensor
        )
        if len(results) != tensor_number and kind == SchemaKind.functional:
            raise RuntimeError(
                f"The number of result info in yaml is {len(results)}. "
                f"That does not match {f.func.name}'s returns tensor number {tensor_number}"
            )
        input_tensors = filt_input_tensor(f.func.arguments.flat_all)
        if kind == SchemaKind.out:
            result_names = cpp.return_names(f)
        else:
            result_names = list(results)

        res_infos = []
        for name in result_names:
            info = results.get(name, None)
            if kind == SchemaKind.out and info is None:
                # No entry for an out result: the out argument supplies size and dtype.
                size, dtype = name, name
                infer_name = None
            elif info is None:
                raise RuntimeError(f"The '{name}' is missing in results for {f.func.name}")
            else:
                # Work on a copy so popping keys does not consume the caller's yaml data.
                info = dict(info)
                size = info.pop('size', None)
                dtype = info.pop('dtype', None)
                if kind != SchemaKind.inplace:
                    if kind == SchemaKind.out:
                        # Out results default to the out argument itself.
                        size = name if size is None else size
                        dtype = name if dtype is None else dtype
                    if size is None:
                        raise RuntimeError(f"The {name}'s size is None in {f.func.name}")
                    if dtype is None:
                        raise RuntimeError(f"The {name}'s dtype is None in {f.func.name}")
                infer_name = info.pop('name', None)

            if not input_tensors:
                raise RuntimeError(
                    f"{f.func.name} has no tensor input to take the options of result '{name}' from"
                )
            size_formula = f'{size}.sizes()' if size in input_tensors else size
            dtype_formula = f'{dtype}.scalar_type()' if dtype in input_tensors else dtype
            res_infos.append(
                ResInfo(
                    name=name,
                    size=size_formula,
                    dtype=dtype_formula,
                    option=input_tensors[0],
                    infer_name=infer_name,
                )
            )
        return res_infos
@dataclass(frozen=True)
class CpuScalarOp:
    """One entry of a ``cpu_scalar_op`` list in a gen_opapi yaml configuration."""

    # Value of the entry's ``param`` key ('' when absent).
    param: str
    # Value of the entry's ``exec`` key ('' when absent).
    exec_cmd: str

    @staticmethod
    def parse(op_list: List[Dict[str, str]]) -> List['CpuScalarOp']:
        """Turn each mapping in ``op_list`` into a CpuScalarOp, in order.

        Missing ``param`` / ``exec`` keys default to the empty string.
        """
        return [
            CpuScalarOp(param=entry.get('param', ''), exec_cmd=entry.get('exec', ''))
            for entry in op_list
        ]
@dataclass(frozen=True)
class StructInfo:
    """Code-generation settings for one op-api function from op_plugin_functions.yaml.

    Built by :meth:`from_yaml`. Inplace functions that declare
    ``structured_inherit`` only fill ``name``, ``func``, ``structured_inherit``
    and ``acl_op``; all other functions fill the exec/result fields.
    """

    # Function name including overload name (schema text before '(').
    name: str
    # Function named by ``gen_opapi.structured_inherit`` (set for inplace functions only).
    structured_inherit: 'NativeFunction' = None
    # First comma-separated entry of ``gen_opapi.exec``.
    aclnn_name: str = None
    # The NativeFunction matching the entry's ``func`` schema.
    func: 'NativeFunction' = None
    # ``gen_opapi.exec`` text; when it names only the kernel, the function's arguments
    # (and, for functional ops, its result names) are appended.
    cmd_args: str = None
    # Return expression text with a leading space, or '' when nothing is returned.
    return_args: str = None
    # True when the function also has an ``acl_op`` implementation.
    acl_op: bool = None
    # Result descriptions (empty when ``use_structured_meta`` is set).
    results: List[ResInfo] = None
    # ``gen_opapi.new_params`` mapping ({} when absent).
    new_params: Dict[str, str] = None
    # ``gen_opapi.integral_identity_tensor`` value.
    integral_identity_tensor: str = None
    # ``gen_opapi.cpu_scalar_h2d`` value.
    cpu_scalar_h2d: List[str] = None
    # Parsed ``gen_opapi.cpu_scalar_op`` entries (None when absent or empty).
    cpu_scalar_op: List[CpuScalarOp] = None
    # Mirrors the entry's top-level ``structured`` flag.
    use_structured_meta: bool = False

    @staticmethod
    def from_yaml(
        es: Sequence[Dict[str, object]],
        native_functions: Sequence[NativeFunction],
    ) -> "List[StructInfo]":
        """
        Parse StructInfo objects from entries as directly parsed
        from op_plugin_functions.yaml.

        Args:
            es: Yaml entries, each with a ``func`` schema string and a ``gen_opapi``
                mapping. Entries are modified in place: ``acl_op``, ``op_api`` and
                ``func`` are popped, and consumed keys are popped from ``gen_opapi``.
            native_functions: All parsed native functions; schemas must be unique.

        Returns:
            One StructInfo per entry, in input order.

        Raises:
            RuntimeError: on duplicate schemas, entries without ``func``, unknown
                schemas, functions without an op_api implementation, or invalid
                ``gen_opapi`` / ``structured_inherit`` configuration.
            ValueError: if an out function's result info uses a non-output key.
        """
        functions_by_schema: Dict[str, NativeFunction] = {}
        for function in native_functions:
            schema_key = str(function.func)
            if schema_key in functions_by_schema:
                raise RuntimeError(f"{function.func} has multiple definitions in op_plugin_functions.yaml")
            functions_by_schema[schema_key] = function

        def gen_func_name(schema: str) -> str:
            # "name.overload(args) -> ret" -> "name.overload"
            return schema.split('(')[0]

        def get_return_arguments(func: NativeFunction, results: List[ResInfo]) -> List[str]:
            # Names of the values the generated wrapper returns.
            kind = func.func.kind()
            if kind == SchemaKind.inplace:
                return [func.func.arguments.self_arg.argument.name]
            if kind == SchemaKind.out and len(results) == 0:
                return cpp.return_names(func)
            return [result.name for result in results]

        def format_return_args(func: NativeFunction, return_argument: List[str]) -> str:
            # Render the return expression prefixed by a space ('' when nothing is returned).
            kind = func.func.kind()
            if len(return_argument) == 0:
                return ''
            if len(return_argument) == 1:
                return ''.join([' ', return_argument[0]])
            if kind == SchemaKind.out:
                # Out variants return a tuple of references to the out arguments.
                return ''.join([' ', f"std::forward_as_tuple({', '.join(return_argument)})"])
            move_args = ', '.join(f'std::move({arg})' for arg in return_argument)
            return ''.join([' ', f'std::make_tuple({move_args})'])

        # First pass: index every entry by function name so that
        # ``structured_inherit`` may refer to entries appearing later in ``es``.
        funcname_map: Dict[str, NativeFunction] = {}
        struct_map: Dict[str, Dict] = {}
        for e in es:
            e.pop('acl_op', None)
            e.pop('op_api', None)
            schema_str = e.get('func', None)
            if schema_str is None:
                raise RuntimeError(f"An entry in op_plugin_functions.yaml has no func: {e}")
            op_name = gen_func_name(schema_str)
            funcname_map[op_name] = functions_by_schema.get(schema_str)
            # Snapshot taken before the second pass pops keys out of the entry.
            struct_map[op_name] = copy.deepcopy(e)

        struct_infos: List[StructInfo] = []
        for e in es:
            schema_str = e.pop('func', None)
            defn_name = gen_func_name(schema_str)
            schema_function = functions_by_schema.get(schema_str)
            if not schema_function:
                avail = "\n".join(k for k in functions_by_schema if gen_func_name(k) == defn_name)
                raise RuntimeError(
                    f"could not find ATen function for schema: {schema_str} . Available signatures:\n{avail}"
                )
            func_kind = schema_function.func.kind()
            if 'op_api' not in schema_function.impl_ns:
                raise RuntimeError(
                    f"The Aten function {schema_str} has no op_api implement in op_plugin_functions yaml"
                )
            acl_op = 'acl_op' in schema_function.impl_ns
            gen_opapi_info = e.get('gen_opapi')
            if not isinstance(gen_opapi_info, dict):
                raise RuntimeError(f"The Aten function {schema_str} has invalid gen_opapi configuration")

            structured_inherit = gen_opapi_info.pop('structured_inherit', None)
            if structured_inherit is not None:
                if func_kind == SchemaKind.inplace:
                    inherit_function = funcname_map.get(structured_inherit)
                    if inherit_function is None:
                        raise RuntimeError(f'The structured_inherit func {structured_inherit} is None')
                    struct_infos.append(
                        StructInfo(
                            name=defn_name,
                            func=schema_function,
                            structured_inherit=inherit_function,
                            acl_op=acl_op,
                        )
                    )
                    continue
                inherit_entry = struct_map.get(structured_inherit)
                if inherit_entry is None:
                    raise RuntimeError(f'The structured_inherit func {structured_inherit} is None')
                # Deep-copy: the pops below (and any other consumer of the cached
                # entry) must not strip keys that later inheritors still need.
                gen_opapi_info = copy.deepcopy(inherit_entry.get('gen_opapi'))
                if not isinstance(gen_opapi_info, dict):
                    raise RuntimeError(
                        f"The structured_inherit func {structured_inherit} has invalid gen_opapi configuration"
                    )

            integral_identity_tensor = gen_opapi_info.pop('integral_identity_tensor', None)
            aclnn_arguments = gen_opapi_info.pop('exec', None)
            if aclnn_arguments is None:
                raise RuntimeError(f"The Aten function {schema_str}'s gen_opapi has no exec configuration")
            aclnn_arguments_list = [argument.strip() for argument in aclnn_arguments.split(',')]
            aclnn_name = aclnn_arguments_list[0]
            new_params_dict = gen_opapi_info.pop('new_params', dict())
            cpu_scalar_h2d = gen_opapi_info.pop('cpu_scalar_h2d', None)
            cpu_scalar_op_raw = gen_opapi_info.pop('cpu_scalar_op', None)
            cpu_scalar_op = CpuScalarOp.parse(cpu_scalar_op_raw) if cpu_scalar_op_raw else None

            # When exec names only the kernel, pass every schema argument through in order.
            cmd_args_expand = len(aclnn_arguments_list) == 1
            if cmd_args_expand:
                with native_function_manager(schema_function):
                    sig = NativeSignature(schema_function.func, prefix='', symint=False)
                    for a in sig.arguments():
                        aclnn_arguments_list.append(a.name)
                aclnn_arguments = ', '.join(aclnn_arguments_list)

            if func_kind == SchemaKind.out:
                # After the pops above, only per-result entries may remain.
                output_names = cpp.return_names(schema_function)
                for key in gen_opapi_info:
                    if key not in output_names:
                        raise ValueError(f"Result information contains invalid key: {key} in {schema_str}")

            use_structured_meta = bool(e.get('structured', False))
            results = [] if use_structured_meta else ResInfo.parse(gen_opapi_info, schema_function)
            return_argument = get_return_arguments(schema_function, results)
            return_args = format_return_args(schema_function, return_argument)
            if cmd_args_expand and func_kind == SchemaKind.functional:
                # NOTE(review): when return_argument is empty (e.g. use_structured_meta)
                # this leaves a trailing ", "; confirm the consumer of cmd_args expects that.
                aclnn_arguments = f"{aclnn_arguments}, {', '.join(return_argument)}"

            struct_infos.append(
                StructInfo(
                    name=defn_name,
                    aclnn_name=aclnn_name,
                    cmd_args=aclnn_arguments,
                    func=schema_function,
                    results=results,
                    return_args=return_args,
                    acl_op=acl_op,
                    new_params=new_params_dict,
                    integral_identity_tensor=integral_identity_tensor,
                    cpu_scalar_h2d=cpu_scalar_h2d,
                    cpu_scalar_op=cpu_scalar_op,
                    use_structured_meta=use_structured_meta,
                )
            )
        return struct_infos