import dataclasses
import itertools
import re
from dataclasses import dataclass
from enum import auto, Enum
from typing import Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
from codegen.utils import assert_never, NamespaceHelper, OrderedSet
@dataclass(frozen=True)
class Location:
    """A position inside a file, printed as ``<file>:<line>``."""

    # File path exactly as supplied by the caller.
    file: str
    # Line number inside ``file``.
    line: int

    def __str__(self) -> str:
        return f"{self.file}:{self.line}"
class Variant(Enum):
    """Whether an operator is exposed as a free function or as a method."""

    function = 1
    method = 2
# Default kernel namespace string; its consumers are outside this chunk.
DEFAULT_KERNEL_NAMESPACE = "at::native"

# Backend components. Every FUNCTIONALITY_KEYS prefix combined with every
# component below must name a DispatchKey member (checked at import time).
BACKEND_COMPONENTS = [
    "CPU",
    "CUDA",
    "HIP",
    "XLA",
    "MPS",
    "IPU",
    "XPU",
    "HPU",
    "VE",
    "Lazy",
    "Meta",
    "PrivateUse1",
    "PrivateUse2",
    "PrivateUse3",
]

# Functionality prefixes; "" stands for the plain backend key (e.g. "CPU").
FUNCTIONALITY_KEYS = ["", "Quantized", "Sparse", "NestedTensor", "Autograd"]

# AutogradNestedTensor followed by Autograd<component> for every backend.
AUTOGRAD_KEYS = [
    "AutogradNestedTensor",
    *(f"Autograd{component}" for component in BACKEND_COMPONENTS),
]

FRAGMENT_NAMESPACES = {"quantized", "quantized_decomposed"}
class DispatchKey(Enum):
    """
    Dispatch keys known to the codegen.

    Member order is significant: apart from ``Undefined`` (0) and its alias
    ``CatchAll``, every value comes from ``auto()`` and is therefore determined
    by position. The per-backend members (``<functionality><backend>``) must
    cover FUNCTIONALITY_KEYS x BACKEND_COMPONENTS; a module-level loop checks
    this at import time.
    """

    Undefined = 0
    # Alias: DispatchKey.CatchAll is the same member as DispatchKey.Undefined.
    CatchAll = Undefined
    FPGA = auto()
    ORT = auto()
    Vulkan = auto()
    Metal = auto()
    MKLDNN = auto()
    OpenGL = auto()
    OpenCL = auto()
    IDEEP = auto()
    CustomRNGKeyId = auto()
    MkldnnCPU = auto()
    Sparse = auto()
    SparseCsrCPU = auto()
    SparseCsrCUDA = auto()
    Python = auto()
    FuncTorchDynamicLayerBackMode = auto()
    ZeroTensor = auto()
    BackendSelect = auto()
    Named = auto()
    AutogradOther = auto()
    AutogradFunctionality = auto()
    AutogradNestedTensor = auto()
    Tracer = auto()
    Autocast = auto()
    Batched = auto()
    VmapMode = auto()
    FuncTorchGradWrapper = auto()
    FuncTorchBatched = auto()
    FuncTorchVmapMode = auto()
    FuncTorchDynamicLayerFrontMode = auto()
    Functionalize = auto()
    TESTING_ONLY_GenericWrapper = auto()
    TESTING_ONLY_GenericMode = auto()
    ADInplaceOrView = auto()
    Autograd = auto()
    CompositeImplicitAutograd = auto()
    CompositeImplicitAutogradNestedTensor = auto()
    CompositeExplicitAutograd = auto()
    CompositeExplicitAutogradNonFunctional = auto()
    FuncTorchBatchedDecomposition = auto()
    # Per-backend members: plain backend keys.
    CPU = auto()
    CUDA = auto()
    HIP = auto()
    XLA = auto()
    MPS = auto()
    IPU = auto()
    XPU = auto()
    HPU = auto()
    VE = auto()
    Lazy = auto()
    Meta = auto()
    PrivateUse1 = auto()
    PrivateUse2 = auto()
    PrivateUse3 = auto()
    # Per-backend members: "Quantized" prefix.
    QuantizedCPU = auto()
    QuantizedCUDA = auto()
    QuantizedHIP = auto()
    QuantizedXLA = auto()
    QuantizedMPS = auto()
    QuantizedIPU = auto()
    QuantizedXPU = auto()
    QuantizedHPU = auto()
    QuantizedVE = auto()
    QuantizedLazy = auto()
    QuantizedMeta = auto()
    QuantizedPrivateUse1 = auto()
    QuantizedPrivateUse2 = auto()
    QuantizedPrivateUse3 = auto()
    # Per-backend members: "Sparse" prefix.
    SparseCPU = auto()
    SparseCUDA = auto()
    SparseHIP = auto()
    SparseXLA = auto()
    SparseMPS = auto()
    SparseIPU = auto()
    SparseXPU = auto()
    SparseHPU = auto()
    SparseVE = auto()
    SparseLazy = auto()
    SparseMeta = auto()
    SparsePrivateUse1 = auto()
    SparsePrivateUse2 = auto()
    SparsePrivateUse3 = auto()
    # Per-backend members: "NestedTensor" prefix.
    NestedTensorCPU = auto()
    NestedTensorCUDA = auto()
    NestedTensorHIP = auto()
    NestedTensorXLA = auto()
    NestedTensorMPS = auto()
    NestedTensorIPU = auto()
    NestedTensorXPU = auto()
    NestedTensorHPU = auto()
    NestedTensorVE = auto()
    NestedTensorLazy = auto()
    NestedTensorMeta = auto()
    NestedTensorPrivateUse1 = auto()
    NestedTensorPrivateUse2 = auto()
    NestedTensorPrivateUse3 = auto()
    # Per-backend members: "Autograd" prefix.
    AutogradCPU = auto()
    AutogradCUDA = auto()
    AutogradHIP = auto()
    AutogradXLA = auto()
    AutogradMPS = auto()
    AutogradIPU = auto()
    AutogradXPU = auto()
    AutogradHPU = auto()
    AutogradVE = auto()
    AutogradLazy = auto()
    AutogradMeta = auto()
    AutogradPrivateUse1 = auto()
    AutogradPrivateUse2 = auto()
    AutogradPrivateUse3 = auto()

    def __str__(self) -> str:
        # The canonical member name (CatchAll prints as "Undefined").
        return self.name

    def lower(self) -> str:
        """Lower-cased member name, e.g. "cpu" for DispatchKey.CPU."""
        return str(self).lower()

    @staticmethod
    def parse(value: str) -> "DispatchKey":
        """
        Look up a member by exact name; alias names such as "CatchAll" are
        accepted too.

        Raises AssertionError for an unknown name.
        """
        for k, v in DispatchKey.__members__.items():
            if k == value:
                return v
        raise AssertionError(f"unknown dispatch key {value}")
def codegen_per_backend_entries() -> str:
    """
    Render the expected per-backend DispatchKey members as source lines.

    One line `` <prefix><backend> = auto()`` is emitted per
    (FUNCTIONALITY_KEYS, BACKEND_COMPONENTS) pair, functionality-major, and the
    lines are joined with newlines.
    """
    return "\n".join(
        f" {prefix}{backend} = auto()"
        for prefix, backend in itertools.product(FUNCTIONALITY_KEYS, BACKEND_COMPONENTS)
    )
# Import-time sanity check: every <functionality><backend> combination must be
# a DispatchKey member. On the first missing one, print the full expected list
# (so it can be pasted into the enum) and abort.
for fk, bc in itertools.product(FUNCTIONALITY_KEYS, BACKEND_COMPONENTS):
    if hasattr(DispatchKey, fk + bc):
        continue
    r = codegen_per_backend_entries()
    print(r)
    raise RuntimeError(
        f"Missing {fk}{bc} from DispatchKey enum. Here is the autogenerated list we expect to have:\n\n{r}"
    )
# Dispatch keys accepted by is_structured_dispatch_key.
STRUCTURED_DISPATCH_KEYS = {DispatchKey.MPS, DispatchKey.CUDA, DispatchKey.CPU}
# Dispatch keys accepted by is_ufunc_dispatch_key.
UFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU}

# Ordered list of dispatch keys; how the order is consumed is outside this
# chunk, so keep it stable.
dispatch_keys = [
    DispatchKey.CPU,
    DispatchKey.SparseCPU,
    DispatchKey.SparseCsrCPU,
    DispatchKey.MkldnnCPU,
    DispatchKey.CUDA,
    DispatchKey.MPS,
    DispatchKey.SparseCUDA,
    DispatchKey.SparseCsrCUDA,
    DispatchKey.QuantizedCPU,
    DispatchKey.QuantizedCUDA,
    DispatchKey.CompositeImplicitAutograd,
    DispatchKey.CompositeImplicitAutogradNestedTensor,
    DispatchKey.CompositeExplicitAutograd,
    DispatchKey.CompositeExplicitAutogradNonFunctional,
    DispatchKey.NestedTensorCPU,
    DispatchKey.NestedTensorCUDA,
    DispatchKey.Meta,
    DispatchKey.SparseMeta,
    DispatchKey.QuantizedMeta,
    DispatchKey.NestedTensorMeta,
    DispatchKey.ZeroTensor,
]
def is_generic_dispatch_key(dk: DispatchKey) -> bool:
    """Return True if ``dk`` is one of the four Composite* dispatch keys."""
    composite_keys = {
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
    }
    return dk in composite_keys
def is_cuda_dispatch_key(dk: DispatchKey) -> bool:
    """Return True if ``dk`` is one of the CUDA-flavoured dispatch keys."""
    cuda_keys = {
        DispatchKey.AutogradCUDA,
        DispatchKey.CUDA,
        DispatchKey.NestedTensorCUDA,
        DispatchKey.QuantizedCUDA,
        DispatchKey.SparseCUDA,
        DispatchKey.SparseCsrCUDA,
    }
    return dk in cuda_keys
