import re
from dataclasses import dataclass
from typing import List, Dict, Optional, Iterator, Tuple, Set, NoReturn, Sequence, Callable, Union
from enum import Enum, auto
import itertools
def assert_never(x: NoReturn) -> NoReturn:
    """Fail when a supposedly exhaustive type dispatch falls through.

    Args:
        x: The value that no branch handled.

    Raises:
        AssertionError: Always; the message names the runtime type of ``x``.
    """
    type_name = type(x).__name__
    raise AssertionError(f"Unhandled type: {type_name}")
@dataclass(frozen=True)
class Location:
    """Immutable (file, line) pair that renders as ``file:line``."""
    file: str
    line: int

    def __str__(self) -> str:
        """Return the location formatted as ``<file>:<line>``."""
        return f"{self.file}:{self.line}"
# How a native function can be exposed. NativeFunction.from_yaml fills its
# `variants` set from the comma-separated 'variants' YAML field, which may
# name 'function' and/or 'method'.
Variant = Enum('Variant', ('function', 'method'))
class DispatchKey(Enum):
    """Dispatch keys that kernels in a ``dispatch`` table can be registered under.

    Members assigned from another member (``CatchAll``, ``EndOfBackendKeys``,
    ``DefaultBackend``, ``CPUTensorId``, ...) are enum aliases: they resolve
    to the very same member object as the name on their right-hand side.
    """
    Undefined = 0
    CatchAll = Undefined

    CPU = auto()
    CUDA = auto()
    Math = auto()
    HIP = auto()
    FPGA = auto()
    ORT = auto()
    XLA = auto()
    Lazy = auto()
    Vulkan = auto()
    Metal = auto()
    XPU = auto()
    NPU = auto()
    MKLDNN = auto()
    OpenGL = auto()
    OpenCL = auto()
    IDEEP = auto()
    QuantizedCPU = auto()
    QuantizedCUDA = auto()
    QuantizedXPU = auto()
    QuantizedNPU = auto()
    CustomRNGKeyId = auto()
    MkldnnCPU = auto()
    SparseCPU = auto()
    SparseCUDA = auto()
    SparseCsrCPU = auto()
    SparseCsrCUDA = auto()
    SparseHIP = auto()
    SparseXPU = auto()
    SparseNPU = auto()
    NestedTensor = auto()
    PrivateUse1 = auto()
    PrivateUse2 = auto()
    PrivateUse3 = auto()
    EndOfBackendKeys = PrivateUse3

    Unsupport = auto()
    ZeroTensor = auto()
    Meta = auto()
    BackendSelect = auto()
    Named = auto()
    AutogradOther = auto()
    AutogradCPU = auto()
    AutogradCUDA = auto()
    AutogradXLA = auto()
    AutogradLazy = auto()
    AutogradNestedTensor = auto()
    AutogradXPU = auto()
    AutogradNPU = auto()
    AutogradPrivateUse1 = auto()
    AutogradPrivateUse2 = auto()
    AutogradPrivateUse3 = auto()
    Tracer = auto()
    Autocast = auto()
    Batched = auto()
    VmapMode = auto()
    TESTING_ONLY_GenericWrapper = auto()
    TESTING_ONLY_GenericMode = auto()
    NumDispatchKeys = auto()

    Autograd = auto()
    CompositeImplicitAutograd = auto()
    CompositeExplicitAutograd = auto()
    EndOfAliasKeys = CompositeExplicitAutograd

    # Additional names that alias members declared above.
    DefaultBackend = CompositeExplicitAutograd
    CPUTensorId = CPU
    CUDATensorId = CUDA
    PrivateUse1_PreAutograd = AutogradPrivateUse1
    PrivateUse2_PreAutograd = AutogradPrivateUse2
    PrivateUse3_PreAutograd = AutogradPrivateUse3

    def __str__(self) -> str:
        """Return the bare member name (no ``DispatchKey.`` prefix)."""
        return self.name

    def lower(self) -> str:
        """Return the member name in lower case."""
        return self.name.lower()

    @staticmethod
    def parse(value: str) -> 'DispatchKey':
        """Look up a dispatch key by its exact, case-sensitive name.

        Alias names are accepted and resolve to the member they alias.

        Raises:
            AssertionError: If ``value`` names no member or alias.
        """
        found = next(
            (member for key, member in DispatchKey.__members__.items() if key == value),
            None,
        )
        if found is None:
            raise AssertionError(f'unknown dispatch key {value}')
        return found
class UseC10Dispatcher(Enum):
    # Two-valued selector; nothing in the visible part of this file reads it,
    # so the meaning of each option must be taken from its consumers.
    full = 0
    hacky_wrapper_for_legacy_signatures = 1
# Dispatch keys whose backend metadata is flagged `structured` for a structured
# function (see NativeFunction.from_yaml). A function that declares a
# structured_delegate must not list any of these keys in its dispatch table.
STRUCTURED_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU}
def is_generic_dispatch_key(dk: DispatchKey) -> bool:
    """Return True iff ``dk`` is one of the two Composite*Autograd keys."""
    composite_keys = {
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeImplicitAutograd,
    }
    return dk in composite_keys
def is_cuda_dispatch_key(dk: DispatchKey) -> bool:
    """Return True iff ``dk`` is one of the CUDA-flavoured dispatch keys.

    ``CUDATensorId`` is an alias of ``CUDA``, so it adds no extra member to
    the set; it is listed only to mirror the original enumeration.
    """
    cuda_keys = {
        DispatchKey.CUDA,
        DispatchKey.QuantizedCUDA,
        DispatchKey.SparseCUDA,
        DispatchKey.SparseCsrCUDA,
        DispatchKey.AutogradCUDA,
        DispatchKey.CUDATensorId,
    }
    return dk in cuda_keys
def is_structured_dispatch_key(dk: DispatchKey) -> bool:
    """Return True iff ``dk`` is a member of ``STRUCTURED_DISPATCH_KEYS``."""
    is_member = dk in STRUCTURED_DISPATCH_KEYS
    return is_member
class DeviceCheckType(Enum):
    # Value of the optional 'device_check' YAML field. NativeFunction.from_yaml
    # looks the member up by name and defaults to ExactSame when the field is
    # absent; `_foreach*` functions are required to use NoCheck.
    NoCheck = 0
    ExactSame = 1
class Tag(Enum):
    """Tags that can be attached to a native function via the ``tags`` field."""
    inplace_view = 0

    def __str__(self) -> str:
        """Return the bare member name."""
        return self.name

    @staticmethod
    def parse(value: str) -> 'Tag':
        """Look up a tag by its exact name.

        Raises:
            AssertionError: If ``value`` names no member.
        """
        found = next(
            (member for key, member in Tag.__members__.items() if key == value),
            None,
        )
        if found is None:
            raise AssertionError(f'unknown tag {value}')
        return found
