from typing import Union, overload
from ..core.dtype import KnownTypes as KT
from ..core.enums import GatherMaskMode
from ..core.ir_value import materialize_ir_value as _mat, PlainValue, RuntimeBool, RuntimeInt
from ..core.tensor import LocalTensor
from ..core.utils import require_jit, global_builder, DefaultValued
from ..core.types import GatherMaskParams
from .utils import OverloadDispatcher, set_common_docstring
def check_type_gather_mask(dst: LocalTensor, src0: LocalTensor, src1_pattern: Union[LocalTensor, int]) -> None:
    """
    Check data type constraints for a GatherMask operation.

    According to the GatherMask specification:
    - T (dst and src0 data types):
        Atlas inference AI Core: half/uint16_t/int16_t/float/uint32_t/int32_t
        Atlas A2 training/Atlas 800I A2 inference/A200I A2 Box:
            half/bfloat16_t/uint16_t/int16_t/float/uint32_t/int32_t
        Atlas 200I/500 A2 inference: half/uint16_t/int16_t/float/uint32_t/int32_t
        Atlas A3 training/A3 inference:
            half/bfloat16_t/uint16_t/int16_t/float/uint32_t/int32_t
    - U (src1_pattern data type when it is a LocalTensor): uint16_t/uint32_t
        When dst is half/uint16_t/int16_t: src1_pattern should be uint16_t
        When dst is float/uint32_t/int32_t: src1_pattern should be uint32_t

    The checks run in the order below and the first violation raises:

    1. ``dst.dtype`` is one of half/uint16/int16/float/uint32/int32.
    2. ``src0.dtype`` is one of the same types.
    3. ``dst.dtype`` equals ``src0.dtype``.
    4. ``src1_pattern`` is checked according to its Python type:
       a LocalTensor must carry the pattern dtype matching ``dst.dtype``;
       an ``int`` must lie in the inclusive range 1..7.

    NOTE(review): the specification above lists bfloat16_t for the A2/A3
    products, but this check does not accept it -- confirm whether that is
    intentional.
    NOTE(review): only ``LocalTensor`` and Python ``int`` patterns pass; any
    other object (for example a traced runtime value, if callers can supply
    one) is rejected with TypeError -- confirm against the callers.

    Args:
        dst: Destination tensor; only its ``dtype`` is inspected.
        src0: Source tensor; only its ``dtype`` is inspected.
        src1_pattern: Either a pattern tensor or a built-in pattern number.

    Raises:
        TypeError: If a dtype is unsupported, dst/src0 dtypes differ, the
            pattern tensor dtype does not match dst, or ``src1_pattern`` is
            neither a LocalTensor nor an int.
        ValueError: If an int ``src1_pattern`` is outside 1..7.
    """
    # dst dtypes grouped by the pattern-tensor dtype they require.
    uint16_pattern_dst_types = (KT.half, KT.uint16, KT.int16)
    uint32_pattern_dst_types = (KT.float_, KT.uint32, KT.int32)
    valid_dst_src0_types = uint16_pattern_dst_types + uint32_pattern_dst_types

    if dst.dtype not in valid_dst_src0_types:
        raise TypeError(f"Invalid dst data type for GatherMask: {dst.dtype}. "
                        f"Supported types: half, uint16, int16, float, uint32, int32")
    if src0.dtype not in valid_dst_src0_types:
        raise TypeError(f"Invalid src0 data type for GatherMask: {src0.dtype}. "
                        f"Supported types: half, uint16, int16, float, uint32, int32")
    if dst.dtype != src0.dtype:
        raise TypeError(f"dst and src0 must have same data type. Got dst={dst.dtype}, src0={src0.dtype}")

    if isinstance(src1_pattern, LocalTensor):
        # dst.dtype was validated above, so it belongs to exactly one of the
        # two groups; no third "unsupported dst" case can occur here.
        if dst.dtype in uint16_pattern_dst_types:
            if src1_pattern.dtype != KT.uint16:
                raise TypeError(
                    f"For dst data type {dst.dtype}, src1_pattern must be uint16. Got {src1_pattern.dtype}")
        elif src1_pattern.dtype != KT.uint32:
            raise TypeError(
                f"For dst data type {dst.dtype}, src1_pattern must be uint32. Got {src1_pattern.dtype}")
    elif isinstance(src1_pattern, int):
        if not (1 <= src1_pattern <= 7):
            raise ValueError(f"Built-in src1_pattern must be between 1 and 7. Got {src1_pattern}")
    else:
        raise TypeError(f"src1_pattern must be either LocalTensor or int. Got {type(src1_pattern)}")
