Triton-CUDA → Triton-Ascend 跨平台算子生成
将 SGLang / vLLM 的 triton_cuda 算子改写为 triton_ascend,通过预生成参考数据实现跨平台验证。
整体流程
┌─────────────────────────────────┐ scp / rsync ┌──────────────────────────────────┐
│ CUDA 服务器 (A100) │ ──────────────────── │ Ascend 服务器 (910B) │
│ │ .pt 文件 │ │
│ 1. gen_reference_cache.py │ │ 2. run_adaptive_search_with_ │
│ ↓ │ │ cache.py │
│ 遍历 sglang/vllm 算子 │ │ (adaptive_search, 多轮进化 │
│ ↓ │ │ 搜索 + 性能优化) │
│ 对每个算子: │ │ ↓ │
│ - import Model │ │ 验证: .pt inputs → compare outputs│
│ - get_inputs() → forward() │ │ Profiling: NPU 实际跑性能 │
│ - save inputs + outputs │ │ ↓ │
│ ↓ │ │ PASS/FAIL + speedup │
│ ~/.akg/.tmp/reference_data/ │ │ │
│ triton_cuda_cache/ │ │ │
│ ├── manifest.json │ │ │
│ ├── sglang/*.pt │ │ │
│ ├── vllm_triton/*.pt │ │ │
│ └── vllm_torch/*.pt │ │ │
└─────────────────────────────────┘ └──────────────────────────────────┘
Step 1: 在 CUDA 服务器生成参考数据
前置条件
- CUDA GPU 可用
source env.sh- thirdparty 子模块已初始化:
git submodule update --init "akg_agents/thirdparty/*"
运行
# 生成全部(sglang + vllm,约 90 个算子)
python reproduce/wip/triton-cuda-to-ascend/gen_reference_cache.py
# 只生成 sglang
python reproduce/wip/triton-cuda-to-ascend/gen_reference_cache.py --source sglang
# 只生成 vllm triton 算子
python reproduce/wip/triton-cuda-to-ascend/gen_reference_cache.py --source vllm_triton
# 指定算子
python reproduce/wip/triton-cuda-to-ascend/gen_reference_cache.py \
--source sglang --ops triton_tanh merge_state_triton
# 指定输出目录和设备
python reproduce/wip/triton-cuda-to-ascend/gen_reference_cache.py \
--output-dir /data/ref_cache --device 1
产出
~/.akg/.tmp/reference_data/triton_cuda_cache/
├── manifest.json # 包含全部算子的生成结果清单
├── sglang/
│ ├── triton_tanh.pt # 每个 .pt 包含 inputs + outputs + init_inputs
│ ├── merge_state_triton.pt
│ └── ... # ~46 个算子
├── vllm_triton/
│ ├── rms_norm_kernel.pt
│ └── ... # ~28 个算子
└── vllm_torch/
├── silu_and_mul.pt
└── ... # ~17 个算子
每个 .pt 文件内容(torch.load 后的 dict):
{
'op_name': 'triton_tanh',
'seed': 0,
'save_inputs': True,
'inputs': [tensor_cpu, ...], # 原始输入(forward 前 clone,避免 in-place 污染)
'input_shapes': [(16, 16384), ...],
'input_dtypes': ['torch.float32', ...],
'outputs': [tensor_cpu, ...], # 参考输出
'output_shapes': [(16, 16384), ...],
'output_dtypes': ['torch.float32', ...],
'init_inputs': [...], # Model.__init__ 参数
}
拷贝到 Ascend 服务器
scp -r ~/.akg/.tmp/reference_data/triton_cuda_cache \
ascend-server:~/.akg/.tmp/reference_data/triton_cuda_cache
Step 2: 在 Ascend 上自适应搜索生成 triton_ascend 算子
前置条件
- Ascend NPU 可用
source env.sh- API key 已配置
- .pt 缓存已拷贝到本机
单算子模式
# 按 source + op 定位
python reproduce/wip/triton-cuda-to-ascend/run_adaptive_search_with_cache.py \
--source sglang --op triton_tanh
# 多设备 + 更多搜索轮次
python reproduce/wip/triton-cuda-to-ascend/run_adaptive_search_with_cache.py \
--source sglang --op merge_state_triton \
--devices 0 1 2 3 --max-concurrent 4 --max-tasks 40
# 手动指定文件路径
python reproduce/wip/triton-cuda-to-ascend/run_adaptive_search_with_cache.py \
--ref-pt /path/to/merge_state_triton.pt \
--benchmark-file benchmark/akg_kernels_bench/thirdparty/sglang/merge_state_triton.py
批量模式(断点续存)
不指定 --op 即进入批量模式,串行遍历 cache 目录下所有 .pt,逐个跑 adaptive_search。
进度记录在脚本同级 .tmp/batch_progress.json,重跑自动跳过已完成和已失败的 case。
# 批量跑某个 source 下所有 cache
python reproduce/wip/triton-cuda-to-ascend/run_adaptive_search_with_cache.py \
--source sglang
# 批量跑全部 source(sglang + vllm_triton + vllm_torch,约 91 个算子)
python reproduce/wip/triton-cuda-to-ascend/run_adaptive_search_with_cache.py
# 中途 Ctrl+C 或异常退出后,直接重跑同样的命令即可续跑
# 已完成(status=done)的 case 会自动跳过
进度文件格式
reproduce/wip/triton-cuda-to-ascend/.tmp/
└── batch_progress.json
{
"cases": {
"sglang/triton_tanh": {
"status": "done",
"source": "sglang",
"op_name": "triton_tanh",
"finished_at": "2026-04-08T15:30:00",
"elapsed_time": 120.5,
"result": { "total_success": 8, "best_speedup": 1.35, "..." }
},
"sglang/merge_state_triton": {
"status": "error",
"error": "RuntimeError: ...",
"..."
}
},
"summary": { "total": 46, "succeeded": 30, "failed": 2, "skipped": 32 }
}
重跑某个 case
已完成(done)和已失败(error)的 case 都会被跳过。
想重跑某个 case,手动删掉 JSON 中对应的条目再跑即可。
想全部重跑,删掉整个 .tmp/batch_progress.json。
原理说明
为什么需要两步?
原始 triton_cuda 代码(如 import triton、CUDA kernel)在 Ascend 环境上无法 import,
因此无法直接调用 get_inputs() / Model.forward() 来获取参考输入和输出。
通过预生成参考数据 (.pt),将平台依赖解耦:
- CUDA 端: 运行原始代码,保存 inputs + outputs 到 .pt(所有 tensor 已
.cpu()) - Ascend 端: 加载 .pt,跳过原始 framework 代码的 import,直接用保存的 inputs 测试 LLM 生成的 triton_ascend 代码
关键配置项
每个算子运行时的 config 会注入以下参考数据字段:
config["use_reference_data"] = True # 启用参考数据模式
config["use_reference_inputs"] = True # 同时使用参考输入(不调用 get_inputs())
config["reference_data"] = ref_bytes # .pt 文件的 bytes
这会触发验证模板中的条件分支:
- 不 import framework 文件(避免 CUDA 依赖)
- 从 .pt 加载
init_inputs(用于ModelNew(*init_inputs)) - 从 .pt 加载
inputs(用于ModelNew.forward(*inputs)) - 从 .pt 加载
outputs(作为比对基准)
in-place 安全
generate_reference_data(save_inputs=True) 在 model.forward() 之前 clone inputs,
确保即使 forward 有 in-place 操作(如 x.add_()),保存的 inputs 也是原始值。
推荐工作流
# 1. CUDA 上生成参考数据
python gen_reference_cache.py --source sglang
# 2. scp 到 Ascend
scp -r ~/.akg/.tmp/reference_data/triton_cuda_cache ascend:~/.akg/.tmp/reference_data/
# 3. 批量跑自适应搜索(断点续存,可随时中断重跑)
python run_adaptive_search_with_cache.py --source sglang --max-tasks 30
# 4. 对重点算子加大搜索力度
python run_adaptive_search_with_cache.py --source sglang --op triton_tanh --max-tasks 60
手动使用 .pt 缓存(编程接口)
如果不想用批量脚本,可以在自己的代码中直接使用:
import asyncio
from akg_agents.op.langgraph_op.task import LangGraphTask
from akg_agents.op.config.config_validator import load_config
from akg_agents.core.worker.manager import register_local_worker
async def convert_single_op():
await register_local_worker([0], backend="ascend", arch="ascend910b4")
with open("~/.akg/.tmp/reference_data/triton_cuda_cache/sglang/triton_tanh.pt", "rb") as f:
ref_bytes = f.read()
with open("benchmark/akg_kernels_bench/thirdparty/sglang/triton_tanh.py") as f:
framework_code = f.read()
config = load_config(dsl="triton_ascend", backend="ascend")
config["use_reference_data"] = True
config["use_reference_inputs"] = True
config["reference_data"] = ref_bytes
config["max_step"] = 10
task = LangGraphTask(
op_name="triton_tanh",
task_desc=framework_code,
task_id="manual_001",
backend="ascend", arch="ascend910b4",
dsl="triton_ascend",
config=config, framework="torch",
workflow="coder_only_workflow",
)
op_name, success, info = await task.run()
print(f"{'PASS' if success else 'FAIL'}: {op_name}")
if success:
print(info.get("coder_code", "")[:200])
asyncio.run(convert_single_op())
已覆盖算子
| 数据源 | 算子数 | 说明 |
|---|---|---|
sglang |
~46 | SGLang triton kernel(不含 class_method) |
vllm_triton |
~28 | vLLM triton kernel |
vllm_torch |
~17 | vLLM torch op(激活函数、归一化、RoPE 等) |
| 合计 | ~91 |