from typing import List, Tuple, overload
from ..core.ir_value import PlainValue, RuntimeBool, RuntimeInt, materialize_ir_value as _mat
from ..core.tensor import LocalTensor, MrgSortSrcList
from ..core.types import KnownTypes as KT, MrgSort4Info
from ..core.utils import DefaultValued, require_jit, global_builder, OverloadDispatcher
from .utils import set_common_docstring
# Typing-only signatures: `elem_count` may be a Python int or a RuntimeInt.
@overload
def get_sort_len(elem_count: int) -> RuntimeInt:
    ...


@overload
def get_sort_len(elem_count: RuntimeInt) -> RuntimeInt:
    ...


@require_jit
@set_common_docstring(api_name="get_sort_len")
def get_sort_len(elem_count):
    # Builds a GetSortLenOp with a ui32 result type; `elem_count` is
    # materialized as uint32 before being handed to the IR builder.
    ir_builder = global_builder.get_ir_builder()
    result_type = ir_builder.get_ui32_type()
    elem_count_ir = _mat(elem_count, KT.uint32).to_ir()
    # NOTE(review): the builder result is returned as-is (not wrapped the way
    # get_mrg_sort_result wraps its results in PlainValue) -- confirm intended.
    return ir_builder.create_asc_GetSortLenOp(result_type, elem_count_ir)
# Typing-only signatures: `elem_offset` may be a Python int or a RuntimeInt.
@overload
def get_sort_offset(elem_offset: int) -> RuntimeInt:
    ...


@overload
def get_sort_offset(elem_offset: RuntimeInt) -> RuntimeInt:
    ...


@require_jit
@set_common_docstring(api_name="get_sort_offset")
def get_sort_offset(elem_offset):
    # Builds a GetSortOffsetOp with a ui32 result type; `elem_offset` is
    # materialized as uint32 before being handed to the IR builder.
    ir_builder = global_builder.get_ir_builder()
    result_type = ir_builder.get_ui32_type()
    elem_offset_ir = _mat(elem_offset, KT.uint32).to_ir()
    # NOTE(review): the builder result is returned unwrapped -- confirm that
    # callers expect the raw builder value here.
    return ir_builder.create_asc_GetSortOffsetOp(result_type, elem_offset_ir)
# Typing-only stub. Uses typing.Tuple (already imported by this file and used
# by the implementation below) instead of the builtin generic `tuple[...]`,
# which is evaluated at definition time and fails on Python < 3.9.
@overload
def get_mrg_sort_result() -> Tuple[int, int, int, int]:
    ...


@require_jit
@set_common_docstring(api_name="get_mrg_sort_result")
def get_mrg_sort_result() -> Tuple[RuntimeInt, RuntimeInt, RuntimeInt, RuntimeInt]:
    # Builds a GetMrgSortResults op with four uint16 result types and returns
    # the four results, each wrapped in a PlainValue.
    # The decorator is now called with the `api_name=` keyword, matching every
    # other wrapper in this file.
    builder = global_builder.get_ir_builder()
    arg1, arg2, arg3, arg4 = builder.create_asc_GetMrgSortResults(
        KT.uint16.to_ir(),
        KT.uint16.to_ir(),
        KT.uint16.to_ir(),
        KT.uint16.to_ir(),
    )
    return PlainValue(arg1), PlainValue(arg2), PlainValue(arg3), PlainValue(arg4)
# Typing-only signatures for the two supported call forms: explicit
# list/scalar arguments, or a single MrgSort4Info parameter object.
@overload
def mrg_sort(dst: LocalTensor, sort_list: MrgSortSrcList, element_count_list: List[int],
             sorted_num: List[int], valid_bit: int, repeat_time: int,
             is_exhausted_suspension: bool = False) -> None:
    ...


@overload
def mrg_sort(dst: LocalTensor, sort_list: MrgSortSrcList, params: MrgSort4Info) -> None:
    ...


@require_jit
@set_common_docstring(api_name="mrg_sort")
def mrg_sort(dst: LocalTensor, sort_list: MrgSortSrcList, *args, **kwargs) -> None:
    # The trailing arguments are routed through an OverloadDispatcher to one
    # of the two handlers registered below; each handler emits one IR op.
    ir_builder = global_builder.get_ir_builder()
    dispatch = OverloadDispatcher("mrg_sort")

    @dispatch.register(
        element_count_list=List,
        sorted_num=List,
        valid_bit=RuntimeInt,
        repeat_time=RuntimeInt,
        is_exhausted_suspension=DefaultValued(RuntimeBool, False),
    )
    def _(element_count_list: List, sorted_num: List, valid_bit: RuntimeInt,
          repeat_time: RuntimeInt, is_exhausted_suspension: bool = False):
        # The flag is passed to the builder as a plain Python value, so it is
        # checked against the literals True/False first. The membership test
        # uses equality, so values equal to True/False (e.g. 1/0) also pass.
        flag_is_literal = is_exhausted_suspension in (True, False)
        if not flag_is_literal:
            raise TypeError(
                f"The 'is_exhausted_suspension' argument must be a boolean literal (True or False), "
                f"but got {is_exhausted_suspension} of type {type(is_exhausted_suspension).__name__}. "
                f"This parameter must be a compile-time constant."
            )

        # Element counts are materialized as uint16, sorted counts as uint32.
        counts_ir = [_mat(item, KT.uint16).to_ir() for item in element_count_list]
        sorted_ir = [_mat(item, KT.uint32).to_ir() for item in sorted_num]

        dst_ir = dst.to_ir()
        sources_ir = sort_list.to_ir()
        valid_bit_ir = _mat(valid_bit, KT.uint16).to_ir()
        repeat_ir = _mat(repeat_time, KT.uint16).to_ir()
        ir_builder.create_asc_MrgSortOp(
            dst_ir,
            sources_ir,
            counts_ir,
            sorted_ir,
            valid_bit_ir,
            repeat_ir,
            is_exhausted_suspension,
        )

    @dispatch.register(params=MrgSort4Info)
    def _(params: MrgSort4Info):
        # Parameter-object form: forwards the info's IR value unchanged.
        dst_ir = dst.to_ir()
        sources_ir = sort_list.to_ir()
        ir_builder.create_asc_MrgSortWithInfoOp(dst_ir, sources_ir, params.to_ir())

    dispatch(*args, **kwargs)