def is_structured_dispatch_key(dk: DispatchKey) -> bool:
    """Return True if ``dk`` is listed in STRUCTURED_DISPATCH_KEYS."""
    is_structured = dk in STRUCTURED_DISPATCH_KEYS
    return is_structured
def is_ufunc_dispatch_key(dk: DispatchKey) -> bool:
    """Return True if ``dk`` is listed in UFUNC_DISPATCH_KEYS."""
    is_ufunc = dk in UFUNC_DISPATCH_KEYS
    return is_ufunc
class ScalarType(Enum):
    """Scalar dtypes that may appear in dtype lists (see DTYPE_CLASSES)."""

    Byte = auto()
    Char = auto()
    Short = auto()
    Int = auto()
    Long = auto()
    Half = auto()
    Float = auto()
    Double = auto()
    ComplexHalf = auto()
    ComplexFloat = auto()
    ComplexDouble = auto()
    Bool = auto()
    BFloat16 = auto()

    def __str__(self) -> str:
        return self.name

    @staticmethod
    def maybe_parse(value: str) -> Optional["ScalarType"]:
        """Return the member named ``value``, or None when no member has that name."""
        return ScalarType.__members__.get(value)

    @staticmethod
    def parse(value: str) -> "ScalarType":
        """Return the member named ``value``; raise TypeError when unknown."""
        parsed = ScalarType.maybe_parse(value)
        if parsed is not None:
            return parsed
        raise TypeError(f"unknown dtype {value}")

    @staticmethod
    def parse_set(values: str) -> OrderedSet["ScalarType"]:
        """
        Parse a ", "-separated list of dtype names and/or DTYPE_CLASSES group
        names into an OrderedSet, expanding groups into their members.
        """
        dtypes: OrderedSet[ScalarType] = OrderedSet()
        for token in values.split(", "):
            dtype_group = DTYPE_CLASSES.get(token)
            if dtype_group is None:
                dtypes.add(ScalarType.parse(token))
            else:
                dtypes.update(dtype_group)
        return dtypes
# Named dtype groups that ScalarType.parse_set expands in place of a single
# dtype name.
DTYPE_CLASSES: Dict[str, OrderedSet[ScalarType]] = {
    "Integral": OrderedSet(
        [
            ScalarType.Byte,
            ScalarType.Char,
            ScalarType.Int,
            ScalarType.Long,
            ScalarType.Short,
        ]
    ),
    "Floating": OrderedSet([ScalarType.Float, ScalarType.Double]),
    "Complex": OrderedSet([ScalarType.ComplexFloat, ScalarType.ComplexDouble]),
}
# Composite groups, built as unions of the base groups above.
DTYPE_CLASSES["All"] = DTYPE_CLASSES["Integral"] | DTYPE_CLASSES["Floating"]
DTYPE_CLASSES["AllAndComplex"] = DTYPE_CLASSES["All"] | DTYPE_CLASSES["Complex"]
DTYPE_CLASSES["FloatingAndComplex"] = DTYPE_CLASSES["Floating"] | DTYPE_CLASSES["Complex"]
class UfuncKey(Enum):
    """Kinds of ufunc inner-loop kernels."""

    CUDAFunctor = auto()
    CUDAFunctorOnOther = auto()
    CUDAFunctorOnSelf = auto()
    CPUScalar = auto()
    CPUVector = auto()
    ScalarOnly = auto()
    Generic = auto()

    def __str__(self) -> str:
        return self.name

    @staticmethod
    def parse(value: str) -> "UfuncKey":
        """Return the member named ``value``; raise AssertionError when unknown."""
        member = UfuncKey.__members__.get(value)
        if member is None:
            raise AssertionError(f"unknown ufunc key {value}")
        return member
class DeviceCheckType(Enum):
    """Device-check policy: no check, or all tensors on exactly the same device."""

    NoCheck = 0
    ExactSame = auto()
class ViewSchemaKind(Enum):
    """View classification returned by NativeFunction.view_schema_kind."""

    aliasing = 1
    aliasing_inplace = 2
    non_aliasing = 3
@dataclass(frozen=True)
class NativeFunction:
    """
    One entry of native_functions.yaml in parsed form; build it with
    ``from_yaml``.
    """

    # Parsed schema of the ``func:`` key, without its namespace prefix.
    func: "FunctionSchema"
    # Raw ``impl_name:`` value, if present.
    impl_name: Optional[str]
    # ``impl_ns:`` split on ", ". from_yaml always stores a list here
    # ([""] when the key is absent), never a plain string.
    # NOTE(review): a list field makes hash() of this frozen dataclass raise
    # TypeError; confirm instances are never hashed, or switch to a tuple.
    impl_ns: Optional[List[str]]
    # Raw ``sparse:`` value, if present.
    sparse: Optional[str]
    # Raw ``internal_format_opapi:`` value, if present.
    internal_format_opapi: Optional[str]
    # Raw ``use_const_ref_for_mutable_tensors:`` flag (default False).
    use_const_ref_for_mutable_tensors: bool
    # Raw ``structured:`` flag (default False).
    structured: bool
    # Operator named by ``structured_delegate:``, if any.
    structured_delegate: Optional["OperatorName"]
    # C++ namespace from the ``func:`` entry ("aten" when unqualified).
    # NativeFunctionsGroup compares this across the variants of one group.
    namespace: str = "aten"

    @staticmethod
    def from_yaml(
        ei: Dict[str, object],
        ignore_keys: Optional[Set[DispatchKey]] = None,
    ) -> Tuple[
        "NativeFunction",
        Optional[Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]],
    ]:
        """
        Parse a NativeFunction from a dictionary as directly parsed
        from native_functions.yaml.

        ``ei`` is copied and not mutated. Keys other than the ones popped below
        are ignored, and ``ignore_keys`` is currently unused. The second tuple
        element is always None: this parser builds no backend index.

        Raises TypeError if the ``func`` entry is not a string.
        """
        e = ei.copy()
        funcs = e.pop("func")
        if not isinstance(funcs, str):
            raise TypeError(f"not a str: {funcs}")
        namespace_helper = NamespaceHelper.from_namespaced_entity(
            namespaced_entity=funcs, max_level=1
        )
        namespace = namespace_helper.get_cpp_namespace(default="aten")
        func = FunctionSchema.parse(namespace_helper.entity_name)
        impl_name = e.pop("impl_name", None)
        # An explicit but empty ``impl_ns:`` (None) is treated like a missing key
        # instead of crashing on None.split.
        impl_ns = (e.pop("impl_ns", None) or "").split(", ")
        sparse = e.pop("sparse", None)
        internal_format_opapi = e.pop("internal_format_opapi", None)
        use_const_ref_for_mutable_tensors = e.pop("use_const_ref_for_mutable_tensors", False)
        structured = e.pop("structured", False)
        structured_delegate_s = e.pop("structured_delegate", None)
        # Store an OperatorName (as the field annotation says), because
        # NativeFunctionsGroup compares it against ``out.func.name``, which is
        # an OperatorName; a raw string would not compare equal to it
        # (assuming OperatorName keeps dataclass equality - TODO confirm).
        structured_delegate = (
            OperatorName.parse(structured_delegate_s)
            if structured_delegate_s is not None
            else None
        )
        return (
            NativeFunction(
                func=func,
                impl_name=impl_name,
                impl_ns=impl_ns,
                sparse=sparse,
                internal_format_opapi=internal_format_opapi,
                use_const_ref_for_mutable_tensors=use_const_ref_for_mutable_tensors,
                structured=structured,
                structured_delegate=structured_delegate,
                namespace=namespace,
            ),
            None,
        )

    def validate_unstructured(self) -> None:
        """
        Raise ValueError if this function is marked structured or delegates to
        a structured kernel. Intended for functions that ended up without a
        valid structured group.
        """
        if self.structured:
            raise ValueError("This function is structured, but there was "
                             "no valid functional variant of it.")
        if self.structured_delegate:
            raise ValueError("This function delegates to another structured out function, "
                             "but no valid function was found (the delegate may not exist, or it has the wrong type)")

    def __post_init__(self) -> None:
        pass

    # NOTE(review): the properties below read self.tags and
    # self.has_composite_*_kernel, which are not defined in this class as
    # visible here; confirm where they come from.

    @property
    def has_composite_kernel(self) -> bool:
        """True if any Composite* kernel is registered for this function."""
        # The second clause is implied by the first (it also requires the
        # implicit-autograd kernel); kept to mirror the upstream definition.
        return (
            self.has_composite_implicit_autograd_kernel
            or self.has_composite_explicit_autograd_kernel
            or self.has_composite_explicit_autograd_non_functional_kernel
        ) or (
            self.has_composite_implicit_autograd_kernel
            and self.has_composite_implicit_autograd_nested_tensor_kernel
        )

    @property
    def is_view_op(self) -> bool:
        """
        True if the schema returns a non-written alias of an input, is tagged
        ``inplace_view`` (except resize_/resize_as_), or has an argument whose
        alias set after the call contains the wildcard ``*``.
        """
        rets = self.func.returns
        is_non_mutating_view = len(rets) > 0 and any(
            r.annotation is not None and not r.annotation.is_write for r in rets
        )
        is_inplace_view = (
            "inplace_view" in self.tags
            and str(self.func.name) != "resize_"
            and str(self.func.name) != "resize_as_"
        )
        is_wildcard_view = any(
            inp.annotation is not None and "*" in inp.annotation.alias_set_after
            for inp in self.func.schema_order_arguments()
        )
        return is_non_mutating_view or is_inplace_view or is_wildcard_view

    @property
    def view_schema_kind(self) -> ViewSchemaKind:
        """
        Classify this function as an inplace view, a view, or non-aliasing.

        Raises ValueError for an inplace view op that lacks the
        ``inplace_view`` tag.
        """
        if self.is_view_op and self.func.name.name.inplace:
            if "inplace_view" not in self.tags:
                raise ValueError("inplace_view is undefined")
            return ViewSchemaKind.aliasing_inplace
        if self.is_view_op:
            return ViewSchemaKind.aliasing
        else:
            return ViewSchemaKind.non_aliasing

    @property
    def root_name(self) -> str:
        """Base operator name (no overload name, no inplace suffix)."""
        return self.func.name.name.base

    @property
    def part_of_structured_group(self) -> bool:
        """True if structured itself or delegating to a structured kernel."""
        return self.structured or self.structured_delegate is not None
