from dataclasses import dataclass
from typing import Optional, Union, TypeVar, List, Dict
from enum import Enum
from codegen.model import Argument, SelfArgument, TensorOptionsArguments, BaseTy
# Generic type variable for helpers in this module (its uses are not visible
# in this chunk).
_T = TypeVar('_T')
# Argument names that are not plain string identifiers. Built with the
# functional Enum API, so the single member
# ``possibly_redundant_memory_format`` gets the value 1.
SpecialArgName = Enum('SpecialArgName', (
'possibly_redundant_memory_format',
))
# An argument name is either an ordinary identifier string or one of the
# SpecialArgName members.
ArgName = Union[str, SpecialArgName]
@dataclass(frozen=True)
class BaseCppType:
    """A named C++ type, optionally qualified by a namespace.

    ``str()`` renders the qualified spelling (``ns::name``); when ``ns`` is
    ``None`` or the empty string only the bare ``name`` is produced.
    """

    # Namespace prefix such as 'at' or 'c10'; None or '' means unqualified.
    ns: Optional[str]
    # Unqualified type name.
    name: str

    def __str__(self) -> str:
        # Both None and '' are falsy, so either yields no qualifier.
        qualifier = f"{self.ns}::" if self.ns else ""
        return qualifier + self.name
# --- Unqualified C++ primitives (empty namespace => no ``ns::`` prefix) ---
intT = BaseCppType(ns='', name='int64_t')
doubleT = BaseCppType(ns='', name='double')
boolT = BaseCppType(ns='', name='bool')
voidT = BaseCppType(ns='', name='void')
# Same namespace and spelling as intT, so the two compare equal.
longT = BaseCppType(ns='', name='int64_t')

# --- c10 namespace ---
stringT = BaseCppType(ns='c10', name='string_view')

# --- at namespace ---
generatorT = BaseCppType(ns='at', name='Generator')
scalarTypeT = BaseCppType(ns='at', name='ScalarType')
tensorT = BaseCppType(ns='at', name='Tensor')
optionalTensorRefT = BaseCppType(ns='at', name='OptionalTensorRef')
tensorListT = BaseCppType(ns='at', name='TensorList')
dimnameT = BaseCppType(ns='at', name='Dimname')
dimnameListT = BaseCppType(ns='at', name='DimnameList')
layoutT = BaseCppType(ns='at', name='Layout')
deviceT = BaseCppType(ns='at', name='Device')
scalarT = BaseCppType(ns='at', name='Scalar')
optionalScalarRefT = BaseCppType(ns='at', name='OptionalScalarRef')
memoryFormatT = BaseCppType(ns='at', name='MemoryFormat')
qschemeT = BaseCppType(ns='at', name='QScheme')
storageT = BaseCppType(ns='at', name='Storage')
streamT = BaseCppType(ns='at', name='Stream')
intArrayRefT = BaseCppType(ns='at', name='IntArrayRef')
tensorOptionsT = BaseCppType(ns='at', name='TensorOptions')

# --- autograd codegen helpers ---
typeAndSizeT = BaseCppType(ns='torch::autograd::generated', name='TypeAndSize')
tensorGeometryT = BaseCppType(ns='at', name='TensorGeometry')
# Table from schema base types (BaseTy) to the C++ type each is rendered as.
# Only the BaseTy members listed here have an entry; any other member is
# absent from this dict. NOTE(review): confirm callers handle missing keys.
BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
    BaseTy.int: intT,        # 64-bit integer (int64_t)
    BaseTy.float: doubleT,   # schema 'float' is rendered as C++ double
    BaseTy.bool: boolT,
    BaseTy.str: stringT,     # rendered as c10::string_view
    BaseTy.Generator: generatorT,
    BaseTy.ScalarType: scalarTypeT,
    BaseTy.Tensor: tensorT,
    BaseTy.Dimname: dimnameT,
    BaseTy.Layout: layoutT,
    BaseTy.Device: deviceT,
    BaseTy.Scalar: scalarT,
    BaseTy.MemoryFormat: memoryFormatT,
    BaseTy.QScheme: qschemeT,
    BaseTy.Storage: storageT,
    BaseTy.Stream: streamT,
}
@dataclass(frozen=True)
class BaseCType:
    """A C type that is exactly one BaseCppType, with no wrapper or reference."""

    # The underlying, possibly namespaced, C++ type.
    type: BaseCppType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Render the C++ spelling; ``strip_ref`` has nothing to strip here."""
        rendered = str(self.type)
        return rendered

    def cpp_type_registration_declarations(self) -> str:
        """Render the type with every ``at::`` qualifier removed."""
        qualified = str(self.type)
        return qualified.replace('at::', '')

    def remove_const_ref(self) -> 'CType':
        """Return ``self``: a bare base type carries no const/ref to remove."""
        return self
@dataclass(frozen=True)
class ConstRefCType:
    """A ``const T &`` reference to another C type ``T``."""

    # The referenced type T.
    elem: 'CType'

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Render ``const T &``, or just ``T`` (recursively stripped) when
        ``strip_ref`` is set."""
        if not strip_ref:
            return 'const ' + self.elem.cpp_type() + ' &'
        # Forward the flag so nested references are dropped as well.
        return self.elem.cpp_type(strip_ref=strip_ref)

    def cpp_type_registration_declarations(self) -> str:
        """Render ``const T &`` using the registration spelling of ``T``."""
        referent = self.elem.cpp_type_registration_declarations()
        return 'const ' + referent + ' &'

    def remove_const_ref(self) -> 'CType':
        """Drop this const reference and any const/refs inside ``T``."""
        inner = self.elem
        return inner.remove_const_ref()