@dataclass(frozen=True)
class NativeFunction:
    """Immutable record describing one native function entry.

    Instances are normally built by ``from_yaml`` from one YAML dictionary;
    ``__post_init__`` then enforces the cross-field invariants.
    """
    func: 'FunctionSchema'
    use_const_ref_for_mutable_tensors: bool
    device_guard: bool
    device_check: DeviceCheckType
    python_module: Optional[str]
    category_override: Optional[str]
    variants: Set[Variant]
    manual_kernel_registration: bool
    manual_cpp_binding: bool
    structured: bool
    structured_delegate: Optional['OperatorName']
    structured_inherits: Optional[str]
    precomputed: Optional['Precompute']
    cpp_no_default_args: Set[str]
    is_abstract: bool
    has_composite_implicit_autograd_kernel: bool
    has_composite_explicit_autograd_kernel: bool
    tag: Optional['Tag']
    op_api: bool

    @staticmethod
    def from_yaml(ei: Dict[str, object]) -> Tuple['NativeFunction',
                                                   Dict[DispatchKey, Dict['OperatorName', 'BackendMetadata']]]:
        """Build a NativeFunction and its per-dispatch-key metadata from one YAML entry.

        The input mapping is copied and consumed key by key; any key still
        present at the end (other than ``__line__``) is reported as an error.

        Args:
            ei: The raw YAML mapping for a single function.

        Returns:
            ``(native_function, backend_metadata)`` where ``backend_metadata``
            maps each dispatch key in the dispatch table to
            ``{func.name: BackendMetadata(...)}``.

        Raises:
            TypeError: A field has the wrong Python type, or ``precomputed``
                is given while ``structured`` is False.
            AssertionError: ``variants`` names an unknown variant, or a
                dispatch/tag name is unknown.
            ValueError: The dispatch table is contradictory or redundant, or
                unconsumed keys remain.
            KeyError: Both composite keys are present, a structured delegate
                lists a structured dispatch key, or ``device_check`` is not a
                ``DeviceCheckType`` name.
        """
        entry = ei.copy()

        # -- func: the schema string -------------------------------------
        schema_text = entry.pop('func')
        if not isinstance(schema_text, str):
            raise TypeError(f'not a str: {schema_text}')
        func = FunctionSchema.parse(schema_text)

        # -- cpp_no_default_args -----------------------------------------
        no_default_list = entry.pop('cpp_no_default_args', [])
        if not isinstance(no_default_list, list):
            raise TypeError(f'not a list: {no_default_list}')
        cpp_no_default_args = set(no_default_list)

        # -- variants: ', '-separated list of variant names ---------------
        variants_text = entry.pop('variants', 'function')
        if not isinstance(variants_text, str):
            raise TypeError(f'not a str: {variants_text}')
        variant_by_name = {'function': Variant.function, 'method': Variant.method}
        variants: Set[Variant] = set()
        for variant_name in variants_text.split(', '):
            if variant_name not in variant_by_name:
                raise AssertionError(f'illegal variant {variant_name}')
            variants.add(variant_by_name[variant_name])

        # -- plain boolean flags (all popped first, then type-checked) -----
        use_const_ref_for_mutable_tensors = entry.pop('use_const_ref_for_mutable_tensors', False)
        manual_kernel_registration = entry.pop('manual_kernel_registration', False)
        manual_cpp_binding = entry.pop('manual_cpp_binding', False)
        device_guard = entry.pop('device_guard', True)
        for flag in (use_const_ref_for_mutable_tensors, manual_kernel_registration,
                     manual_cpp_binding, device_guard):
            if not isinstance(flag, bool):
                raise TypeError('exists non-bool value.')

        # -- device_check: DeviceCheckType member name, default ExactSame --
        device_check_text = entry.pop('device_check', None)
        device_check: DeviceCheckType
        if device_check_text is None:
            device_check = DeviceCheckType.ExactSame
        elif isinstance(device_check_text, str):
            device_check = DeviceCheckType[device_check_text]
        else:
            raise TypeError(f'not a str: {device_check_text}')

        # -- structured / structured_delegate / structured_inherits -------
        structured = entry.pop('structured', False)
        if not isinstance(structured, bool):
            raise TypeError(f'not a bool: {structured}')

        delegate_text = entry.pop('structured_delegate', None)
        structured_delegate: Optional[OperatorName] = None
        if delegate_text is not None:
            if not isinstance(delegate_text, str):
                raise TypeError(f'not a str: {delegate_text}')
            structured_delegate = OperatorName.parse(delegate_text)

        structured_inherits = entry.pop('structured_inherits', None)
        if not (structured_inherits is None or isinstance(structured_inherits, str)):
            raise TypeError(f'not a str: {structured_inherits}')

        # -- python_module / category_override ----------------------------
        python_module = entry.pop('python_module', None)
        if not (python_module is None or isinstance(python_module, str)):
            raise TypeError(f'not a str: {python_module}')

        category_override = entry.pop('category_override', None)
        if not (category_override is None or isinstance(category_override, str)):
            raise TypeError(f'not a str: {category_override}')

        # -- precomputed: only allowed alongside structured ---------------
        precomputed_dict = entry.pop('precomputed', None)
        if precomputed_dict is not None and structured is False:
            raise TypeError("structured is False.")
        precomputed = Precompute.parse(precomputed_dict) if precomputed_dict else None

        # -- tags ---------------------------------------------------------
        tag_text = entry.pop('tags', None)
        if not (tag_text is None or isinstance(tag_text, str)):
            raise TypeError(f'not a str: {tag_text}')
        tag = Tag.parse(tag_text) if tag_text else None

        op_api = entry.pop("op_api", False)

        # Imported at this point inside the function, as in the original.
        from codegen.api import cpp

        # -- dispatch table -----------------------------------------------
        raw_dispatch = entry.pop('dispatch', None)
        if not (raw_dispatch is None or isinstance(raw_dispatch, dict)):
            raise TypeError("raw_dispatch is not dict.")
        dispatch: Dict[DispatchKey, str] = {}
        if raw_dispatch is not None:
            if manual_kernel_registration:
                raise ValueError("cannot specify both manual_kernel_registration and dispatch; with "
                                 "manual registration, dispatch has no effect!")
            for keys_text, kernel_name in raw_dispatch.items():
                if keys_text == '__line__':
                    continue
                if not isinstance(keys_text, str):
                    raise TypeError(f'not a str: {keys_text}')
                if not isinstance(kernel_name, str):
                    raise TypeError(f'not a str: {kernel_name}')
                # One YAML key may list several dispatch keys separated by commas.
                for key_text in keys_text.split(","):
                    dispatch[DispatchKey.parse(key_text.strip())] = kernel_name
            if dispatch == {DispatchKey.CompositeImplicitAutograd: cpp.name(func)}:
                raise ValueError("unnecessary dispatch table for this function; just delete the dispatch "
                                 "key entirely")
            if not structured_delegate and dispatch.keys() == {DispatchKey.CompositeImplicitAutograd}:
                raise ValueError(
                    f"unexpected name for singleton CompositeImplicitAutograd dispatch entry:"
                    f" expected {cpp.name(func)} but got {dispatch[DispatchKey.CompositeImplicitAutograd]}."
                    f" Rename your implementation to the expected name, then delete the dispatch table")
        elif not structured and structured_delegate is None:
            # No table given: default to a single CompositeImplicitAutograd kernel.
            dispatch[DispatchKey.CompositeImplicitAutograd] = cpp.name(func)
        if DispatchKey.CompositeExplicitAutograd in dispatch and \
                DispatchKey.CompositeImplicitAutograd in dispatch:
            raise KeyError(
                "cannot specify both CompositeExplicitAutograd and CompositeImplicitAutograd on a single kernel;"
                " each strictly subsumes the other.  If you wanted to provide an explicit autograd "
                "implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only")

        # -- facts derived from the dispatch table ------------------------
        if structured_delegate:
            is_abstract = True
        else:
            is_abstract = dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
        has_composite_implicit_autograd_kernel = DispatchKey.CompositeImplicitAutograd in dispatch
        has_composite_explicit_autograd_kernel = DispatchKey.CompositeExplicitAutograd in dispatch

        backend_metadata = {
            key: {func.name: BackendMetadata(
                kernel=kernel, structured=structured and is_structured_dispatch_key(key))}
            for key, kernel in dispatch.items()
        }

        # -- final checks: nothing left over, delegate rules respected ----
        entry.pop('__line__', None)
        if entry:
            raise ValueError(f"leftover entries: {entry}")
        if structured_delegate is not None:
            for key in STRUCTURED_DISPATCH_KEYS:
                if key in dispatch:
                    raise KeyError(f"if structured_delegate, then must not have {key} in dispatch dictionary "
                                   "(it is delegated!)")

        native_function = NativeFunction(
            func=func,
            use_const_ref_for_mutable_tensors=use_const_ref_for_mutable_tensors,
            variants=variants,
            structured=structured,
            structured_delegate=structured_delegate,
            structured_inherits=structured_inherits,
            precomputed=precomputed,
            manual_kernel_registration=manual_kernel_registration,
            manual_cpp_binding=manual_cpp_binding,
            python_module=python_module,
            category_override=category_override,
            device_guard=device_guard,
            device_check=device_check,
            cpp_no_default_args=cpp_no_default_args,
            is_abstract=is_abstract,
            has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel,
            has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel,
            tag=tag,
            op_api=op_api,
        )
        return native_function, backend_metadata

    def validate_unstructured(self) -> None:
        """Raise ValueError if ``structured`` is set or ``structured_delegate`` is falsy."""
        if self.structured:
            raise ValueError("This function is structured, but there was "
                             "no valid functional variant of it.")
        if not self.structured_delegate:
            raise ValueError("This function delegates to another structured out function, "
                             "but no valid function was found (the delegate may not exist, or it has the wrong type)")

    def __post_init__(self) -> None:
        """Check cross-field invariants; raises ValueError on the first violation."""
        if self.func.arguments.out and self.variants != {Variant.function}:
            raise ValueError("Native functions with out arguments MUST "
                             "be declared with only function variant; e.g., variants: function; "
                             "otherwise you will tickle a Python argument binding bug "
                             "(which usually manifests itself as the result variable being undefined.)")
        if self.structured:
            if self.func.kind() != SchemaKind.out:
                raise ValueError("Put structured field on the out= "
                                 "variant of a function; did you mean structured_delegate?")
            if not self.device_guard:
                raise ValueError("device_guard: False is not respected by structured kernels")
        if self.structured_delegate:
            if self.func.kind() == SchemaKind.out:
                raise ValueError("structured_delegate field not allowed "
                                 "on out= functions; did you mean structured?")
            if not self.device_guard:
                raise ValueError("device_guard: False is not respected by structured kernels")
        if self.structured and self.structured_delegate:
            raise ValueError("Cannot have both structured and structured_delegate on function")

        # Every name in cpp_no_default_args must be an argument that has a default.
        defaulted_arguments = {
            arg.name for arg in self.func.schema_order_arguments() if arg.default is not None
        }
        invalid_args = set.difference(self.cpp_no_default_args, defaulted_arguments)
        if len(invalid_args) != 0:
            raise ValueError(f'Invalid cpp_no_default_args: {invalid_args}')

        if self.structured_inherits is not None and not self.structured:
            raise ValueError("structured_inherits must also imply structured: True")

        if str(self.func.name).startswith('_foreach') and self.device_check != DeviceCheckType.NoCheck:
            raise ValueError("foreach kernels fall back to slow path when tensor are on different devices, "
                             "device_check not allowed to be enabled")

    @property
    def has_composite_kernel(self) -> bool:
        """True when either composite (implicit or explicit autograd) kernel is present."""
        return self.has_composite_implicit_autograd_kernel or self.has_composite_explicit_autograd_kernel

    @property
    def is_view_op(self) -> bool:
        """True when the schema or tag marks this function as a view.

        That is: some return carries a non-write alias annotation, the tag is
        ``Tag.inplace_view``, or some argument's annotation has a non-empty
        ``alias_set_after``.
        """
        rets = self.func.returns
        is_non_mutating_view = len(rets) > 0 and any(
            ret.annotation is not None and not ret.annotation.is_write for ret in rets
        )
        is_inplace_view = self.tag is Tag.inplace_view
        is_wildcard_view = any(
            arg.annotation is not None and arg.annotation.alias_set_after != ""
            for arg in self.func.schema_order_arguments()
        )
        return is_non_mutating_view or is_inplace_view or is_wildcard_view

    @property
    def root_name(self) -> str:
        """The base operator name, taken from ``func.name.name.base``."""
        return self.func.name.name.base