class SchemaKind(Enum):
    """Schema classification returned by FunctionSchema.kind()."""

    functional = 1
    inplace = 2
    out = 3
    mutable = 4
    scratch = 5
@dataclass(frozen=True)
class NativeFunctionsGroup:
    """
    The functional and out= variants of one operator, plus its optional
    inplace and mutable variants.

    ``__post_init__`` checks that all variants share one signature, agree on
    being structured, have the expected SchemaKind and namespace, and that
    generated variants are declared through ``autogen``.
    """

    functional: NativeFunction
    inplace: Optional[NativeFunction]
    mutable: Optional[NativeFunction]
    out: NativeFunction

    @property
    def structured(self) -> bool:
        """Whether the group is structured; taken from the out= variant."""
        return self.out.structured

    def __post_init__(self) -> None:
        test_sig: FunctionSchema = self.functional.func.signature()
        for f in self.functions():
            if test_sig != f.func.signature():
                raise AssertionError(
                    "NativeFunctionsGroup constructed from two NativeFunctions "
                    f"that don't have matching signatures: {test_sig} != {f.func.signature()}"
                )
            if self.structured != f.part_of_structured_group:
                raise AssertionError(
                    "NativeFunctionsGroup constructed from structured and unstructured "
                    f"functions: {self.out.func.name} and {f.func.name}"
                )
        if self.functional.func.kind() != SchemaKind.functional:
            raise ValueError("self.functional.func.kind() != SchemaKind.functional")
        if self.out.func.kind() != SchemaKind.out:
            raise ValueError("self.out.func.kind() != SchemaKind.out")
        # Checked once (this identical comparison used to be repeated).
        if self.functional.namespace != self.out.namespace:
            raise ValueError("self.functional.namespace != self.out.namespace")
        if self.inplace is not None:
            if self.inplace.func.kind() != SchemaKind.inplace:
                raise ValueError("self.inplace.func.kind() != SchemaKind.inplace")
            if self.inplace.namespace != self.functional.namespace:
                raise ValueError("self.inplace.namespace != self.functional.namespace")
        if self.mutable is not None:
            if self.mutable.func.kind() != SchemaKind.mutable:
                raise ValueError("self.mutable.func.kind() != SchemaKind.mutable")
            if self.mutable.namespace != self.functional.namespace:
                raise ValueError("self.mutable.namespace != self.functional.namespace")
            # A mutable op's functional counterpart must use a "functional" overload.
            if not self.functional.func.name.name.functional_overload:
                raise ValueError("self.functional.func.name.name.functional_overload is false")
        if self.structured:
            # Structured groups may not carry CompositeImplicitAutograd kernels.
            if (
                self.out.has_composite_implicit_autograd_kernel
                or self.out.has_composite_implicit_autograd_nested_tensor_kernel
            ):
                raise ValueError(
                    "self.out.has_composite_implicit_autograd_kernel "
                    "or self.out.has_composite_implicit_autograd_nested_tensor_kernel is true"
                )
            if self.functional.structured_delegate != self.out.func.name:
                raise ValueError(
                    f"{self.functional.func.name} delegates to {self.functional.structured_delegate} "
                    f"but its actual delegate is {self.out.func.name}"
                )
            if self.inplace is not None:
                if self.inplace.structured_delegate != self.out.func.name:
                    raise ValueError("self.inplace.structured_delegate != self.out.func.name")

        # Every variant tagged "generated" must be listed in some variant's autogen.
        generated_fns = sorted(
            [str(f.func.name) for f in self.functions() if "generated" in f.tags]
        )
        generated_fns_str = ", ".join(str(x) for x in generated_fns)
        expected_generated_fns: Set[str] = set()
        for f in self.functions():
            expected_generated_fns.update(str(op) for op in f.autogen)
        expected_generated_fns_str = ", ".join(
            str(x) for x in sorted(expected_generated_fns)
        )
        if len(expected_generated_fns) == 0 and len(generated_fns) > 0:
            # ``f`` is the last variant visited by the loop above.
            raise RuntimeError(
                f"The codegen expects to be able to generate '{generated_fns_str}'."
                " In order to generate them however, we expect them to be called out explicitly in the yaml."
                f" Please add an 'autogen: {generated_fns_str}' line to the entry for {str(f.func.name)}"
            )
        if expected_generated_fns_str != generated_fns_str:
            raise RuntimeError(
                f"The codegen expects to be able to generate '{generated_fns_str}'."
                f" To do so, it expects a line: 'autogen: {generated_fns_str}'."
                f" Instead, it found 'autogen: {expected_generated_fns_str}'"
            )

    def signature(self) -> "FunctionSchema":
        """The shared normalized signature (computed from the out= variant)."""
        return self.out.func.signature()

    def functions(self) -> Iterator[NativeFunction]:
        """Yield functional, out, then inplace and mutable when present."""
        yield self.functional
        yield self.out
        if self.inplace is not None:
            yield self.inplace
        if self.mutable is not None:
            yield self.mutable

    @property
    def root_name(self) -> str:
        """Root name of the functional variant."""
        return self.functional.root_name

    @staticmethod
    def from_dict(
        d: Dict[SchemaKind, NativeFunction]
    ) -> Optional["NativeFunctionsGroup"]:
        """
        Build a group from variants keyed by SchemaKind.

        Returns None for a single variant or when there is no out= variant.
        Raises ValueError for an empty dict, for unexpected kinds (e.g.
        scratch), or when the functional variant is missing.
        """
        if not d:
            raise ValueError("d is none")
        if len(d) == 1:
            return None
        d = dict(d)
        functional = d.pop(SchemaKind.functional, None)
        inplace = d.pop(SchemaKind.inplace, None)
        mutable = d.pop(SchemaKind.mutable, None)
        out = d.pop(SchemaKind.out, None)
        if d:
            raise ValueError("d is not None")
        if functional is None:
            raise ValueError("functional is None")
        if out is None:
            return None
        return NativeFunctionsGroup(
            functional=functional,
            inplace=inplace,
            mutable=mutable,
            out=out,
        )
@dataclass(frozen=True)
class BackendMetadata:
    """Kernel information registered for one operator on one backend."""

    # Kernel function name.
    kernel: str
    structured: bool
    cpp_namespace: str

    def supports_symint(self) -> bool:
        """True when the kernel name contains the substring ``_symint``."""
        return self.kernel.find("_symint") != -1
@dataclass(frozen=True)
class UfuncInnerLoop:
    """A ufunc inner-loop kernel: its name, supported dtypes and ufunc key."""

    name: str
    supported_dtypes: OrderedSet[ScalarType]
    ufunc_key: UfuncKey

    @staticmethod
    def parse(value: str, ufunc_key: UfuncKey) -> "UfuncInnerLoop":
        """
        Parse ``"<name> (<dtype or dtype group>, ...)"``.

        Raises ValueError if there is no space separator or the dtype list is
        not wrapped in parentheses; unknown dtypes raise TypeError from
        ScalarType.parse.
        """
        name, supported_dtypes_str = value.split(" ", 1)
        # startswith/endswith also reject an empty dtype string cleanly
        # (indexing [0]/[-1] raised IndexError for it).
        if not supported_dtypes_str.startswith("("):
            raise ValueError("supported_dtypes_str[0] != /(/")
        if not supported_dtypes_str.endswith(")"):
            raise ValueError("supported_dtypes_str[-1] != /)/")
        supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
        for k in supported_dtypes_str[1:-1].split(", "):
            supported_dtypes |= ScalarType.parse_set(k)
        return UfuncInnerLoop(
            name=name, supported_dtypes=supported_dtypes, ufunc_key=ufunc_key
        )