@dataclass(frozen=True)
class MutRefCType:
    """A mutable ``T &`` reference to another C type ``T``."""

    # The referenced type T.
    elem: 'CType'

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Render ``T &``, or just ``T`` (recursively stripped) when
        ``strip_ref`` is set."""
        return (
            self.elem.cpp_type(strip_ref=strip_ref)
            if strip_ref
            else f'{self.elem.cpp_type()} &'
        )

    def cpp_type_registration_declarations(self) -> str:
        """Render ``T &`` using the registration spelling of ``T``."""
        return self.elem.cpp_type_registration_declarations() + ' &'

    def remove_const_ref(self) -> 'CType':
        """Drop this reference and any const/refs inside ``T``."""
        stripped = self.elem.remove_const_ref()
        return stripped
@dataclass(frozen=True)
class OptionalCType:
    """A ``c10::optional<T>`` wrapper around another C type ``T``."""

    # The wrapped type T.
    elem: 'CType'

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Render ``c10::optional<T>``.

        ``strip_ref`` only concerns a top-level reference and is not
        forwarded into the template argument.
        """
        return 'c10::optional<{}>'.format(self.elem.cpp_type())

    def cpp_type_registration_declarations(self) -> str:
        """Render ``c10::optional<T>`` using the registration spelling of ``T``."""
        return 'c10::optional<{}>'.format(
            self.elem.cpp_type_registration_declarations()
        )

    def remove_const_ref(self) -> 'CType':
        """Return an optional whose element has its const/refs removed."""
        inner = self.elem.remove_const_ref()
        return OptionalCType(inner)
@dataclass(frozen=True)
class ListCType:
    """A ``c10::List<T>`` of another C type ``T``."""

    # The element type T.
    elem: 'CType'

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Render ``c10::List<T>``; ``strip_ref`` is not forwarded to ``T``."""
        element = self.elem.cpp_type()
        return f'c10::List<{element}>'

    def cpp_type_registration_declarations(self) -> str:
        """Render ``c10::List<T>`` using the registration spelling of ``T``."""
        element = self.elem.cpp_type_registration_declarations()
        return f'c10::List<{element}>'

    def remove_const_ref(self) -> 'CType':
        """Return a list type whose element has its const/refs removed."""
        return ListCType(elem=self.elem.remove_const_ref())
@dataclass(frozen=True)
class ArrayRefCType:
    """An ``at::ArrayRef<T>`` view over elements of C type ``T``."""

    # The element type T.
    elem: 'CType'

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Render ``at::ArrayRef<T>``; ``strip_ref`` is not forwarded to ``T``."""
        return 'at::ArrayRef<' + self.elem.cpp_type() + '>'

    def cpp_type_registration_declarations(self) -> str:
        """Render ``ArrayRef<T>`` (no ``at::`` prefix) with the registration
        spelling of ``T``."""
        return 'ArrayRef<' + self.elem.cpp_type_registration_declarations() + '>'

    def remove_const_ref(self) -> 'CType':
        """Return an ArrayRef type whose element has its const/refs removed."""
        stripped = self.elem.remove_const_ref()
        return ArrayRefCType(stripped)
@dataclass(frozen=True)
class VectorCType:
    """A ``::std::vector<T>`` of another C type ``T``."""

    # The element type T.
    elem: 'CType'

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Render ``::std::vector<T>``; ``strip_ref`` is not forwarded to ``T``."""
        inner = self.elem.cpp_type()
        return '::std::vector<' + inner + '>'

    def cpp_type_registration_declarations(self) -> str:
        """Render ``::std::vector<T>`` with the registration spelling of ``T``."""
        inner = self.elem.cpp_type_registration_declarations()
        return '::std::vector<' + inner + '>'

    def remove_const_ref(self) -> 'CType':
        """Return a vector type whose element has its const/refs removed."""
        return VectorCType(elem=self.elem.remove_const_ref())
@dataclass(frozen=True)
class ArrayCType:
    """A fixed-size ``::std::array<T,N>`` of another C type ``T``."""

    # The element type T.
    elem: 'CType'
    # The compile-time length N.
    size: int

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Render ``::std::array<T,N>``; ``strip_ref`` is not forwarded to ``T``."""
        return '::std::array<{},{}>'.format(self.elem.cpp_type(), self.size)

    def cpp_type_registration_declarations(self) -> str:
        """Render ``::std::array<T,N>`` with the registration spelling of ``T``."""
        return '::std::array<{},{}>'.format(
            self.elem.cpp_type_registration_declarations(), self.size
        )

    def remove_const_ref(self) -> 'CType':
        """Return an array type of the same size whose element has its
        const/refs removed."""
        stripped = self.elem.remove_const_ref()
        return ArrayCType(stripped, self.size)