@require_jit
@set_common_docstring(api_name="mrg_sort4")
def mrg_sort4(dst: LocalTensor, src: MrgSortSrcList, params: MrgSort4Info) -> None:
    # Emits a MrgSort4Op from the IR forms of the destination tensor, the
    # source list and the parameter object; nothing is returned.
    ir_builder = global_builder.get_ir_builder()
    dst_ir = dst.to_ir()
    src_ir = src.to_ir()
    params_ir = params.to_ir()
    ir_builder.create_asc_MrgSort4Op(dst_ir, src_ir, params_ir)
# Typing-only stub advertising plain Python ints for the scalar arguments.
@overload
def proposal_concat(dst: LocalTensor, src: LocalTensor, repeat_time: int, mode_number: int) -> None:
    ...


@require_jit
@set_common_docstring(api_name="proposal_concat")
def proposal_concat(dst: LocalTensor, src: LocalTensor, repeat_time: RuntimeInt, mode_number: RuntimeInt) -> None:
    # Emits a ProposalConcatOp. Both scalars are materialized without an
    # explicit target type, so the type is whatever `_mat` picks by default.
    ir_builder = global_builder.get_ir_builder()
    ir_builder.create_asc_ProposalConcatOp(
        dst.to_ir(),
        src.to_ir(),
        _mat(repeat_time).to_ir(),
        _mat(mode_number).to_ir(),
    )
# Typing-only stub advertising plain Python ints for the scalar arguments.
@overload
def proposal_extract(dst: LocalTensor, src: LocalTensor, repeat_time: int, mode_number: int) -> None:
    ...


@require_jit
@set_common_docstring(api_name="proposal_extract")
def proposal_extract(dst: LocalTensor, src: LocalTensor, repeat_time: RuntimeInt, mode_number: RuntimeInt) -> None:
    # Emits a ProposalExtractOp. Both scalars are materialized without an
    # explicit target type, so the type is whatever `_mat` picks by default.
    ir_builder = global_builder.get_ir_builder()
    ir_builder.create_asc_ProposalExtractOp(
        dst.to_ir(),
        src.to_ir(),
        _mat(repeat_time).to_ir(),
        _mat(mode_number).to_ir(),
    )
# Typing-only stub advertising a plain Python int for `repeat_time`.
@overload
def rp_sort16(dst: LocalTensor, src: LocalTensor, repeat_time: int) -> None:
    ...


@require_jit
@set_common_docstring(api_name="rp_sort16")
def rp_sort16(dst: LocalTensor, src: LocalTensor, repeat_time: RuntimeInt) -> None:
    # Emits an RpSort16Op; `repeat_time` is materialized as int32.
    ir_builder = global_builder.get_ir_builder()
    dst_ir = dst.to_ir()
    src_ir = src.to_ir()
    repeat_ir = _mat(repeat_time, KT.int32).to_ir()
    ir_builder.create_asc_RpSort16Op(dst_ir, src_ir, repeat_ir)
# Typing-only stub. It now lists `is_full_sort`, which the implementation
# below accepts, so type checkers no longer reject a supported argument.
@overload
def sort(dst: LocalTensor, concat: LocalTensor, index: LocalTensor, tmp: LocalTensor, repeat_time: int,
         is_full_sort: bool = False) -> None:
    ...


@require_jit
@set_common_docstring(api_name="sort")
def sort(dst: LocalTensor, concat: LocalTensor, index: LocalTensor, tmp: LocalTensor, repeat_time: RuntimeInt,
         is_full_sort: bool = False) -> None:
    # Emits a SortOp from the four tensors, `repeat_time` materialized as
    # int32, and the `is_full_sort` flag passed as a plain Python value.
    #
    # Raises TypeError when `is_full_sort` does not compare equal to True or
    # False. The membership test uses equality, so 1/0 also pass.
    if is_full_sort not in (True, False):
        raise TypeError(
            f"The 'is_full_sort' argument must be a boolean literal (True or False), "
            f"but got {is_full_sort} of type {type(is_full_sort).__name__}. "
            f"This parameter must be a compile-time constant."
        )
    builder = global_builder.get_ir_builder()
    builder.create_asc_SortOp(
        dst.to_ir(),
        concat.to_ir(),
        index.to_ir(),
        tmp.to_ir(),
        _mat(repeat_time, KT.int32).to_ir(),
        is_full_sort,
    )
# Typing-only stub advertising a plain Python int for `repeat_time`.
@overload
def sort32(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, repeat_time: int) -> None:
    ...


@require_jit
@set_common_docstring(api_name="sort32")
def sort32(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, repeat_time: RuntimeInt) -> None:
    # Emits a Sort32Op from the destination and the two source tensors;
    # `repeat_time` is materialized as int32.
    ir_builder = global_builder.get_ir_builder()
    dst_ir = dst.to_ir()
    src0_ir = src0.to_ir()
    src1_ir = src1.to_ir()
    repeat_ir = _mat(repeat_time, KT.int32).to_ir()
    ir_builder.create_asc_Sort32Op(dst_ir, src0_ir, src1_ir, repeat_ir)