# Classification returned by FunctionSchema.kind(): 'inplace' when the operator
# name is flagged inplace, 'out' when the schema has out arguments, and
# 'functional' otherwise.
SchemaKind = Enum('SchemaKind', ('functional', 'inplace', 'out'))
@dataclass(frozen=True)
class NativeFunctionsGroup:
    """The functional, out and (optionally) inplace variants of one operator.

    All members must share the same ``func.signature()``; this is enforced in
    ``__post_init__``.
    """
    functional: NativeFunction
    inplace: Optional[NativeFunction]
    out: NativeFunction

    @property
    def structured(self) -> bool:
        """Whether the group is structured; taken from the out variant."""
        return self.out.structured

    def __post_init__(self) -> None:
        """Validate that the members agree with each other.

        Raises:
            AssertionError: A member's signature differs from the functional one.
            ValueError: A member has the wrong schema kind, or (for structured
                groups) the delegation wiring is inconsistent.
        """
        reference_sig: FunctionSchema = self.functional.func.signature()
        for member in self.functions():
            member_sig = member.func.signature()
            if reference_sig != member_sig:
                raise AssertionError(
                    "NativeFunctionsGroup constructed from two NativeFunctions "
                    f"that don't have matching signatures: {reference_sig} != {member_sig}"
                )

        if self.functional.func.kind() != SchemaKind.functional:
            raise ValueError("self.functional.func.kind() != SchemaKind.functional")
        if self.out.func.kind() != SchemaKind.out:
            raise ValueError("self.out.func.kind() != SchemaKind.out")
        if self.inplace is not None and self.inplace.func.kind() != SchemaKind.inplace:
            raise ValueError("self.inplace.func.kind() != SchemaKind.inplace")

        if not self.structured:
            return

        # Structured groups: out must not be composite-implicit, and the
        # other variants must delegate to exactly the out variant.
        out_name = self.out.func.name
        if self.out.has_composite_implicit_autograd_kernel:
            raise ValueError("self.out.has_composite_implicit_autograd_kernel")
        if self.functional.structured_delegate != out_name:
            raise ValueError(f"{self.functional.func.name} delegates to {self.functional.structured_delegate} "
                             f"but its actual delegate is {self.out.func.name}")
        if self.inplace is not None and self.inplace.structured_delegate != out_name:
            raise ValueError("self.inplace.structured_delegate != self.out.func.name")

    def signature(self) -> 'FunctionSchema':
        """Return the signature of the out variant."""
        return self.out.func.signature()

    def functions(self) -> Iterator[NativeFunction]:
        """Yield the functional variant, the out variant, then inplace if present."""
        yield from (self.functional, self.out)
        if self.inplace is not None:
            yield self.inplace

    @staticmethod
    def from_dict(d: Dict[SchemaKind, NativeFunction]) -> Optional['NativeFunctionsGroup']:
        """Assemble a group from a kind -> function mapping.

        Returns:
            The group, or ``None`` when the mapping holds a single function or
            has no out variant.

        Raises:
            ValueError: The mapping is empty, contains an unexpected key, or
                lacks a functional variant.
        """
        if not d:
            raise ValueError("not d")
        if len(d) == 1:
            return None
        remaining = dict(d)  # work on a copy so the caller's mapping is untouched
        functional = remaining.pop(SchemaKind.functional, None)
        inplace = remaining.pop(SchemaKind.inplace, None)
        out = remaining.pop(SchemaKind.out, None)
        if remaining:
            raise ValueError("d")
        if functional is None:
            raise ValueError("functional is None")
        if out is None:
            return None
        return NativeFunctionsGroup(functional=functional, inplace=inplace, out=out)
