torch.topk / TensorTopK 迁移执行计划
1. 目标与范围
- 算子名称:
torch.topk/aten::topk - 原始 CUDA 文件位置:
/home/y00621698/simt-buddy/pytorch/aten/src/ATen/native/cuda/TensorTopK.cu、TensorTopK.cpp、TensorTopK.h - 迁移模式:
torch_npu - 源产物形态:PyTorch ATen 内建 CUDA backend,用户侧调用路径为 Python
torch.topk - 目标迁移模式:
torch_npu - 目标交付形态:out-of-tree
torch_npu扩展,提供topk_simt.topk()与可选torch.topkpatch - 实际输出模式:
torch_npu - 迁移目标:保留 Python 调用路径、NPU tensor 输入输出、
values/indices双输出、largest/sorted/k/dim参数语义、sbtopk/warptopk/mbtopk路径命名与 host 选择入口 - 若实际输出形态与源产物不一致,形态差异说明:当前仓库未提供将实现直接并入 PyTorch 主树并注册
aten::topkPrivateUse1 structured kernel 的完整构建链;交付为仓外扩展与 Python patch。该 out-of-tree 形态为仓库默认允许路径,但不是 upstream in-tree 形态。 - 若存在真实降级,降级原因:用户已明确不同意降级;本次未把 double 改为 float,也未把失败路径声明为通过。当前代码为非 double 子集提供可运行实现,double 保真保持阻塞。
- 当前阻塞:
doubledevice 路径:bisheng在mbtopk::computeBlockDigitCounts<double>codegen 阶段段错误。按用户决策,不采用 device double 改 float 的降级。RunOpApiV2launch 包装:构建可通过,但 Python 调用卡在 op 内;直接 SIMT launch 可通过验证。- 完整 CUDA radix/CUB 依赖链未一比一复刻为共享 Ascend backend counterpart;当前实现保留路径入口但不是完整
TensorTopK.curadix/CUB 多 kernel 算法复刻。
- 恢复到目标模式所需条件:
bisheng支持或修复 device-sidedoubletopk codegen;- 明确
torch_npu自定义 SIMT kernel 在RunOpApiV2下的正确包装方式; - 补齐
TensorInfo、TopKTypeConfig、block scan、scan-by-key、radix select、small sort 的可复用 Ascend counterpart。
- 一比一复刻目标:
TensorTopK.cpp::topk_out_cuda的k==0、selection、sorted后处理;TensorTopK.cu中sbtopk、warptopk、mbtopk三类路径;- dtype 覆盖:float、double、uint8/int8/int16/int32/int64、half、bfloat16;
- shape/stride:contiguous 输入、按
dim折叠为 slice、输出 indices 为 int64。
- 本次迁移范围:
ported-ops/topk/extension_cpp/topk_simt/torch_npu 扩展;- SIMT kernel、注册代码、Python wrapper、Python 测试;
- 中文执行计划和 README。
- 本次不包含内容:
- 直接修改 PyTorch upstream;
- 完整 CUB scan-by-key counterpart;
- 接受任何用户未批准的 double 精度降级或单路径替代结论。
2. CUDA 函数清单
| 类型 | 所属结构体 | 函数/方法名 | 文件 | 说明 |
|---|---|---|---|---|
__device__ |
sbtopk::AddOp |
operator() |
TensorTopK.cu |
prefix scan 加法 |
__global__ |
sbtopk |
gatherTopK |
TensorTopK.cu |
单 block topk gather |
| host helper | sbtopk |
launch |
TensorTopK.cu |
单 block 路径 launch |
__device__ |
warptopk |
reserveWarpSpace / sort helper |
TensorTopK.cu |
ROCm warp compaction / merge sort 路径 |
__global__ |
warptopk |
warpMergeSortTopK |
TensorTopK.cu |
小 slice warp topk |
| host helper | warptopk |
launch |
TensorTopK.cu |
warp path launch |
__global__ |
mbtopk |
fill |
TensorTopK.cu |
临时数组初始化 |
__global__ |
mbtopk |
computeBlockDigitCounts |
TensorTopK.cu |
radix digit histogram |
__global__ |
mbtopk |
computeDigitCumSum |
TensorTopK.cu |
CUB block scan |
__global__ |
mbtopk |
computeBlockwiseWithinKCounts |
TensorTopK.cu |
radix pass 内的 within-k 计数 |
__global__ |
mbtopk |
computeBlockwiseKthCounts |
TensorTopK.cu |
kth value 计数 |
__global__ |
mbtopk |
gatherTopK |
TensorTopK.cu |
multi-block gather |
| host helper | mbtopk |
get_items_per_thread / launch |
TensorTopK.cu |
multi-block launch policy |
| host helper | 无 | should_use_multiblock / should_use_warp_topk |
TensorTopK.cu |
kernel 选择 |
| host entry | at::native |
launch_gather_topk_kernel |
TensorTopK.cu |
dtype/index/dim 分发 |
| host entry | at::native |
topk_out_cuda |
TensorTopK.cpp |
structured CUDA dispatch |
3. 调用链分析
3.1 torch.topk 调用链
Python torch.topk
└── aten::topk / aten::topk.values
└── CUDA dispatch: topk_out_cuda
├── should_use_sort(self, dim)
│ └── topk_out_with_sort
├── launch_gather_topk_kernel
│ ├── should_use_warp_topk -> warptopk::launch
│ ├── should_use_multiblock -> mbtopk::launch
│ └── fallback -> sbtopk::launch
└── sorted 后处理
├── sortKeyValueInplace
└── at::cuda::sort_outf + gather
3.2 当前交付调用链
Python topk_simt.topk / enable_torch_topk_patch 后的 torch.topk
└── topk_simt::topk PrivateUse1 custom op
└── topk_npu
├── should_use_warp_topk -> warptopk::launch
├── should_use_multiblock -> mbtopk::launch
└── fallback -> sbtopk::launch
4. API 映射分析
4.1 TensorTopK API 映射
| API 类型 | CUDA API | Ascend API | 迁移方式 | 说明 |
|---|---|---|---|---|
| device | threadIdx / blockIdx / blockDim / gridDim |
同名 SIMT 内建 | direct_replace |
device_api.yaml 已映射 |
| device | __global__ |
__global__ |
direct_replace |
Ascend SIMT 支持 CUDA-like launch |
| device | __device__ |
__aicore__ |
migrate |
grammar.md 要求替换 device qualifier |
| device | __shared__ |
__ubuf__ |
migrate |
grammar.md 要求替换 shared memory;当前代码未使用 shared buffer |
| device | __syncthreads |
asc_syncthreads |
migrate |
当前实现未使用同步;完整 counterpart 需要 |
| device | atomicAdd |
asc_atomic_add |
migrate |
完整 mbtopk / ROCm warp compaction 需要 |
| device | WARP_BALLOT / WARP_SHFL_DOWN |
asc_ballot / asc_shfl_down |
migrate |
完整 warp counterpart 需要 |
| device | __float_as_int / __int_as_float |
同名 API | direct_replace |
TopKTypeConfig<float> 完整 radix counterpart 需要 |
| runtime | at::cuda::getCurrentCUDAStream() |
c10_npu::getCurrentNPUStream().stream(true) |
migrate |
当前已使用 |
| runtime | kernel<<<grid, block, smem, stream>>> |
SIMT <<<grid, block, dyn_ubuf, acl_stream>>> |
migrate |
当前已使用 |
| runtime | CUDACachingAllocator::allocate |
NPU tensor / torch_npu allocator | migrate |
当前未实现完整临时 buffer 链 |
| runtime | cudaMemsetAsync |
aclrtMemsetAsync |
migrate |
完整 mbtopk 需要 |
| helper | CUB BlockScan / inclusive_sum_by_key |
Ascend counterpart | blocked |
当前未实现共享 scan counterpart |
5. 依赖闭包分析
| 依赖项 | 依赖类型 | 来源 | 处理方式 | 说明 |
|---|---|---|---|---|
| ATen schema / meta | 公共上层抽象 | native_functions.yaml / Sorting.cpp |
reuse |
out-of-tree 中通过 Python wrapper 保持调用语义 |
TensorTopK.cpp host 分支 |
CUDA 特有依赖 | CUDA backend | migrate |
k==0、largest/sorted 参数、dim wrap 已迁移 |
TensorInfo / IndexToOffset |
CUDA 特有依赖 | ATen/cuda/detail |
migrate |
当前用 contiguous + outer/inner slice counterpart |
TopKTypeConfig |
CUDA 特有依赖 | SortingRadixSelect.cuh |
blocked |
完整 radix bit ordering 尚未复刻 |
radixSelect |
CUDA 特有依赖 | SortingRadixSelect.cuh |
blocked |
当前未完成一比一 counterpart |
CUB BlockScan |
CUDA 特有依赖 | ATen/cuda/cub.cuh |
blocked |
需实现可复用 block scan |
CUB inclusive_sum_by_key |
CUDA 特有依赖 | ATen/cuda/cub.cuh |
blocked |
multi-block prefix 依赖 |
sortKeyValueInplace / sort_outf |
CUDA 特有依赖 | SortUtils.cuh / CUDA sort |
blocked |
当前 top-k 输出为 selection 顺序,未复刻完整后处理 |
torch_npu |
其他 | Python/C++ 扩展 | migrate |
已用于构建、NPU stream、PrivateUse1 custom op |
bisheng / SIMT API |
其他 | CANN 9.0 | migrate |
已构建并运行非 double 子集 |
6. 能力覆盖矩阵
| 能力项 | 源能力说明 | 当前处理 | 影响范围 | 恢复条件/备注 |
|---|---|---|---|---|
| dtype 覆盖 | float、double、uint8/int8/int16/int32/int64、half、bfloat16 | blocked |
double 未通过;其他子集已验证部分 dtype | bisheng 修复 double device codegen 后恢复 |
| 行为分支覆盖 | largest=True/False |
preserve |
selection 方向 | float32 largest/smallest 已验证 |
| 行为分支覆盖 | sorted=True/False |
migrate |
输出顺序 | 当前始终按选择顺序输出;sorted=False 合法但不是源未排序实现细节 |
| layout / stride 假设 | CUDA 源先 self.contiguous() 后按 dim stride 取 slice |
preserve |
非连续输入 | 当前同样使用 contiguous 后按 outer/inner 计算 |
| shape / size fast path | warp / multi-block / single-block | migrate |
kernel 选择入口 | 命名和 host 选择保留,完整内部算法未一比一复刻 |
| kernel 选择行为 | should_use_warp_topk / should_use_multiblock |
migrate |
launch path | 当前保留 heuristic,其中 warp path 按 Ascend 可编译条件开启 |
| 调用路径 | Python torch.topk |
migrate |
用户入口 | topk_simt.topk() 和 patch 后 torch.topk 已验证 |
| 测试路径 | Python 调用、NPU device、CPU reference | preserve |
验证 | 6 个 Python 用例通过 |
| Ascend 950 PR 验证 | 实机 build + Python 数值/索引校验 | preserve |
已验证子集 | 设备名 Ascend950PR_9599 |
7. 函数迁移方案
7.1 topk_npu 迁移方案
- 原始职责:承接
topk_out_cuda的输入检查、shape 创建、dim wrap、kernel launch 和 sorted 后处理。 - 迁移方式:扩展中实现
topk_simt::topk_npu,返回(values, indices)。 - 是否直接映射:部分。
- 需要重写的逻辑:ATen structured out 参数改为返回新 tensor;CUDA stream 改为 NPU stream。
- 语法替换点:CUDA launch 改为 Ascend SIMT launch。
- 语法/约束依据:
grammar.md、device_api.yaml、validation-checklist.md。 - 风险说明:不是直接注册 upstream
aten::topk.valuesstructured kernel。
7.2 sbtopk::gatherTopK 迁移方案
- 原始职责:单 block radix-select kth value 后 gather。
- 迁移方式:保留
sbtopk::gatherTopKkernel 名称和路径入口。 - 是否直接映射:否。
- 需要重写的逻辑:当前使用逐 slice selection,未完整复刻 radixSelect + prefix scan。
- 语法替换点:
__device__helper 改__aicore__。 - 语法/约束依据:
grammar.md。 - 风险说明:这是未完成的一比一复刻项,不能声明完整成功。
7.3 warptopk::warpMergeSortTopK 迁移方案
- 原始职责:小 slice warp merge sort topk。
- 迁移方式:保留路径入口与 launch heuristic。
- 是否直接映射:否。
- 需要重写的逻辑:完整
WarpMergeSortcounterpart 尚未实现。 - 语法替换点:warp primitive 后续需映射到
asc_shfl*/asc_ballot。 - 语法/约束依据:
device_api.yaml中 warp API 映射。 - 风险说明:当前验证的是行为子集,不是 warp merge sort 内核等价实现。
7.4 mbtopk 迁移方案
- 原始职责:大 slice 多 block radix histogram、scan-by-key 和 gather。
- 迁移方式:保留
mbtopknamespace、computeBlockDigitCounts名称和 host 选择入口。 - 是否直接映射:否。
- 需要重写的逻辑:CUB
BlockScan、inclusive_sum_by_key、临时 buffer 和 radix pass 尚待实现。 - 语法替换点:CUB/Allocator/cudaMemsetAsync 替换为 Ascend counterpart。
- 语法/约束依据:
runtime_api.yaml中cudaMemsetAsync映射,device_api.yaml中 atomic/warp 映射。 - 风险说明:
double版本触发bishengcodegen 段错误,是当前硬阻塞。
7.5 可复用抽象 / Ascend Counterpart 方案
- 候选抽象层:
TopKTypeConfig、radix select、block scan、scan-by-key、TensorInfo-like slice offset helper。 - 来源文件:
SortingRadixSelect.cuh、TensorTopK.cu、ATen/cuda/cub.cuh。 - 是否可被兄弟算子复用:是,sort/select/search 等算子可复用。
- 本次决策:
blocked - 决策原因:用户不同意降级;完整 counterpart 工程量超过单算子局部 patch,且 double codegen 当前失败。
- 若不实现 counterpart,对当前与后续算子的影响:无法声明
TensorTopK.cu一比一迁移完成,只能声明非 double 行为子集已验证。
8. 迁移实施顺序
- 读取
TensorTopK.cu/.cpp/.h,确认模式为torch_npu。 - 建立
ported-ops/topk/extension_cpp扩展骨架。 - 编写 SIMT kernel 与 Python wrapper。
- 运行 build,定位
doublecodegen 段错误。 - 按用户拒绝降级原则,将 double 保真标记为阻塞,不改 float 替代。
- 修复
RunOpApiV2调用卡死,改为直接 SIMT launch。 - 运行 Ascend 950 PR Python 验证。
- 运行 editable 安装与裸导入验证。
- 更新中文计划和 README。
9. 测试优先策略
- 验证入口类型:
Python、package - 目标行为:
topk_simt.topk与 patch 后torch.topk返回正确 values/indices。 - 对应验证方式:CPU
torch.topkreference;bfloat16 tie 场景验证 gathered values。 - 计划复用的现有测试:
pytorch/test/test_sort_and_select.py的 topk 行为意图。 - 计划新增的测试:
ported-ops/topk/test/test_topk_simt.py - 预期失败或预期缺口:double 保真、完整 radix/CUB counterpart。
- 通过标准:支持子集 Python 测试全部通过;安装后可 import。
10. 风险与阻塞
10.0 降级决策门
- 是否存在重大降级候选:是
out-of-tree是否仅为默认允许路径且不单独触发审批:是- 若存在,候选项:double 改 float、完整 radix/CUB 算法改为单 kernel selection、直接 host fallback
- 是否涉及抽象层未一比一复刻:是
- 是否涉及多路径 kernel 调度被简化为单路径:当前保留多路径入口,但内部算法未完整复刻
- 与用户沟通前是否暂停编码:用户已在初始请求中明确“不同意降级,坚持一对一迁移”
- 在获得用户明确批准前是否禁止编码:是;未实现用户未批准的降级结论
- 是否已触发硬停止审批门:是
- 用户是否已明确批准:用户明确不批准降级
- 备选方案对比:
- 方案 A:继续补齐完整 Ascend counterpart,保留 double,等待/修复 bisheng double codegen。
- 方案 B:降级为非 double 单 kernel fallback。
- 对比维度:
- 迁移成本:A 高,B 低
- 迁移难度:A 高,B 低
- 迁移收益:A 可复用且接近源,B 仅局部可用
- 影响算子范围:A 影响 sort/select 族,B 只影响当前扩展
- 验证影响:A 可覆盖源能力,B 不能覆盖源能力
- 恢复路径:A 直接延续,B 后续仍需重写
- 用户最终选择:拒绝降级,选择方案 A;当前实现仅作为可运行子集和阻塞证据。
10.1 已识别风险
doubledevice codegen 段错误。RunOpApiV2包装 custom SIMT launch 时调用卡死。- bfloat16 tie 场景 indices 可与 CPU reference 不同,但 gathered value 正确。
- 当前未完成 scan-by-key/radix-select 共享 counterpart。
10.2 已识别阻塞
UNSUPPORTED_FEATURE_ERROR/ compiler blocker:device-sidedouble保真。VALIDATION_FAILEDfor full source surface:完整 dtype/algorithm surface 未覆盖。
10.3 阻塞处理建议
- 建议动作:优先最小化复现
bishengdouble kernel crash,向编译器侧确认支持状态;并实现TopKTypeConfig+ block scan counterpart。 - 所需条件:编译器修复、Ascend scan-by-key 或本地实现、ATen PrivateUse1 structured 注册方案。
- 下一步责任:继续迁移共享 backend counterpart。
10.4 根因分析记录
-
问题现象:
mbtopk::computeBlockDigitCounts<double>构建时bisheng段错误。 -
初步根因判断:device-side double lowering 触发后端崩溃,与
constraints.md中 double restricted 风险一致。 -
验证方法:去除 double device 实例化后构建通过。
-
处理结论:不采用 float 降级,记录为阻塞。
-
问题现象:使用
at_npu::native::OpCommand::RunOpApiV2包装 kernel launch 时 Python 调用超时。 -
初步根因判断:自定义 SIMT kernel launch 与
RunOpApiV2包装组合不适配。 -
验证方法:改为直接
kernel<<<...>>>后 4 元素用例通过。 -
处理结论:当前采用直接 launch。
11. 验证计划
11.1 环境检查
- 设备检查:
/dev/davinci0、/dev/davinci_manager、/dev/hisi_hdc存在;torch.npu.get_device_name(0)为Ascend950PR_9599 - 环境变量检查:
ASCEND_TOOLKIT_HOME=/usr/local/Ascend/cann-9.0.0 - 编译工具检查:
/usr/local/Ascend/cann-9.0.0/bin/bisheng
11.2 原生/C++ 侧验证计划
- 验证对象:
topk_simt/_C*.so - 验证入口:
python3 setup.py build_ext --inplace - 测试方法:编译扩展。
- 首次验证结果:包含 double 时失败;去除 double device 实例化并记录阻塞后通过。
- 通过标准:非 double 子集构建通过。
11.3 Python 侧验证计划
- 是否需要:是
- 验证入口:
python3 -m unittest discover -s ported-ops/topk/test -v - 验证方法:NPU 输出对比 CPU reference。
- 首次验证结果:重复值 tie 导致 indices reference 调整;最终通过。
- 通过标准:6 个测试全部通过。
11.4 出包/安装验证计划
- 是否需要:是
- 验证入口:
python3 -m pip install -e . --no-build-isolation、import topk_simt - 验证方法:editable 安装后裸导入。
- 首次验证结果:初次裸导入缺少
libtorch_npu.so;在__init__.py先导入torch_npu后通过。 - 通过标准:安装成功且
import topk_simt成功。
12. 迁移进展
| 项目 | 状态 | 说明 |
|---|---|---|
| 模式选择 | completed | torch_npu |
| 扩展骨架 | completed | setup.py / pyproject.toml / package |
| SIMT kernel | partial | 非 double 子集可运行,完整 radix/CUB 未完成 |
| Python wrapper | completed | topk_simt.topk / patch |
| 构建验证 | completed | 非 double 子集通过 |
| Python 验证 | completed | 6 tests OK on Ascend950PR_9599 |
| 安装导入 | completed | editable install + import OK |
| 一比一完整复刻 | blocked | double 与共享 backend counterpart |
13. 执行记录与证据
- 环境初始化命令:当前 shell 已具备 CANN PATH;
ASCEND_TOOLKIT_HOME=/usr/local/Ascend/cann-9.0.0 - 编译命令:
python3 setup.py build_ext --inplace - 原生/C++ 侧验证命令:同编译命令
- Python 侧验证命令:
PYTHONUNBUFFERED=1 PYTHONPATH=/home/y00621698/simt-buddy/tasks/ported-ops/topk/extension_cpp timeout 180 python3 -m unittest discover -s /home/y00621698/simt-buddy/tasks/ported-ops/topk/test -v - 出包/安装命令:
python3 -m pip install -e . --no-build-isolation - 关键语法/约束依据:
grammar.md的__device__->__aicore__、shared memory 规则;constraints.md的 device-sidedoublerestricted;validation-checklist.md - 关键结果摘要:
- build 通过;
- 设备为
Ascend950PR_9599; - Python 测试
Ran 6 tests ... OK; - editable install 与
import topk_simt通过; - double 保真阻塞未解除。
14. 最终验证结论
- Ascend 950 PR 实机验证状态:支持子集已在
Ascend950PR_9599验证;完整源能力未验证通过 - 构建状态:非 double 子集构建通过
- 原生/C++ 侧验证状态:构建通过
- Python 侧验证状态:6 个用例通过
- 出包/安装状态:editable install + import 通过
- 结果证据:见第 13 节命令记录
- 失败/阻塞根因摘要:device-side
doublecodegen 段错误;完整 radix/CUB 依赖链未一比一复刻 - 最终结论:
blocked by environment or unsupported feature。已交付可运行的torch_npu子集迁移与验证证据,但不能声明TensorTopK.cu完整一对一迁移成功。