README.md

aclblasLtMatmul 接口实现

概述

BLAS Lt 矩阵乘法(aclblasLtMatmul)接口实现与精度测试。

aclblasLtMatmul 实现了通用矩阵乘法运算,对应的数学表达式为:

D = alpha * op(A) * op(B) + beta * C

其中 A、B 为输入矩阵,C 为累加矩阵,D 为输出矩阵,alpha 和 beta 为标量,op(A)/op(B) 支持不转置(N)和转置(T)。当前实现支持 FP32、MXFP8(E4M3FN)、MXFP4(E2M1)三种输入类型组合,输出支持 FP32 和 BF16。

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品

MXFP8/MXFP4 量化路径依赖 CANN asc-devkit >= 9.1(ASC_DEVKIT_MAJOR >= 9 && ASC_DEVKIT_MINOR >= 1)。

目录结构介绍

接口实现位于 blasLt/

blasLt/
├── aclblasLt.cpp                          // aclBLASLt 库入口,含 aclblasLtMatmul 路由
├── matmul_fp32/arch35/
│   ├── matmul_fp32_host.cpp               // FP32 Host 侧 Tiling
│   └── matmul_fp32_kernel.cpp             // FP32 Kernel 侧实现
├── matmul_mxfp8/arch35/
│   ├── matmul_mxfp8_host.cpp              // MXFP8 Host 侧 Tiling
│   └── matmul_mxfp8_kernel.cpp            // MXFP8 Kernel 侧实现
├── matmul_mxfp4/arch35/
│   ├── matmul_mxfp4_host.cpp              // MXFP4 Host 侧 Tiling
│   └── matmul_mxfp4_kernel.cpp            // MXFP4 Kernel 侧实现
└── utils/
    └── kernel_utils.h                       // shared kernel helpers

测试代码位于 test/blasLtMatmul/

test/blasLtMatmul/
├── CMakeLists.txt                         // 编译工程文件
├── blasLtMatmul_param.h                   // 参数结构体(继承 BlasTestParamBase)
├── blasLtMatmul_golden.h                  // CPU golden(封装 aclblasLtMatmul CPU 参考)
└── arch35/
    ├── blasLtMatmul_npu_wrapper.h         // NPU wrapper(封装 aclrtMalloc/H2D/kernel/D2H/free)
    ├── blasLtMatmul_test.cpp              // 精度测试(GTest 入口)
    └── blasLtMatmul_test.csv              // 精度测试用例表

接口描述

  • 接口功能:
    执行矩阵乘法 D = alpha * op(A) * op(B) + beta * C。支持 FP32 全精度路径,以及 MXFP8/MXFP4 量化输入路径(需配合 scale factor)。

  • 对应接口为:

aclblasStatus_t aclblasLtMatmul(
    aclblasLtHandle_t lightHandle,
    aclblasLtMatmulDesc_t computeDesc,
    const void* alpha,
    const void* A,
    aclblasLtMatrixLayout_t Adesc,
    const void* B,
    aclblasLtMatrixLayout_t Bdesc,
    const void* beta,
    const void* C,
    aclblasLtMatrixLayout_t Cdesc,
    void* D,
    aclblasLtMatrixLayout_t Ddesc,
    const aclblasLtMatmulAlgo_t* algo,
    void* workspace,
    size_t workspaceSizeInBytes,
    aclrtStream stream);
参数 aclblasLtMatmul 参数说明
参数列表 Param. Memory in/out 含义
lightHandle in aclBLASLt 库上下文句柄,由 aclblasLtCreate 创建。不可为 NULL,否则返回 ACLBLAS_STATUS_NOT_INITIALIZED。
computeDesc in 矩阵乘法描述符,设置 transA/transB、epilogue、scale 指针等属性。不可为 NULL。
alpha host in 用于乘法的 float 标量。不可为 NULL。
A device in 输入矩阵 A,数据类型由 Adesc 指定。不可为 NULL(m>0 且 n>0 时)。
Adesc in 矩阵 A 的 layout 描述符(rows/cols/ld/order/dtype)。
B device in 输入矩阵 B,数据类型由 Bdesc 指定。不可为 NULL(m>0 且 n>0 时)。
Bdesc in 矩阵 B 的 layout 描述符。
beta host in 用于累加的 float 标量。不可为 NULL。beta=0 时 C 可不参与计算。
C device in 累加矩阵 C。beta=0 时可为 NULL。当前测试覆盖 C=NULL 场景。
Cdesc in 矩阵 C 的 layout 描述符。
D device out 输出矩阵 D,维度 m x n。不可为 NULL(m>0 且 n>0 时)。
Ddesc in 矩阵 D 的 layout 描述符,指定输出数据类型(FP32 或 BF16)。
algo in 算法描述符,可为 NULL(使用默认算法)。
workspace device in 工作空间内存,可为 NULL。非 NULL 时需 16B 对齐。
workspaceSizeInBytes in 工作空间大小(字节)。
stream in AscendCL 执行流。