# Operator names recognised by is_foreach_op. Built once at import time
# instead of being rebuilt from a list on every call.
_FOREACH_OP_NAMES = frozenset({
    '_amp_foreach_non_finite_check_and_unscale_',
    '_foreach_add_.ScalarList',
    '_foreach_sub_.ScalarList',
    '_foreach_mul_.ScalarList',
    '_foreach_div_.ScalarList',
    '_foreach_add_.Scalar',
    '_foreach_sub_.Scalar',
    '_foreach_mul_.Scalar',
    '_foreach_div_.Scalar',
    '_foreach_add_.List',
    '_foreach_sub_.List',
    '_foreach_mul_.List',
    '_foreach_div_.List',
    '_foreach_exp_',
    '_foreach_sqrt_',
    '_foreach_abs_',
    '_foreach_acos_',
    '_foreach_asin_',
    '_foreach_atan_',
    '_foreach_ceil_',
    '_foreach_cos_',
    '_foreach_cosh_',
    '_foreach_erf_',
    '_foreach_erfc_',
    '_foreach_expm1_',
    '_foreach_floor_',
    '_foreach_log_',
    '_foreach_log10_',
    '_foreach_log1p_',
    '_foreach_log2_',
    '_foreach_neg_',
    '_foreach_tan_',
    '_foreach_tanh_',
    '_foreach_sin_',
    '_foreach_sinh_',
    '_foreach_round_',
    '_foreach_lgamma_',
    '_foreach_frac_',
    '_foreach_reciprocal_',
    '_foreach_sigmoid_',
    '_foreach_trunc_',
    '_foreach_addcmul_.Scalar',
    '_foreach_addcdiv_.Scalar',
    '_foreach_addcmul_.ScalarList',
    '_foreach_addcdiv_.ScalarList',
    '_foreach_zero_',
})


def is_foreach_op(name: str) -> bool:
    """Return True iff ``name`` is one of the listed in-place foreach operators.

    Args:
        name: Full operator name including any overload suffix, e.g.
            ``'_foreach_add_.Scalar'``. It is passed through ``str()`` first,
            so objects whose string form is an operator name are accepted.

    Returns:
        Whether the exact name is in the fixed list above (no prefix matching).
    """
    return str(name) in _FOREACH_OP_NAMES
@dataclass(frozen=True)
class BackendMetadata:
    # Per-backend information recorded for one operator.
    # kernel: the kernel name taken from the function's dispatch table.
    kernel: str
    # structured: set by NativeFunction.from_yaml to
    # `structured and is_structured_dispatch_key(key)`.
    structured: bool
