from typing import Callable, List, TypeVar, Union, overload
from ..core.dtype import KnownTypes as KT
from ..core.ir_value import RuntimeBool, RuntimeInt, materialize_ir_value as _mat
from ..core.tensor import LocalTensor
from ..core.utils import OverloadDispatcher, require_jit, global_builder
from ..core.types import BinaryRepeatParams
from .utils import check_type, op_impl, set_binary_docstring
# Generic type variable bounded to callables.
# NOTE(review): not referenced anywhere in this chunk of the module — confirm it is
# used elsewhere (or by importers) before relying on / removing it.
T = TypeVar("T", bound=Callable)
# Accepted call forms: (count[, is_set_mask]), (mask: int, repeat_times, repeat_params[, is_set_mask])
# and (mask: List[int], repeat_times, repeat_params[, is_set_mask]).
@overload
def add(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, count: int, is_set_mask: bool = True) -> None:
    ...

@overload
def add(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: int, repeat_times: int,
        repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@overload
def add(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: List[int], repeat_times: int,
        repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@require_jit
@set_binary_docstring(cpp_name="Add", append_text="按元素求和。")
def add(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, *args, **kwargs) -> None:
    # Element-wise sum (Ascend C ``Add``). The three op creators are handed to
    # ``op_impl``, which picks the L0/L1/L2 variant from the remaining arguments.
    ir_builder = global_builder.get_ir_builder()
    level_creators = (
        ir_builder.create_asc_AddL0Op,
        ir_builder.create_asc_AddL1Op,
        ir_builder.create_asc_AddL2Op,
    )
    op_impl("add", dst, src0, src1, args, kwargs, *level_creators)
# Accepted call forms: (count[, is_set_mask]), (mask: int, repeat_times, repeat_params[, is_set_mask])
# and (mask: List[int], repeat_times, repeat_params[, is_set_mask]).
@overload
def add_deq_relu(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, count: int, is_set_mask: bool = True) -> None:
    ...

@overload
def add_deq_relu(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: int, repeat_times: int,
                 repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@overload
def add_deq_relu(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: List[int], repeat_times: int,
                 repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@require_jit
@set_binary_docstring(cpp_name="AddDeqRelu", append_text="依次计算按元素求和、结果进行deq量化后再进行relu计算(结果和0对比取较大值)。")
def add_deq_relu(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, *args, **kwargs) -> None:
    # Element-wise sum, then deq quantization, then ReLU (max with 0) — Ascend C ``AddDeqRelu``.
    # ``op_impl`` chooses which of the three level creators to invoke.
    ir_builder = global_builder.get_ir_builder()
    level_creators = (
        ir_builder.create_asc_AddDeqReluL0Op,
        ir_builder.create_asc_AddDeqReluL1Op,
        ir_builder.create_asc_AddDeqReluL2Op,
    )
    op_impl("add_deq_relu", dst, src0, src1, args, kwargs, *level_creators)
# Accepted call forms: (count[, is_set_mask]), (mask: int, repeat_times, repeat_params[, is_set_mask])
# and (mask: List[int], repeat_times, repeat_params[, is_set_mask]).
@overload
def add_relu(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, count: int, is_set_mask: bool = True) -> None:
    ...

@overload
def add_relu(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: int, repeat_times: int,
             repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@overload
def add_relu(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: List[int], repeat_times: int,
             repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@require_jit
@set_binary_docstring(cpp_name="AddRelu", append_text="按元素求和,再进行Relu计算(结果和0对比取较大值)。")
def add_relu(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, *args, **kwargs) -> None:
    # Element-wise sum followed by ReLU (max with 0) — Ascend C ``AddRelu``.
    # ``op_impl`` chooses which of the three level creators to invoke.
    ir_builder = global_builder.get_ir_builder()
    level_creators = (
        ir_builder.create_asc_AddReluL0Op,
        ir_builder.create_asc_AddReluL1Op,
        ir_builder.create_asc_AddReluL2Op,
    )
    op_impl("add_relu", dst, src0, src1, args, kwargs, *level_creators)
# Accepted call forms: (count[, is_set_mask]), (mask: int, repeat_times, repeat_params[, is_set_mask])
# and (mask: List[int], repeat_times, repeat_params[, is_set_mask]).
@overload
def bitwise_and(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, count: int, is_set_mask: bool = True) -> None:
    ...

@overload
def bitwise_and(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: int, repeat_times: int,
                repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@overload
def bitwise_and(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: List[int], repeat_times: int,
                repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@require_jit
@set_binary_docstring(cpp_name="And", append_text="每对elements按位与运算。命名为 bitwise_and 避免与Python关键字重名。")
def bitwise_and(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, *args, **kwargs) -> None:
    """
    Element-wise bitwise AND, emitted as AscendC::And.
    Named bitwise_and because ``and`` is a reserved Python keyword.
    """
    ir_builder = global_builder.get_ir_builder()
    level_creators = (
        ir_builder.create_asc_AndL0Op,
        ir_builder.create_asc_AndL1Op,
        ir_builder.create_asc_AndL2Op,
    )
    op_impl("bitwise_and", dst, src0, src1, args, kwargs, *level_creators)
# Accepted call forms: (count[, is_set_mask]), (mask: int, repeat_times, repeat_params[, is_set_mask])
# and (mask: List[int], repeat_times, repeat_params[, is_set_mask]).
@overload
def bitwise_or(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, count: int, is_set_mask: bool = True) -> None:
    ...

@overload
def bitwise_or(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: int, repeat_times: int,
               repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@overload
def bitwise_or(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: List[int], repeat_times: int,
               repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@require_jit
@set_binary_docstring(cpp_name="Or", append_text="每对elements按位或运算。命名为 bitwise_or 避免与Python关键字重名。")
def bitwise_or(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, *args, **kwargs) -> None:
    """
    Element-wise bitwise OR, emitted as AscendC::Or.
    Named bitwise_or because ``or`` is a reserved Python keyword.
    """
    ir_builder = global_builder.get_ir_builder()
    level_creators = (
        ir_builder.create_asc_OrL0Op,
        ir_builder.create_asc_OrL1Op,
        ir_builder.create_asc_OrL2Op,
    )
    op_impl("bitwise_or", dst, src0, src1, args, kwargs, *level_creators)
# Accepted call forms: (count[, is_set_mask]), (mask: int, repeat_times, repeat_params[, is_set_mask])
# and (mask: List[int], repeat_times, repeat_params[, is_set_mask]).
@overload
def div(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, count: int, is_set_mask: bool = True) -> None:
    ...

@overload
def div(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: int, repeat_times: int,
        repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@overload
def div(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: List[int], repeat_times: int,
        repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@require_jit
@set_binary_docstring(cpp_name="Div", append_text="按元素求商。")
def div(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, *args, **kwargs) -> None:
    # Element-wise quotient (Ascend C ``Div``); ``op_impl`` selects the level-specific creator.
    ir_builder = global_builder.get_ir_builder()
    level_creators = (
        ir_builder.create_asc_DivL0Op,
        ir_builder.create_asc_DivL1Op,
        ir_builder.create_asc_DivL2Op,
    )
    op_impl("div", dst, src0, src1, args, kwargs, *level_creators)
# Accepted call forms: (count[, is_set_mask]), (mask: int, repeat_times, repeat_params[, is_set_mask])
# and (mask: List[int], repeat_times, repeat_params[, is_set_mask]).
@overload
def fused_mul_add(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor,
                  count: int, is_set_mask: bool = True) -> None:
    ...

@overload
def fused_mul_add(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: int, repeat_times: int,
                  repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@overload
def fused_mul_add(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: List[int], repeat_times: int,
                  repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@require_jit
@set_binary_docstring(cpp_name="FusedMulAdd", append_text="按元素将src0和dst相乘并加上src1,最终结果存放入dst。")
def fused_mul_add(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, *args, **kwargs) -> None:
    # dst = src0 * dst + src1, element-wise (Ascend C ``FusedMulAdd``).
    # ``op_impl`` selects the level-specific creator.
    ir_builder = global_builder.get_ir_builder()
    level_creators = (
        ir_builder.create_asc_FusedMulAddL0Op,
        ir_builder.create_asc_FusedMulAddL1Op,
        ir_builder.create_asc_FusedMulAddL2Op,
    )
    op_impl("fused_mul_add", dst, src0, src1, args, kwargs, *level_creators)
# Accepted call forms: (count[, is_set_mask]), (mask: int, repeat_times, repeat_params[, is_set_mask])
# and (mask: List[int], repeat_times, repeat_params[, is_set_mask]).
@overload
def fused_mul_add_relu(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor,
                       count: int, is_set_mask: bool = True) -> None:
    ...

@overload
def fused_mul_add_relu(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: int, repeat_times: int,
                       repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@overload
def fused_mul_add_relu(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: List[int], repeat_times: int,
                       repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@require_jit
@set_binary_docstring(cpp_name="FusedMulAddRelu",
                      append_text="按元素将src0和dst相乘并加上src1,再进行Relu计算(结果和0对比取较大值),最终结果存放进dst中。")
def fused_mul_add_relu(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, *args, **kwargs) -> None:
    # dst = relu(src0 * dst + src1), element-wise (Ascend C ``FusedMulAddRelu``).
    # ``op_impl`` selects the level-specific creator.
    ir_builder = global_builder.get_ir_builder()
    level_creators = (
        ir_builder.create_asc_FusedMulAddReluL0Op,
        ir_builder.create_asc_FusedMulAddReluL1Op,
        ir_builder.create_asc_FusedMulAddReluL2Op,
    )
    op_impl("fused_mul_add_relu", dst, src0, src1, args, kwargs, *level_creators)
# Accepted call forms: (count[, is_set_mask]), (mask: int, repeat_times, repeat_params[, is_set_mask])
# and (mask: List[int], repeat_times, repeat_params[, is_set_mask]).
# NOTE: a module-level ``max`` shadows the builtin for the rest of this module.
@overload
def max(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, count: int, is_set_mask: bool = True) -> None:
    ...

@overload
def max(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: int, repeat_times: int,
        repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@overload
def max(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: List[int], repeat_times: int,
        repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@require_jit
@set_binary_docstring(cpp_name="Max", append_text="按元素求最大值。")
def max(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, *args, **kwargs) -> None:
    # Element-wise maximum (Ascend C ``Max``); ``op_impl`` selects the level-specific creator.
    ir_builder = global_builder.get_ir_builder()
    level_creators = (
        ir_builder.create_asc_MaxL0Op,
        ir_builder.create_asc_MaxL1Op,
        ir_builder.create_asc_MaxL2Op,
    )
    op_impl("max", dst, src0, src1, args, kwargs, *level_creators)
# Accepted call forms: (count[, is_set_mask]), (mask: int, repeat_times, repeat_params[, is_set_mask])
# and (mask: List[int], repeat_times, repeat_params[, is_set_mask]).
# NOTE: a module-level ``min`` shadows the builtin for the rest of this module.
@overload
def min(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, count: int, is_set_mask: bool = True) -> None:
    ...

@overload
def min(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: int, repeat_times: int,
        repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@overload
def min(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: List[int], repeat_times: int,
        repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@require_jit
@set_binary_docstring(cpp_name="Min", append_text="按元素求最小值。")
def min(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, *args, **kwargs) -> None:
    # Element-wise minimum (Ascend C ``Min``); ``op_impl`` selects the level-specific creator.
    ir_builder = global_builder.get_ir_builder()
    level_creators = (
        ir_builder.create_asc_MinL0Op,
        ir_builder.create_asc_MinL1Op,
        ir_builder.create_asc_MinL2Op,
    )
    op_impl("min", dst, src0, src1, args, kwargs, *level_creators)
# Accepted call forms: (count[, is_set_mask]), (mask: int, repeat_times, repeat_params[, is_set_mask])
# and (mask: List[int], repeat_times, repeat_params[, is_set_mask]).
@overload
def mul(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, count: int, is_set_mask: bool = True) -> None:
    ...

@overload
def mul(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: int, repeat_times: int,
        repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@overload
def mul(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: List[int], repeat_times: int,
        repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@require_jit
@set_binary_docstring(cpp_name="Mul", append_text="按元素求积。")
def mul(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, *args, **kwargs) -> None:
    # Element-wise product (Ascend C ``Mul``); ``op_impl`` selects the level-specific creator.
    ir_builder = global_builder.get_ir_builder()
    level_creators = (
        ir_builder.create_asc_MulL0Op,
        ir_builder.create_asc_MulL1Op,
        ir_builder.create_asc_MulL2Op,
    )
    op_impl("mul", dst, src0, src1, args, kwargs, *level_creators)
# Accepted call forms: (count[, is_set_mask]), (mask: int, repeat_times, repeat_params[, is_set_mask])
# and (mask: List[int], repeat_times, repeat_params[, is_set_mask]).
@overload
def mul_add_dst(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, count: int, is_set_mask: bool = True) -> None:
    ...

@overload
def mul_add_dst(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: int, repeat_times: int,
                repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@overload
def mul_add_dst(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: List[int], repeat_times: int,
                repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@require_jit
@set_binary_docstring(cpp_name="MulAddDst", append_text="按元素将src0和src1相乘并和dst相加,将最终结果存放进dst中。")
def mul_add_dst(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, *args, **kwargs) -> None:
    # dst = src0 * src1 + dst, element-wise (Ascend C ``MulAddDst``).
    # ``op_impl`` selects the level-specific creator.
    ir_builder = global_builder.get_ir_builder()
    level_creators = (
        ir_builder.create_asc_MulAddDstL0Op,
        ir_builder.create_asc_MulAddDstL1Op,
        ir_builder.create_asc_MulAddDstL2Op,
    )
    op_impl("mul_add_dst", dst, src0, src1, args, kwargs, *level_creators)
# Accepted call forms (note: no ``is_set_mask`` flag, unlike the other binary ops here):
# (count), (mask: int, repeat_times, repeat_params) and (mask: List[int], repeat_times, repeat_params).
@overload
def mul_cast(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, count: int) -> None:
    ...

@overload
def mul_cast(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: int, repeat_times: int,
             repeat_params: BinaryRepeatParams) -> None:
    ...

@overload
def mul_cast(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: List[int], repeat_times: int,
             repeat_params: BinaryRepeatParams) -> None:
    ...

@require_jit
@set_binary_docstring(cpp_name="MulCast", append_text="按元素求积,并根据源操作数和目的操作数Tensor的数据类型进行精度转换。")
def mul_cast(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, *args, **kwargs) -> None:
    # Element-wise product with precision conversion to dst's dtype (Ascend C ``MulCast``).
    # This op dispatches locally instead of going through ``op_impl`` (presumably because
    # it has no ``is_set_mask`` parameter).
    # NOTE(review): the dispatcher is labelled with the module ``__name__`` while
    # ``check_type`` is given "mul_cast" — confirm which label diagnostics should show.
    dispatcher = OverloadDispatcher(__name__)
    ir_builder = global_builder.get_ir_builder()
    check_type("mul_cast", dst, src0, src1)

    def tensor_operands():
        # IR handles of (dst, src0, src1), the leading operands of every MulCast level op.
        return dst.to_ir(), src0.to_ir(), src1.to_ir()

    # Scalar (uint64) mask -> L0 op.
    # NOTE(review): repeat_times is materialized as int8, whereas bilinear_interpolation in
    # this file uses uint8 for its repeat counts — confirm against the MulCast op definition.
    @dispatcher.register(mask=RuntimeInt, repeat_times=RuntimeInt, repeat_params=BinaryRepeatParams)
    def _(mask: RuntimeInt, repeat_times: RuntimeInt, repeat_params: BinaryRepeatParams):
        ir_builder.create_asc_MulCastL0Op(*tensor_operands(), _mat(mask, KT.uint64).to_ir(),
                                          _mat(repeat_times, KT.int8).to_ir(), repeat_params.to_ir())

    # List mask (each entry materialized as uint64) -> L1 op.
    @dispatcher.register(mask=list, repeat_times=RuntimeInt, repeat_params=BinaryRepeatParams)
    def _(mask: list, repeat_times: RuntimeInt, repeat_params: BinaryRepeatParams):
        mask_words = [_mat(word, KT.uint64).to_ir() for word in mask]
        ir_builder.create_asc_MulCastL1Op(*tensor_operands(), mask_words,
                                          _mat(repeat_times, KT.int8).to_ir(), repeat_params.to_ir())

    # Element count (int32) -> L2 op.
    @dispatcher.register_auto
    def _(count: RuntimeInt):
        ir_builder.create_asc_MulCastL2Op(*tensor_operands(), _mat(count, KT.int32).to_ir())

    dispatcher(*args, **kwargs)
# Accepted call forms: (count[, is_set_mask]), (mask: int, repeat_times, repeat_params[, is_set_mask])
# and (mask: List[int], repeat_times, repeat_params[, is_set_mask]).
@overload
def sub(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, count: int, is_set_mask: bool = True) -> None:
    ...

@overload
def sub(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: int, repeat_times: int,
        repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@overload
def sub(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: List[int], repeat_times: int,
        repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@require_jit
@set_binary_docstring(cpp_name="Sub", append_text="按元素求差。")
def sub(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, *args, **kwargs) -> None:
    # Element-wise difference (Ascend C ``Sub``); ``op_impl`` selects the level-specific creator.
    ir_builder = global_builder.get_ir_builder()
    level_creators = (
        ir_builder.create_asc_SubL0Op,
        ir_builder.create_asc_SubL1Op,
        ir_builder.create_asc_SubL2Op,
    )
    op_impl("sub", dst, src0, src1, args, kwargs, *level_creators)
# Accepted call forms: (count[, is_set_mask]), (mask: int, repeat_times, repeat_params[, is_set_mask])
# and (mask: List[int], repeat_times, repeat_params[, is_set_mask]).
@overload
def sub_relu(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, count: int, is_set_mask: bool = True) -> None:
    ...

@overload
def sub_relu(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: int, repeat_times: int,
             repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@overload
def sub_relu(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, mask: List[int], repeat_times: int,
             repeat_params: BinaryRepeatParams, is_set_mask: bool = True) -> None:
    ...

@require_jit
@set_binary_docstring(cpp_name="SubRelu", append_text="按元素求差,再进行Relu计算(结果和0对比取较大值)。")
def sub_relu(dst: LocalTensor, src0: LocalTensor, src1: LocalTensor, *args, **kwargs) -> None:
    # Element-wise difference followed by ReLU (max with 0) — Ascend C ``SubRelu``.
    # ``op_impl`` selects the level-specific creator.
    ir_builder = global_builder.get_ir_builder()
    level_creators = (
        ir_builder.create_asc_SubReluL0Op,
        ir_builder.create_asc_SubReluL1Op,
        ir_builder.create_asc_SubReluL2Op,
    )
    op_impl("sub_relu", dst, src0, src1, args, kwargs, *level_creators)
@overload
def bilinear_interpolation(dst: LocalTensor, src0: LocalTensor, src0_offset: LocalTensor, src1: LocalTensor, mask: int,
                           h_repeat: int, repeat_mode: bool, dst_blk_stride: int, v_r_offset: int, v_repeat: int,
                           shared_tmp_buffer: LocalTensor) -> None:
    ...

@overload
def bilinear_interpolation(dst: LocalTensor, src0: LocalTensor, src0_offset: LocalTensor, src1: LocalTensor,
                           mask: List[int], h_repeat: int, repeat_mode: bool, dst_blk_stride: int, v_r_offset: int,
                           v_repeat: int, shared_tmp_buffer: LocalTensor) -> None:
    ...

@require_jit
def bilinear_interpolation(dst: LocalTensor, src0: LocalTensor, src0_offset: LocalTensor, src1: LocalTensor,
                           mask: Union[list, RuntimeInt], h_repeat: RuntimeInt, repeat_mode: RuntimeBool,
                           dst_blk_stride: RuntimeInt, v_r_offset: RuntimeInt, v_repeat: RuntimeInt,
                           shared_tmp_buffer: LocalTensor) -> None:
    """
    Bilinear interpolation, corresponding to ``AscendC::BilinearInterpolation``.

    The computation consists of horizontal iterations nested inside vertical iterations.

    Each horizontal iteration reads 8 consecutive offsets from ``src0_offset``; each offset is the
    start address of one data block in ``src0``. If ``repeat_mode`` is ``False``, one value is taken
    from ``src1`` and multiplied with every value of the 8 ``src0`` data blocks; if ``repeat_mode``
    is ``True``, 8 values are taken from ``src1`` and multiplied, in order, with the 8 ``src0`` data
    blocks. The result of the current iteration is then accumulated block-wise onto the previous
    ``dst`` result and written back; the ``dst`` address stays fixed within one horizontal iteration.

    Each vertical iteration performs ``h_repeat`` horizontal iterations and occupies the 8 data
    blocks following its ``dst`` start address. The ``dst`` start address of a vertical iteration is
    the previous vertical iteration's start address plus ``v_r_offset``.

    **Corresponding Ascend C prototypes**

    .. code-block:: c++

        template <typename T>
        __aicore__ inline void BilinearInterpolation(const LocalTensor<T> &dst, const LocalTensor<T> &src0,
            const LocalTensor<uint32_t> &src0Offset, const LocalTensor<T> &src1, uint64_t mask[],
            uint8_t hRepeat, bool repeatMode, uint16_t dstBlkStride, uint16_t vROffset,
            uint8_t vRepeat, const LocalTensor<uint8_t> &sharedTmpBuffer)

    .. code-block:: c++

        template <typename T>
        __aicore__ inline void BilinearInterpolation(const LocalTensor<T> &dst, const LocalTensor<T> &src0,
            const LocalTensor<uint32_t> &src0Offset, const LocalTensor<T> &src1, uint64_t mask,
            uint8_t hRepeat, bool repeatMode, uint16_t dstBlkStride, uint16_t vROffset,
            uint8_t vRepeat, const LocalTensor<uint8_t> &sharedTmpBuffer)

    **Parameters**

    - dst: destination operand, a LocalTensor. Supported TPosition: VECIN/VECCALC/VECOUT.
    - src0, src1: source operands, LocalTensors. Supported TPosition: VECIN/VECCALC/VECOUT.
    - src0_offset: LocalTensor of ``uint32`` offsets into ``src0`` (``src0Offset`` in Ascend C).
    - mask: selects the elements that take part in each iteration. An ``int`` uses the contiguous
      mask mode and a ``list`` of ints the bit-wise mask mode; every value is materialized as ``uint64``.
    - h_repeat: number of horizontal iterations per vertical iteration (``uint8``).
    - repeat_mode: selects how values are taken from ``src1`` (see above).
    - dst_blk_stride: data-block stride of ``dst`` within one iteration (``uint16``).
    - v_r_offset: ``dst`` start-address offset between consecutive vertical iterations (``uint16``).
    - v_repeat: number of vertical iterations (``uint8``).
    - shared_tmp_buffer: temporary ``uint8`` LocalTensor (``sharedTmpBuffer`` in Ascend C).

    **Examples**

    - Contiguous mask mode

    .. code-block:: python

        mask = 128              # contiguous mask mode
        h_repeat = 2            # 2 horizontal iterations
        repeat_mode = False     # iteration mode
        dst_blk_stride = 1      # data written contiguously within one iteration
        v_r_offset = 128        # data written contiguously between adjacent iterations
        v_repeat = 2            # 2 vertical iterations
        asc.bilinear_interpolation(dst_local, src0_local, src0_offset_local, src1_local, mask, h_repeat,
                                   repeat_mode, dst_blk_stride, v_r_offset, v_repeat, tmp_local)

    - Bit-wise mask mode

    .. code-block:: python

        uint64_max = 2**64 - 1
        mask = [uint64_max, uint64_max]   # bit-wise mask mode
        h_repeat = 2                      # 2 horizontal iterations
        repeat_mode = False               # iteration mode
        dst_blk_stride = 1                # data written contiguously within one iteration
        v_r_offset = 128                  # data written contiguously between adjacent iterations
        v_repeat = 2                      # 2 vertical iterations
        asc.bilinear_interpolation(dst_local, src0_local, src0_offset_local, src1_local, mask, h_repeat,
                                   repeat_mode, dst_blk_stride, v_r_offset, v_repeat, tmp_local)
    """
    builder = global_builder.get_ir_builder()
    check_type("bilinear_interpolation", dst, src0, src1)
    # A list mask (bit-wise mode) maps to the L1 op and a scalar mask (contiguous mode) to the L0 op.
    # Everything after ``mask`` is identical for both, so it is materialized only once, in the same
    # order as before (mask first, then the scalar arguments left to right).
    if isinstance(mask, list):
        mask_ir = [_mat(v, KT.uint64).to_ir() for v in mask]
        create_op = builder.create_asc_BilinearInterpolationL1Op
    else:
        mask_ir = _mat(mask, KT.uint64).to_ir()
        create_op = builder.create_asc_BilinearInterpolationL0Op
    create_op(dst.to_ir(), src0.to_ir(), src0_offset.to_ir(), src1.to_ir(), mask_ir,
              _mat(h_repeat, KT.uint8).to_ir(),
              _mat(repeat_mode, KT.bool_).to_ir(),
              _mat(dst_blk_stride, KT.uint16).to_ir(),
              _mat(v_r_offset, KT.uint16).to_ir(),
              _mat(v_repeat, KT.uint8).to_ir(),
              shared_tmp_buffer.to_ir())