@dataclass(frozen=True)
class BackendIndex:
    """Kernel table for one dispatch key: operator name -> BackendMetadata."""

    dispatch_key: DispatchKey
    # If True, a group is looked up through its out= variant, else functional.
    use_out_as_primary: bool
    device_guard: bool
    external: bool
    index: Dict["OperatorName", BackendMetadata]

    @staticmethod
    def grow_index(
        parent_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]],
        child_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]],
    ) -> None:
        """
        Merge ``child_index`` into ``parent_index`` in place.

        A dispatch key that is not yet in ``parent_index`` gets a new entry
        instead of raising KeyError. Raises ValueError when an operator is
        already registered for the same dispatch key.
        """
        for k, v in child_index.items():
            if not v:
                # Nothing to merge; do not create an empty entry for k.
                continue
            parent_for_key = parent_index.setdefault(k, {})
            for op_name, metadata in v.items():
                if op_name in parent_for_key:
                    raise ValueError(f"duplicate operator {op_name} for dispatch key {k}")
                parent_for_key[op_name] = metadata

    def primary(self, g: NativeFunctionsGroup) -> NativeFunction:
        """The variant of ``g`` whose name is used for kernel lookup."""
        if self.use_out_as_primary:
            return g.out
        else:
            return g.functional

    def has_kernel(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool:
        """True if a kernel is registered for ``g``."""
        return self.get_kernel(g) is not None

    def get_kernel(
        self, g: Union[NativeFunction, NativeFunctionsGroup]
    ) -> Optional[BackendMetadata]:
        """Registered metadata for ``g`` (groups use primary()), or None."""
        if isinstance(g, NativeFunction):
            f = g
        elif isinstance(g, NativeFunctionsGroup):
            f = self.primary(g)
        else:
            assert_never(g)
        return self.index.get(f.func.name)

    def native_function_class_name(self) -> Optional[str]:
        """``<DispatchKey>NativeFunctions`` for external backends, else None."""
        if self.external:
            return f"{str(self.dispatch_key)}NativeFunctions"
        else:
            return None
@dataclass(frozen=True)
class FunctionSchema:
    """
    A parsed operator schema of the form ``name(arguments) -> returns``.

    ``__post_init__`` enforces the codegen's invariants for out=, inplace,
    mutable and factory schemas and raises ValueError when one is violated.
    """

    name: "OperatorName"
    arguments: "Arguments"
    returns: Tuple["Return", ...]

    def schema_order_arguments(self) -> Iterator["Argument"]:
        """Positional, then keyword-only, then out arguments."""
        return itertools.chain(
            self.arguments.flat_positional,
            self.arguments.flat_kwarg_only,
            self.arguments.out,
        )

    # Splits "name(args) -> returns" into its three parts.
    decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")

    @staticmethod
    def parse(func: str) -> "FunctionSchema":
        """
        Parse a schema string.

        Raises ValueError if ``decl_re`` does not match exactly once, or if
        the parsed schema does not print back to ``func``.
        """
        decls = FunctionSchema.decl_re.findall(func)
        if len(decls) != 1:
            raise ValueError(f"Invalid function schema: {func}")
        ops, args, return_decl = decls[0]
        name = OperatorName.parse(ops)
        arguments = Arguments.parse(args)
        returns = parse_returns(return_decl)
        func_schema_obj = FunctionSchema(name=name, arguments=arguments, returns=returns)
        if str(func_schema_obj) != func:
            raise ValueError(f"{str(func_schema_obj)} != {func}")
        return func_schema_obj

    def returns_are_aliased(self) -> bool:
        """True if any return carries a write (``!``) alias annotation."""
        return any(
            r
            for r in self.returns
            if r.annotation is not None and r.annotation.is_write
        )

    def __post_init__(self) -> None:
        # The i-th out argument must alias the i-th return.
        for arg, ret in zip(self.arguments.out, self.returns):
            if arg.annotation != ret.annotation:
                raise ValueError("Out arguments must have matching return Tensor; furthermore, "
                                 "the ith-argument needs to correspond to the ith return")
        # Mutable positional arguments after self must not also be returned.
        for a in self.arguments.post_self_positional_mutable:
            if any(a.annotation == r.annotation for r in self.returns):
                raise ValueError(f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}")
        out_and_self = list(self.arguments.out) + [
            arg for arg in self.arguments.flat_positional if arg.name == "self"
        ]
        mutable_returns = [
            ret
            for ret in self.returns
            if ret.annotation is not None and ret.annotation.is_write
        ]
        immutable_returns = [
            ret
            for ret in self.returns
            if ret.annotation is None or not ret.annotation.is_write
        ]
        if len(mutable_returns) != 0 and len(immutable_returns) != 0:
            # The f-prefix belongs on the fragment that holds the placeholder.
            raise ValueError("NativeFunctions must have either only mutable returns, "
                             f"or only immutable returns. Found: {str(self)}")
        for ret in mutable_returns:
            if not any(ret.annotation == arg.annotation for arg in out_and_self):
                raise ValueError('All mutable returns must be aliased either to a keyword argument, or to "self". '
                                 "Did you forget to mark an out argument as keyword-only?")
        if self.arguments.out:
            if any(a.type != BaseType(BaseTy.Tensor) for a in self.arguments.out):
                if len(self.returns) != 0:
                    raise ValueError("out= ops that accept tensor lists as out arguments "
                                     "are expected to have no return type (since you can't do method chaining on them)")
            else:
                # Scratch out arguments are not returned.
                non_scratch_outs = [
                    arg
                    for arg in self.arguments.out
                    if not arg.name.startswith("_scratch_")
                ]
                if len(non_scratch_outs) != len(self.returns):
                    raise ValueError("Must return as many arguments as there are out arguments, or no return at all")
        if self.name.name.inplace:
            self_a = self.arguments.self_arg
            if not (self_a and self_a.argument.annotation
                    and self_a.argument.annotation.is_write):
                # This used to ``raise`` a plain string, which Python turns into
                # TypeError("exceptions must derive from BaseException").
                raise ValueError("Inplace operators must take a mutable self argument "
                                 "(self_a and self_a.argument.annotation "
                                 "and self_a.argument.annotation.is_write is false).")
            if self_a.argument.type == BaseType(BaseTy.Tensor):
                # A plain Tensor self must be returned (enables method chaining).
                if not (len(self.returns) == 1
                        and self.returns[0].annotation == self_a.argument.annotation):
                    raise ValueError("len(self.returns) != 1"
                                     "or self.returns[0].annotation != self_a.argument.annotation)")
            else:
                if len(self.returns) != 0:
                    raise ValueError("len(self.returns) != 0")
        if self.arguments.tensor_options is not None:
            if self.kind() != SchemaKind.functional:
                raise ValueError("Found an operator that is not functional or out variant, but has tensor options arguments."
                                 "This is not allowed- tensor options arguments are only allowed for factory functions."
                                 f"schema: {str(self)}")
        if self.is_functional_fn():
            if self.kind() != SchemaKind.functional:
                raise ValueError("Found an operator that is not functional, but its overload contains the string 'functional'."
                                 "This is a special keyword in the codegen, please use a different overload name."
                                 f"schema: {str(self)}")

    def is_functional_fn(self) -> bool:
        """True if the overload name contains "functional"."""
        return "functional" in self.name.overload_name

    def is_out_fn(self) -> bool:
        """True if the schema has out arguments."""
        return bool(self.arguments.out)

    def kind(self) -> SchemaKind:
        """
        What kind of schema is this? A functional schema is one
        that returns a newly allocated output; an inplace schema
        modifies the self argument inplace; an out schema writes
        the result into an explicitly provided out argument.

        Precedence: inplace, then scratch (an out= schema with an argument
        named ``_scratch_*``), then out, then mutable (a written positional
        argument after self), then functional.
        """
        is_out = bool(self.arguments.out)
        is_scratch = any(arg.name.startswith("_scratch_") for arg in self.arguments.out)
        is_inplace = self.name.name.inplace
        is_mutable = any(
            a.annotation is not None and a.annotation.is_write
            for a in self.arguments.post_self_positional
        )
        if is_out and is_inplace:
            raise ValueError("(is_out and is_inplace)")
        if is_inplace:
            return SchemaKind.inplace
        elif is_scratch:
            # Scratch args live in the out list, so this cannot fail today;
            # kept as an explicit invariant.
            if not is_out:
                raise ValueError("invariant: all scratch operators are expected to be out= operators too")
            return SchemaKind.scratch
        elif is_out:
            # Scratch schemas were already handled above.
            return SchemaKind.out
        elif is_mutable:
            return SchemaKind.mutable
        else:
            return SchemaKind.functional

    def aliased_return_names(self) -> List[Optional[str]]:
        """
        For each return, the name of the single argument sharing its alias
        annotation, or None. Raises AssertionError if several arguments match.
        """
        outs: List[Optional[str]] = []
        for return_name in self.returns:
            aliased_args = [
                a
                for a in self.arguments.flat_all
                if a.annotation is not None and a.annotation == return_name.annotation
            ]
            if len(aliased_args) == 0:
                outs.append(None)
            elif len(aliased_args) == 1:
                outs.append(aliased_args[0].name)
            else:
                aliased_names = ", ".join(a.name for a in aliased_args)
                raise AssertionError(
                    f"Found a return ({return_name.name})that aliases multiple inputs ({aliased_names})"
                )
        return outs

    def signature(
        self,
        *,
        strip_default: bool = False,
        strip_view_copy_name: bool = False,
        keep_return_names: bool = False,
    ) -> "FunctionSchema":
        """
        Certain schemas are 'related', in that they are simply
        inplace/out/functional versions of the same function. This method
        factors these schemas into the "core" functional signature which
        is equal across all versions.
        Here is what normalization happens to the schema to convert
        it to a signature:
        - The overload name is stripped (name is retained, since
        it expresses semantic content about what the function does)
        - Inplace is set False
        - Out arguments are stripped
        - Mutable post_self_positional args are converted to returns
        - Mutability annotations are stripped (this is sound
        because you cannot overload on mutability annotation)
        - Return names are stripped since they are not overloadable and
        some variants have return names but some not
        - TensorOptions are dropped
        because out= variants of factory functions don't include them
        (and we want to be able to pair up factory functions with their out variants)
        Finally, we want to be able to pair up related "view" and their
        corresponding "view_copy" operators. We do this by optionally
        stripping the trailing "_copy" from the base name.
        Example of a mutable op before and after:
        f.func (Mutable operator):
        _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950
        f.func (Corresponding functional operator):
        _fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out) # noqa: B950
        f.func.signature() output:
        _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) # noqa: B950
        """
        def strip_ret_annotation(return_name: Return) -> Return:
            return Return(
                name=return_name.name if keep_return_names else None,
                type=return_name.type,
                annotation=None,
            )

        base_name = self.name.name.base
        copy_suffix = "_copy"
        if strip_view_copy_name and base_name.endswith(copy_suffix):
            # Remove only the trailing suffix, as documented; str.replace would
            # also delete any interior "_copy".
            base_name = base_name[: -len(copy_suffix)]
        returns_from_mutable_inputs = []
        # Order matters: self, then out, then post-self positionals, so that
        # related variants produce identical signatures.
        for a in itertools.chain(
            [self.arguments.self_arg.argument]
            if self.arguments.self_arg is not None
            else [],
            self.arguments.out,
            self.arguments.post_self_positional,
        ):
            if a.annotation is not None and a.annotation.is_write and not any(
                    a.annotation == r.annotation for r in self.returns):
                returns_from_mutable_inputs.append(Return(
                    name=f"{a.name}_out" if keep_return_names else None,
                    type=a.type,
                    annotation=None,
                ))
        original_returns = tuple(map(strip_ret_annotation, self.returns))
        returns = original_returns + tuple(returns_from_mutable_inputs)
        args_sig = self.arguments.signature(strip_default=strip_default)
        # Special case kept from the original: bernoulli.p gets p=0.5 as default.
        if str(self.name) == "bernoulli.p":
            args_sig = Arguments.parse(str(args_sig).replace("float p", "float p=0.5"))
        return FunctionSchema(
            name=OperatorName(
                name=BaseOperatorName(
                    base=base_name,
                    inplace=False,
                    dunder_method=self.name.name.dunder_method,
                ),
                overload_name="",
            ),
            arguments=args_sig,
            returns=returns,
        )

    def view_signature(self) -> "FunctionSchema":
        """signature() with the trailing "_copy" stripped from the base name."""
        return self.signature(strip_view_copy_name=True)

    def with_name(self, name: "OperatorName") -> "FunctionSchema":
        """A copy of this schema under a different operator name."""
        return FunctionSchema(
            name=name,
            arguments=self.arguments,
            returns=self.returns,
        )

    @property
    def modifies_arguments(self) -> bool:
        """True for inplace, out and mutable schemas."""
        return self.kind() in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable]

    def has_symint(self) -> bool:
        """Delegates to Arguments.has_symint_arg()."""
        return self.arguments.has_symint_arg()

    def __str__(self) -> str:
        all_arguments_str = str(self.arguments)
        if len(self.returns) == 1:
            returns = str(self.returns[0])
        else:
            returns = "(" + ", ".join(map(str, self.returns)) + ")"
        return f"{self.name}({all_arguments_str}) -> {returns}"
