from typing import Optional, Sequence, Union, List, Set
from codegen.model import (Argument, Arguments, BaseTy, BaseType,
FunctionSchema, ListType, NativeFunction,
OptionalType, Return, SelfArgument,
TensorOptionsArguments, Type, assert_never)
from codegen.api.types import (ArgName, BaseCType, Binding, ConstRefCType, NamedCType, CType,
MutRefCType, ArrayCType, ListCType, VectorCType, ArrayRefCType,
OptionalCType, TupleCType, SpecialArgName, boolT, scalarT,
tensorListT, dimnameListT, tensorT, voidT,
BaseTypeToCppMapping, intArrayRefT, tensorOptionsT)
from codegen import local
def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str:
    """Return the C++ function name used for ``func``.

    Non-out overloads use the bare operator name. Out overloads get an
    ``_out`` suffix, or ``_outf`` when ``faithful_name_for_out_overloads``
    is true.
    """
    base_name = str(func.name.name)
    if not func.is_out_fn():
        return base_name
    suffix = '_outf' if faithful_name_for_out_overloads else '_out'
    return base_name + suffix
def valuetype_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]:
    """Translate ``t`` to its C++ value type, or return None if it is not one.

    Tensor and Scalar base types are never value types here (None is
    returned). Other base types map through ``BaseTypeToCppMapping``.
    An optional is a value type when its element is. A list is a value type
    only when its element is ``bool``, and then it maps to an ``ArrayCType``
    of the list's size.

    Raises:
        ValueError: for a ``bool`` list with no fixed size.
        AssertionError: for an unrecognized kind of type.
    """
    if isinstance(t, BaseType):
        # Tensor/Scalar are left to the callers' reference-type handling.
        if t.name in (BaseTy.Tensor, BaseTy.Scalar):
            return None
        return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))

    if isinstance(t, OptionalType):
        inner = valuetype_type(t.elem, binds=binds)
        return None if inner is None else NamedCType(binds, OptionalCType(inner.type))

    if isinstance(t, ListType):
        if str(t.elem) != 'bool':
            return None
        # A bool list becomes a fixed-length array, so its size is required.
        if t.size is None:
            raise ValueError("t.size is None")
        return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))

    raise AssertionError(f"unrecognized type {repr(t)}")
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Translate schema type ``t`` into the C++ type used for an argument.

    Value types are delegated to ``valuetype_type``. Tensors and Scalars are
    taken by const reference. A mutable Tensor becomes a mutable reference
    unless ``local.use_const_ref_for_mutable_tensors()`` is set. Optionals and
    lists of other element types are translated recursively.

    Raises:
        AssertionError: for a base type that ``valuetype_type`` should have
            handled, or for an unrecognized kind of type.
    """
    as_value = valuetype_type(t, binds=binds)
    if as_value is not None:
        return as_value

    def mutable_tensor_ref() -> bool:
        # Evaluated lazily. The ``and`` short-circuits, so the local setting
        # is only consulted when the argument is mutable.
        return mutable and not local.use_const_ref_for_mutable_tensors()

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable_tensor_ref():
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        if t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        elem_name = str(t.elem)
        if elem_name == 'Tensor':
            if mutable_tensor_ref():
                # NOTE(review): a mutable optional Tensor is bound as a plain
                # mutable Tensor reference (no optional wrapper), unlike the
                # const case below. This is preserved as-is.
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(tensorT))))
        if elem_name == 'Scalar':
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(inner.type))

    if isinstance(t, ListType):
        elem_name = str(t.elem)
        # These element types have a dedicated C++ list type. Anything else
        # becomes an ArrayRef of the recursively translated element type.
        if elem_name == 'int':
            list_type: CType = BaseCType(intArrayRefT)
        elif elem_name == 'Tensor':
            list_type = BaseCType(tensorListT)
        elif elem_name == 'Scalar':
            list_type = ArrayRefCType(BaseCType(scalarT))
        elif elem_name == 'Dimname':
            list_type = BaseCType(dimnameListT)
        elif elem_name == 'Tensor?':
            list_type = ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
        else:
            element = argumenttype_type(t.elem, mutable=mutable, binds=binds)
            list_type = ArrayRefCType(element.type)
        return NamedCType(binds, list_type)

    raise AssertionError(f"unrecognized type {repr(t)}")
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
    """Return the C++ argument type for ``a``, honoring whether it is written to."""
    schema_type, is_mutable = a.type, a.is_write
    return argumenttype_type(schema_type, mutable=is_mutable, binds=binds)
def returntype_type(t: Type, *, mutable: bool) -> CType:
    """Translate schema type ``t`` into the C++ type used for a return value.

    Value types are delegated to ``valuetype_type``. A non-mutable Tensor is
    returned by value. A mutable Tensor is returned by const reference when
    ``local.use_const_ref_for_mutable_tensors()`` is set, and by mutable
    reference otherwise. Scalars are returned by value. Lists become a
    ``VectorCType`` of the translated element type.

    Raises:
        ValueError: for a fixed-size list return.
        AssertionError: for any other unrecognized return type.
    """
    # Only the C++ type is needed, so the binding name is a throwaway.
    as_value = valuetype_type(t, binds="__placeholder__")
    if as_value is not None:
        return as_value.type

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if not mutable:
                return BaseCType(tensorT)
            if local.use_const_ref_for_mutable_tensors():
                return ConstRefCType(BaseCType(tensorT))
            return MutRefCType(BaseCType(tensorT))
        if t.name == BaseTy.Scalar:
            return BaseCType(scalarT)
    elif isinstance(t, ListType):
        # The element is translated before the size check, so an unsupported
        # element type is reported ahead of an unsupported fixed size.
        element = returntype_type(t.elem, mutable=mutable)
        if t.size is not None:
            raise ValueError(f"fixed size list returns not supported: {t}")
        return VectorCType(element)

    raise AssertionError(f"unrecognized return type {t}")
def return_type(r: Return) -> CType:
    """Return the C++ type for a single return ``r``, honoring whether it is written to."""
    schema_type, is_mutable = r.type, r.is_write
    return returntype_type(schema_type, mutable=is_mutable)
def returns_type(rs: Sequence[Return]) -> CType:
    """Return the C++ type for a function's complete list of returns.

    No returns map to ``void``, a single return maps to its own type, and
    several returns map to a ``TupleCType`` of the individual types.
    """
    count = len(rs)
    if count == 1:
        return return_type(rs[0])
    if count == 0:
        return BaseCType(voidT)
    return TupleCType([return_type(r) for r in rs])
def return_names(f: NativeFunction, *, fallback_name: str = 'result') -> Sequence[str]:
    """Choose the C++ variable name for each return of ``f``, in order.

    Rules, checked per return:

    * in-place function: the (single) return is named ``self``;
    * out function: the name of the corresponding ``out`` argument;
    * explicitly named return: that name, suffixed with ``_return`` if it
      clashes with an argument name (in schema order);
    * otherwise ``fallback_name`` for a single return, or ``fallback_name``
      followed by the zero-based index when there are several.

    Raises:
        ValueError: if an in-place function has more than one return.
    """
    returns: List[str] = []
    n_returns = len(f.func.returns)
    # Argument names in schema order. They are built at most once, and only
    # when a named return has to be checked for a clash.
    arg_names: Optional[Set[str]] = None
    for i, r in enumerate(f.func.returns):
        if f.func.name.name.inplace:
            if i != 0:
                raise ValueError("illegal inplace function with multiple returns")
            func_name = 'self'
        elif f.func.is_out_fn():
            func_name = f.func.arguments.out[i].name
        elif r.name:
            if arg_names is None:
                arg_names = {a.name for a in f.func.schema_order_arguments()}
            # Out functions were handled above, so any clash here is with an
            # input argument; a suffix keeps the two C++ names distinct.
            func_name = f'{r.name}_return' if r.name in arg_names else r.name
        else:
            func_name = fallback_name if n_returns == 1 else f'{fallback_name}{i}'
        returns.append(func_name)
    return returns
# Default-value literals as written in function schemas, mapped to the C++
# expression emitted for them. default_expr() falls back to this table for
# defaults it does not handle specially; literals not listed here are
# emitted unchanged.
JIT_TO_CPP_DEFAULT = {
    'False': 'false',
    'True': 'true',
    'None': 'c10::nullopt',
    'Mean': 'at::Reduction::Mean',
    '[]': '{}',
    'contiguous_format': 'c10::MemoryFormat::Contiguous',
    'long': 'at::kLong',
}
def default_expr(d: str, t: Type) -> str:
    """Convert schema default literal ``d`` for type ``t`` into a C++ expression.

    * ``None`` for ``Tensor?`` becomes ``{}``. ``None`` for any other optional
      becomes ``c10::nullopt``. Other optional defaults are converted using
      the element type.
    * Single-quoted ``str`` defaults are re-emitted as double-quoted C++
      string literals.
    * ``[...]`` list defaults become brace-initializer lists.
    * Anything else is looked up in ``JIT_TO_CPP_DEFAULT``, and emitted
      unchanged if absent.

    Raises:
        ValueError: if an unsized list type has a default not written as
            ``[...]``.
    """
    def to_cpp_string_literal(literal: str) -> str:
        # ``literal`` is single-quoted. Walk the characters between the
        # quotes: bare double quotes get escaped, \' is unescaped (a single
        # quote needs no escape inside "..."), and any other backslash escape
        # is kept verbatim as two characters.
        pieces: List[str] = []
        pos = 1
        closing = len(literal) - 1
        while pos < closing:
            ch = literal[pos]
            if ch == '\\':
                # The escaped character may be the closing quote itself.
                escaped = literal[pos + 1]
                pieces.append("'" if escaped == "'" else literal[pos:pos + 2])
                pos += 2
            else:
                pieces.append('\\"' if ch == '"' else ch)
                pos += 1
        return '"' + ''.join(pieces) + '"'

    if d == 'None' and str(t) == 'Tensor?':
        return '{}'
    if (isinstance(t, BaseType) and t.name is BaseTy.str
            and len(d) >= 2 and d.startswith("'") and d.endswith("'")):
        return to_cpp_string_literal(d)
    if isinstance(t, OptionalType):
        return 'c10::nullopt' if d == 'None' else default_expr(d, t.elem)
    if isinstance(t, ListType):
        if d.startswith('[') and d.endswith(']'):
            return '{' + d[1:-1] + '}'
        if t.size is None:
            raise ValueError(f"Expected a list default '[...]' but found: '{d}'")
        # A sized list with a non-[...] default falls through to the lookup.
    return JIT_TO_CPP_DEFAULT.get(d, d)
def argument(
    a: Union[Argument, TensorOptionsArguments, SelfArgument],
    *, cpp_no_default_args: Set[str], method: bool, faithful: bool,
    has_tensor_options: bool
) -> List[Binding]:
    """Translate one schema argument into zero or more C++ bindings.

    * ``Argument``: a single binding. The binding name is
      ``SpecialArgName.possibly_redundant_memory_format`` when the argument is
      named ``memory_format`` and the function has tensor options. Its default
      is rendered with ``default_expr`` unless the name is in
      ``cpp_no_default_args`` or there is no default.
    * ``TensorOptionsArguments``: when ``faithful``, separate bindings for
      dtype, layout, device and pin_memory (in that order). Otherwise one
      ``options`` binding of type ``tensorOptionsT``, defaulting to ``{}`` if
      every component defaults to ``"None"``, else to ``at::kLong`` if the
      dtype default is ``"long"``, else no default.
    * ``SelfArgument``: no binding for methods, otherwise the wrapped argument.

    Raises:
        KeyError: if ``'options'`` is in ``cpp_no_default_args`` when tensor
            options are collapsed into a single ``options`` binding.
    """
    def expand(part: Union[Argument, TensorOptionsArguments, SelfArgument]) -> List[Binding]:
        # Recurse with exactly the same configuration as this call.
        return argument(
            part, cpp_no_default_args=cpp_no_default_args, method=method,
            faithful=faithful, has_tensor_options=has_tensor_options)

    if isinstance(a, Argument):
        binds: ArgName = (
            SpecialArgName.possibly_redundant_memory_format
            if a.name == "memory_format" and has_tensor_options
            else a.name
        )
        wants_default = a.name not in cpp_no_default_args and a.default is not None
        default: Optional[str] = default_expr(a.default, a.type) if wants_default else None
        return [Binding(
            nctype=argument_type(a, binds=binds),
            name=a.name,
            default=default,
            argument=a,
        )]

    if isinstance(a, TensorOptionsArguments):
        if faithful:
            bindings: List[Binding] = []
            for part in (a.dtype, a.layout, a.device, a.pin_memory):
                bindings.extend(expand(part))
            return bindings
        if 'options' in cpp_no_default_args:
            raise KeyError("'options' in cpp_no_default_args")
        if all(x.default == "None" for x in a.all()):
            options_default: Optional[str] = '{}'
        elif a.dtype.default == "long":
            options_default = 'at::kLong'
        else:
            options_default = None
        return [Binding(
            nctype=NamedCType('options', BaseCType(tensorOptionsT)),
            name='options',
            default=options_default,
            argument=a,
        )]

    if isinstance(a, SelfArgument):
        # For methods no binding is emitted for self (presumably the caller
        # supplies it as the implicit receiver — confirm at call sites).
        return [] if method else expand(a.argument)

    assert_never(a)
def arguments(
    func_arguments: Arguments,
    *, faithful: bool, method: bool, cpp_no_default_args: Set[str]
) -> List[Binding]:
    """Build the C++ bindings for all of a function's arguments.

    Faithful signatures list non-out arguments before out arguments and strip
    every default via ``Binding.no_default()``. Non-faithful signatures put
    the out arguments first and keep defaults.
    """
    if faithful:
        ordered: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [
            *func_arguments.non_out, *func_arguments.out]
    else:
        ordered = [*func_arguments.out, *func_arguments.non_out]
    has_tensor_options = func_arguments.tensor_options is not None

    bindings: List[Binding] = []
    for arg in ordered:
        for binding in argument(
                arg, faithful=faithful, method=method,
                has_tensor_options=has_tensor_options,
                cpp_no_default_args=cpp_no_default_args):
            bindings.append(binding.no_default() if faithful else binding)
    return bindings