from typing import overload
from ..._C import ir
from ..core.aipp_types import AippParams
from ..core.dtype import DataType, KnownTypes, KnownTypes as KT
from ..core.enums import AippInputFormat, MaskMode, TPosition, AtomicDtype, AtomicOp
from ..core.ir_value import GlobalAddress, PlainValue, materialize_ir_value as _mat, RuntimeBool, RuntimeInt
from ..core.tensor import LocalTensor, GlobalTensor
from ..core.types import CheckLocalMemoryIAParam
from ..core.utils import require_jit, global_builder, OverloadDispatcher
from .utils import set_common_docstring
@require_jit
def ascend_is_aic() -> PlainValue:
    """Emit an ``AscendIsAICOp`` and return its result wrapped in a ``PlainValue``.

    The op is created with an ``i1`` result type.  Judging by the op name the
    value presumably reports whether the code runs on an AIC core -- confirm
    against the dialect documentation.
    """
    ir_builder = global_builder.get_ir_builder()
    i1_type = ir_builder.get_i1_type()
    return PlainValue(ir_builder.create_asc_AscendIsAICOp(i1_type))
@require_jit
def ascend_is_aiv() -> PlainValue:
    """Emit an ``AscendIsAIVOp`` and return its result wrapped in a ``PlainValue``.

    The op is created with an ``i1`` result type.  Judging by the op name the
    value presumably reports whether the code runs on an AIV core -- confirm
    against the dialect documentation.
    """
    ir_builder = global_builder.get_ir_builder()
    i1_type = ir_builder.get_i1_type()
    return PlainValue(ir_builder.create_asc_AscendIsAIVOp(i1_type))
@overload
def data_cache_preload(src: GlobalTensor, cache_offset: int) -> None:
    ...


@require_jit
@set_common_docstring(api_name="data_cache_preload")
def data_cache_preload(src: GlobalTensor, cache_offset: RuntimeInt) -> None:
    # The user-facing docstring is presumably attached by set_common_docstring,
    # so only implementation notes live here.
    #
    # Materializes ``cache_offset`` into an IR value (no explicit dtype is
    # requested) and emits a DataCachePreloadOp on ``src`` with that offset.
    ir_builder = global_builder.get_ir_builder()
    offset = _mat(cache_offset)
    ir_builder.create_asc_DataCachePreloadOp(src.to_ir(), offset.to_ir())
@require_jit
@set_common_docstring(api_name="get_hccl_context")
def get_hccl_context(index: int) -> GlobalAddress:
    # The user-facing docstring is presumably attached by set_common_docstring,
    # so only implementation notes live here.
    ir_builder = global_builder.get_ir_builder()
    # The index is materialized as a uint32 IR value.
    index_ir = _mat(index, KT.uint32).to_ir()
    # Result type: dynamically shaped ui8 memref in the ``gm`` address space.
    result_type = ir.get_memref_type(ir_builder.get_ui8_type(), [ir.dynshape], ir.AddressSpace.gm)
    handle = ir_builder.create_asc_GetHcclContextOp(result_type, index_ir)
    return GlobalAddress(handle, KT.uint8)
@require_jit
@set_common_docstring(api_name="get_sys_workspace")
def get_sys_workspace() -> GlobalAddress:
    # The user-facing docstring is presumably attached by set_common_docstring,
    # so only implementation notes live here.
    ir_builder = global_builder.get_ir_builder()
    # Result type: dynamically shaped ui8 memref in the ``gm`` address space.
    result_type = ir.get_memref_type(ir_builder.get_ui8_type(), [ir.dynshape], ir.AddressSpace.gm)
    handle = ir_builder.create_asc_GetSysWorkspacePtrOp(result_type)
    return GlobalAddress(handle, KT.uint8)
@require_jit
def pop_stack_buffer(dtype: DataType, position: TPosition = TPosition.VECCALC) -> LocalTensor:
    """Create a local tensor of ``dtype`` and emit a ``PopStackBufferOp`` for it.

    Args:
        dtype: Element type of the tensor; also used to build its IR type.
        position: Position passed (symbolized) to the op. Defaults to
            ``TPosition.VECCALC``.

    Returns:
        The ``LocalTensor`` wrapping the newly created handle (``shape=None``).
    """
    ir_builder = global_builder.get_ir_builder()
    tensor_type = ir.get_local_tensor_type(dtype.to_ir())
    result = LocalTensor(ir_builder.create_asc_LocalTensorOp(tensor_type), dtype, shape=None)
    # The op takes the tensor as an operand; the tensor itself is what callers get back.
    ir_builder.create_asc_PopStackBufferOp(ir.TPosition.symbolize(position), result.to_ir())
    return result
@require_jit
@set_common_docstring(api_name="reset_mask")
def reset_mask() -> None:
    # The user-facing docstring is presumably attached by set_common_docstring.
    # Emits a ResetMaskOp; the op takes no operands.
    global_builder.get_ir_builder().create_asc_ResetMaskOp()
