from typing import Any, Callable, Dict, List, Tuple, overload
from ..._C import ir
from ..core.dtype import KnownTypes as KT
from ..core.ir_value import RuntimeInt, materialize_ir_value as _mat
from ..core.tensor import LocalTensor
from ..core.utils import DefaultValued, OverloadDispatcher, require_jit, global_builder
from ..core.types import UnaryRepeatParams
from .utils import set_unary_docstring
def op_impl(callee: str, dst: LocalTensor, src: LocalTensor, args: Tuple[Any], kwargs: Dict[str, Any],
build_l0: Callable, build_l1: Callable, build_l2: Callable) -> None:
builder = build_l0.__self__
if not isinstance(builder, ir.Builder):
raise TypeError("Input builder must be ir.Builder")
dispatcher = OverloadDispatcher(callee)
@dispatcher.register(mask=RuntimeInt, repeat_times=RuntimeInt, repeat_params=UnaryRepeatParams,
is_set_mask=DefaultValued(bool, True))
def _(mask: RuntimeInt, repeat_times: RuntimeInt, repeat_params: UnaryRepeatParams, is_set_mask: bool = True):
build_l0(dst.to_ir(), src.to_ir(),
_mat(mask, KT.uint64).to_ir(),
_mat(repeat_times, KT.int8).to_ir(), repeat_params.to_ir(), is_set_mask)
@dispatcher.register(mask=list, repeat_times=RuntimeInt, repeat_params=UnaryRepeatParams,
is_set_mask=DefaultValued(bool, True))
def _(mask: list, repeat_times: RuntimeInt, repeat_params: UnaryRepeatParams, is_set_mask: bool = True):
mask = [_mat(v, KT.uint64).to_ir() for v in mask]
build_l1(dst.to_ir(), src.to_ir(), mask, _mat(repeat_times, KT.int8).to_ir(),
repeat_params.to_ir(), is_set_mask)
@dispatcher.register_auto
def _(count: RuntimeInt):
build_l2(dst.to_ir(), src.to_ir(), _mat(count, KT.int32).to_ir())
dispatcher(*args, **kwargs)
@overload
def abs(dst: LocalTensor, src: LocalTensor, count: int) -> None:
...
@overload
def abs(dst: LocalTensor, src: LocalTensor, mask: int, repeat_times: int,
repeat_params: UnaryRepeatParams, is_set_mask: bool = True) -> None:
...
@overload
def abs(dst: LocalTensor, src: LocalTensor, mask: List[int], repeat_times: int,
repeat_params: UnaryRepeatParams, is_set_mask: bool = True) -> None:
...
@require_jit
@set_unary_docstring(cpp_name="Abs", append_text="按元素取绝对值。")
def abs(dst: LocalTensor, src: LocalTensor, *args, **kwargs) -> None:
builder = global_builder.get_ir_builder()
op_impl("abs", dst, src, args, kwargs, builder.create_asc_AbsL0Op, builder.create_asc_AbsL1Op,
builder.create_asc_AbsL2Op)
@overload
def exp(dst: LocalTensor, src: LocalTensor, count: int) -> None:
...
@overload
def exp(dst: LocalTensor, src: LocalTensor, mask: int, repeat_times: int,
repeat_params: UnaryRepeatParams, is_set_mask: bool = True) -> None:
...
@overload
def exp(dst: LocalTensor, src: LocalTensor, mask: List[int], repeat_times: int,
repeat_params: UnaryRepeatParams, is_set_mask: bool = True) -> None:
...
@require_jit
@set_unary_docstring(cpp_name="Exp", append_text="按元素取自然指数。")
def exp(dst: LocalTensor, src: LocalTensor, *args, **kwargs) -> None:
builder = global_builder.get_ir_builder()
op_impl("exp", dst, src, args, kwargs, builder.create_asc_ExpL0Op, builder.create_asc_ExpL1Op,
builder.create_asc_ExpL2Op)
@overload
def ln(dst: LocalTensor, src: LocalTensor, count: int) -> None:
...
@overload
def ln(dst: LocalTensor, src: LocalTensor, mask: int, repeat_times: int,
repeat_params: UnaryRepeatParams, is_set_mask: bool = True) -> None:
...
@overload
def ln(dst: LocalTensor, src: LocalTensor, mask: List[int], repeat_times: int,
repeat_params: UnaryRepeatParams, is_set_mask: bool = True) -> None:
...
@require_jit
@set_unary_docstring(cpp_name="Ln", append_text="按元素取自然对数。")
def ln(dst: LocalTensor, src: LocalTensor, *args, **kwargs) -> None:
builder = global_builder.get_ir_builder()
op_impl("ln", dst, src, args, kwargs, builder.create_asc_LnL0Op, builder.create_asc_LnL1Op,
builder.create_asc_LnL2Op)
@overload
def bitwise_not(dst: LocalTensor, src: LocalTensor, count: int) -> None:
...
@overload
def bitwise_not(dst: LocalTensor, src: LocalTensor, mask: int, repeat_times: int,
repeat_params: UnaryRepeatParams, is_set_mask: bool = True) -> None:
...
@overload
def bitwise_not(dst: LocalTensor, src: LocalTensor, mask: List[int], repeat_times: int,
repeat_params: UnaryRepeatParams, is_set_mask: bool = True) -> None:
...
@require_jit
@set_unary_docstring(cpp_name="Not", append_text="按元素做按位取反。命名为 bitwise_not 避免与Python关键字重名。")
def bitwise_not(dst: LocalTensor, src: LocalTensor, *args, **kwargs) -> None:
builder = global_builder.get_ir_builder()
op_impl("bitwise_not", dst, src, args, kwargs, builder.create_asc_NotL0Op, builder.create_asc_NotL1Op,
builder.create_asc_NotL2Op)
@overload
def reciprocal(dst: LocalTensor, src: LocalTensor, count: int) -> None:
...
@overload
def reciprocal(dst: LocalTensor, src: LocalTensor, mask: int, repeat_times: int,
repeat_params: UnaryRepeatParams, is_set_mask: bool = True) -> None:
...
@overload
def reciprocal(dst: LocalTensor, src: LocalTensor, mask: List[int], repeat_times: int,
repeat_params: UnaryRepeatParams, is_set_mask: bool = True) -> None:
...
@require_jit
@set_unary_docstring(cpp_name="Reciprocal", append_text="按元素取倒数。")
def reciprocal(dst: LocalTensor, src: LocalTensor, *args, **kwargs) -> None:
builder = global_builder.get_ir_builder()
op_impl("reciprocal", dst, src, args, kwargs, builder.create_asc_ReciprocalL0Op, builder.create_asc_ReciprocalL1Op,
builder.create_asc_ReciprocalL2Op)
@overload
def relu(dst: LocalTensor, src: LocalTensor, count: int) -> None:
...
@overload
def relu(dst: LocalTensor, src: LocalTensor, mask: int, repeat_times: int,
repeat_params: UnaryRepeatParams, is_set_mask: bool = True) -> None:
...
@overload
def relu(dst: LocalTensor, src: LocalTensor, mask: List[int], repeat_times: int,
repeat_params: UnaryRepeatParams, is_set_mask: bool = True) -> None:
...
@require_jit
@set_unary_docstring(cpp_name="Relu", append_text="按元素做线性整流Relu。")
def relu(dst: LocalTensor, src: LocalTensor, *args, **kwargs) -> None:
builder = global_builder.get_ir_builder()
op_impl("relu", dst, src, args, kwargs, builder.create_asc_ReluL0Op, builder.create_asc_ReluL1Op,
builder.create_asc_ReluL2Op)
@overload
def rsqrt(dst: LocalTensor, src: LocalTensor, count: int) -> None:
...
@overload
def rsqrt(dst: LocalTensor, src: LocalTensor, mask: int, repeat_times: int,
repeat_params: UnaryRepeatParams, is_set_mask: bool = True) -> None:
...
@overload
def rsqrt(dst: LocalTensor, src: LocalTensor, mask: List[int], repeat_times: int,
repeat_params: UnaryRepeatParams, is_set_mask: bool = True) -> None:
...
@require_jit
@set_unary_docstring(cpp_name="Rsqrt", append_text="按元素进行开方后取倒数的计算。")
def rsqrt(dst: LocalTensor, src: LocalTensor, *args, **kwargs) -> None:
builder = global_builder.get_ir_builder()
op_impl("rsqrt", dst, src, args, kwargs, builder.create_asc_RsqrtL0Op, builder.create_asc_RsqrtL1Op,
builder.create_asc_RsqrtL2Op)
@overload
def sqrt(dst: LocalTensor, src: LocalTensor, count: int) -> None:
...
@overload
def sqrt(dst: LocalTensor, src: LocalTensor, mask: int, repeat_times: int,
repeat_params: UnaryRepeatParams, is_set_mask: bool = True) -> None:
...
@overload
def sqrt(dst: LocalTensor, src: LocalTensor, mask: List[int], repeat_times: int,
repeat_params: UnaryRepeatParams, is_set_mask: bool = True) -> None:
...
@require_jit
@set_unary_docstring(cpp_name="Sqrt", append_text="按元素做开方。")
def sqrt(dst: LocalTensor, src: LocalTensor, *args, **kwargs) -> None:
builder = global_builder.get_ir_builder()
op_impl("sqrt", dst, src, args, kwargs, builder.create_asc_SqrtL0Op, builder.create_asc_SqrtL1Op,
builder.create_asc_SqrtL2Op)