@dataclass(frozen=True)
class Annotation:
    """
    Alias annotation on a Tensor type, e.g. ``a``, ``a!``, ``a|b`` or
    ``a -> *``.
    """

    # Alias sets before the call, e.g. ("a", "b") for "a|b".
    alias_set: Tuple[str, ...]
    # True when the annotation ends with "!".
    is_write: bool
    # Alias sets after "->" (possibly ("*",)); empty when absent.
    alias_set_after: Tuple[str, ...]

    @staticmethod
    def parse(ann: str) -> "Annotation":
        """
        Parse an alias annotation string.

        Raises ValueError if the string is malformed, if a written (``!``)
        annotation names more than one alias set, or if both the before and
        after alias sets have more than one entry.
        """
        # The whole "a|b|c" list is captured as one group; the old
        # "([a-z])(\|[a-z])*" form only kept the last repetition, so
        # "a|b|c" was read as "a|c" and then rejected.
        m = re.match(r"^([a-z](?:\|[a-z])*)(!?)( -> (\*|[a-z](?:\|[a-z])*))?$", ann)
        if m is None:
            raise ValueError(f"unrecognized alias annotation {ann}")
        before_alias = m.group(1)
        alias_set = tuple(before_alias.split("|"))
        is_write = m.group(2) == "!"
        if is_write and len(alias_set) > 1:
            raise ValueError(f"alias set larger than 1 is not mutable, got {ann} instead.")
        after_set = tuple(m.group(4).split("|")) if m.group(4) else tuple()
        if len(alias_set) > 1 and len(after_set) > 1:
            # Both fragments need the f-prefix; "{ann}" used to be printed literally.
            raise ValueError("before alias set and after alias set cannot be larger "
                             f"than 1 at the same time, got {ann} instead.")
        annotation_obj = Annotation(
            alias_set=alias_set, is_write=is_write, alias_set_after=after_set
        )
        if str(annotation_obj) != ann:
            raise ValueError(f"{annotation_obj} != {ann}")
        return annotation_obj

    def __str__(self) -> str:
        alias_set = "|".join(self.alias_set)
        if self.is_write:
            alias_set = f"{alias_set}!"
        alias_set_after = "|".join(self.alias_set_after)
        if alias_set_after:
            alias_set = f'{alias_set}{" -> "}{alias_set_after}'
        return alias_set
@dataclass(frozen=True)
class Type:
    """
    Abstract base for schema types. ``parse`` builds BaseType, OptionalType,
    ListType or CustomClassType instances.
    """

    @staticmethod
    def parse(t: str) -> "Type":
        """Parse ``t``; raise ValueError unless the result prints back as ``t``."""
        parsed = Type._parse(t)
        if str(parsed) == t:
            return parsed
        raise ValueError(f"{parsed} != {t}")

    @staticmethod
    def _parse(t: str) -> "Type":
        # Checked in this order: trailing "?", then "[N]"/"[]", then a
        # __torch__ custom class, then a plain BaseTy name.
        optional_match = re.match(r"^(.+)\?$", t)
        if optional_match is not None:
            return OptionalType(Type.parse(optional_match.group(1)))
        list_match = re.match(r"^(.+)\[([0-9]+)?\]$", t)
        if list_match is not None:
            raw_size = list_match.group(2)
            return ListType(
                elem=Type.parse(list_match.group(1)),
                size=None if raw_size is None else int(raw_size),
            )
        custom_match = re.match(r"^__torch__\.torch\.classes\.([a-zA-Z0-9_.]+)$", t)
        if custom_match is not None:
            return CustomClassType(custom_match.group(1))
        try:
            base = BaseTy[t]
        except KeyError as e:
            raise RuntimeError(f"unrecognized type {t}") from e
        return BaseType(base)

    def __str__(self) -> str:
        raise NotImplementedError

    def is_base_ty_like(self, base_ty: "BaseTy") -> bool:
        raise NotImplementedError

    def is_tensor_like(self) -> bool:
        return self.is_base_ty_like(BaseTy.Tensor)

    def is_generator_like(self) -> bool:
        return self.is_base_ty_like(BaseTy.Generator)

    def is_symint_like(self) -> bool:
        return self.is_base_ty_like(BaseTy.SymInt)

    def is_nullable(self) -> bool:
        raise NotImplementedError

    def is_list_like(self) -> Optional["ListType"]:
        raise NotImplementedError
class BaseTy(Enum):
    """Names of the non-composite schema types (looked up by name in Type._parse)."""

    Generator = 1
    ScalarType = 2
    Tensor = 3
    int = 4
    Dimname = 5
    DimVector = 6
    float = 7
    str = 8
    bool = 9
    Layout = 10
    Device = 11
    Scalar = 12
    MemoryFormat = 13
    QScheme = 14
    Storage = 15
    Stream = 16
    SymInt = 17
    ConstQuantizerPtr = 18
@dataclass(frozen=True)
class BaseType(Type):
    """A non-composite type identified by a BaseTy member, e.g. ``Tensor``."""

    name: BaseTy

    def __str__(self) -> str:
        return self.name.name

    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
        return base_ty == self.name

    def is_nullable(self) -> bool:
        # A bare base type is never optional.
        return False

    def is_list_like(self) -> Optional["ListType"]:
        return None

    def is_symint_like(self) -> bool:
        return BaseTy.SymInt == self.name
@dataclass(frozen=True)
class OptionalType(Type):
    """A nullable wrapper ``T?`` around ``elem``."""

    elem: Type

    def __str__(self) -> str:
        return str(self.elem) + "?"

    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
        inner = self.elem
        return inner.is_base_ty_like(base_ty)

    def is_symint_like(self) -> bool:
        inner = self.elem
        return inner.is_symint_like()

    def is_nullable(self) -> bool:
        # Optional types are nullable by definition.
        return True

    def is_list_like(self) -> Optional["ListType"]:
        inner = self.elem
        return inner.is_list_like()
@dataclass(frozen=True)
class CustomClassType(Type):
    """A TorchScript custom class, written ``__torch__.torch.classes.<name>``."""

    class_name: str

    def __str__(self) -> str:
        """Render the class name with the ``__torch__.torch.classes.`` prefix."""
        return f"__torch__.torch.classes.{self.class_name}"

    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
        # A custom class never matches a BaseTy.
        return False

    def is_symint_like(self) -> bool:
        return False

    def is_nullable(self) -> bool:
        """Custom classes are treated as non-nullable."""
        return False

    def is_list_like(self) -> Optional["ListType"]:
        return None
@dataclass(frozen=True)
class ListType(Type):
    """A list type ``T[]`` or fixed-size ``T[N]``."""

    elem: Type
    # Fixed length, or None for an unsized list.
    size: Optional[int]

    def __str__(self) -> str:
        # Compare with None: a truthiness test rendered size=0 as the unsized
        # "[]", so "T[0]" could not round-trip through Type.parse.
        size = "" if self.size is None else f"{self.size}"
        return f"{self.elem}[{size}]"

    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
        return self.elem.is_base_ty_like(base_ty)

    def is_symint_like(self) -> bool:
        return self.elem.is_symint_like()

    def is_nullable(self) -> bool:
        # Nullability follows the element type.
        return self.elem.is_nullable()

    def is_list_like(self) -> Optional["ListType"]:
        return self
