AllGatherMatmul
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
功能说明
-
算子功能:完成AllGather通信与MatMul计算融合。
-
计算公式:
y=AllGather(x1)@x2+biasy=AllGather(x1)@x2+bias
gatherOut=AllGather(x1)gatherOut=AllGather(x1)
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x1 | 输入 | 公式中的输入x1。 | FLOAT16、BFLOAT16 | ND |
| x2 | 输入 | 公式中的输入x2。 | FLOAT16、BFLOAT16 | ND |
| bias | 可选输入 | 公式中的输入bias。 | FLOAT16、BFLOAT16 | ND |
| y | 输出 | 公式中的输出y。 | FLOAT16、BFLOAT16 | ND |
| gather_out | 输出 | 公式中的输出gatherOut。 | FLOAT16、BFLOAT16 | ND |
| group | 属性 | CHAR*、STRING | - | |
| is_trans_a | 可选属性 | BOOL | - | |
| is_trans_b | 可选属性 | BOOL | - | |
| gather_index | 可选属性 | INT64 | - | |
| commTurn | 可选属性 | INT64 | - | |
| rank_size | 可选属性 | INT64 | - | |
| is_gather_out | 可选属性 | BOOL | - |
约束说明
- 当前版本中,输入x1为2维,其shape为(m, k)。x2必须是2维,其shape为(k, n),轴满足MM算子入参要求,k轴相等,且k轴取值范围为[256, 65535)。
- x1/x2支持的空tensor场景,m和n可以为空,k不可为空,且需要满足以下条件:
- m为空,k不为空,n不为空;
- m不为空,k不为空,n为空;
- m为空,k不为空,n为空。
- x1计算输入、x2计算输入、output计算输出、gather_out计算输出的数据类型均需保持一致。
- x2矩阵支持转置/不转置场景,x2矩阵支持通过转置构造的非连续的Tensor,x1矩阵只支持不转置场景。
- bias可选可为空,非空时,当前版本仅支持一维输入,且暂不支持bias输入为非0的场景。
- 输出为2维,其shape为(m*rank_size, n), rank_size为卡数。
- gather_index当前版本仅支持输入0。
- commTurn当前版本仅支持输入0。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:
- 支持2、4、8卡,并且仅支持HCCS链路all mesh组网。
- 一个模型中的通算融合MC2算子,仅支持相同通信域。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:
- 支持2、4、8、16、32卡,并且仅支持HCCS链路double ring组网。
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_all_gather_matmul | 通过aclnnAllGatherMatmul接口方式调用AllGatherMatmul算子。 |