__input__ = {
"kernel": {
"all_gather_matmul_v2": "all_gather_matmul_v2_inputs"
}
}
import numpy as np
def all_gather_matmul_v2_inputs(
x1,
x2,
bias=None,
x1_scale=None,
x2_scale=None,
quant_scale=None,
group="",
is_trans_a=False,
is_trans_b=False,
gather_index: int = 0,
comm_turn: int = 0,
rank_size: int = 0,
block_size: int = 0,
group_size: int = 0,
is_gather_out: bool = True,
is_amax_out: bool = False,
y_dtype: int = -1,
comm_mode: str = "aicpu",
**kwargs
):
"""
AllGatherMatmulV2 inputs validation and adjustment
Args:
x1: Input tensor 1
x2: Input tensor 2
bias: Bias tensor (optional)
x1_scale: Scale for x1 (optional)
x2_scale: Scale for x2 (optional)
quant_scale: Quantization scale (optional)
group: Communication group name
is_trans_a: Whether to transpose x1
is_trans_b: Whether to transpose x2
gather_index: Gather index
comm_turn: Communication turn
rank_size: Number of ranks
block_size: Block size for quantization
group_size: Group size for quantization
is_gather_out: Whether to output gather result
is_amax_out: Whether to output amax
y_dtype: Output data type
comm_mode: Communication mode
**kwargs: Additional arguments
Returns:
Processed inputs
"""
if is_trans_b:
x2 = x2.transpose()
return x1, x2, bias, x1_scale, x2_scale, quant_scale, group, is_trans_a, is_trans_b, \
gather_index, comm_turn, rank_size, block_size, group_size, is_gather_out, is_amax_out, y_dtype, comm_mode