import abc
from typing import Dict, Optional, Type
from .._C import ir
from ..language.core.constexpr import ConstExpr
from ..language.core.dtype import DataType
from ..language.core.ir_value import IRValue
from ..language.core.struct import Struct
class BaseArgType(abc.ABC):
@abc.abstractmethod
def to_ir(self) -> ir.Type:
raise NotImplementedError
class PointerArgType(BaseArgType):
def __init__(self, dtype: DataType):
self.dtype = dtype
def to_ir(self) -> ir.Type:
return ir.get_memref_type(self.dtype.to_ir(), ir.dynshape, ir.AddressSpace.gm)
class PlainArgType(BaseArgType):
def __init__(self, dtype: DataType):
self.dtype = dtype
def to_ir(self) -> ir.Type:
return self.dtype.to_ir()
class IRArgType(BaseArgType):
def __init__(self, py_type: Type[IRValue], ir_type: ir.Type):
if not issubclass(py_type, IRValue):
raise TypeError("Only IRValue can be passed between JIT functions")
self.py_type = py_type
self.ir_type = ir_type
def to_ir(self) -> ir.Type:
return self.ir_type
class StructArgType(BaseArgType):
def __init__(self, py_type: Type[Struct]):
if not issubclass(py_type, Struct):
raise TypeError("Only Struct can be passed from host to device")
self.py_type = py_type
def to_ir(self):
return ir.get_memref_type(self.py_type.get_ir_type(), ir.dynshape, ir.AddressSpace.gm)
class Specialization:
def __init__(self, args: Dict[str, BaseArgType], constexprs: Optional[Dict[str, ConstExpr]] = None):
self.args = args
self.constexprs = {} if constexprs is None else constexprs