from typing import List, Union, overload
from ..core.dtype import KnownTypes as KT
from ..core.ir_value import RuntimeInt, RuntimeNumeric, materialize_ir_value as _mat
from ..core.tensor import LocalTensor
from ..core.utils import DefaultValued, OverloadDispatcher, require_jit, global_builder
from .utils import set_common_docstring
@overload
def duplicate(dst: LocalTensor, scalar: Union[int, float], count: int) -> None:
    """Typing stub for the count-based call form ``duplicate(dst, scalar, count)``.

    At runtime the undecorated implementation handles this form through its
    ``count`` dispatch branch; this stub only exists for static type checkers.
    """
@overload
def duplicate(dst: LocalTensor, scalar: Union[int, float], mask: int, repeat_times: int,
              dst_block_stride: int, dst_repeat_stride: int, is_set_mask: bool = True) -> None:
    """Typing stub for the form taking a single integer ``mask``.

    The mask is followed by the repeat count and the destination block / repeat
    strides; ``is_set_mask`` defaults to ``True``. Only static type checkers see
    this stub — the runtime implementation dispatches on the argument types.
    """
@overload
def duplicate(dst: LocalTensor, scalar: Union[int, float], mask: List[int], repeat_times: int,
              dst_block_stride: int, dst_repeat_stride: int, is_set_mask: bool = True) -> None:
    """Typing stub for the form taking ``mask`` as a list of integers.

    Apart from the list-valued mask, the parameters mirror the integer-mask form.
    Only static type checkers see this stub — the runtime implementation
    dispatches on the argument types.
    """
@require_jit
@set_common_docstring(api_name="duplicate")
def duplicate(dst: LocalTensor, scalar: RuntimeNumeric, *args, **kwargs) -> None:
    # The user-facing docstring is attached by ``set_common_docstring``. This body
    # only routes the trailing arguments to one of three builder ops, chosen by the
    # dispatcher from the registered keyword/type patterns below (in registration
    # order). Every branch materializes ``scalar`` with the destination's dtype.
    dispatcher = OverloadDispatcher(__name__)
    builder = global_builder.get_ir_builder()

    def _as_int8(value):
        # The repeat count and both strides are materialized as int8 IR values.
        return _mat(value, KT.int8).to_ir()

    # Integer-mask form -> DuplicateL0Op (mask materialized as uint64).
    @dispatcher.register(mask=RuntimeInt, repeat_times=RuntimeInt, dst_block_stride=RuntimeInt,
                         dst_repeat_stride=RuntimeInt, is_set_mask=DefaultValued(bool, True))
    def _(mask: RuntimeInt, repeat_times: RuntimeInt, dst_block_stride: RuntimeInt,
          dst_repeat_stride: RuntimeInt, is_set_mask: bool = True):
        # Operands are built strictly left to right, matching the op's argument order.
        dst_ir = dst.to_ir()
        value_ir = _mat(scalar, dst.dtype).to_ir()
        mask_ir = _mat(mask, KT.uint64).to_ir()
        repeat_ir, block_stride_ir, repeat_stride_ir = map(
            _as_int8, (repeat_times, dst_block_stride, dst_repeat_stride))
        builder.create_asc_DuplicateL0Op(dst_ir, value_ir, mask_ir,
                                         repeat_ir, block_stride_ir, repeat_stride_ir,
                                         is_set_mask)

    # List-mask form -> DuplicateL1Op (each mask element materialized as uint64).
    @dispatcher.register(mask=list, repeat_times=RuntimeInt, dst_block_stride=RuntimeInt,
                         dst_repeat_stride=RuntimeInt, is_set_mask=DefaultValued(bool, True))
    def _(mask: list, repeat_times: RuntimeInt, dst_block_stride: RuntimeInt,
          dst_repeat_stride: RuntimeInt, is_set_mask: bool = True):
        # The mask elements are materialized before the destination/scalar operands.
        mask_irs = [_mat(bits, KT.uint64).to_ir() for bits in mask]
        dst_ir = dst.to_ir()
        value_ir = _mat(scalar, dst.dtype).to_ir()
        repeat_ir, block_stride_ir, repeat_stride_ir = map(
            _as_int8, (repeat_times, dst_block_stride, dst_repeat_stride))
        builder.create_asc_DuplicateL1Op(dst_ir, value_ir, mask_irs,
                                         repeat_ir, block_stride_ir, repeat_stride_ir,
                                         is_set_mask)

    # Count form -> DuplicateL2Op (count materialized as int32).
    @dispatcher.register(count=RuntimeInt, is_set_mask=DefaultValued(bool, True))
    def _(count: RuntimeInt, is_set_mask: bool = True):
        # NOTE(review): ``is_set_mask`` is accepted here but never forwarded to
        # create_asc_DuplicateL2Op, so passing it has no effect on this form.
        dst_ir = dst.to_ir()
        value_ir = _mat(scalar, dst.dtype).to_ir()
        count_ir = _mat(count, KT.int32).to_ir()
        builder.create_asc_DuplicateL2Op(dst_ir, value_ir, count_ir)

    dispatcher(*args, **kwargs)