Torch Catlass
tests/optest 是 CATLASS 示例算子接入 PyTorch 的测试框架。框架将 AscendC/CATLASS kernel 通过 C++ extension 注册为 torch.ops.catlass.*,并由 Python 包 torch_catlass 提供调用入口。
详细设计见 docs/design.md。
架构总览
框架按职责分为五层:
Python API (`torch_catlass.ops.*`)
-> Python loader (`torch_catlass/__init__.py`)
-> PyTorch C++ extension (`src/`)
-> Kernel adapter (`src/include/template/`)
-> Kernel 实现 (`kernels/`)
- Python 层负责用户接口与动态库加载。
- C++ extension 层负责
torch.ops.catlass.*注册和 NPU dispatch。 - adapter 层负责 Tensor 参数到 kernel ABI 参数的转换。
- kernel 层负责 prebuilt kernel 与 JIT kernel 的实现和执行。
算子接入清单
已接入
- 00_basic_matmul
- 01_batched_matmul
- 02_grouped_matmul_slice_m
- 03_matmul_add
- 04_padding_matmul
- 05_grouped_matmul_slice_k
- 06_optimized_matmul
- 07_grouped_matmul_slice_m_per_token_dequant_moe
- 08_grouped_matmul
- 09_splitk_matmul
- 10_grouped_matmul_slice_m_per_token_dequant
- 11_grouped_matmul_slice_k_per_token_dequant
- 12_quant_matmul
- 13_basic_matmul_tla
- 14_optimized_matmul_tla
- 19_mla
- 20_matmul_bias
- 21_basic_matmul_preload_zN
- 22_padding_splitk_matmul
- 23_flash_attention_infer
- 40_flash_attention_infer_tla
- 25_matmul_full_loadA
- 26_matmul_relu
- 27_matmul_gelu
- 28_matmul_silu
- 29_a2_fp8_e4m3_matmul
- 30_w8a16_matmul
- 31_small_matmul
- 32_w4a8_matmul
- 34_single_core_splitk_matmul
- 37_streamk_matmul
- 39_big_matmul_tla
- 41_sparse_matmul_tla
- 42_quant_optimized_matmul_tla
- 43_ascend950_basic_matmul
- 44_quant_matmul_full_loadA_tla
- 45_strided_batched_matmul_tla
- 52_quant_multi_core_splitk_matmul_tla
- 53_ascend950_fp8_mx_matmul (Ascend950)
- 54_ascend950_fp4_mx_matmul (Ascend950)
- 58_ascend950_fp8_mx_batch_matmul (Ascend950)
- 63_ascend950_dual_level_quant_mx_batch_matmul (Ascend950)
暂未接入
- 15_gemm
- 16_group_gemm
- 17_gemv_aiv
- 18_gemv_aic
- 24_conv_bias
- 33_basic_conv2d
- 38_w4a4_per_token_per_channel_dequant
- 46_ascend950_matmul_fixpipe_opti
- 47_ascend950_grouped_slice_m_dequant
- 48_ascend950_grouped_slice_m_pt_pc_dequant
- 49_ascend950_flash_attention_infer
- 50_ascend950_basic_matmul_gemv
- 51_ascend950_quant_per_group_per_block
- 55_ascend950_mx_grouped_slice_m
- 56_ascend950_basic_conv2d_tla
- 57_ascend950_matmul_full_dequant
- 59_ascend950_a8w4_mx_matmul
- 60_ascend950_grouped_matmul_slice_m
- 61_ascend950_svd_quant_matmul
- 102_dynamic_optimized_matmul
- 103_dynamic_optimized_quant_matmul_per_token_basic
目录结构
tests/optest/
├── CMakeLists.txt
├── build.sh
├── pyproject.toml
├── docs/
│ └── design.md
├── include/
│ ├── catlass_kernel.h
│ ├── catlass_kernel_jit.h
│ └── catlass_kernel_prebuilt.h
├── src/
│ ├── catlass_torch.cpp
│ └── include/common/
├── kernels/
│ ├── 00_basic_matmul/
│ ├── include/
│ └── jit/
├── torch_catlass/
│ ├── __init__.py
│ └── ops/
├── tests/
│ └── test_00_basic_matmul.py
└── utils/
昇腾与架构识别说明
本项目面向昇腾 NPU 的 CANN/Ascend C 开发栈,相关术语和 SoC 信息以昇腾公开开发文档为准。
- Python loader 侧:通过
torch_npu.npu.get_device_name()映射 arch id,用于加载torch_catlass/lib/<arch>/下的预构建库。 - JIT 编译侧:通过
GetCurrentNPUArch()(AscendC 平台 API)获取当前 SoC 对应架构,用于设置编译参数。 - 当前代码中的映射:
Ascend910B.*/Ascend910_93->2201,Ascend950->3510。
公共 ABI 与接口分层
include/catlass_kernel.h 仅作为聚合头,接口按模块拆分为:
- catlass_kernel_jit.h:matmul/gemm/gemv 等 JIT 接口,参数语义为
TParams + Params。 - catlass_kernel_prebuilt.h:flash-attention/conv/mla 等 prebuilt 接口。
其中:
TParams表示编译期模板参数(dtype/layout/transpose 等)。Params表示运行期参数(shape、buffer address 等)。
构建与测试
环境要求
- Python 3.11+
- PyTorch + torch-npu(与本机 CANN/驱动版本匹配)
- CMake 3.16+
- 可用昇腾 NPU 设备与环境
编译
bash build.sh
测试
pytest tests/ -v
环境变量
外部配置(用户可设置)
| 变量 | 作用 | 可接受值 | 默认值 |
|---|---|---|---|
ASCEND_HOME_PATH |
CANN 安装根目录,查找 compiler(ccec) 和 runtime 库 |
绝对路径 | —(必设) |
CATLASS_JIT_CACHE_DIR |
JIT 编译产物 .so 磁盘缓存根目录,版本号作为二级目录 |
绝对路径 | ~/.cache/catlass/jit_cache |
CATLASS_JIT_LOG_LEVEL |
JIT 编译日志等级 | 0=None, 1=Info, 2=Debug |
0 |
MS_SANITIZE_MEMORY |
启用 Ascend memory sanitizer 调试 | 1 |
— |
CATLASS_JIT_AIC_AS_MIX |
强制 AIC kernel 以 __mix__(1,0) 编译 |
任意非空 | —(默认 __cube__) |
CATLASS_JIT_AIV_AS_MIX |
强制 AIV kernel 以 __mix__(0,1) 编译 |
任意非空 | —(默认 __vector__) |
CATLASS_JIT_MIX_CV_11 |
强制 MIX kernel 以 __mix__(1,1) 编译 |
任意非空 | —(默认 __mix__(1,2)) |
PYTHON |
build.sh 使用的 Python 解释器路径 |
which python3 的输出 |
$(which python3) |
包内注入(import 时自动设置,用户不直接改)
| 变量 | 作用 | 来源 |
|---|---|---|
CATLASS_JIT_VERSION |
注入 JIT 编译的版本标识 -DCATLASS_VERSION_FULL,同时作为缓存目录的二级目录名 |
torch_catlass/__init__.py 从 catlass git 推导 |
CATLASS_JIT_PKG_DIR |
Python 包安装目录,JIT 依据此路径定位 templates 和 include | torch_catlass/__init__.py 通过 _find_pkg_dir() 解析 |
构建选项
| 变量 / 参数 | 作用 | 可接受值 |
|---|---|---|
build.sh --build-type |
CMake 构建类型 | Release(默认), Debug |
build.sh --skip-wheel |
跳过 wheel 打包,使用 editable install | — |
build.sh --clean |
清理 JIT 缓存 + 构建产物 | — |
CATLASS_ARCH_LIST (CMake) |
限制 prebuilt kernel 编译的 NPU 架构 | 分号分隔列表,如 2201;3510 |
prebuilt kernel 默认为每个 arch 编译两份:普通版本和
_ms(sanitizer) 版本,无需额外选项。
使用示例
import torch
import torch_catlass
a = torch.randn(1024, 1024, dtype=torch.float16, device="npu")
b = torch.randn(1024, 1024, dtype=torch.float16, device="npu")
c = torch_catlass.basic_matmul(a, b, outDType="float16")
开发指南
新增算子流程
- 在
include/catlass_kernel_jit.h或include/catlass_kernel_prebuilt.h预留 ABI(带 example 编号说明)。 - 在
kernels/实现 kernel(JIT template 或 prebuilt)。 - 在
src/增加 C++ op 适配和注册。 - 在
torch_catlass/ops/增加 Python wrapper。 - 在
tests/增加 pytest 用例。