@dataclass(frozen=True)
class Argument:
    """One argument of a function schema, e.g. ``Tensor(a!) self`` or ``int dim=0``.

    ``parse`` verifies that ``str`` of its result reproduces the input text.
    """

    # Argument name as written in the schema.
    name: str
    # Parsed type; any alias annotation is split out into ``annotation``.
    type: Type
    # Default value verbatim from the schema text, or None when absent.
    default: Optional[str]
    # Alias annotation taken from ``Tensor(...)`` forms, or None.
    annotation: Optional[Annotation]

    @staticmethod
    def parse(arg: str) -> "Argument":
        """Parse a single ``<type> <name>[=<default>]`` declaration.

        Raises:
            ValueError: if ``arg`` has no space separating type and name,
                uses an alias annotation on something other than ``Tensor``,
                ``Tensor?`` or ``Tensor[]``, or does not print back to ``arg``.
        """
        type_and_annot, name_and_default = arg.rsplit(" ", 1)
        # Split on the first "=" only. A default that itself contains "=" is
        # then kept intact instead of breaking a two-way unpacking.
        name, sep, raw_default = name_and_default.partition("=")
        default: Optional[str] = raw_default if sep else None
        match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
        annotation: Optional[Annotation]
        if match:
            if match.group(2) not in ["", "?", "[]"]:
                raise ValueError("unrecognized alias analysis form with Tensor")
            type_s = "Tensor" + match.group(2)
            annotation = Annotation.parse(match.group(1))
        else:
            type_s = type_and_annot
            annotation = None
        argument_obj = Argument(
            name=name,
            type=Type.parse(type_s),
            default=default,
            annotation=annotation,
        )
        # Round-trip check guards against silently mis-parsed declarations.
        if str(argument_obj) != arg:
            raise ValueError(f"{str(argument_obj)} != {arg}")
        return argument_obj

    @property
    def is_write(self) -> bool:
        """True when the argument carries an annotation whose ``is_write`` is set."""
        return self.annotation is not None and self.annotation.is_write

    def __str__(self) -> str:
        self_type = f"{self.type}"
        if self.annotation:
            # Alias annotations are only representable on Tensor forms.
            if self_type not in ["Tensor", "Tensor?", "Tensor[]"]:
                raise ValueError("type not in [Tensor, Tensor?, Tensor[]]")
            self_type = self_type.replace("Tensor", f"Tensor({self.annotation})")
        if self.name is None:
            return self_type
        mb_default = f"={self.default}" if self.default else ""
        return f"{self_type} {self.name}{mb_default}"
@dataclass(frozen=True)
class Return:
    """One entry of a schema's return list, e.g. ``Tensor(a!)`` or ``Tensor values``."""

    # Optional return name; None for unnamed returns.
    name: Optional[str]
    # Parsed return type (alias annotation split out).
    type: Type
    # Alias annotation taken from ``Tensor(...)`` forms, or None.
    annotation: Optional[Annotation]

    @staticmethod
    def parse(arg: str) -> "Return":
        """Parse ``<type>`` or ``<type> <name>`` into a Return.

        Raises ValueError on unsupported alias forms or when the result does
        not print back to ``arg``.
        """
        name: Optional[str] = None
        type_and_annot = arg
        if " " in arg:
            # Everything after the last space is the return's name.
            type_and_annot, name = arg.rsplit(" ", 1)

        annotation: Optional[Annotation] = None
        type_s = type_and_annot
        alias_match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
        if alias_match:
            alias_text, suffix = alias_match.groups()
            if suffix not in ("", "?", "[]"):
                raise ValueError("unrecognized alias analysis form with Tensor")
            type_s = f"Tensor{suffix}"
            annotation = Annotation.parse(alias_text)

        parsed = Return(name=name, type=Type.parse(type_s), annotation=annotation)
        rendered = str(parsed)
        if rendered != arg:
            raise ValueError(f"{rendered} != {arg}")
        return parsed

    @property
    def is_write(self) -> bool:
        """True when the return carries an annotation whose ``is_write`` is set."""
        ann = self.annotation
        return ann is not None and ann.is_write

    def __str__(self) -> str:
        rendered_type = f"{self.type}"
        if self.annotation:
            if rendered_type not in ["Tensor", "Tensor?", "Tensor[]"]:
                raise ValueError("type not in [Tensor, Tensor?, Tensor[]]")
            # Re-insert the alias annotation right after "Tensor".
            rendered_type = rendered_type.replace("Tensor", f"Tensor({self.annotation})")
        return rendered_type if self.name is None else f"{rendered_type} {self.name}"
@dataclass(frozen=True)
class SelfArgument:
    """Wrapper marking the positional schema argument named ``self``."""

    # The underlying parsed argument.
    argument: Argument
@dataclass(frozen=True)
class TensorOptionsArguments:
    """The dtype/layout/device/pin_memory kwarg-only group, kept as one unit."""

    dtype: Argument
    layout: Argument
    device: Argument
    pin_memory: Argument

    def all(self) -> Sequence[Argument]:
        """Return the four grouped arguments as a list, in declaration order."""
        attr_order = ("dtype", "layout", "device", "pin_memory")
        return [getattr(self, attr) for attr in attr_order]
