A2Fp8E4M3Matmul Example Readme
代码组织
├── 29_a2_fp8_e4m3_matmul
│ ├── CMakeLists.txt # CMake编译文件
│ ├── README.md
│ ├── gen_data.py # 数据生成脚本
│ └── fp8_matmul.cpp # 主文件
功能介绍
该算子支持输入A矩阵和B矩阵的数据类型为FP8 E4M3格式(软实现),然后进行矩阵乘输出C矩阵(FP16)。
实现细节
1、输入处理:接收两个FP8 E4M3格式的输入矩阵A和B
2、伪量化:将FP8数据伪量化成FP16格式(per-tensor量化模式)
3、矩阵乘:使用FP16数据进行矩阵乘,中间结果使用FP32精度进行累加
4、输出转换:将最终结果转换成FP16格式输出
使用示例
example使用
- 第一步,编译
- 获取代码之后编译相应的算子可执行文件,可参考quickstart
# 编译指定用例
bash scripts/build.sh 29_a2_fp8_e4m3_matmul
- 第二步,执行
gen_data.py生成测试数据,测试用例规格从命令行输入
cd examples/29_a2_fp8_e4m3_matmul && python gen_data.py 256 512 1024 0 0 && cd ../..
# 输入参数分别对应 m, n, k, trans_a, trans_b
# trans_a表示A矩阵是否转置,0是不转置,1是转置
# trans_b表示B矩阵是否转置,0是不转置,1是转置
执行该命令后会在当前路径下生成input和output目录,包含算子的输入数据和用于精度验证的golden数据
├── input
│ ├── a_8.bin
│ ├── b_8.bin
└── output
└── expected_data.bin
- 第三步,执行算子,注意提供给算子的输入shape和上面第二步生成数据的shape需一致
# 可执行文件名 |矩阵m轴|n轴|k轴|Device ID
# Device ID可选,默认为0
./output/bin/29_a2_fp8_e4m3_matmul 256 512 1024 0
执行结果如下,说明精度比对成功。
Compare success.
说明
1、 gen_data.py的输入支持trans_a和trans_b,但29_a2_fp8_e4m3_matmul可执行文件不支持,仅仅是trans_a和trans_b均为0的example示例。
若要对应转置情况请修改example示例中的layout,因为layout隐式表征转置状态,即layout::RowMajor表示不转置,layout::ColumnMajor表示转置。
其对应关系如下表:
| trans_a | trans_b | LayoutA | LayoutB |
|---|---|---|---|
| 0 | 0 | layout::RowMajor | layout::RowMajor |
| 0 | 1 | layout::RowMajor | layout::ColumnMajor |
| 1 | 0 | layout::ColumnMajor | layout::RowMajor |
| 1 | 1 | layout::ColumnMajor | layout::ColumnMajor |
2、对比FP16 Matmul,该样例针对大shape的case有较为明显的显存收益
3、针对小shape场景,可以参考catlass_optimize_guidance对样例进行tiling调优