from __future__ import annotations
from typing import overload
from ..._C import ir
from ..core.dtype import DataType, KnownTypes as KT
from ..core.ir_value import GlobalAddress, IRValue, IRHandle, PlainValue, RuntimeInt, materialize_ir_value as _mat
from ..core.utils import require_jit, global_builder
from ..core.tensor import GlobalTensor
from .utils import set_tensor_docstring
class TensorDesc(IRValue):
@overload
def __init__(self, dtype: DataType = KT.float32) -> None:
...
@overload
def __init__(self, handle: IRHandle, dtype: DataType = KT.float32) -> None:
...
def __init__(self, *args, **kwargs) -> None:
dtype = kwargs.pop("dtype", KT.float32)
if 'handle' in kwargs:
self.handle = kwargs['handle']
self.dtype = dtype
return
if global_builder.get_ir_builder() is not None:
self.__initjit__(dtype=dtype)
return
builder = global_builder.get_ir_builder()
self.dtype = dtype
self.handle = builder.create_asc_TensorDescOp(
builder.get_asc_TensorDescType(),
dtype.to_ir()
)
@require_jit
def __initjit__(self, dtype: DataType = KT.float32) -> None:
builder = global_builder.get_ir_builder()
self.dtype = dtype
self.handle = builder.create_asc_TensorDescOp(
builder.get_asc_TensorDescType(),
dtype.to_ir()
)
@classmethod
def from_ir(cls, handle: IRHandle, dtype: DataType = KT.float32) -> TensorDesc:
return cls(handle=handle, dtype=dtype)
def to_ir(self) -> IRHandle:
return self.handle
@overload
def set_shape_addr(self, shape_ptr: int) -> None:
...
@require_jit
@set_tensor_docstring(tensor_name="TensorDesc", api_name="set_shape_addr")
def set_shape_addr(self, shape_ptr: RuntimeInt) -> None:
builder = global_builder.get_ir_builder()
self.handle = builder.create_asc_TensorDescSetShapeAddrOp(self.to_ir(), _mat(shape_ptr, KT.uint64).to_ir())
@require_jit
@set_tensor_docstring(tensor_name="TensorDesc", api_name="get_dim")
def get_dim(self) -> RuntimeInt:
builder = global_builder.get_ir_builder()
handle = builder.create_asc_TensorDescGetDimOp(builder.get_ui64_type(), self.to_ir())
return PlainValue(handle)
@require_jit
@set_tensor_docstring(tensor_name="TensorDesc", api_name="get_index")
def get_index(self) -> RuntimeInt:
builder = global_builder.get_ir_builder()
handle = builder.create_asc_TensorDescGetIndexOp(builder.get_ui64_type(), self.to_ir())
return PlainValue(handle)
@require_jit
@set_tensor_docstring(tensor_name="TensorDesc", api_name="get_shape")
def get_shape(self, offset: RuntimeInt) -> RuntimeInt:
builder = global_builder.get_ir_builder()
handle = builder.create_asc_TensorDescGetShapeOp(
builder.get_i64_type(),
self.to_ir(),
_mat(offset, KT.int32).to_ir()
)
return PlainValue(handle)
@require_jit
@set_tensor_docstring(tensor_name="TensorDesc", api_name="get_data_ptr")
def get_data_ptr(self) -> GlobalAddress:
builder = global_builder.get_ir_builder()
element_type = self.dtype.to_ir()
handle = builder.create_asc_TensorDescGetDataPtrOp(
ir.get_unranked_memref_type(element_type, ir.AddressSpace.gm),
self.to_ir()
)
return GlobalAddress(handle, self.dtype)
@require_jit
@set_tensor_docstring(tensor_name="TensorDesc", api_name="get_data_obj")
def get_data_obj(self) -> GlobalTensor:
builder = global_builder.get_ir_builder()
element_type = self.dtype.to_ir()
handle = builder.create_asc_TensorDescGetDataObjOp(
ir.get_global_tensor_type(element_type),
self.to_ir()
)
return GlobalTensor(handle)
class ListTensorDesc(IRValue):
@overload
def __init__(self) -> None:
...
@overload
def __init__(self, data: GlobalAddress, length: int = 0xffffffff, shape_size: int = 0xffffffff) -> None:
...
@overload
def __init__(self, handle: IRHandle) -> None:
"""This contructor should not be called by user"""
...
def __init__(self, *args, **kwargs) -> None:
if 'handle' in kwargs:
self.handle = kwargs['handle']
return
builder = global_builder.get_ir_builder()
if 'data' in kwargs:
data = kwargs['data']
length = kwargs['length']
shape_size = kwargs['shape_size']
self.handle = builder.create_asc_ListTensorDescV2Op(builder.get_asc_ListTensorDescType(), data.to_ir(),
_mat(length, KT.uint32).to_ir(),
_mat(shape_size, KT.uint32).to_ir())
return
self.handle = builder.create_asc_ListTensorDescOp(builder.get_asc_ListTensorDescType())
@classmethod
def from_ir(cls, handle: IRHandle) -> ListTensorDesc:
return cls(handle=handle)
def to_ir(self) -> IRHandle:
return self.handle
@overload
def init(self, data: GlobalAddress, length: int = 0xffffffff, shape_size: int = 0xffffffff) -> None:
...
@require_jit
@set_tensor_docstring(tensor_name="ListTensorDesc", api_name="init")
def init(self, data: GlobalAddress, length: RuntimeInt = 0xffffffff, shape_size: RuntimeInt = 0xffffffff) -> None:
builder = global_builder.get_ir_builder()
self.handle = builder.create_asc_ListTensorDescInitOp(self.to_ir(), data.to_ir(),
_mat(length, KT.uint32).to_ir(),
_mat(shape_size, KT.uint32).to_ir())
@overload
def get_desc(self, desc: TensorDesc, index: int) -> None:
...
@require_jit
@set_tensor_docstring(tensor_name="ListTensorDesc", api_name="get_desc")
def get_desc(self, desc: TensorDesc, index: RuntimeInt) -> None:
builder = global_builder.get_ir_builder()
self.handle = builder.create_asc_ListTensorDescGetDescOp(self.to_ir(), desc.to_ir(),
_mat(index, KT.uint32).to_ir())
@overload
def get_data_ptr(self, index: int, dtype: DataType) -> GlobalAddress:
...
@require_jit
@set_tensor_docstring(tensor_name="ListTensorDesc", api_name="get_data_ptr")
def get_data_ptr(self, index: RuntimeInt, dtype: DataType) -> GlobalAddress:
builder = global_builder.get_ir_builder()
ga_type = ir.get_unranked_memref_type(dtype.to_ir(), ir.AddressSpace.gm)
handle = builder.create_asc_ListTensorDescGetDataPtrOp(ga_type, self.to_ir(), _mat(index, KT.uint32).to_ir(),
dtype.to_ir())
return GlobalAddress(handle)
@overload
def get_size(self) -> int:
...
@require_jit
@set_tensor_docstring(tensor_name="ListTensorDesc", api_name="get_size")
def get_size(self) -> RuntimeInt:
builder = global_builder.get_ir_builder()
self.handle = builder.create_asc_ListTensorDescGetSizeOp(builder.get_ui32_type(), self.to_ir())