@dataclass(frozen=True)
class Arguments:
    """Structured representation of a function schema's argument list.

    Arguments are bucketed relative to ``self`` and to the contiguous
    ``dtype, layout, device, pin_memory`` TensorOptions group. Written
    keyword-only arguments are collected separately in ``out``. ``parse``
    builds the buckets from schema text and ``__str__`` prints them back.
    """

    # Positional arguments that precede ``self`` (must be empty when there is no self).
    pre_self_positional: Tuple[Argument, ...]
    # The positional argument named ``self``, if present.
    self_arg: Optional[SelfArgument]
    # Positional arguments after ``self``; all positionals when there is no self.
    post_self_positional: Tuple[Argument, ...]
    # Keyword-only arguments before the TensorOptions group (all of them if there is no group).
    pre_tensor_options_kwarg_only: Tuple[Argument, ...]
    # The dtype/layout/device/pin_memory group, when found among the kwarg-only arguments.
    tensor_options: Optional[TensorOptionsArguments]
    # Keyword-only arguments after the TensorOptions group (must be empty without one).
    post_tensor_options_kwarg_only: Tuple[Argument, ...]
    # Keyword-only arguments with a write annotation, i.e. out arguments.
    out: Tuple[Argument, ...]

    @property
    def flat_non_out(self) -> Sequence[Argument]:
        """Positional then kwarg-only arguments (TensorOptions expanded), excluding out."""
        ret: List[Argument] = []
        ret.extend(self.flat_positional)
        ret.extend(self.flat_kwarg_only)
        return ret

    @property
    def flat_positional(self) -> Sequence[Argument]:
        """Positional arguments in order, with ``self`` unwrapped to its Argument."""
        ret: List[Argument] = []
        ret.extend(self.pre_self_positional)
        if self.self_arg is not None:
            ret.append(self.self_arg.argument)
        ret.extend(self.post_self_positional)
        return ret

    @property
    def post_self_positional_mutable(self) -> Sequence[Argument]:
        """Positional arguments after ``self`` whose ``is_write`` is true."""
        return [a for a in self.post_self_positional if a.is_write]

    @property
    def flat_kwarg_only(self) -> Sequence[Argument]:
        """Keyword-only (non-out) arguments, with TensorOptions expanded in place."""
        ret: List[Argument] = []
        ret.extend(self.pre_tensor_options_kwarg_only)
        if self.tensor_options is not None:
            ret.extend(self.tensor_options.all())
        ret.extend(self.post_tensor_options_kwarg_only)
        return ret

    @property
    def flat_all(self) -> Sequence[Argument]:
        """Every argument (positional, kwarg-only, then out) as plain Arguments."""
        ret: List[Argument] = []
        ret.extend(self.flat_positional)
        ret.extend(self.flat_kwarg_only)
        ret.extend(self.out)
        return ret

    @property
    def non_out(
        self,
    ) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]:
        """Like ``flat_non_out`` but keeps the SelfArgument/TensorOptionsArguments wrappers."""
        ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = []
        ret.extend(self.positional)
        ret.extend(self.kwarg_only)
        return ret

    @property
    def positional(self) -> Sequence[Union[Argument, SelfArgument]]:
        """Positional arguments with ``self`` kept as a SelfArgument."""
        ret: List[Union[Argument, SelfArgument]] = []
        ret.extend(self.pre_self_positional)
        if self.self_arg is not None:
            ret.append(self.self_arg)
        ret.extend(self.post_self_positional)
        return ret

    @property
    def kwarg_only(self) -> Sequence[Union[Argument, TensorOptionsArguments]]:
        """Keyword-only (non-out) arguments with the TensorOptions group kept whole."""
        ret: List[Union[Argument, TensorOptionsArguments]] = []
        ret.extend(self.pre_tensor_options_kwarg_only)
        if self.tensor_options is not None:
            ret.append(self.tensor_options)
        ret.extend(self.post_tensor_options_kwarg_only)
        return ret

    @property
    def all(self) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]:
        """All arguments including out, with the grouping wrappers kept."""
        ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = []
        ret.extend(self.positional)
        ret.extend(self.kwarg_only)
        ret.extend(self.out)
        return ret

    def mutable_arg_names(self) -> List[str]:
        """Names of every argument (out included) whose annotation ``is_write``."""
        return [
            a.name
            for a in self.flat_all
            if a.annotation is not None and a.annotation.is_write
        ]

    # The three predicates below inspect non-out arguments only.
    def has_tensor_arg(self) -> bool:
        return any(a.type.is_tensor_like() for a in self.flat_non_out)

    def has_symint_arg(self) -> bool:
        return any(a.type.is_symint_like() for a in self.flat_non_out)

    def has_generator_arg(self) -> bool:
        return any(a.type.is_generator_like() for a in self.flat_non_out)

    def signature(self, *, strip_default: bool = False) -> "Arguments":
        """Return a normalized copy used to compare related schemas.

        Annotations are stripped from every argument, defaults are dropped
        only when ``strip_default`` is True, the TensorOptions group and out
        arguments are removed, and post-options kwargs are folded into the
        pre-options bucket.
        """
        def strip_arg_annotation(a: Argument) -> Argument:
            return Argument(
                name=a.name,
                type=a.type,
                default=a.default if not strip_default else None,
                annotation=None,
            )
        return Arguments(
            pre_self_positional=tuple(
                map(strip_arg_annotation, self.pre_self_positional)
            ),
            self_arg=SelfArgument(strip_arg_annotation(self.self_arg.argument))
            if self.self_arg is not None
            else None,
            post_self_positional=tuple(
                map(strip_arg_annotation, self.post_self_positional)
            ),
            # With tensor_options dropped, post-options kwargs must move to the
            # pre-options bucket so that __post_init__ accepts the result.
            pre_tensor_options_kwarg_only=tuple(
                map(strip_arg_annotation, self.pre_tensor_options_kwarg_only)
            )
            + tuple(map(strip_arg_annotation, self.post_tensor_options_kwarg_only)),
            tensor_options=None,
            post_tensor_options_kwarg_only=tuple(),
            out=(),
        )

    def remove_self_annotation(self) -> "Arguments":
        """Return a copy whose ``self`` argument has no annotation.

        Raises ValueError when there is no ``self`` argument.
        """
        if self.self_arg is None:
            raise ValueError("self.self_arg is None")
        return dataclasses.replace(
            self,
            self_arg=SelfArgument(
                dataclasses.replace(self.self_arg.argument, annotation=None)
            ),
        )

    def with_out_args(self, outs: List[Argument]) -> "Arguments":
        """Return a copy with ``outs`` as out arguments; raises ValueError if out args already exist."""
        if len(self.out) != 0:
            raise ValueError("len(self.out) != 0")
        return dataclasses.replace(
            self,
            out=tuple(outs),
        )

    @staticmethod
    def _preparse(args: str) -> Tuple[List[Argument], List[Argument], List[Argument]]:
        """Split schema argument text into (positional, kwarg_only, out) lists.

        Raises TypeError if ``*`` occurs more than once, or if an argument
        without a write annotation follows an out argument.
        """
        positional: List[Argument] = []
        kwarg_only: List[Argument] = []
        out: List[Argument] = []
        # arguments_acc aliases whichever list is being filled; the identity
        # (``is``) checks below encode the parser state.
        arguments_acc = positional
        # NOTE(review): splitting on ", " assumes no single argument's text
        # contains ", " itself.
        for arg in args.split(", "):
            if not arg:
                continue
            if arg == "*":
                if arguments_acc is not positional:
                    raise TypeError("invalid syntax: kwarg-only specifier * can only occur once")
                arguments_acc = kwarg_only
                continue
            parg = Argument.parse(arg)
            if parg.annotation is not None and parg.annotation.is_write:
                if arguments_acc is positional:
                    # Written positional arguments stay positional.
                    pass
                elif arguments_acc is kwarg_only:
                    # The first written kwarg-only argument starts the out section.
                    arguments_acc = out
            else:
                # Once in the out section, every remaining argument must be written.
                if arguments_acc is out:
                    raise TypeError("arguments_acc is not out")
            arguments_acc.append(parg)
        return positional, kwarg_only, out

    @staticmethod
    def parse(args: str) -> "Arguments":
        """Parse a comma-separated schema argument list into bucketed Arguments.

        Input: 'int x, int y, int z'
        """
        positional, kwarg_only, out = Arguments._preparse(args)
        # Locate the first positional argument named "self", if any.
        self_ix = None
        for i, a in enumerate(positional):
            if a.name == "self":
                self_ix = i
                break
        pre_self_positional: List[Argument]
        self_arg: Optional[SelfArgument]
        post_self_positional: List[Argument]
        if self_ix is not None:
            pre_self_positional = positional[:self_ix]
            self_arg = SelfArgument(positional[self_ix])
            post_self_positional = positional[self_ix + 1 :]
        else:
            pre_self_positional = []
            self_arg = None
            post_self_positional = positional
        pre_tensor_options_kwarg_only: List[Argument] = []
        tensor_options: Optional[TensorOptionsArguments] = None
        post_tensor_options_kwarg_only: List[Argument] = []
        kwarg_only_acc = pre_tensor_options_kwarg_only
        # Match an argument by exact name and by type ``ty`` or ``ty?``.
        def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
            return lambda a: a.name == name and a.type in [ty, OptionalType(ty)]
        # The TensorOptions group must appear in exactly this order.
        predicates = [
            pred("dtype", Type.parse("ScalarType")),
            pred("layout", Type.parse("Layout")),
            pred("device", Type.parse("Device")),
            pred("pin_memory", Type.parse("bool")),
        ]
        i = 0
        # Scan the kwarg-only arguments. A window of four consecutive
        # arguments matching the predicates becomes tensor_options, and later
        # arguments go to the post-options bucket.
        while i < len(kwarg_only):
            if i <= len(kwarg_only) - len(predicates):
                if all(
                    p(a)
                    for p, a in zip(predicates, kwarg_only[i : i + len(predicates)])
                ):
                    # A second matching group is rejected.
                    if kwarg_only_acc is not pre_tensor_options_kwarg_only:
                        raise ValueError("kwarg_only_acc is not pre_tensor_options_kwarg_only")
                    tensor_options = TensorOptionsArguments(
                        dtype=kwarg_only[i],
                        layout=kwarg_only[i + 1],
                        device=kwarg_only[i + 2],
                        pin_memory=kwarg_only[i + 3],
                    )
                    i += len(predicates)
                    kwarg_only_acc = post_tensor_options_kwarg_only
                    continue
            kwarg_only_acc.append(kwarg_only[i])
            i += 1
        return Arguments(
            pre_self_positional=tuple(pre_self_positional),
            self_arg=self_arg,
            post_self_positional=tuple(post_self_positional),
            pre_tensor_options_kwarg_only=tuple(pre_tensor_options_kwarg_only),
            tensor_options=tensor_options,
            post_tensor_options_kwarg_only=tuple(post_tensor_options_kwarg_only),
            out=tuple(out),
        )

    def __str__(self) -> str:
        """Render back to schema text; ``*`` precedes any kwarg-only or out arguments."""
        all_arguments: List[str] = []
        all_arguments.extend(map(str, self.flat_positional))
        if self.flat_kwarg_only or self.out:
            all_arguments.append("*")
        all_arguments.extend(map(str, self.flat_kwarg_only))
        all_arguments.extend(map(str, self.out))
        return ", ".join(all_arguments)

    def __post_init__(self) -> None:
        """Validate bucket invariants; raises ValueError when one is violated."""
        # Without self, nothing can be "before self".
        if self.self_arg is None:
            if self.pre_self_positional:
                raise ValueError("self.pre_self_positional is true.")
        # Without a TensorOptions group, nothing can be "after" it.
        if self.tensor_options is None:
            if self.post_tensor_options_kwarg_only:
                raise ValueError("self.post_tensor_options_kwarg_only")
        mutable_pre_self_positionals = [
            a
            for a in self.pre_self_positional
            if a.annotation is not None and a.annotation.is_write
        ]
        if len(mutable_pre_self_positionals) != 0:
            raise ValueError("mutable pre_self_positional arguments are not currently supported in the schema")
# Base names whose inplace dunder form is spelled ``__i<name>__``
# (recognized by BaseOperatorName.parse, e.g. ``__iadd__``).
AUGMENTED_ASSIGNMENT_NAMES = "add sub mul div mod pow lshift rshift and xor or".split()
@dataclass(frozen=True)
class BaseOperatorName:
    """Operator base name plus its inplace / dunder / functional-overload markers.

    Spellings handled by ``parse`` and ``__str__``:
      * ``foo``             plain
      * ``foo_``            inplace
      * ``foo_functional``  functional overload
      * ``__foo__``         dunder method
      * ``__ifoo__``        inplace dunder (foo in AUGMENTED_ASSIGNMENT_NAMES)
    """

    base: str
    inplace: bool
    dunder_method: bool
    functional_overload: bool = False

    @staticmethod
    def parse(op: str) -> "BaseOperatorName":
        """Parse an operator base name.

        Raises:
            ValueError: for an empty name; a name ending in ``_out``; a
                non-augmented dunder whose name starts with ``i``; a
                ``_functional`` name that is also inplace or a dunder; or a
                result that does not print back to ``op``.
        """
        if op == "":
            raise ValueError("op is an empty str.")
        if op.endswith("_out"):
            raise ValueError(
                "_out suffix is reserved and not permitted for operator names; "
                "did you mean to specify an out overload name instead?"
            )
        m = re.match(r"^__([^_]+)__$", op)
        if m is not None:
            dunder_method = True
            base = m.group(1)
            if any(base == f"i{n}" for n in AUGMENTED_ASSIGNMENT_NAMES):
                # __iadd__ and friends: strip the leading "i" marker.
                inplace = True
                base = base[1:]
            else:
                inplace = False
                # Other dunder names starting with "i" are rejected.
                if base[0] == "i":
                    raise ValueError("base[0] == i")
        else:
            dunder_method = False
            base = op
            inplace = base.endswith("_")
            if inplace:
                base = base[:-1]
        functional_suffix = "_functional"
        functional_overload = base.endswith(functional_suffix)
        if functional_overload:
            base = base[: -len(functional_suffix)]
            # Functional overloads are only supported on plain names.
            if dunder_method or inplace:
                raise ValueError("(not dunder_method and not inplace) must be true")
        base_operator_name_obj = BaseOperatorName(
            base=base,
            inplace=inplace,
            dunder_method=dunder_method,
            functional_overload=functional_overload,
        )
        # Report both spellings on a round-trip mismatch, as OperatorName.parse does.
        if str(base_operator_name_obj) != op:
            raise ValueError(f"{str(base_operator_name_obj)} != {op}")
        return base_operator_name_obj

    def __str__(self) -> str:
        if self.dunder_method:
            prefix = "i" if self.inplace else ""
            return f"__{prefix}{self.base}__"
        if self.inplace:
            suffix = "_"
        elif self.functional_overload:
            suffix = "_functional"
        else:
            suffix = ""
        return f"{self.base}{suffix}"