@dataclass(frozen=True)
class TupleCType:
    """A ``::std::tuple<T1,...,Tn>`` of zero or more C types."""

    # Element types, in tuple order.
    elems: List['CType']

    def __hash__(self) -> int:
        # The dataclass-generated __hash__ for a frozen class hashes the field
        # tuple ``(self.elems,)`` and raises TypeError because lists are
        # unhashable. That made TupleCType, and any NamedCType/Binding holding
        # one, unusable as a dict key or set member, unlike the other frozen
        # CType variants. Hashing a tuple snapshot stays consistent with the
        # generated __eq__, which compares the lists element-wise. Because
        # ``frozen`` does not stop callers mutating the list in place, don't
        # mutate ``elems`` after the instance has been hashed.
        return hash(tuple(self.elems))

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Render ``::std::tuple<...>``; ``strip_ref`` is not forwarded to the
        elements."""
        return f'::std::tuple<{",".join(e.cpp_type() for e in self.elems)}>'

    def cpp_type_registration_declarations(self) -> str:
        """Render ``::std::tuple<...>`` with each element's registration
        spelling."""
        inner = ",".join(e.cpp_type_registration_declarations() for e in self.elems)
        return f'::std::tuple<{inner}>'

    def remove_const_ref(self) -> 'CType':
        """Return a tuple type whose elements have their const/refs removed."""
        return TupleCType([e.remove_const_ref() for e in self.elems])
# Every C type node defined above. Wrapper variants (optional, reference,
# list, ArrayRef, array, vector, tuple) hold nested CType values, and
# BaseCType is the leaf. All variants provide cpp_type(),
# cpp_type_registration_declarations() and remove_const_ref().
CType = Union[
    BaseCType,
    OptionalCType,
    ConstRefCType,
    MutRefCType,
    ListCType,
    ArrayRefCType,
    ArrayCType,
    VectorCType,
    TupleCType
]
@dataclass(frozen=True)
class NamedCType:
    """A C type paired with the argument name it is bound to.

    All type-rendering methods delegate to the wrapped ``type``; the name is
    carried along unchanged.
    """

    # Argument name: a plain identifier string or a SpecialArgName member.
    name: ArgName
    # The C type of the named value.
    type: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Render the wrapped type, forwarding ``strip_ref``."""
        wrapped = self.type
        return wrapped.cpp_type(strip_ref=strip_ref)

    def cpp_type_registration_declarations(self) -> str:
        """Render the wrapped type in registration spelling."""
        wrapped = self.type
        return wrapped.cpp_type_registration_declarations()

    def remove_const_ref(self) -> 'NamedCType':
        """Return a copy with the same name and const/refs removed from the type."""
        return NamedCType(name=self.name, type=self.type.remove_const_ref())

    def with_name(self, name: str) -> 'NamedCType':
        """Return a copy with the same type under a different name."""
        return NamedCType(name=name, type=self.type)
@dataclass(frozen=True)
class Binding:
    """A named C++ parameter binding.

    Holds the parameter name, its named C type, the schema argument it was
    derived from, and an optional default-value expression, and renders
    declaration/definition snippets from them.
    """

    # Parameter name as emitted in C++.
    name: str
    # The C type (plus its own name) used to render the parameter type.
    nctype: NamedCType
    # The schema-level argument this binding was produced from.
    argument: Union[Argument, TensorOptionsArguments, SelfArgument]
    # Default-value expression emitted after '=' in declarations, if any.
    default: Optional[str] = None

    @property
    def type(self) -> str:
        """The rendered C++ type of this binding."""
        return self.nctype.cpp_type()

    def no_default(self) -> 'Binding':
        """Return a copy of this binding with ``default`` cleared."""
        return Binding(self.name, self.nctype, self.argument, None)

    def decl(self, *, func_ptr_cast: bool = False) -> str:
        """Render the parameter for a declaration.

        With ``func_ptr_cast`` only the type is produced (no name, no
        default); otherwise ``<type> <name>`` followed by ``=<default>``
        when a default exists.
        """
        if func_ptr_cast:
            return f"{self.type}"
        suffix = "" if self.default is None else f"={self.default}"
        return f"{self.type} {self.name}{suffix}"

    def decl_registration_declarations(self) -> str:
        """Render ``<type> <name>[=<default>]`` using the registration spelling
        of the type."""
        spelled = self.nctype.cpp_type_registration_declarations()
        suffix = "" if self.default is None else f"={self.default}"
        return f"{spelled} {self.name}{suffix}"

    def defn(self) -> str:
        """Render ``<type> <name>`` for a definition (defaults never appear)."""
        return f"{self.type} {self.name}"
@dataclass(frozen=True)
class Expr:
    """A C++ expression string paired with the named C type it evaluates to."""
    # The C++ expression text.
    expr: str
    # The expression's C type together with its associated name.
    type: NamedCType