from typing import overload, List
from ..core.ir_value import RuntimeInt, materialize_ir_value as _mat
from ..core.tensor import LocalTensor
from ..core.utils import require_jit, global_builder, OverloadDispatcher
from ..core.types import TransposeParamsExt, TransDataTo5HDParams, KnownTypes
from .utils import check_type_transpose, check_type_5hd, set_common_docstring
@overload
def transpose(dst: LocalTensor, src: LocalTensor) -> None:
...
@overload
def transpose(dst: LocalTensor, src: LocalTensor, shared_tmp_buffer: LocalTensor,
params: TransposeParamsExt) -> None:
...
@require_jit
@set_common_docstring(api_name="transpose")
def transpose(dst: LocalTensor, src: LocalTensor, *args, **kwargs) -> None:
check_type_transpose("transpose", dst, src, *args)
builder = global_builder.get_ir_builder()
if not args and not kwargs:
builder.create_asc_TransposeOp(dst.to_ir(), src.to_ir())
return
dispatcher = OverloadDispatcher("transpose")
@dispatcher.register_auto
def _(shared_tmp_buffer: LocalTensor, params: TransposeParamsExt):
builder.create_asc_TransposeExtOp(dst.to_ir(), src.to_ir(),
shared_tmp_buffer.to_ir(), params.to_ir())
dispatcher(*args, **kwargs)
TensorList = List[LocalTensor]
AddrList = List[RuntimeInt]
@overload
def trans_data_to_5hd(dst_list: TensorList, src_list: TensorList, params: TransDataTo5HDParams) -> None:
...
@overload
def trans_data_to_5hd(dst_list: AddrList, src_list: AddrList, params: TransDataTo5HDParams) -> None:
...
@overload
def trans_data_to_5hd(dst: LocalTensor, src: LocalTensor, params: TransDataTo5HDParams) -> None:
...
@require_jit
@set_common_docstring(api_name="trans_data_to_5hd")
def trans_data_to_5hd(dst_or_list, src_or_list, params: TransDataTo5HDParams, *args, **kwargs) -> None:
check_type_5hd("trans_data_to_5hd", dst_or_list, src_or_list)
builder = global_builder.get_ir_builder()
if isinstance(dst_or_list, LocalTensor):
builder.create_asc_TransDataTo5HDOp(dst_or_list.to_ir(), src_or_list.to_ir(), params.to_ir())
elif isinstance(dst_or_list, list):
if not dst_or_list:
return
if isinstance(dst_or_list[0], LocalTensor):
dst_ir_list = [t.to_ir() for t in dst_or_list]
src_ir_list = [t.to_ir() for t in src_or_list]
builder.create_asc_TransDataTo5HDTensorListOp(dst_ir_list, src_ir_list, params.to_ir())
else:
dst_ir_list = [_mat(addr, KnownTypes.uint64).to_ir() for addr in dst_or_list]
src_ir_list = [_mat(addr, KnownTypes.uint64).to_ir() for addr in src_or_list]
builder.create_asc_TransDataTo5HDUintListOp(dst_ir_list, src_ir_list, params.to_ir())