@dataclass(frozen=True)
class OperatorName:
    """Full operator name: a base name plus an optional overload, e.g. ``add.Tensor``."""

    name: BaseOperatorName
    # Overload name; the empty string means "no overload".
    overload_name: str

    @staticmethod
    def parse(op_name: str) -> "OperatorName":
        """Parse ``<base>[.<overload>]``; raises ValueError if it does not round-trip."""
        # Only the first "." separates the base name from the overload.
        base_text, _, overload_name = op_name.partition(".")
        parsed = OperatorName(name=BaseOperatorName.parse(base_text), overload_name=overload_name)
        if str(parsed) != op_name:
            raise ValueError(f"{str(parsed)} != {op_name}")
        return parsed

    def __str__(self) -> str:
        return f"{self.name}.{self.overload_name}" if self.overload_name else f"{self.name}"

    def unambiguous_name(self) -> str:
        """Like ``str`` but joins the overload with "_" instead of "."."""
        return f"{self.name}_{self.overload_name}" if self.overload_name else f"{self.name}"

    def with_overload(self, overload: str) -> "OperatorName":
        """Return this name with ``overload`` as its overload and inplace cleared.

        The rebuilt base name keeps ``base`` and ``dunder_method``;
        ``functional_overload`` is reset to its default.
        """
        plain_base = BaseOperatorName(
            base=self.name.base,
            inplace=False,
            dunder_method=self.name.dunder_method,
        )
        return OperatorName(name=plain_base, overload_name=overload)

    def remove_inplace(self) -> "OperatorName":
        """Return this name with inplace cleared, keeping the current overload."""
        return self.with_overload(self.overload_name)
def gets_generated_out_inplace_wrapper(
    f: NativeFunction, g: NativeFunctionsGroup, b: BackendIndex
) -> bool:
    """Decide whether backend ``b`` gets a generated wrapper for ``f``.

    True only when ``f`` is not the functional variant, ``b`` has no kernel
    for ``f`` itself, and ``b`` does have a kernel for ``g.functional``.
    NOTE(review): presumably the wrapper forwards to that functional kernel;
    confirm with the code generator.
    """
    if f.func.kind() is SchemaKind.functional:
        return False
    if b.has_kernel(f):
        return False
    return b.has_kernel(g.functional)
@dataclass(frozen=True)
class NativeFunctionsViewGroup:
    """Groups a view op with its optional ``*_copy`` and inplace variants.

    ``__post_init__`` checks naming, signature, tag and composite-kernel
    consistency between the members and raises ValueError on a mismatch.
    """

    # The view operator itself; must satisfy ``is_view_op``.
    view: NativeFunction
    # The ``<name>_copy`` counterpart, if one is declared.
    view_copy: Optional[NativeFunction]
    # The inplace variant of the view, if one is declared.
    view_inplace: Optional[NativeFunction]

    def __post_init__(self) -> None:
        if not self.view.is_view_op:
            raise ValueError(f"{str(self.view.func.name)} is not a view op (view.is_view_op is false)")
        if self.view_copy is None:
            if gets_generated_view_copy(self.view):
                raise ValueError(f"{str(self.view.func.name)} appears to be a new operator that aliases its inputs."
                                 " The codegen expects you to add a corresponding operator to native_functions.yaml:"
                                 f" {get_view_copy_name(self.view)!s}."
                                 " See Note [view_copy NativeFunctions] for details.")
        else:
            if not self.view_copy.func.name.name.base.endswith("_copy"):
                raise ValueError(f"{str(self.view_copy.func.name)}: view_copy operator base name must end with _copy")
            if self.view.func.signature() != self.view_copy.func.signature(strip_view_copy_name=True):
                raise ValueError("self.view.func.signature() != "
                                 "self.view_copy.func.signature(strip_view_copy_name=True)")
            # Report the view_copy op's own tags: those are what the check inspects.
            if "view_copy" not in self.view_copy.tags:
                raise ValueError(f"{str(self.view_copy.func.name)} (tags: {self.view_copy.tags}) appears to be a view_copy operator. The codegen expects"
                                 " view_copy operators to be annotated with the 'view_copy' tag in native_functions.yaml."
                                 " See Note [view_copy NativeFunction] for details.")
        if self.view_inplace is not None:
            if self.view.func.signature() != self.view_inplace.func.signature():
                raise ValueError("self.view.func.signature() != self.view_inplace.func.signature()")
        # The view and inplace variants must agree on composite kernels.
        if self.view.has_composite_implicit_autograd_kernel:
            if self.view_inplace is not None:
                if not self.view_inplace.has_composite_implicit_autograd_kernel:
                    raise ValueError(f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
                                     " both have CompositeImplicitAutograd kernels, or both not have composite kernels.")
        if self.view.has_composite_implicit_autograd_nested_tensor_kernel:
            if self.view_inplace is not None:
                if not self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel:
                    raise ValueError(f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
                                     " both have CompositeImplicitAutogradNestedTensor kernels,"
                                     " or both not have composite kernels.")

    def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]:
        """Yield view, then view_inplace if present, then view_copy if present and requested."""
        yield self.view
        if self.view_inplace is not None:
            yield self.view_inplace
        if self.view_copy is not None and include_copy:
            yield self.view_copy

    @property
    def root_name(self) -> str:
        """Root name of the underlying view op."""
        return self.view.root_name

    @property
    def composite(self) -> bool:
        """Whether the view op has a CompositeImplicitAutograd kernel."""
        return self.view.has_composite_implicit_autograd_kernel
def gets_generated_view_copy(f: NativeFunction) -> bool:
    """Return True when view op ``f`` should get a generated ``*_copy`` variant.

    That requires ``f`` to be a view op with no CompositeImplicitAutograd
    kernel and without the ``inplace_view`` tag.
    """
    return (
        bool(f.is_view_op)
        and not f.has_composite_implicit_autograd_kernel
        and "inplace_view" not in f.tags
    )
def get_view_copy_name(f: NativeFunction) -> "OperatorName":
    """Return the ``<base>_copy`` operator name paired with view op ``f``.

    The dunder flag and overload name are carried over from ``f``, and the
    result is never inplace. Raises ValueError if ``f`` would not get a
    generated view_copy, unless ``f`` is listed as having an explicitly
    declared view_copy operator (currently only ``narrow``).
    """
    explicit_view_copy_ops = ("narrow",)
    view_name = f.func.name
    if str(view_name) not in explicit_view_copy_ops and not gets_generated_view_copy(f):
        raise ValueError("gets_generated_view_copy(f) is not true")
    return OperatorName(
        name=BaseOperatorName(
            base=f"{view_name.name.base}_copy",
            inplace=False,
            dunder_method=view_name.name.dunder_method,
        ),
        overload_name=view_name.overload_name,
    )
def parse_returns(return_decl: str) -> Tuple[Return, ...]:
    """Parse a schema return declaration into a tuple of Returns.

    ``'()'`` yields an empty tuple. Text wrapped in parentheses, such as
    ``'(Tensor a, Tensor b)'``, is unwrapped first. The remaining text is
    split on ``", "`` and each piece is parsed with ``Return.parse``.
    """
    if return_decl == "()":
        return ()
    parenthesized = return_decl[0] == "(" and return_decl[-1] == ")"
    inner = return_decl[1:-1] if parenthesized else return_decl
    return tuple(map(Return.parse, inner.split(", ")))
@dataclass(frozen=True)
class Precompute:
    """Precomputed kernel parameters parsed from a list of strings.

    Each entry except possibly the last has the form
    ``<kernel param> -> <decl>, <decl>, ...``. It records which precomputed
    arguments replace that kernel parameter. A final entry without `` -> ``
    lists precomputed arguments that are added without replacement.
    """

    # Kernel parameter name -> precomputed arguments that replace it.
    replace: Dict[str, List[Argument]]
    # Precomputed arguments added without replacing any kernel parameter.
    add: List[Argument]

    @staticmethod
    def parse(src: object) -> "Precompute":
        """Parse ``src`` (a list of strings) into a Precompute.

        An empty list yields an empty Precompute.

        Raises:
            TypeError: if ``src`` is not a list or a replacement entry is not a str.
            ValueError: if a non-final entry lacks `` -> ``, or the parsed
                replacements do not print back to the input entries.
        """
        if not isinstance(src, list):
            raise TypeError("isinstance(src, list) is not true")
        add_args: List[Argument] = []
        # Only the last entry may omit " -> "; it lists the added arguments.
        # The emptiness check avoids an IndexError on an empty list.
        if src and " -> " not in src[-1]:
            add_args = [Argument.parse(name.strip()) for name in src[-1].split(",")]
            src = src[:-1]
        replace: Dict[str, List[Argument]] = {}
        for raw_replace_item in src:
            if not isinstance(raw_replace_item, str):
                raise TypeError("isinstance(raw_replace_item, str) is not true")
            if " -> " not in raw_replace_item:
                raise ValueError("precomputed parameters without replacement"
                                 " are allowed only in the last line")
            arg, with_list_raw = raw_replace_item.split(" -> ")
            replace[arg] = [Argument.parse(name.strip()) for name in with_list_raw.split(",")]
        precompute_obj = Precompute(replace=replace, add=add_args)
        # Round-trip check: the entries must already be in the form to_list() prints.
        rendered = precompute_obj.to_list()
        if rendered != src:
            raise ValueError(f"{rendered} != {src}")
        return precompute_obj

    def __post_init__(self) -> None:
        # Reject names that equal their uppercase form.
        # NOTE(review): presumably this avoids clashes with uppercase template
        # parameter names; confirm.
        for a in itertools.chain(self.add, *self.replace.values()):
            if a.name.upper() == a.name:
                raise ValueError(
                    f"precomputed argument name {a.name!r} must not be all uppercase"
                )

    def to_list(self) -> List[str]:
        """Render the replacement entries as ``<param> -> <decl>, ...`` strings.

        Added (non-replacing) arguments are not included.
        """
        replace_list = []
        for kernel_param, replacement_params in self.replace.items():
            replacements = ", ".join(str(param) for param in replacement_params)
            replace_list.append(f"{kernel_param} -> {replacements}")
        return replace_list