@overload
def set_aipp_functions(src0: GlobalTensor, input_format: AippInputFormat, config: AippParams) -> None:
    ...


@overload
def set_aipp_functions(src0: GlobalTensor, src1: GlobalTensor,
                       input_format: AippInputFormat, config: AippParams) -> None:
    ...


@require_jit
@set_common_docstring(api_name="set_aipp_functions")
def set_aipp_functions(*args, **kwargs) -> None:
    # The user-facing docstring is presumably attached by set_common_docstring,
    # so only implementation notes live here.
    #
    # Both call shapes emit the same SetAippFunctionsOp; the single-source form
    # passes ``None`` in place of the second source operand.
    ir_builder = global_builder.get_ir_builder()
    overloads = OverloadDispatcher(set_aipp_functions)

    def emit(first, second, fmt, params):
        # Shared op construction for both overloads.
        ir_builder.create_asc_SetAippFunctionsOp(
            first.to_ir(),
            None if second is None else second.to_ir(),
            ir.AippInputFormat.symbolize(fmt),
            params.to_ir(),
        )

    # Registration order is kept: two-source overload first, then single-source.
    @overloads.register_auto
    def _(src0: GlobalTensor, src1: GlobalTensor, input_format: AippInputFormat, config: AippParams):
        emit(src0, src1, input_format, config)

    @overloads.register_auto
    def _(src0: GlobalTensor, input_format: AippInputFormat, config: AippParams):
        emit(src0, None, input_format, config)

    overloads(*args, **kwargs)
# Typing-only stub: like the other overload stubs in this module it carries no
# runtime decorator (``require_jit`` belongs on the implementation only).
@overload
def set_hccl_context(index: int, context: GlobalAddress) -> None:
    ...


@require_jit
@set_common_docstring(api_name="set_hccl_context")
def set_hccl_context(index: RuntimeInt, context: GlobalAddress) -> None:
    # The user-facing docstring is presumably attached by set_common_docstring,
    # so only implementation notes live here.
    #
    # ``index`` is materialized as a uint32 IR value and passed, together with
    # the IR form of ``context``, to a SetHcclContextOp.
    builder = global_builder.get_ir_builder()
    idx_ir = _mat(index, KnownTypes.uint32).to_ir()
    builder.create_asc_SetHcclContextOp(idx_ir, context.to_ir())
@overload
def set_hf32_mode(hf32_mode: bool) -> None:
    ...


@require_jit
@set_common_docstring(api_name="set_hf32_mode")
def set_hf32_mode(hf32_mode: RuntimeBool) -> None:
    # The user-facing docstring is presumably attached by set_common_docstring.
    # Materializes the flag into an IR value and emits a SetHF32ModeOp with it.
    ir_builder = global_builder.get_ir_builder()
    flag = _mat(hf32_mode)
    ir_builder.create_asc_SetHF32ModeOp(flag.to_ir())
@overload
def set_hf32_trans_mode(trans_mode: bool) -> None:
    ...


@require_jit
@set_common_docstring(api_name="set_hf32_trans_mode")
def set_hf32_trans_mode(trans_mode: RuntimeBool) -> None:
    # The user-facing docstring is presumably attached by set_common_docstring.
    # Materializes the flag into an IR value and emits a SetHF32TransModeOp with it.
    ir_builder = global_builder.get_ir_builder()
    flag = _mat(trans_mode)
    ir_builder.create_asc_SetHF32TransModeOp(flag.to_ir())
@overload
def set_mm_layout_transform(mm_layout_mode: bool) -> None:
    ...


@require_jit
@set_common_docstring(api_name="set_mm_layout_transform")
def set_mm_layout_transform(mm_layout_mode: RuntimeBool) -> None:
    # The user-facing docstring is presumably attached by set_common_docstring.
    # Materializes the flag into an IR value and emits a SetMMLayoutTransformOp with it.
    ir_builder = global_builder.get_ir_builder()
    flag = _mat(mm_layout_mode)
    ir_builder.create_asc_SetMMLayoutTransformOp(flag.to_ir())
@require_jit
@set_common_docstring(api_name="set_mask_count")
def set_mask_count() -> None:
    # The user-facing docstring is presumably attached by set_common_docstring.
    # Emits a SetMaskCountOp; the op takes no operands.
    global_builder.get_ir_builder().create_asc_SetMaskCountOp()
@require_jit
@set_common_docstring(api_name="set_mask_norm")
def set_mask_norm() -> None:
    # The user-facing docstring is presumably attached by set_common_docstring.
    # Emits a SetMaskNormOp; the op takes no operands.
    global_builder.get_ir_builder().create_asc_SetMaskNormOp()
