BroadcastMatmulPerblockQuant Example Readme

代码组织

├── 62_ascend950_broadcast_matmul_perblock_quant
│   ├── CMakeLists.txt              # CMake编译文件
│   ├── README.md
│   ├── gen_data_compare.py         # 数据生成+精度比对脚本
│   └── broadcast_matmul_perblock_quant_tla.cpp # 算子调用示例

功能说明

该算子实现了张量A (shape [B,M,K])和矩阵B(shape [K,N])的广播矩阵乘法,并对计算结果进行perblock量化(block大小为[M,K])。 算子典型应用场景为Q,K,V与旋转矩阵进行矩阵乘法以平滑数据分布,然后进行MXFP8(E4M3)量化。

参数说明

参数名 输入/输出 描述 数据类型 数据格式
a 输入 张量a bfloat16 ND
b 输入 矩阵b bfloat16 ND
out 输出 a与b的广播矩阵乘法的量化结果 float8_e4m3fn ND
scale 输出 a与b的广播矩阵乘法的量化缩放系数 float32 ND
  • 输入a的shape为[B,M,K]
  • 输入b的shape为[K,N]
  • 输出out的shape为[B,M,N]
  • 输出scale的shape为[B]

约束说明

  • B的取值范围为[1,65536]; 对应Q,K,V按照MXFP8量化分块后block的数量。
  • M的取值范围为{128,256}; 对应MXFP8量化的block大小。
  • N和K的取值范围为{128}; 对应旋转矩阵的大小

使用示例

数据生成与精度比对

# 编译指定用例
bash scripts/build.sh 62_ascend950_broadcast_matmul_perblock_quant -DCATLASS_ARCH=3510

# 在示例目录下运行数据生成和比对脚本
cd examples/62_ascend950_broadcast_matmul_perblock_quant

# python3 gen_data_compare.py <batch_count> <m> <n> <k> <device_id>
python3 gen_data_compare.py 1024 128 128 128 0

执行结果如下,说明dst和scale精度比对成功。

------ 生成测试数据 ------
batch_count=1024, m=128, n=128, k=128
------ 运行NPU算子 ------
npu op run log =
------ 比对结果 ------
------ 计算相对误差 -----
------ 综合精度指标 ------
dst: npu mare=0.1250, golden mare=0.125000
dst: npu mere=0.0018, golden mere=0.001839
dst: npu rmse=1.3765, golden rmse=1.377344
scale: npu mare=0.0030, golden mare=0.003028
scale: npu mere=0.0013, golden mere=0.001291
scale: npu rmse=0.0006, golden rmse=0.000632
------ 开始比较 ------
精度指标比较结果:Compare success