torch.topk 迁移说明
1. 算子说明
- 算子名称:
torch.topk - 迁移模式:
torch_npu - 来源与交付形态摘要:来源为 PyTorch in-tree CUDA backend
TensorTopK.cu;交付为 out-of-tree Ascend C SIMTtorch_npu扩展。 - 功能概述:沿指定维度计算 top-k values 和 int64 indices。
- 输入输出说明:输入为 NPU tensor;输出为
(values, indices),valuesdtype 与输入一致,indices为 int64。 - 一比一复刻结论:未完整完成。非 double 子集已在 Ascend 950 PR 上构建和 Python 验证;由于用户拒绝降级,
double保真和完整 radix/CUB 依赖链仍标记为阻塞。
2. 原始 CUDA 实现摘要
- 原始实现文件:
/home/y00621698/simt-buddy/pytorch/aten/src/ATen/native/cuda/TensorTopK.cu - Kernel 入口:
sbtopk::gatherTopK、warptopk::warpMergeSortTopK、mbtopk::computeBlockDigitCounts/gatherTopK - 用户侧调用路径摘要:
torch.topk->aten::topk/aten::topk.values->topk_out_cuda->launch_gather_topk_kernel
3. Ascend C SIMT 迁移摘要
3.1 迁移策略
- 采用
torch_npu扩展,提供topk_simt.topk()。 - 提供
enable_torch_topk_patch(),可让 NPU tensor 的torch.topk走本扩展。 - 保留
sbtopk、warptopk、mbtopk命名和 host 选择入口。 - 保留 contiguous 后按
dim切 slice 的行为。 - 是否实现可复用 Ascend counterpart:部分。当前未完成
TopKTypeConfig、radix select、block scan、scan-by-key 的共享 counterpart。
3.2 主要改动与结论
- CUDA stream 改为
c10_npu::getCurrentNPUStream().stream(true)。 - CUDA device helper 改为 Ascend SIMT
__aicore__helper。 RunOpApiV2包装会导致 Python 调用超时,当前改为直接 SIMT launch。doubledevice 实例化触发bisheng后端段错误;因用户拒绝降级,未改用 float。- 重大降级决策(如有):用户明确不同意降级;当前不接受 double 精度降级、host fallback 或删除多路径分支作为完成结论。
3.3 API 与语法替换说明
| CUDA 项 | Ascend 项 | 说明 |
|---|---|---|
threadIdx / blockIdx / blockDim / gridDim |
同名 SIMT 内建 | 直接使用 |
__device__ |
__aicore__ |
按 grammar.md 替换 |
__global__ launch |
SIMT <<<grid, block, dyn_ubuf, stream>>> |
直接 launch |
at::cuda::getCurrentCUDAStream() |
c10_npu::getCurrentNPUStream().stream(true) |
torch_npu 当前流 |
CUB BlockScan / inclusive_sum_by_key |
待实现 Ascend counterpart | 当前阻塞 |
device-side double |
无已验证保真路径 | bisheng codegen 段错误 |
3.4 依赖与能力覆盖结论
- 依赖闭包处理:ATen/Python 调用语义通过扩展复用;CUDA 专有 radix/CUB helper 尚未完整迁移。
- 能力覆盖结论:float32、float16、bfloat16、int32 的代表性路径已验证;double 阻塞。
- shape / kernel 选择行为保留情况:host 选择入口保留;内部算法仍需补齐一比一 radix/CUB counterpart。
- 语法规避依据:
grammar.md、constraints.md、device_api.yaml、runtime_api.yaml。 - 显式降级项:无用户接受的降级项;未完成项均记录为 blocked。
4. 目录结构
ported-ops/topk/
├── plan.md
├── README.md
├── extension_cpp/
│ ├── pyproject.toml
│ ├── requirements.txt
│ ├── setup.py
│ └── topk_simt/
│ ├── __init__.py
│ ├── ops.py
│ └── csrc/
│ ├── topk_register.asc
│ └── simt/
│ └── topk.asc
└── test/
└── test_topk_simt.py
5. 构建方式
5.1 环境准备
source /usr/local/Ascend/cann/set_env.sh
5.2 编译步骤
cd /home/y00621698/simt-buddy/tasks/ported-ops/topk/extension_cpp
python3 setup.py build_ext --inplace
5.3 安装步骤
cd /home/y00621698/simt-buddy/tasks/ported-ops/topk/extension_cpp
python3 -m pip install -e . --no-build-isolation
5.4 出包步骤
cd /home/y00621698/simt-buddy/tasks/ported-ops/topk/extension_cpp
python3 -m pip wheel . --no-build-isolation -w dist
6. 验证与使用方式
6.1 原生/C++ 侧验证
cd /home/y00621698/simt-buddy/tasks/ported-ops/topk/extension_cpp
python3 setup.py build_ext --inplace
6.2 Python 侧验证
PYTHONUNBUFFERED=1 \
PYTHONPATH=/home/y00621698/simt-buddy/tasks/ported-ops/topk/extension_cpp \
timeout 180 \
python3 -m unittest discover -s /home/y00621698/simt-buddy/tasks/ported-ops/topk/test -v
6.3 出包验证
cd /home/y00621698/simt-buddy/tasks/ported-ops/topk/extension_cpp
python3 -m pip install -e . --no-build-isolation
python3 - <<'PY'
import topk_simt
print("import ok", topk_simt.__all__)
PY
6.4 使用示例
import torch
import torch_npu
import topk_simt
x = torch.randn(4, 16, device="npu")
values, indices = topk_simt.topk(x, 3, dim=-1)
topk_simt.enable_torch_topk_patch()
values2, indices2 = torch.topk(x, 3, dim=-1)
7. 最终验证结果摘要
- 验证环境:CANN 9.0.0,
bisheng,torch 2.11.0+cpu,torch_npu 2.11.0.dev20260414 - Ascend 950 PR 实机验证:已确认
torch.npu.get_device_name(0) => Ascend950PR_9599 - 构建结果:通过
- 原生/C++ 侧验证结果:
python3 setup.py build_ext --inplace通过 - Python 侧验证结果:
Ran 6 tests ... OK - 出包/安装结果:editable install 通过,
import topk_simt通过 - 结果校验结论:支持子集 values/indices 校验通过;bfloat16 tie 场景验证 gathered values。
- 当前状态标签:
blocked by environment or unsupported feature
8. 已知限制
- 显式降级项:无用户接受的降级项。
- 已阻塞项:
doubledevice 保真路径,bisheng在 double topk kernel lowering 时段错误。- 完整
TensorTopK.curadix/CUB 多 kernel 算法未一比一复刻。 - 直接注册 upstream
aten::topk.valuesPrivateUse1 structured kernel 未完成。
- 用户已接受的降级项:无。用户明确要求不同意降级。
9. 问题排查记录
-
问题阶段:构建
-
问题现象:实例化
mbtopk::computeBlockDigitCounts<double>时bisheng段错误。 -
根因判断:device-side
doublelowering 风险,与本地constraints.md中 double restricted 项一致。 -
已验证动作:移除 double device 实例化后构建通过;未用 float 替代 double。
-
相关语法/约束依据:
constraints.md的double in kernel code。 -
问题阶段:Python 调用
-
问题现象:
RunOpApiV2包装下 4 元素topk调用超时。 -
根因判断:自定义 SIMT launch 与该包装路径不适配。
-
已验证动作:改为直接 SIMT launch 后 4 元素验证和完整 Python 测试通过。
-
相关语法/约束依据:SIMT
<<<...>>>launch 支持。
10. 后续建议
- 补齐
TopKTypeConfig与 radix-select Ascend counterpart。 - 实现可复用 block scan 与 scan-by-key helper,供 sort/select 族复用。
- 最小化复现并修复/上报
bishengdouble codegen crash。 - 将 custom op 迁移为
aten::topkPrivateUse1 structured 注册,减少 Python patch 依赖。
11. 未完成事项
- 完整 double dtype 覆盖。
- 完整
sbtopk/warptopk/mbtopk算法等价复刻。 - 源
sorted=True后处理的sortKeyValueInplace/sort_outf等价 counterpart。