# Registry table: under the "kernel" key, maps the kernel name
# "all_gather_matmul" to the name of the input-preparation function defined
# in this file (all_gather_matmul_inputs).
# NOTE(review): the consumer of this table is not visible here — presumably an
# external harness resolves the function by this string; confirm before renaming.
__input__ = {
"kernel": {
"all_gather_matmul": "all_gather_matmul_inputs"
}
}
from typing import List
import torch
def all_gather_matmul_inputs(x1, x2, bias=None, is_trans_a=False, is_trans_b=False,
                             gather_index: int = 0, comm_turn: int = 0, rank_size: int = 1,
                             is_gather_out: bool = True, **kwargs):
    """Validate the all_gather_matmul arguments and return them unchanged.

    Only the ``.shape`` attribute of ``x1``, ``x2`` and ``bias`` is inspected;
    no tensor is modified or copied.

    According to the original author's note, the parameter names must match
    ``all_gather_matmul_def.cpp`` exactly (that file is not visible here).

    Args:
        x1: Left matmul operand (required). Its contraction size K is read
            from ``shape[1]``, or from ``shape[0]`` when ``is_trans_a`` is
            true or ``x1`` is 1-D.
        x2: Right matmul operand (required). Its contraction size K is read
            from ``shape[0]``, or from ``shape[1]`` when ``is_trans_b`` is
            true and ``x2`` has more than one dimension.
        bias: Optional bias; its last dimension must equal the output width
            taken from ``x2``.
        is_trans_a: Whether ``x1`` is supplied transposed.
        is_trans_b: Whether ``x2`` is supplied transposed.
        gather_index: Gather index; passed through without validation.
        comm_turn: Communication turn; passed through without validation.
        rank_size: Number of ranks; must be positive.
            NOTE(review): the original note cites an op-definition default of
            ``Int(0)``, but this wrapper defaults to 1 and rejects values
            <= 0 — TODO confirm against the op definition.
        is_gather_out: Whether the gather result is output; passed through.
        **kwargs: Extension arguments; accepted and ignored.

    Returns:
        The tuple ``(x1, x2, bias, is_trans_a, is_trans_b, gather_index,
        comm_turn, rank_size, is_gather_out)``.

    Raises:
        ValueError: If the K dimensions of ``x1`` and ``x2`` differ, if
            ``bias`` does not match the output width, if ``bias`` is given
            with a 1-D non-transposed ``x2``, or if ``rank_size`` is not
            positive.
    """
    # Contraction size of x1. A transposed x1 is laid out as (K, M), so K is
    # its first dimension; a 1-D x1 has only its single dimension to offer.
    if is_trans_a or len(x1.shape) == 1:
        x1_k = x1.shape[0]
    else:
        x1_k = x1.shape[1]

    # Contraction size of x2. A transposed x2 is laid out as (N, K); the 1-D
    # fallback to shape[0] mirrors the tolerance applied to x1.
    if is_trans_b and len(x2.shape) > 1:
        x2_k = x2.shape[1]
    else:
        x2_k = x2.shape[0]

    if x1_k != x2_k:
        if is_trans_b:
            raise ValueError(f"x1 and x2 shape mismatch after transpose: "
                             f"x1 K={x1_k}, x2 transposed K={x2_k}")
        raise ValueError(f"x1 and x2 shape mismatch: "
                         f"x1 K={x1_k}, x2 K={x2_k}")

    if bias is not None:
        if is_trans_b:
            output_dim = x2.shape[0]
        elif len(x2.shape) > 1:
            output_dim = x2.shape[1]
        else:
            # A 1-D, non-transposed x2 has no output-width dimension; report
            # that clearly instead of failing with a bare IndexError.
            raise ValueError(f"x2 must have at least 2 dimensions to validate bias "
                             f"when is_trans_b is False, got shape {tuple(x2.shape)}")
        if bias.shape[-1] != output_dim:
            raise ValueError(f"bias shape mismatch: expected {output_dim}, got {bias.shape[-1]}")

    if rank_size <= 0:
        raise ValueError(f"rank_size must be positive, got {rank_size}")

    return x1, x2, bias, is_trans_a, is_trans_b, gather_index, comm_turn, rank_size, is_gather_out