torch.topk 迁移说明

1. 算子说明

  • 算子名称:torch.topk
  • 迁移模式:torch_npu
  • 来源与交付形态摘要:来源为 PyTorch in-tree CUDA backend TensorTopK.cu;交付为 out-of-tree Ascend C SIMT torch_npu 扩展。
  • 功能概述:沿指定维度计算 top-k values 和 int64 indices。
  • 输入输出说明:输入为 NPU tensor;输出为 (values, indices)values dtype 与输入一致,indices 为 int64。
  • 一比一复刻结论:未完整完成。非 double 子集已在 Ascend 950 PR 上构建和 Python 验证;由于用户拒绝降级,double 保真和完整 radix/CUB 依赖链仍标记为阻塞。

2. 原始 CUDA 实现摘要

  • 原始实现文件:/home/y00621698/simt-buddy/pytorch/aten/src/ATen/native/cuda/TensorTopK.cu
  • Kernel 入口:sbtopk::gatherTopKwarptopk::warpMergeSortTopKmbtopk::computeBlockDigitCounts / gatherTopK
  • 用户侧调用路径摘要:torch.topk -> aten::topk / aten::topk.values -> topk_out_cuda -> launch_gather_topk_kernel

3. Ascend C SIMT 迁移摘要

3.1 迁移策略

  • 采用 torch_npu 扩展,提供 topk_simt.topk()
  • 提供 enable_torch_topk_patch(),可让 NPU tensor 的 torch.topk 走本扩展。
  • 保留 sbtopkwarptopkmbtopk 命名和 host 选择入口。
  • 保留 contiguous 后按 dim 切 slice 的行为。
  • 是否实现可复用 Ascend counterpart:部分。当前未完成 TopKTypeConfig、radix select、block scan、scan-by-key 的共享 counterpart。

3.2 主要改动与结论

  • CUDA stream 改为 c10_npu::getCurrentNPUStream().stream(true)
  • CUDA device helper 改为 Ascend SIMT __aicore__ helper。
  • RunOpApiV2 包装会导致 Python 调用超时,当前改为直接 SIMT launch。
  • double device 实例化触发 bisheng 后端段错误;因用户拒绝降级,未改用 float。
  • 重大降级决策(如有):用户明确不同意降级;当前不接受 double 精度降级、host fallback 或删除多路径分支作为完成结论。

3.3 API 与语法替换说明

CUDA 项 Ascend 项 说明
threadIdx / blockIdx / blockDim / gridDim 同名 SIMT 内建 直接使用
__device__ __aicore__ grammar.md 替换
__global__ launch SIMT <<<grid, block, dyn_ubuf, stream>>> 直接 launch
at::cuda::getCurrentCUDAStream() c10_npu::getCurrentNPUStream().stream(true) torch_npu 当前流
CUB BlockScan / inclusive_sum_by_key 待实现 Ascend counterpart 当前阻塞
device-side double 无已验证保真路径 bisheng codegen 段错误

3.4 依赖与能力覆盖结论

  • 依赖闭包处理:ATen/Python 调用语义通过扩展复用;CUDA 专有 radix/CUB helper 尚未完整迁移。
  • 能力覆盖结论:float32、float16、bfloat16、int32 的代表性路径已验证;double 阻塞。
  • shape / kernel 选择行为保留情况:host 选择入口保留;内部算法仍需补齐一比一 radix/CUB counterpart。
  • 语法规避依据:grammar.mdconstraints.mddevice_api.yamlruntime_api.yaml
  • 显式降级项:无用户接受的降级项;未完成项均记录为 blocked。

4. 目录结构

ported-ops/topk/
├── plan.md
├── README.md
├── extension_cpp/
│   ├── pyproject.toml
│   ├── requirements.txt
│   ├── setup.py
│   └── topk_simt/
│       ├── __init__.py
│       ├── ops.py
│       └── csrc/
│           ├── topk_register.asc
│           └── simt/
│               └── topk.asc
└── test/
    └── test_topk_simt.py

5. 构建方式

5.1 环境准备

source /usr/local/Ascend/cann/set_env.sh

5.2 编译步骤