@overload
def gather_mask(
    dst: LocalTensor,
    src0: LocalTensor,
    src1_pattern: LocalTensor,
    reduce_mode: bool,
    mask: int,
    params: GatherMaskParams,
    gather_mask_mode=GatherMaskMode.DEFAULT,
) -> int:
    """Typing-only signature: ``src1_pattern`` is supplied as a LocalTensor."""
@overload
def gather_mask(
    dst: LocalTensor,
    src0: LocalTensor,
    src1_pattern: int,
    reduce_mode: bool,
    mask: int,
    params: GatherMaskParams,
    gather_mask_mode=GatherMaskMode.DEFAULT,
) -> int:
    """Typing-only signature: ``src1_pattern`` is a built-in pattern number (int)."""
@require_jit
@set_common_docstring("gather_mask")
def gather_mask(dst: LocalTensor, src0: LocalTensor, *args, **kwargs) -> RuntimeInt:
    # No literal docstring is written here: the set_common_docstring decorator
    # is applied to this function (presumably it attaches the shared text --
    # confirm in .utils before adding one).
    #
    # The remaining arguments are resolved by an OverloadDispatcher against two
    # registered handlers that differ only in how src1_pattern is lowered:
    #   * LocalTensor pattern -> passed through via its own to_ir()
    #   * RuntimeInt pattern  -> materialized as a uint8 IR value
    # Both handlers validate dtypes first, then emit a single
    # GatherMaskAndResult op whose result type is uint64, and return that
    # result wrapped in PlainValue.
    ir_builder = global_builder.get_ir_builder()
    overloads = OverloadDispatcher("gather_mask")

    @overloads.register(
        src1_pattern=LocalTensor,
        reduce_mode=RuntimeBool,
        mask=RuntimeInt,
        params=GatherMaskParams,
        gather_mask_mode=DefaultValued(GatherMaskMode, GatherMaskMode.DEFAULT),
    )
    def _(src1_pattern: LocalTensor, reduce_mode: RuntimeBool, mask: RuntimeInt,
          params: GatherMaskParams, gather_mask_mode: GatherMaskMode):
        check_type_gather_mask(dst, src0, src1_pattern)
        # Operands are lowered one by one, in the same order in which they are
        # handed to the builder.
        result_type = KT.uint64.to_ir()
        dst_ir = dst.to_ir()
        src0_ir = src0.to_ir()
        pattern_ir = src1_pattern.to_ir()
        reduce_mode_ir = _mat(reduce_mode, KT.bool_).to_ir()
        mask_ir = _mat(mask, KT.uint32).to_ir()
        params_ir = params.to_ir()
        return PlainValue(
            ir_builder.create_asc_GatherMaskAndResult(
                result_type, dst_ir, src0_ir, pattern_ir,
                reduce_mode_ir, mask_ir, params_ir, gather_mask_mode,
            )
        )

    @overloads.register(
        src1_pattern=RuntimeInt,
        reduce_mode=RuntimeBool,
        mask=RuntimeInt,
        params=GatherMaskParams,
        gather_mask_mode=DefaultValued(GatherMaskMode, GatherMaskMode.DEFAULT),
    )
    def _(src1_pattern: RuntimeInt, reduce_mode: RuntimeBool, mask: RuntimeInt,
          params: GatherMaskParams, gather_mask_mode: GatherMaskMode):
        # NOTE(review): check_type_gather_mask only accepts a Python int here;
        # if RuntimeInt also admits non-int runtime values they would be
        # rejected with TypeError -- confirm.
        check_type_gather_mask(dst, src0, src1_pattern)
        result_type = KT.uint64.to_ir()
        dst_ir = dst.to_ir()
        src0_ir = src0.to_ir()
        pattern_ir = _mat(src1_pattern, KT.uint8).to_ir()
        reduce_mode_ir = _mat(reduce_mode, KT.bool_).to_ir()
        mask_ir = _mat(mask, KT.uint32).to_ir()
        params_ir = params.to_ir()
        return PlainValue(
            ir_builder.create_asc_GatherMaskAndResult(
                result_type, dst_ir, src0_ir, pattern_ir,
                reduce_mode_ir, mask_ir, params_ir, gather_mask_mode,
            )
        )

    return overloads(*args, **kwargs)
@require_jit
def get_gather_mask_remain_count() -> int:
    """
    Emit a GetGatherMaskRemainCountOp through the current IR builder.

    The op is created with the builder's ui64 type as its result type, and
    whatever the builder returns is handed back unchanged.

    NOTE(review): the value is the builder's raw result, not wrapped in
    PlainValue the way gather_mask wraps its result, and the ``int``
    annotation does not describe it -- confirm which form callers expect.
    """
    ir_builder = global_builder.get_ir_builder()
    return ir_builder.create_asc_GetGatherMaskRemainCountOp(ir_builder.get_ui64_type())