文件最后提交记录最后更新时间
1 个月前
21 天前
21 天前
1 个月前
21 天前
1 个月前
29 天前
4 个月前
23 天前
README.md

TransposeQuantBatchMatMul

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品 ×
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:完成张量x1与张量x2量化的矩阵乘计算,支持K-C、MX量化模式。仅支持三维的Tensor传入。Tensor支持转置,转置序列根据传入的序列进行变更。permX1代表张量x1的转置序列,支持[1,0,2],permX2代表张量x2的转置序列,K-C模式支持[0,1,2],mx模式支持[0,1,2]和[0,2,1],permY表示矩阵乘输出矩阵的转置序列,当前仅支持[1,0,2],序列值为0的是batch维度,其余两个维度做矩阵乘法。x1Scale和x2Scale表示输出矩阵的量化系数;bias为预留参数,当前暂不支持,详细约束条件可见约束说明或者aclnnTransposeQuantBatchMatMul调用说明文档。

  • 示例: 假设x1的shape是[M, B, K],x2的shape是[B, K, N],x1Scale和x2Scale不为None,batchSplitFactor等于1时,计算输出out的shape是[M, B, N]。

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
x1 输入 矩阵乘运算中的左矩阵。 FLOAT8_E5M2, FLOAT8_E4M3FN ND
x2 输入 矩阵乘运算中的右矩阵。 FLOAT8_E5M2, FLOAT8_E4M3FN ND
bias 输入 矩阵乘运算后累加的偏置,预留参数。 FLOAT32, FLOAT16, BF16 ND
x1Scale 输入 量化参数的缩放因子。 FLOAT32、FLOAT8_E8M0 ND
x2Scale 输入 量化参数的缩放因子。 FLOAT32、FLOAT8_E8M0 ND
dtype 输入 用于指定输出矩阵的数据类型。 INT32 -
groupSize 输入 用于指定量化分组大小。 INT64 -
permX1 输入 表示矩阵乘的第一个矩阵的转置序列。 INT64 -
permX2 输入 表示矩阵乘的第二个矩阵的转置序列。 INT64 -
permY 输入 表示矩阵乘输出矩阵的转置序列。 INT64 -
batchSplitFactor 输入 用于指定矩阵乘输出矩阵中N维的切分大小。 INT32 -
y 输出 矩阵乘运算的计算结果。 FLOAT16, BFLOAT16 ND
  • Ascend 950PR/Ascend 950DT:只有输入x2支持FRACTAL_NZ格式。

约束说明

  • Ascend 950PR/Ascend 950DT:
    • permX1和permY支持[1, 0, 2]
    • K-C量化场景,permX2支持输入[0, 1, 2];MX量化场景,permX2支持输入[0, 1, 2]或[0, 2, 1]。
    • K-C量化场景,K仅支持512,N仅支持128。x1Scale和x2Scale为1维,并且x1Scale为(M,), x2Scale为(N,),group_size仅支持配置为0,其他取值不生效。
    • MX量化场景,K仅支持64的倍数。x1Scale和x2Scale为4维,并且x1Scale为(M, B, K/64, 2), x2Scale为(B, K/64, N, 2)或(B, N, K/64, 2),group_size的groupSizeM和groupSizeN仅支持0或1,groupSizeK仅支持32。

调用说明

调用方式 样例代码 说明
aclnn接口 test_aclnn_quant_batch_mat_mul 通过
- aclnnTransposeQuantBatchMatMul
- aclnnTransposeQuantBatchMatMulWeightNz
等方式调用TransposeQuantBatchMatMul算子。