@dataclass(frozen=True)
class BackendIndex:
    """Operator-name -> BackendMetadata lookup table for one dispatch key."""
    dispatch_key: DispatchKey
    # Selects which member of a NativeFunctionsGroup `primary` returns:
    # the out variant when True, the functional variant otherwise.
    use_out_as_primary: bool
    # When True, native_function_class_name() returns a class name; else None.
    external: bool
    index: Dict['OperatorName', BackendMetadata]

    @staticmethod
    def grow_index(
            parent_index: Dict[DispatchKey, Dict['OperatorName', BackendMetadata]],
            child_index: Dict[DispatchKey, Dict['OperatorName', BackendMetadata]]
    ) -> None:
        """Merge ``child_index`` into ``parent_index`` in place.

        A dispatch key that is missing from ``parent_index`` gets a fresh
        bucket, so a plain ``dict`` works as well as a ``defaultdict``.

        Raises:
            ValueError: An operator in ``child_index`` is already registered
                under the same dispatch key in ``parent_index``.
        """
        for dispatch_key, child_ops in child_index.items():
            for op_name, metadata in child_ops.items():
                try:
                    # Plain indexing first so a defaultdict still uses its own factory.
                    parent_ops = parent_index[dispatch_key]
                except KeyError:
                    parent_ops = parent_index[dispatch_key] = {}
                if op_name in parent_ops:
                    raise ValueError(f'duplicate operator {op_name} for dispatch key {dispatch_key}')
                parent_ops[op_name] = metadata

    def primary(self, g: NativeFunctionsGroup) -> NativeFunction:
        """Return the group's out variant if ``use_out_as_primary``, else its functional variant."""
        if self.use_out_as_primary:
            return g.out
        return g.functional

    def has_kernel(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool:
        """Return True iff ``get_kernel(g)`` finds an entry."""
        return self.get_kernel(g) is not None

    def get_kernel(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> Optional[BackendMetadata]:
        """Return the metadata registered for ``g``, or ``None`` if there is none.

        For a group, the lookup uses the group's primary function.
        """
        if isinstance(g, NativeFunction):
            f = g
        elif isinstance(g, NativeFunctionsGroup):
            f = self.primary(g)
        else:
            assert_never(g)
        if f.func.name not in self.index:
            return None
        return self.index[f.func.name]

    def native_function_class_name(self) -> Optional[str]:
        """Return ``'<DispatchKey>NativeFunctions'`` for external backends, else ``None``."""
        if self.external:
            return f'{str(self.dispatch_key)}NativeFunctions'
        return None
@dataclass(frozen=True)
class FunctionSchema:
    """Parsed form of a schema string ``name(arguments) -> returns``.

    ``parse`` builds an instance from text and verifies that ``str()`` of the
    result reproduces the input exactly; ``__post_init__`` checks that
    out-arguments and mutable returns are consistent.
    """
    name: 'OperatorName'
    arguments: 'Arguments'
    returns: Tuple['Return', ...]

    def schema_order_arguments(self) -> Iterator['Argument']:
        """Iterate over positional, then keyword-only, then out arguments."""
        return itertools.chain(
            self.arguments.flat_positional,
            self.arguments.flat_kwarg_only,
            self.arguments.out
        )

    @staticmethod
    def parse(func: str) -> 'FunctionSchema':
        """Parse a schema string.

        The text is split at the last ``' -> '`` into declaration and return
        type, and the declaration at its first ``'('`` into operator name and
        argument list.

        Raises:
            ValueError: ``' -> '`` or ``'('`` is missing, the argument list
                does not end with ``')'`` (including when nothing follows the
                ``'('``), or the parsed schema does not print back to ``func``.
        """
        if ' -> ' not in func:
            raise ValueError("function schema missing return type (spaces are mandatory)")
        last_index = func.rfind(" -> ")
        func_decl = func[:last_index]
        return_decl = func[last_index + len(" -> "):]
        if '(' not in func_decl:
            raise ValueError("function schema missing opening ( of the argument list")
        ops, args = func_decl.split('(', 1)
        # Slice rather than index so an empty remainder is reported as a
        # missing ')' instead of raising IndexError.
        if args[-1:] != ")":
            raise ValueError("Expecting closing )")
        args = args[:-1]
        name = OperatorName.parse(ops)
        arguments = Arguments.parse(args)
        returns = parse_returns(return_decl)
        r = FunctionSchema(
            name=name,
            arguments=arguments,
            returns=returns
        )
        if str(r) != func:
            raise ValueError(f'{str(r)} != {func}')
        return r

    def __post_init__(self) -> None:
        """Validate the relationship between out arguments, ``self`` and returns.

        Raises:
            ValueError: On the first violated rule.
        """
        # The i-th out argument must carry the same annotation as the i-th return.
        for arg, ret in zip(self.arguments.out, self.returns):
            if arg.annotation != ret.annotation:
                raise ValueError("Out arguments must have matching return Tensor; furthermore, "
                                 "the ith-argument needs to correspond to the ith return")
        # Every write-annotated return must alias an out argument or `self`.
        out_and_self = list(self.arguments.out) + [arg for arg in self.arguments.flat_positional if arg.name == "self"]
        mutable_returns = [ret for ret in self.returns if ret.annotation is not None and ret.annotation.is_write]
        for ret in mutable_returns:
            if not any(ret.annotation == arg.annotation for arg in out_and_self):
                raise ValueError("All mutable returns must be aliased either to a keyword argument, or to \"self\". "
                                 "Did you forget to mark an out argument as keyword-only?")
        if self.arguments.out:
            if len(self.arguments.out) != len(self.returns):
                raise ValueError("Must return as many arguments as there are out arguments")
        # In-place operators return exactly one value, except the listed foreach ops.
        if self.name.name.inplace:
            if not is_foreach_op(str(self.name)):
                if len(self.returns) != 1:
                    raise ValueError("len(self.returns) != 1")

    def is_out_fn(self) -> bool:
        """Return True iff the schema has at least one out argument."""
        return bool(self.arguments.out)

    def kind(self) -> SchemaKind:
        """
        What kind of schema is this?  A functional schema is one
        that returns a newly allocated output; an inplace schema
        modifies the self argument inplace; an out schema writes
        the result into an explicitly provided out argument.

        Raises ValueError if the schema is both inplace and out.
        """
        is_inplace = self.name.name.inplace
        is_out = bool(self.arguments.out)
        if is_inplace and is_out:
            raise ValueError("is_inplace and is_out")
        if is_inplace:
            return SchemaKind.inplace
        elif is_out:
            return SchemaKind.out
        else:
            return SchemaKind.functional

    def signature(self, *, strip_default: bool = False) -> 'FunctionSchema':
        """
        Certain schemas are 'related', in that they are simply
        inplace/out/functional versions of the same function.  This method
        factors these schemas into the "core" functional signature which
        is equal across all versions.

        Here is what normalization happens to the schema to convert
        it to a signature:
        - The overload name is stripped (name is retained, since
          it expresses semantic content about what the function does)
        - Inplace is set False
        - Out arguments are stripped
        - Mutability annotations are stripped  (this is sound
          because you cannot overload on mutability annotation)
        - Return names are stripped since they are not overloadable and
          some variants have return names but some not

        With ``strip_default=True`` argument defaults are dropped as well
        (see ``Arguments.signature``).
        """
        def strip_ret_annotation(r: Return) -> Return:
            return Return(
                name=None,
                type=r.type,
                annotation=None,
            )

        return FunctionSchema(
            name=OperatorName(
                name=BaseOperatorName(
                    base=self.name.name.base,
                    inplace=False,
                    dunder_method=self.name.name.dunder_method,
                ),
                overload_name="",
            ),
            arguments=self.arguments.signature(strip_default=strip_default),
            returns=tuple(map(strip_ret_annotation, self.returns)),
        )

    def __str__(self) -> str:
        """Render the schema; a single return is printed bare, otherwise as a tuple."""
        all_arguments_str = str(self.arguments)
        if len(self.returns) == 1:
            returns = str(self.returns[0])
        else:
            returns = '(' + ', '.join(map(str, self.returns)) + ')'
        return f'{self.name}({all_arguments_str}) -> {returns}'
@dataclass(frozen=True)
class Annotation:
    """Alias annotation attached to a Tensor argument or return.

    The textual form is a single lower-case letter, optionally followed by
    ``' -> *'`` and optionally by ``'!'`` for a write.
    """
    alias_set: Tuple[str, ...]
    is_write: bool
    alias_set_after: str

    @staticmethod
    def parse(ann: str) -> 'Annotation':
        """Parse an annotation and verify that it prints back unchanged.

        Raises:
            ValueError: The text is not a recognised annotation, or
                ``str()`` of the result differs from ``ann``.
        """
        wildcard_marker = " -> *"
        marker_pos = ann.find(wildcard_marker)
        if marker_pos == -1:
            after_set = ""
            core = ann
        else:
            # Cut the marker out so the same regex handles both forms.
            after_set = "*"
            core = ann[:marker_pos] + ann[marker_pos + len(wildcard_marker):]
        m = re.match(r'^([a-z])(!?)(!?)$', core)
        if m is None:
            raise ValueError(f'unrecognized alias annotation {ann}')
        parsed = Annotation(
            alias_set=(m.group(1),),
            is_write=(m.group(2) == '!'),
            alias_set_after=after_set,
        )
        if str(parsed) != ann:
            raise ValueError(f'{parsed} != {ann}')
        return parsed

    def __str__(self) -> str:
        """Render as ``<set>[ -> <after>][!]``, with set members joined by ``|``."""
        text = '|'.join(self.alias_set)
        if self.alias_set_after:
            text = f'{text}{" -> "}{self.alias_set_after}'
        if self.is_write:
            text += '!'
        return text
@dataclass(frozen=True)
class Type:
    """Base class for schema types; subclasses implement the query methods."""

    @staticmethod
    def parse(t: str) -> 'Type':
        """Parse a type string and verify that it prints back unchanged.

        Raises:
            ValueError: ``str()`` of the parsed type differs from ``t``.
            RuntimeError: The innermost name is not a known base type.
        """
        parsed = Type._parse(t)
        if str(parsed) != t:
            raise ValueError(f'{parsed} != {t}')
        return parsed

    @staticmethod
    def _parse(t: str) -> 'Type':
        """Parse without the round-trip check: ``X?``, then ``X[n]``/``X[]``, then a base name."""
        optional_match = re.match(r'^(.+)\?$', t)
        if optional_match is not None:
            return OptionalType(Type.parse(optional_match.group(1)))

        list_match = re.match(r'^(.+)\[([0-9]+)?\]$', t)
        if list_match is not None:
            elem_text, size_text = list_match.groups()
            size = None if size_text is None else int(size_text)
            return ListType(elem=Type.parse(elem_text), size=size)

        try:
            return BaseType(BaseTy[t])
        except KeyError as e:
            raise RuntimeError(f"unrecognized type {t}") from e

    def __str__(self) -> str:
        """Textual form of the type; must be provided by subclasses."""
        raise NotImplementedError

    def is_tensor_like(self) -> bool:
        """Whether the type is Tensor or wraps Tensor; must be provided by subclasses."""
        raise NotImplementedError

    def is_nullable(self) -> bool:
        """Whether the type is nullable; must be provided by subclasses."""
        raise NotImplementedError

    def is_list_like(self) -> Optional['ListType']:
        """The underlying ListType, if any; must be provided by subclasses."""
        raise NotImplementedError
# Names of the base (non-optional, non-list) types accepted in a schema.
# Type._parse looks a type string up here by member name after ruling out
# the `?` and `[...]` forms; an unknown name becomes a RuntimeError there.
BaseTy = Enum('BaseTy', (
    'Generator',
    'ScalarType',
    'Tensor',
    'int',
    'Dimname',
    'float',
    'str',
    'bool',
    'Layout',
    'Device',
    'Scalar',
    'MemoryFormat',
    'QScheme',
    'Storage',
    'Stream',
    'ConstQuantizerPtr',
))
@dataclass(frozen=True)
class BaseType(Type):
    """A plain type named by a ``BaseTy`` member."""
    name: BaseTy

    def __str__(self) -> str:
        """Return the ``BaseTy`` member name."""
        base_name = self.name.name
        return f'{base_name}'

    def is_tensor_like(self) -> bool:
        """True only for ``BaseTy.Tensor``."""
        return self.name == BaseTy.Tensor

    def is_nullable(self) -> bool:
        """Base types are never nullable."""
        return False

    def is_list_like(self) -> Optional['ListType']:
        """Base types are never lists."""
        return None
@dataclass(frozen=True)
class OptionalType(Type):
    """``elem?`` -- an optional wrapper around another type."""
    elem: Type

    def __str__(self) -> str:
        """Render as the element type followed by ``?``."""
        return '{}?'.format(self.elem)

    def is_tensor_like(self) -> bool:
        """Delegates to the element type."""
        return self.elem.is_tensor_like()

    def is_nullable(self) -> bool:
        """Optional types are always nullable."""
        return True

    def is_list_like(self) -> Optional['ListType']:
        """Delegates to the element type."""
        return self.elem.is_list_like()
@dataclass(frozen=True)
class ListType(Type):
    """``elem[]`` or ``elem[N]`` -- a list of ``elem`` with an optional fixed size."""
    elem: Type
    # None means "no size given" (printed as `[]`); any int, including 0, is
    # an explicit size (printed as `[N]`).
    size: Optional[int]

    def __str__(self) -> str:
        """Render as ``elem[size]``, leaving the brackets empty when size is None."""
        # Test against None rather than truthiness: an explicit size of 0 must
        # print as `[0]`, otherwise Type.parse('x[0]') fails its round-trip check.
        size = f'{self.size}' if self.size is not None else ''
        return f'{self.elem}[{size}]'

    def is_tensor_like(self) -> bool:
        """Delegates to the element type."""
        return self.elem.is_tensor_like()

    def is_nullable(self) -> bool:
        """Delegates to the element type."""
        return self.elem.is_nullable()

    def is_list_like(self) -> Optional['ListType']:
        """A list type is its own list-like view."""
        return self
@dataclass(frozen=True)
class Argument:
    """One argument of a schema: ``<type> <name>[=<default>]``.

    A Tensor type may carry an alias annotation in parentheses, e.g.
    ``Tensor(a!) self``.
    """
    name: str
    type: Type
    default: Optional[str]
    annotation: Optional[Annotation]

    @staticmethod
    def parse(arg: str) -> 'Argument':
        """Parse one argument and verify that it prints back unchanged.

        Raises:
            ValueError: The text has no space-separated name, has an
                unsupported annotated Tensor form, or does not round-trip.
        """
        name: str
        default: Optional[str]
        # The last space separates the (possibly annotated) type from `name[=default]`.
        type_and_annot, name_and_default = arg.rsplit(' ', 1)
        if '=' in name_and_default:
            name, default = name_and_default.split('=')
        else:
            name, default = name_and_default, None

        annotation: Optional[Annotation]
        alias_match = re.match(r'Tensor\((.+)\)(.*)', type_and_annot)
        if alias_match is None:
            type_s = type_and_annot
            annotation = None
        else:
            suffix = alias_match.group(2)
            if suffix not in ['', '?', '[]']:
                raise ValueError('unrecognized alias analysis form with Tensor')
            type_s = 'Tensor' + suffix
            annotation = Annotation.parse(alias_match.group(1))

        parsed = Argument(
            name=name,
            type=Type.parse(type_s),
            default=default,
            annotation=annotation,
        )
        if str(parsed) != arg:
            raise ValueError(f'{str(parsed)} != {arg}')
        return parsed

    @property
    def is_write(self) -> bool:
        """True when the argument has an annotation marked as a write."""
        return self.annotation is not None and self.annotation.is_write

    def __str__(self) -> str:
        """Render the argument in schema syntax.

        Raises:
            KeyError: An annotation is present on a type other than
                ``Tensor``, ``Tensor?`` or ``Tensor[]``.
        """
        type_text = f'{self.type}'
        if self.annotation:
            if type_text not in ['Tensor', 'Tensor?', 'Tensor[]']:
                raise KeyError("func_type not in ['Tensor', 'Tensor?', 'Tensor[]']")
            type_text = type_text.replace('Tensor', f'Tensor({self.annotation})')
        if self.name is None:
            return type_text
        default_suffix = f'={self.default}' if self.default else ''
        return f"{type_text} {self.name}{default_suffix}"
@dataclass(frozen=True)
class Return:
    """One return value of a schema: ``<type>[ <name>]``, optionally alias-annotated."""
    name: Optional[str]
    type: Type
    annotation: Optional[Annotation]

    @staticmethod
    def parse(arg: str) -> 'Return':
        """Parse one return declaration and verify that it prints back unchanged.

        Raises:
            KeyError: An annotated Tensor has an unsupported suffix.
            ValueError: The parsed value does not round-trip to ``arg``.
        """
        name: Optional[str]
        if ' ' in arg:
            type_and_annot, name = arg.rsplit(' ', 1)
        else:
            type_and_annot, name = arg, None

        annotation: Optional[Annotation]
        alias_match = re.match(r'Tensor\((.+)\)(.*)', type_and_annot)
        if alias_match is None:
            type_s = type_and_annot
            annotation = None
        else:
            suffix = alias_match.group(2)
            if suffix not in ['', '?', '[]']:
                raise KeyError('unrecognized alias analysis form with Tensor')
            type_s = 'Tensor' + suffix
            annotation = Annotation.parse(alias_match.group(1))

        parsed = Return(
            name=name,
            type=Type.parse(type_s),
            annotation=annotation,
        )
        if str(parsed) != arg:
            raise ValueError(f'{str(parsed)} != {arg}')
        return parsed

    @property
    def is_write(self) -> bool:
        """True when the return has an annotation marked as a write."""
        return self.annotation is not None and self.annotation.is_write

    def __str__(self) -> str:
        """Render the return in schema syntax.

        Raises:
            KeyError: An annotation is present on a type other than
                ``Tensor``, ``Tensor?`` or ``Tensor[]``.
        """
        type_text = f'{self.type}'
        if self.annotation:
            if type_text not in ['Tensor', 'Tensor?', 'Tensor[]']:
                raise KeyError("func_type not in ['Tensor', 'Tensor?', 'Tensor[]']")
            type_text = type_text.replace('Tensor', f'Tensor({self.annotation})')
        if self.name is None:
            return type_text
        return f"{type_text} {self.name}"
@dataclass(frozen=True)
class SelfArgument:
    # Wrapper marking the positional argument named "self"; Arguments.parse
    # creates it for the first positional argument with that name.
    argument: Argument
@dataclass(frozen=True)
class TensorOptionsArguments:
    """The four consecutive keyword-only arguments dtype, layout, device, pin_memory."""
    dtype: Argument
    layout: Argument
    device: Argument
    pin_memory: Argument

    def all(self) -> Sequence[Argument]:
        """Return the four arguments as a list, in declaration order."""
        members = (self.dtype, self.layout, self.device, self.pin_memory)
        return list(members)
@dataclass(frozen=True)
class Arguments:
pre_self_positional: Tuple[Argument, ...]
self_arg: Optional[SelfArgument]
post_self_positional: Tuple[Argument, ...]
pre_tensor_options_kwarg_only: Tuple[Argument, ...]
tensor_options: Optional[TensorOptionsArguments]
post_tensor_options_kwarg_only: Tuple[Argument, ...]
out: Tuple[Argument, ...]
@property
def flat_non_out(self) -> Sequence[Argument]:
ret: List[Argument] = []
ret.extend(self.flat_positional)
ret.extend(self.flat_kwarg_only)
return ret
@property
def flat_positional(self) -> Sequence[Argument]:
ret: List[Argument] = []
ret.extend(self.pre_self_positional)
if self.self_arg is not None:
ret.append(self.self_arg.argument)
ret.extend(self.post_self_positional)
return ret
@property
def flat_kwarg_only(self) -> Sequence[Argument]:
ret: List[Argument] = []
ret.extend(self.pre_tensor_options_kwarg_only)
if self.tensor_options is not None:
ret.extend(self.tensor_options.all())
ret.extend(self.post_tensor_options_kwarg_only)
return ret
@property
def non_out(self) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]:
ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = []
ret.extend(self.positional)
ret.extend(self.kwarg_only)
return ret
@property
def positional(self) -> Sequence[Union[Argument, SelfArgument]]:
ret: List[Union[Argument, SelfArgument]] = []
ret.extend(self.pre_self_positional)
if self.self_arg is not None:
ret.append(self.self_arg)
ret.extend(self.post_self_positional)
return ret
@property
def kwarg_only(self) -> Sequence[Union[Argument, TensorOptionsArguments]]:
ret: List[Union[Argument, TensorOptionsArguments]] = []
ret.extend(self.pre_tensor_options_kwarg_only)
if self.tensor_options is not None:
ret.append(self.tensor_options)
ret.extend(self.post_tensor_options_kwarg_only)
return ret
def signature(self, *, strip_default: bool = False) -> 'Arguments':
def strip_arg_annotation(a: Argument) -> Argument:
return Argument(
name=a.name,
type=a.type,
default=a.default if not strip_default else None,
annotation=None,
)
return Arguments(
pre_self_positional=tuple(map(strip_arg_annotation, self.pre_self_positional)),
self_arg=SelfArgument(
strip_arg_annotation(self.self_arg.argument)
) if self.self_arg is not None else None,
post_self_positional=tuple(map(strip_arg_annotation, self.post_self_positional)),
pre_tensor_options_kwarg_only=tuple(map(strip_arg_annotation, self.pre_tensor_options_kwarg_only)),
tensor_options=self.tensor_options,
post_tensor_options_kwarg_only=tuple(map(strip_arg_annotation, self.post_tensor_options_kwarg_only)),
out=(),
)
@staticmethod
def _preparse(args: str) -> Tuple[List[Argument], List[Argument], List[Argument]]:
positional: List[Argument] = []
kwarg_only: List[Argument] = []
out: List[Argument] = []
arguments_acc = positional
for arg in args.split(', '):
if not arg:
continue
if arg == '*':
if arguments_acc is not positional:
raise ValueError("invalid syntax: kwarg-only specifier * can only occur once")
arguments_acc = kwarg_only
continue
parg = Argument.parse(arg)
if parg.annotation is not None and parg.annotation.is_write:
if arguments_acc is positional:
pass
elif arguments_acc is kwarg_only:
arguments_acc = out
else:
if arguments_acc is out:
raise ValueError("arguments_acc is out")
arguments_acc.append(parg)
return positional, kwarg_only, out
@staticmethod
def parse(args: str) -> 'Arguments':
    """
    Build an ``Arguments`` from the argument-list portion of a schema string.

    The string is first bucketed by ``Arguments._preparse``.  The positional
    arguments are then split around the first argument named "self" (if any),
    and the kwarg-only arguments are scanned for a consecutive run of
    ``dtype``, ``layout``, ``device``, ``pin_memory`` (each of the expected
    type or its optional form), which is folded into a single
    ``TensorOptionsArguments``.

    Raises:
        ValueError: if a second tensor-options run is found; errors raised by
            ``Arguments._preparse`` propagate unchanged.
    """
    positional, kwarg_only, out = Arguments._preparse(args)

    # Index of the first positional argument literally named "self", if any.
    self_ix = next(
        (idx for idx, candidate in enumerate(positional) if candidate.name == "self"),
        None,
    )
    pre_self_positional: List[Argument]
    self_arg: Optional[SelfArgument]
    post_self_positional: List[Argument]
    if self_ix is None:
        pre_self_positional = []
        self_arg = None
        post_self_positional = positional
    else:
        pre_self_positional = positional[:self_ix]
        self_arg = SelfArgument(positional[self_ix])
        post_self_positional = positional[self_ix + 1:]

    def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
        # Matches an argument with the given name whose type is `ty` or Optional[ty].
        def matches(a):
            return a.name == name and a.type in [ty, OptionalType(ty)]
        return matches

    expected_run = (
        ('dtype', 'ScalarType'),
        ('layout', 'Layout'),
        ('device', 'Device'),
        ('pin_memory', 'bool'),
    )
    predicates = [pred(arg_name, Type.parse(type_name)) for arg_name, type_name in expected_run]
    width = len(predicates)

    pre_tensor_options_kwarg_only: List[Argument] = []
    tensor_options: Optional[TensorOptionsArguments] = None
    post_tensor_options_kwarg_only: List[Argument] = []
    # Kwarg-only arguments land here until the tensor-options run is consumed.
    kwarg_only_acc = pre_tensor_options_kwarg_only

    i = 0
    while i < len(kwarg_only):
        window = kwarg_only[i:i + width]
        # A short window means there is not enough room left for a full run.
        if len(window) == width and all(p(a) for p, a in zip(predicates, window)):
            if kwarg_only_acc is not pre_tensor_options_kwarg_only:
                raise ValueError("kwarg_only_acc is not pre_tensor_options_kwarg_only")
            dtype_arg, layout_arg, device_arg, pin_memory_arg = window
            tensor_options = TensorOptionsArguments(
                dtype=dtype_arg,
                layout=layout_arg,
                device=device_arg,
                pin_memory=pin_memory_arg,
            )
            kwarg_only_acc = post_tensor_options_kwarg_only
            i += width
        else:
            kwarg_only_acc.append(kwarg_only[i])
            i += 1

    return Arguments(
        pre_self_positional=tuple(pre_self_positional),
        self_arg=self_arg,
        post_self_positional=tuple(post_self_positional),
        pre_tensor_options_kwarg_only=tuple(pre_tensor_options_kwarg_only),
        tensor_options=tensor_options,
        post_tensor_options_kwarg_only=tuple(post_tensor_options_kwarg_only),
        out=tuple(out),
    )
def __str__(self) -> str:
    """
    Render the arguments as a ', '-separated string.

    Positional arguments come first; a '*' marker is inserted when there are
    any kwarg-only or out arguments, followed by the kwarg-only arguments and
    then the out arguments.
    """
    rendered = [str(a) for a in self.flat_positional]
    if self.flat_kwarg_only or self.out:
        rendered.append('*')
    rendered += [str(a) for a in self.flat_kwarg_only]
    rendered += [str(a) for a in self.out]
    return ', '.join(rendered)
def __post_init__(self) -> None:
    """
    Check invariants between the fields after construction.

    Raises:
        ValueError: if there are pre-self positional arguments but no self
            argument, or post-tensor-options kwarg-only arguments but no
            tensor options.
    """
    if self.self_arg is None and self.pre_self_positional:
        raise ValueError("self.pre_self_positional is True.")
    if self.tensor_options is None and self.post_tensor_options_kwarg_only:
        raise ValueError("self.post_tensor_options_kwarg_only is True.")
# Operator stems that, when prefixed with 'i' inside a dunder name
# (e.g. '__iadd__'), are treated as the in-place form of that operator.
AUGMENTED_ASSIGNMENT_NAMES = [
    'add', 'sub', 'mul', 'div', 'mod', 'pow',
    'lshift', 'rshift', 'and', 'xor', 'or',
]


@dataclass(frozen=True)
class BaseOperatorName:
    """
    An operator name without its overload suffix, split into its parts.

    ``str()`` of an instance reproduces the textual name: ``base`` with a
    trailing '_' when in-place, or ``__[i]base__`` for dunder methods.
    """
    # The operator stem, with in-place / dunder decoration removed.
    base: str
    # Whether the name denotes the in-place variant.
    inplace: bool
    # Whether the name was written in '__name__' form.
    dunder_method: bool

    @staticmethod
    def parse(op: str) -> 'BaseOperatorName':
        """
        Parse an operator name such as 'add', 'add_', '__and__' or '__iadd__'.

        Raises:
            ValueError: if ``op`` is empty, ends in '_out', is a dunder name
                starting with 'i' that is not a known augmented assignment,
                or does not round-trip through ``str()``.
        """
        if op == '':
            raise ValueError("op == ''")
        if op.endswith('_out'):
            raise ValueError(
                "_out suffix is reserved and not permitted for operator names; "
                "did you mean to specify an out overload name instead?"
            )

        dunder = re.match(r'^__([^_]+)__$', op)
        if dunder is None:
            # Plain name: a trailing underscore marks the in-place variant.
            if op[-1] == '_':
                parsed = BaseOperatorName(base=op[:-1], inplace=True, dunder_method=False)
            else:
                parsed = BaseOperatorName(base=op, inplace=False, dunder_method=False)
        else:
            inner = dunder.group(1)
            starts_with_i = inner[0] == 'i'
            if starts_with_i and inner[1:] in AUGMENTED_ASSIGNMENT_NAMES:
                parsed = BaseOperatorName(base=inner[1:], inplace=True, dunder_method=True)
            elif starts_with_i:
                # A leading 'i' on anything other than a known augmented
                # assignment is rejected rather than guessed at.
                raise ValueError("base[0] == 'i'")
            else:
                parsed = BaseOperatorName(base=inner, inplace=False, dunder_method=True)

        # Sanity check: the parsed pieces must print back to the input.
        if str(parsed) != op:
            raise ValueError(f'{str(parsed)} != {op}')
        return parsed

    def __str__(self) -> str:
        """Return the textual operator name this instance represents."""
        if not self.dunder_method:
            suffix = '_' if self.inplace else ''
            return f'{self.base}{suffix}'
        prefix = 'i' if self.inplace else ''
        return f'__{prefix}{self.base}__'
@dataclass(frozen=True)
class OperatorName:
    """
    A full operator name: a base name plus an optional overload name.

    Written as ``name.overload_name`` when the overload name is non-empty and
    as just ``name`` otherwise.
    """
    # The parsed base operator name (the part before the first '.').
    name: BaseOperatorName
    # The part after the first '.'; the empty string when there is no overload.
    overload_name: str

    @staticmethod
    def parse(op_name: str) -> 'OperatorName':
        """
        Parse 'name' or 'name.overload' into an ``OperatorName``.

        Only the first '.' separates the base name from the overload name.

        Raises:
            ValueError: if the result does not print back to ``op_name``;
                errors raised by ``BaseOperatorName.parse`` propagate unchanged.
        """
        # partition() splits on the first '.', and yields an empty overload
        # when no '.' is present.
        base_text, _, overload_text = op_name.partition('.')
        parsed = OperatorName(
            name=BaseOperatorName.parse(base_text),
            overload_name=overload_text,
        )
        if str(parsed) != op_name:
            raise ValueError(f'{str(parsed)} != {op_name}')
        return parsed

    def __str__(self) -> str:
        """Return 'name.overload_name', or just 'name' without an overload."""
        if not self.overload_name:
            return f"{self.name}"
        return f"{self.name}.{self.overload_name}"

    def unambiguous_name(self) -> str:
        """Return 'name_overload_name', or just 'name' without an overload."""
        if not self.overload_name:
            return f"{self.name}"
        return f"{self.name}_{self.overload_name}"
def gets_generated_out_inplace_wrapper(f: NativeFunction, g: NativeFunctionsGroup, b: BackendIndex) -> bool:
    """
    Decide whether ``f`` qualifies for a generated out/inplace wrapper on ``b``.

    True only when ``f`` is not of functional schema kind, the backend index
    has no kernel for ``f`` itself, and it does have a kernel for the group's
    functional variant.
    """
    if f.func.kind() is SchemaKind.functional:
        return False
    if b.has_kernel(f):
        # The backend already provides this variant directly.
        return False
    return b.has_kernel(g.functional)
def parse_returns(return_decl: str) -> Tuple[Return, ...]:
    """
    Parse the return declaration of a schema into a tuple of ``Return`` objects.

    One enclosing pair of parentheses is stripped if present, and the rest is
    split on ', ' with each piece handed to ``Return.parse``.

    Input: '()'
    Output: ()
    """
    if return_decl == '()':
        return ()
    wrapped = return_decl[0] == '(' and return_decl[-1] == ')'
    inner = return_decl[1:-1] if wrapped else return_decl
    return tuple(map(Return.parse, inner.split(', ')))
@dataclass(frozen=True)
class Precompute:
    """
    A set of 'param -> replacement, replacement, ...' substitution rules.

    ``parse`` reads the rules from a list of strings and ``to_list`` writes
    them back out in the same textual form.
    """
    # Maps a kernel parameter name to the arguments that replace it.
    replace: Dict[str, List[Argument]]

    @staticmethod
    def parse(src: object) -> 'Precompute':
        """
        Parse a list of 'name -> arg, arg, ...' strings.

        Raises:
            TypeError: if ``src`` is not a list or contains a non-string item.
            ValueError: if an item does not split into exactly two parts on
                ' -> ', or if the parsed rules do not serialize back to ``src``.
        """
        if not isinstance(src, list):
            raise TypeError("src is not list.")
        replace = {}
        for entry in src:
            if not isinstance(entry, str):
                raise TypeError("raw_replace_item is not str.")
            arg, with_list_raw = entry.split(' -> ')
            replace[arg] = [
                Argument.parse(piece.strip()) for piece in with_list_raw.split(',')
            ]
        parsed = Precompute(replace=replace)
        # Require an exact round trip so that non-canonical spellings are rejected.
        if parsed.to_list() != src:
            raise ValueError('r.to_list() != src')
        return parsed

    def to_list(self) -> List[str]:
        """Serialize the rules back into 'name -> arg, arg, ...' strings."""
        rendered = []
        for kernel_param, replacement_params in self.replace.items():
            joined = ', '.join(map(str, replacement_params))
            rendered.append(f'{kernel_param} -> {joined}')
        return rendered