cd /home/y00621698/simt-buddy/tasks/ported-ops/topk/extension_cpp
python3 setup.py build_ext --inplace

5.3 安装步骤

cd /home/y00621698/simt-buddy/tasks/ported-ops/topk/extension_cpp
python3 -m pip install -e . --no-build-isolation

5.4 出包步骤

cd /home/y00621698/simt-buddy/tasks/ported-ops/topk/extension_cpp
python3 -m pip wheel . --no-build-isolation -w dist

6. 验证与使用方式

6.1 原生/C++ 侧验证

cd /home/y00621698/simt-buddy/tasks/ported-ops/topk/extension_cpp
python3 setup.py build_ext --inplace

6.2 Python 侧验证

PYTHONUNBUFFERED=1 \
PYTHONPATH=/home/y00621698/simt-buddy/tasks/ported-ops/topk/extension_cpp \
timeout 180 \
python3 -m unittest discover -s /home/y00621698/simt-buddy/tasks/ported-ops/topk/test -v

6.3 出包验证

cd /home/y00621698/simt-buddy/tasks/ported-ops/topk/extension_cpp
python3 -m pip install -e . --no-build-isolation
python3 - <<'PY'
import topk_simt
print("import ok", topk_simt.__all__)
PY

6.4 使用示例

import torch
import torch_npu
import topk_simt

x = torch.randn(4, 16, device="npu")
values, indices = topk_simt.topk(x, 3, dim=-1)

topk_simt.enable_torch_topk_patch()
values2, indices2 = torch.topk(x, 3, dim=-1)

7. 最终验证结果摘要

  • 验证环境:CANN 9.0.0,bishengtorch 2.11.0+cputorch_npu 2.11.0.dev20260414
  • Ascend 950 PR 实机验证:已确认 torch.npu.get_device_name(0) => Ascend950PR_9599
  • 构建结果:通过
  • 原生/C++ 侧验证结果:python3 setup.py build_ext --inplace 通过
  • Python 侧验证结果:Ran 6 tests ... OK
  • 出包/安装结果:editable install 通过,import topk_simt 通过
  • 结果校验结论:支持子集 values/indices 校验通过;bfloat16 tie 场景验证 gathered values。
  • 当前状态标签:blocked by environment or unsupported feature

8. 已知限制

  • 显式降级项:无用户接受的降级项。
  • 已阻塞项:
    • double device 保真路径,bisheng 在 double topk kernel lowering 时段错误。
    • 完整 TensorTopK.cu radix/CUB 多 kernel 算法未一比一复刻。
    • 直接注册 upstream aten::topk.values PrivateUse1 structured kernel 未完成。
  • 用户已接受的降级项:无。用户明确要求不同意降级。

9. 问题排查记录

  • 问题阶段:构建

  • 问题现象:实例化 mbtopk::computeBlockDigitCounts<double>bisheng 段错误。

  • 根因判断:device-side double lowering 风险,与本地 constraints.md 中 double restricted 项一致。

  • 已验证动作:移除 double device 实例化后构建通过;未用 float 替代 double。

  • 相关语法/约束依据:constraints.mddouble in kernel code

  • 问题阶段:Python 调用

  • 问题现象:RunOpApiV2 包装下 4 元素 topk 调用超时。

  • 根因判断:自定义 SIMT launch 与该包装路径不适配。

  • 已验证动作:改为直接 SIMT launch 后 4 元素验证和完整 Python 测试通过。

  • 相关语法/约束依据:SIMT <<<...>>> launch 支持。

10. 后续建议

  • 补齐 TopKTypeConfig 与 radix-select Ascend counterpart。
  • 实现可复用 block scan 与 scan-by-key helper,供 sort/select 族复用。
  • 最小化复现并修复/上报 bisheng double codegen crash。
  • 将 custom op 迁移为 aten::topk PrivateUse1 structured 注册,减少 Python patch 依赖。

11. 未完成事项

  • 完整 double dtype 覆盖。
  • 完整 sbtopk / warptopk / mbtopk 算法等价复刻。
  • sorted=True 后处理的 sortKeyValueInplace / sort_outf 等价 counterpart。