from typing import overload
from ..core.dtype import KnownTypes as KT
from ..core.ir_value import materialize_ir_value as _mat, RuntimeInt
from ..core.tensor import LocalTensor, GlobalTensor
from ..core.types import FixpipeConfig, FixpipeParamsV220
from ..core.utils import require_jit, global_builder, OverloadDispatcher
from .utils import set_common_docstring
# Typing-only signature: fixpipe called without a workspace tensor, i.e.
# (dst, src, params[, config]). `config` defaults to
# FixpipeConfig.cfg_row_major.
@overload
def fixpipe(
    dst: GlobalTensor,
    src: LocalTensor,
    params: FixpipeParamsV220,
    config: FixpipeConfig = FixpipeConfig.cfg_row_major,
) -> None:
    ...
# Typing-only signature: fixpipe called with an extra `cbuf_workspace`
# tensor placed before `params`, i.e.
# (dst, src, cbuf_workspace, params[, config]). `config` defaults to
# FixpipeConfig.cfg_row_major.
@overload
def fixpipe(
    dst: GlobalTensor,
    src: LocalTensor,
    cbuf_workspace: LocalTensor,
    params: FixpipeParamsV220,
    config: FixpipeConfig = FixpipeConfig.cfg_row_major,
) -> None:
    ...
@require_jit
@set_common_docstring("fixpipe")
def fixpipe(dst: GlobalTensor, src: LocalTensor, *args, **kwargs) -> None:
    # Runtime entry point behind the two typing overloads. The trailing
    # arguments are routed through an OverloadDispatcher to one of two
    # handlers, each of which emits a single op on the current IR builder.
    # The user-facing docstring is attached by `set_common_docstring`, so
    # none is written here.
    ir_builder = global_builder.get_ir_builder()
    overloads = OverloadDispatcher("fixpipe")

    def _resolve(cfg):
        # A callable `config` is invoked with no arguments and its result is
        # used instead. NOTE(review): presumably this covers the default
        # `FixpipeConfig.cfg_row_major` being a factory -- confirm.
        return cfg() if callable(cfg) else cfg

    @overloads.register_auto
    def _(params: FixpipeParamsV220, config: FixpipeConfig = FixpipeConfig.cfg_row_major):
        # Variant without a workspace tensor.
        resolved = _resolve(config)
        operands = (dst.to_ir(), src.to_ir(), params.to_ir(), resolved.to_ir())
        ir_builder.create_asc_FixpipeOp(*operands)

    @overloads.register_auto
    def _(cbuf_workspace: LocalTensor, params: FixpipeParamsV220, config: FixpipeConfig = FixpipeConfig.cfg_row_major):
        # Variant that additionally forwards `cbuf_workspace` to the op.
        resolved = _resolve(config)
        operands = (
            dst.to_ir(),
            src.to_ir(),
            cbuf_workspace.to_ir(),
            params.to_ir(),
            resolved.to_ir(),
        )
        ir_builder.create_asc_FixpipeWithWorkspaceOp(*operands)

    # Hand the caller's remaining arguments to the dispatcher, which picks
    # one of the handlers registered above.
    overloads(*args, **kwargs)
# Typing-only signature for set_fix_pipe_pre_quant_flag taking a plain int.
# NOTE(review): this is the only @overload declared for this function in the
# visible source; static type checkers generally expect at least two
# overloads, and the implementation is annotated with RuntimeInt rather than
# int -- confirm the intended public typing.
@overload
def set_fix_pipe_pre_quant_flag(
    config: int,
) -> None:
    ...
@require_jit
@set_common_docstring("set_fix_pipe_pre_quant_flag")
def set_fix_pipe_pre_quant_flag(config: RuntimeInt) -> None:
    # Materialize `config` as a uint64 IR value and emit a
    # SetFixpipePreQuantFlagOp on the current IR builder. The user-facing
    # docstring is attached by `set_common_docstring`.
    ir_builder = global_builder.get_ir_builder()
    flag_value = _mat(config, KT.uint64)
    ir_builder.create_asc_SetFixpipePreQuantFlagOp(flag_value.to_ir())