当前支持的入参/出参范围

参数项 支持范围 说明
dtypeA / dtypeB FP32;MXFP8_E4M3FN;MXFP4_E2M1 A/B 须为同类型组合:FP32×FP32、MXFP8×MXFP8、MXFP4×MXFP4。其他组合返回 ACLBLAS_STATUS_NOT_SUPPORTED。
dtypeC FP32 累加矩阵类型,当前固定为 FP32。
dtypeD FP32;BF16 MXFP8/MXFP4 路径支持 FP32 或 BF16 输出;FP32 路径输出 FP32。
computeType ACLBLAS_COMPUTE_32F 所有已支持路径均使用 32F 计算精度。
transA / transB N、T 对应 ACLBLAS_OP_N(不转置)、ACLBLAS_OP_T(转置)。
M / N / K M,N,K ≥ 0 M=0 或 N=0 时为空操作,直接返回 SUCCESS。MXFP8/MXFP4 路径要求 K 为 32 的整数倍,否则返回 ACLBLAS_STATUS_INVALID_VALUE。
lda / ldb / ldc / ldd ld ≥ 物理列数 行主序(ACLBLASLT_ORDER_ROW)存储,ld 为 leading dimension,须 ≥ 矩阵物理列数。MXFP4 的 ld 为逻辑元素 leading dim(2 个 FP4 元素打包为 1 字节)。
alpha / beta float 当前测试覆盖 alpha=1.0、beta=0.0。beta=0 时 C 可为 NULL。
epilogue ACLBLASLT_EPILOGUE_DEFAULT 当前仅支持默认 epilogue。
scaleA / scaleB MXFP8/MXFP4 必填 通过 computeDesc 的 ACLBLASLT_MATMUL_DESC_A/B_SCALE_POINTER 设置,E8M0 格式,按 K 方向每 32 元素一组。不可为 NULL。
algo default / NULL 可为 NULL,使用默认算法。
order ACLBLASLT_ORDER_ROW 当前实现使用行主序。
  • 算子规格:

    算子类型(OpType)aclblasLtMatmul
    算子输入nameshapedata typeformat
    AM×K(或转置后 K×M)FP32 / MXFP8 / MXFP4ND
    BK×N(或转置后 N×K)FP32 / MXFP8 / MXFP4ND
    CM×NFP32ND
    scaleA按 K 分组E8M0 (uint8)ND
    scaleB按 K 分组E8M0 (uint8)ND
    alpha/beta1floatND
    算子输出DM×NFP32 / BF16ND
    核函数名MatmulFp32Kernel / matmul_mxfp8_kernel_do / ltmatmul_mxfp4_kernel_do
  • 算子实现:
    Host 侧根据 A/B 数据类型路由至 FP32、MXFP8 或 MXFP4 对应的 Tiling 与 Kernel 实现。MXFP 路径在 Kernel 内完成量化矩阵乘加,输出经 epilogue 处理写入 D。

  • 调用实现:
    通过 aclBLASLt 标准 API 调用,内部使用内核调用符 <<<>>> 启动 NPU 核函数。

测试用例覆盖

分组 用例数 覆盖场景
L0 FP32 基础 7 小/中/大规模 NN、TN/NT/TT 转置、algo=nullptr
L0 MXFP8 6 K=32 基础/大规模、四种转置、C=null
L0 MXFP4 7 K=32 基础/大规模、四种转置、C=null、FP32 输出
L0 异常入参 4 handle/desc/alpha/A 为 NULL
L0 边界 6 M=0/N=0 空操作、K 非 32 倍数非法、algo=nullptr
L1 FP32 扩展 11 矩形矩阵、非方阵转置、瘦矩阵 M128×N32×K128
L1 MXFP8 扩展 22 多种 K 规模、矩形/奇数维度、scale 全零、K 非法值
L1 MXFP4 扩展 17 多种 K 规模、矩形/转置、scale 全零、K 非法值
TEST_F 固定用例 4 NullHandle、NullComputeDesc、NullAlpha、NullA

编译运行

在本样例根目录下执行如下步骤,编译并执行测试。

  • 配置环境变量
    请根据当前环境上 CANN 开发套件包的安装方式,选择对应配置环境变量的命令。

    • 默认路径,root 用户安装 CANN 软件包

      source /usr/local/Ascend/cann/set_env.sh
      
    • 默认路径,非 root 用户安装 CANN 软件包

      source $HOME/Ascend/cann/set_env.sh
      
    • 指定路径 install_path,安装 CANN 软件包

      source ${install_path}/cann/set_env.sh
      
  • 样例执行

    bash build.sh --ops=blasLtMatmul --soc=ascend950 --run
    

    其中 --soc可选参数,用于指定目标硬件平台(与上文「产品支持情况」对应)。按实际硬件选用:

    产品 --soc 取值
    Ascend 950PR / Ascend 950DT ascend950

    执行结果如下,说明精度对比成功。

    [PASS] blasLtMatmul_test