@require_jit
def set_sys_workspace(workspace: GlobalAddress) -> None:
    """Emit a ``SetSysWorkspaceOp`` taking the IR form of ``workspace``.

    Args:
        workspace: Global address passed to the op via ``workspace.to_ir()``.
    """
    ir_builder = global_builder.get_ir_builder()
    workspace_ir = workspace.to_ir()
    ir_builder.create_asc_SetSysWorkspaceOp(workspace_ir)
# The stubs mirror the runtime signature ``(*args, dtype, mode)``: the mask
# operands can only be passed positionally, and ``dtype`` / ``mode`` can only
# be passed by keyword.  Advertising them as positional would type-check calls
# that raise ``TypeError`` at runtime.
@overload
def set_vector_mask(length: int, /, *, dtype: DataType, mode: MaskMode) -> None:
    ...


@overload
def set_vector_mask(mask_high: int, mask_low: int, /, *, dtype: DataType,
                    mode: MaskMode) -> None:
    ...


@require_jit
@set_common_docstring("set_vector_mask")
def set_vector_mask(*args, dtype: DataType, mode: MaskMode) -> None:
    # The user-facing docstring is presumably attached by set_common_docstring,
    # so only implementation notes live here.
    #
    # Positional arguments select the overload:
    #   * one operand  -> SetVectorMaskL0Op, operand materialized as int32
    #   * two operands -> SetVectorMaskL1Op, operands materialized as uint64
    # ``dtype`` and ``mode`` are forwarded to the dispatcher by keyword.
    builder = global_builder.get_ir_builder()
    # NOTE(review): the dispatcher is built from a string here, whereas
    # set_aipp_functions passes the function object -- confirm which form
    # OverloadDispatcher expects.
    dispatcher = OverloadDispatcher("set_vector_mask")

    @dispatcher.register_auto
    def _(length: RuntimeInt, dtype: DataType, mode: MaskMode):
        builder.create_asc_SetVectorMaskL0Op(
            _mat(length, KT.int32).to_ir(),
            dtype.to_ir(),
            ir.MaskMode.symbolize(mode),
        )

    @dispatcher.register_auto
    def _(mask_high: RuntimeInt, mask_low: RuntimeInt, dtype: DataType, mode: MaskMode):
        builder.create_asc_SetVectorMaskL1Op(
            _mat(mask_high, KT.uint64).to_ir(),
            _mat(mask_low, KT.uint64).to_ir(),
            dtype.to_ir(),
            ir.MaskMode.symbolize(mode),
        )

    dispatcher(*args, dtype=dtype, mode=mode)
@require_jit
@set_common_docstring("init_soc_state")
def init_soc_state() -> None:
    # The user-facing docstring is presumably attached by set_common_docstring.
    # Emits an InitSocStateOp; the op takes no operands.
    ir_builder = global_builder.get_ir_builder()
    ir_builder.create_asc_InitSocStateOp()
@overload
def set_store_atomic_config(atomic_type: AtomicDtype, atomic_op: AtomicOp) -> None:
    ...


@require_jit
@set_common_docstring("set_store_atomic_config")
def set_store_atomic_config(atomic_type: AtomicDtype, atomic_op: AtomicOp) -> None:
    # The user-facing docstring is presumably attached by set_common_docstring.
    #
    # Emits a SetStoreAtomicConfigOp with both enum arguments forwarded as-is.
    # NOTE(review): other wrappers in this file convert enums through
    # ``ir.<Enum>.symbolize(...)`` before handing them to the builder --
    # confirm that this op accepts the Python enums directly.
    ir_builder = global_builder.get_ir_builder()
    ir_builder.create_asc_SetStoreAtomicConfigOp(atomic_type, atomic_op)
@overload
def get_store_atomic_config() -> tuple[int, int]:
    ...


@require_jit
@set_common_docstring("get_store_atomic_config")
def get_store_atomic_config() -> tuple[RuntimeInt, RuntimeInt]:
    # The user-facing docstring is presumably attached by set_common_docstring.
    #
    # The builder call yields two handles, both requested with an int16 result
    # type; each is wrapped in a PlainValue and returned as (type, op).
    ir_builder = global_builder.get_ir_builder()
    type_handle, op_handle = ir_builder.create_asc_GetStoreAtomicConfigAndResult(
        KT.int16.to_ir(),
        KT.int16.to_ir(),
    )
    return PlainValue(type_handle), PlainValue(op_handle)
@require_jit
@set_common_docstring("check_local_memory_ia")
def check_local_memory_ia(check_params: CheckLocalMemoryIAParam):
    # The user-facing docstring is presumably attached by set_common_docstring.
    # Emits a CheckLocalMemoryIAOp from the IR form of ``check_params``.
    ir_builder = global_builder.get_ir_builder()
    params_ir = check_params.to_ir()
    ir_builder.create_asc_CheckLocalMemoryIAOp(params_ir)