GroupedMatmulAdd

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:实现分组矩阵乘计算,每组矩阵乘的维度大小可以不同。基本功能为矩阵乘,如yRefi[mi,ni]=xi[mi,ki]×weighti[ki,ni]+yi[mi,ni],i=1...gyRef_i[m_i,n_i]=x_i[m_i,k_i] \times weight_i[k_i,n_i]+y_i[m_i,n_i], i=1...g,其中g为分组个数,mi/ki/nim_i/k_i/n_i为对应shape。当前仅支持K轴分组。

    • k轴分组:kik_i各不相同,但mi/nim_i/n_i每组相同。
  • 计算公式:

    yRefi=xi×weighti+yiyRef_i=x_i\times weight_i + y_i

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
x 输入 公式中的输入x。 FLOAT16、BFLOAT16 ND
weight 输入 公式中的weight。 FLOAT16、BFLOAT16 ND
groupList 输入 表示输入K轴方向的matmul大小分布的cumsum结果(累积和)。 INT64 ND
y 输入 表示原地累加的输出矩阵。 FLOAT32 ND
transposeX 属性 表示x矩阵是否转置。 BOOL -
transposeWeight 属性 表示weight矩阵是否转置。 BOOL -
groupType 属性 表示分组类型。 INT64 -
groupListType 属性 表示分组groupList格式。 INT64 -
yRef 输出 表示原地累加的输出矩阵。 FLOAT32 ND

约束说明

  • x和weight中每一组tensor的每一维大小在32字节对齐后都应小于int32的最大值2147483647。

调用说明

调用方式 调用样例 说明
aclnn调用 test_aclnn_grouped_matmul_add 通过接口方式调用GroupedMatmulAdd算子。