"""
Triton-CUDA 参考数据批量生成脚本
在 CUDA 环境中运行 SGLang / vLLM 的 triton_cuda 算子,批量生成 .pt 参考数据缓存。
生成的 .pt 文件包含 inputs + outputs + init_inputs,可 scp 到 Ascend 环境后
直接用于 triton_ascend 算子生成和验证(无需 CUDA 运行时)。
前置条件:
- CUDA GPU 可用(torch.cuda.is_available())
- source env.sh
- KernelBench 子模块已初始化(部分 benchmark 文件在 thirdparty/ 下)
运行方式:
# 生成全部(sglang + vllm triton_ops + vllm torch_ops)
python reproduce/wip/triton-cuda-to-ascend/gen_reference_cache.py
# 只生成 sglang
python reproduce/wip/triton-cuda-to-ascend/gen_reference_cache.py --source sglang
# 只生成 vllm triton_ops
python reproduce/wip/triton-cuda-to-ascend/gen_reference_cache.py --source vllm_triton
# 只生成 vllm torch_ops
python reproduce/wip/triton-cuda-to-ascend/gen_reference_cache.py --source vllm_torch
# 指定算子
python reproduce/wip/triton-cuda-to-ascend/gen_reference_cache.py --source sglang --ops triton_tanh merge_state_triton
# 指定输出目录和设备
python reproduce/wip/triton-cuda-to-ascend/gen_reference_cache.py --output-dir ./my_cache --device 1
产出:
<output_dir>/
├── manifest.json # 汇总清单
├── sglang/
│ ├── triton_tanh.pt
│ ├── merge_state_triton.pt
│ └── ...
├── vllm_triton/
│ ├── rms_norm_kernel.pt
│ └── ...
└── vllm_torch/
├── silu_and_mul.pt
└── ...
"""
import argparse
import asyncio
import json
import logging
import os
import time
from pathlib import Path
# Module-level logger; its handlers/level are set up by logging.basicConfig() in main().
logger = logging.getLogger("gen_reference_cache")
# Ancestor directory three levels above the directory that contains this file
# (parents[0] is the containing directory itself).
PROJECT_ROOT = Path(__file__).resolve().parents[3]
# Root directory of the third-party benchmark files that are scanned for operators.
BENCHMARK_BASE = PROJECT_ROOT / "benchmark" / "akg_kernels_bench" / "thirdparty"
# Default output directory for the generated .pt files and manifest.json ("~" is expanded here).
DEFAULT_OUTPUT_DIR = os.path.expanduser("~/.akg/.tmp/reference_data/triton_cuda_cache")
# Registry of operator sources, keyed by the name accepted by --source.
# Each entry holds:
#   "path":         directory whose top-level *.py files are treated as operators
#   "dsl":          DSL name handed to load_config() and KernelVerifier
#   "exclude_dirs": directory names under "path" that are not scanned for operators
# NOTE(review): "vllm_torch" is also tagged "triton_cuda" — confirm this is intended.
SOURCES = {
"sglang": {
"path": BENCHMARK_BASE / "sglang",
"dsl": "triton_cuda",
"exclude_dirs": ["class_method"],
},
"vllm_triton": {
"path": BENCHMARK_BASE / "vllm" / "triton_ops",
"dsl": "triton_cuda",
"exclude_dirs": [],
},
"vllm_torch": {
"path": BENCHMARK_BASE / "vllm" / "torch_ops",
"dsl": "triton_cuda",
"exclude_dirs": [],
},
}
def discover_ops(source_name: str, ops_filter: list = None) -> list:
    """Discover the operator files that belong to one source.

    Only entries located directly inside the source's ``path`` are
    considered. Sub-directories are never descended into (so no directory
    needs to be excluded explicitly), and files whose name starts with
    ``__`` (for example ``__init__.py``) are ignored.

    Args:
        source_name: Key into ``SOURCES``.
        ops_filter: Optional collection of operator names (file stems) to
            keep. ``None`` or an empty collection keeps every operator.

    Returns:
        A list of ``(op_name, file_path)`` tuples ordered by path, with
        ``file_path`` given as a string. The list is empty when the source
        path is missing or is not a directory.

    Raises:
        KeyError: If ``source_name`` is not a key of ``SOURCES``.
    """
    base_path = SOURCES[source_name]["path"]
    # is_dir() rather than exists(): a path that exists but is a regular file
    # would otherwise make iterdir() raise NotADirectoryError.
    if not base_path.is_dir():
        logger.warning(f"路径不存在或不是目录: {base_path}")
        return []

    # A set gives O(1) membership tests; an empty/None filter means "keep all".
    wanted = set(ops_filter) if ops_filter else None

    return [
        (entry.stem, str(entry))
        for entry in sorted(base_path.iterdir())
        if not entry.is_dir()
        and entry.suffix == ".py"
        and not entry.name.startswith("__")
        and (wanted is None or entry.stem in wanted)
    ]
async def gen_reference_for_op(
    op_name: str,
    op_file: str,
    source_name: str,
    output_dir: str,
    device_id: int,
    timeout: int,
) -> dict:
    """Generate the reference data of one operator and store it as a .pt file.

    The operator file is read as the framework code, a worker is selected
    from the worker manager, and ``KernelVerifier.generate_reference_data``
    is asked (with ``save_inputs=True``) for the reference payload. The
    returned bytes are written unchanged to
    ``<output_dir>/<source_name>/<op_name>.pt``.

    Args:
        op_name: Operator name; also used as the output file stem.
        op_file: Path of the operator's Python file.
        source_name: Key into ``SOURCES``; selects the DSL and the output
            sub-directory.
        output_dir: Root directory for generated files.
        device_id: Not used inside this function; kept for interface
            compatibility with callers.
        timeout: Timeout forwarded to ``generate_reference_data``.

    Returns:
        A dict with the keys ``op_name``, ``source``, ``success``,
        ``pt_path``, ``error`` and ``elapsed_s``; ``size_bytes`` is added on
        success. Failures raised as ``Exception`` are reported through
        ``error`` instead of propagating.
    """
    from akg_agents.op.verifier.kernel_verifier import KernelVerifier
    from akg_agents.op.config.config_validator import load_config
    from akg_agents.core.worker.manager import get_worker_manager

    dsl = SOURCES[source_name]["dsl"]
    framework = "torch"
    backend = "cuda"
    arch = "a100"

    result = {
        "op_name": op_name,
        "source": source_name,
        "success": False,
        "pt_path": None,
        "error": None,
        "elapsed_s": 0,
    }

    # Monotonic clock: the measured duration cannot be skewed by wall-clock
    # adjustments happening while the operator runs.
    started = time.monotonic()
    try:
        with open(op_file, "r", encoding="utf-8") as f:
            framework_code = f.read()

        config = load_config(dsl, backend=backend)
        worker = await get_worker_manager().select(backend=backend, arch=arch)
        if not worker:
            result["error"] = f"No worker available for {backend}/{arch}"
            return result

        verifier = KernelVerifier(
            op_name=op_name,
            framework_code=framework_code,
            task_id=f"ref_cache_{source_name}_{op_name}",
            framework=framework,
            dsl=dsl,
            backend=backend,
            arch=arch,
            config=config,
            worker=worker,
        )
        success, log, ref_bytes = await verifier.generate_reference_data(
            framework_code, save_inputs=True, timeout=timeout
        )

        if not success:
            # Guard against a missing/empty log so the failure is reported as
            # such instead of being replaced by a TypeError from slicing None.
            result["error"] = (
                str(log)[:500] if log else "Reference generation failed (empty log)"
            )
            return result

        # Covers both None and a zero-length payload.
        if not ref_bytes:
            result["error"] = "Empty reference data"
            return result

        sub_dir = os.path.join(output_dir, source_name)
        os.makedirs(sub_dir, exist_ok=True)
        pt_path = os.path.join(sub_dir, f"{op_name}.pt")
        with open(pt_path, "wb") as f:
            f.write(ref_bytes)

        result["success"] = True
        result["pt_path"] = pt_path
        result["size_bytes"] = len(ref_bytes)
    except Exception as e:
        # Include the exception type: some exceptions have an empty message,
        # which would otherwise leave the error field blank.
        detail = str(e)
        name = type(e).__name__
        result["error"] = (f"{name}: {detail}" if detail else name)[:500]
    finally:
        # Runs on every exit path above, including the early returns.
        result["elapsed_s"] = round(time.monotonic() - started, 1)
    return result
def parse_args(argv=None):
    """Parse the command-line options of the reference-cache generator.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        An ``argparse.Namespace`` with ``source`` (list of source names),
        ``ops`` (list of operator names or ``None``), ``output_dir``,
        ``device``, ``timeout`` and ``concurrency``.

    Raises:
        SystemExit: Raised by argparse on invalid input, including a
            non-positive ``--timeout``/``--concurrency``, a negative
            ``--device``, or a non-integer ``$DEVICE_ID`` when ``--device``
            is not given.
    """

    def int_at_least(minimum):
        """Build an argparse ``type`` callable accepting integers >= minimum."""

        def convert(text):
            try:
                value = int(text)
            except ValueError:
                raise argparse.ArgumentTypeError(
                    f"invalid int value: {text!r}"
                ) from None
            if value < minimum:
                raise argparse.ArgumentTypeError(
                    f"must be >= {minimum}, got {value}"
                )
            return value

        return convert

    source_names = list(SOURCES.keys())
    parser = argparse.ArgumentParser(
        description="批量生成 SGLang/vLLM triton_cuda 参考数据缓存",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--source", nargs="+", default=source_names,
        choices=source_names,
        help="指定数据源(默认全部)",
    )
    parser.add_argument(
        "--ops", nargs="+", default=None,
        help="只生成指定算子(默认全部)",
    )
    parser.add_argument(
        "--output-dir", default=DEFAULT_OUTPUT_DIR,
        help=f"输出目录(默认 {DEFAULT_OUTPUT_DIR})",
    )
    parser.add_argument(
        # The default is handed over as a string so that argparse converts it
        # through ``type``: a malformed $DEVICE_ID then yields a regular usage
        # error, and only when --device is not supplied explicitly.
        "--device", type=int_at_least(0), default=os.getenv("DEVICE_ID", "0"),
        help="CUDA 设备 ID(默认 $DEVICE_ID 或 0)",
    )
    parser.add_argument(
        "--timeout", type=int_at_least(1), default=120,
        help="单个算子超时时间/秒(默认 120)",
    )
    parser.add_argument(
        "--concurrency", type=int_at_least(1), default=1,
        help="并行度(默认 1,参考数据生成建议串行以避免 GPU OOM)",
    )
    return parser.parse_args(argv)
def _print_banner(all_ops, args, output_dir):
    """Print the run header: operator count, sources, output dir, device, timeout."""
    rule = "=" * 70
    print(f"\n{rule}")
    print(" Triton-CUDA 参考数据批量生成")
    print(rule)
    print(f" 算子总数: {len(all_ops)}")
    print(f" 数据源: {', '.join(args.source)}")
    print(f" 输出目录: {output_dir}")
    print(f" CUDA 设备: {args.device}")
    print(f" 超时/算子: {args.timeout}s")
    print(f"{rule}\n")


def _write_manifest(output_dir, device, results, success_count, fail_count):
    """Write manifest.json into ``output_dir`` and return the manifest path."""
    manifest = {
        "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "output_dir": output_dir,
        "device": device,
        "total": len(results),
        "success": success_count,
        "failed": fail_count,
        "ops": results,
    }
    manifest_path = os.path.join(output_dir, "manifest.json")
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False, indent=2)
    return manifest_path


def _print_summary(results, success_count, fail_count, manifest_path, output_dir):
    """Print the final totals, the failed operators and the follow-up hint."""
    rule = "=" * 70
    print(f"\n{rule}")
    print(" 完成")
    print(rule)
    print(f" 成功: {success_count}/{len(results)}")
    print(f" 失败: {fail_count}/{len(results)}")
    print(f" 清单: {manifest_path}")
    print(f" 输出: {output_dir}")
    print(rule)
    if fail_count > 0:
        print("\n 失败算子:")
        for r in results:
            if not r["success"]:
                # ``or ""`` keeps the summary alive if an error field is None.
                print(f" - {r['source']}/{r['op_name']}: {(r['error'] or '')[:100]}")
    print(f"\n 下一步:将 {output_dir} 拷贝到 Ascend 环境后运行:")
    print(" python reproduce/wip/triton-cuda-to-ascend/run_adaptive_search_with_cache.py \\")
    print(" --source sglang")


async def main():
    """Run the batch generation: discover operators, generate data, write the manifest.

    Operators are processed one after another; each result is collected into
    ``manifest.json`` inside the output directory and a summary is printed.

    Returns:
        ``0`` when every discovered operator succeeded, ``1`` when no operator
        was discovered or at least one operator failed.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%H:%M:%S",
    )
    args = parse_args()
    output_dir = os.path.expanduser(args.output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # The loop below is strictly sequential, so make an ignored --concurrency
    # visible instead of silently dropping it.
    if args.concurrency != 1:
        logger.warning(f"--concurrency={args.concurrency} 暂未生效,算子仍按串行方式执行")

    from akg_agents.core.worker.manager import register_local_worker
    await register_local_worker([args.device], backend="cuda", arch="a100")

    all_ops = []
    for source_name in args.source:
        ops = discover_ops(source_name, args.ops)
        logger.info(f"[{source_name}] 发现 {len(ops)} 个算子")
        all_ops.extend((source_name, op_name, op_file) for op_name, op_file in ops)

    if not all_ops:
        logger.error("未发现任何算子,请检查 --source 和 --ops 参数")
        return 1

    _print_banner(all_ops, args, output_dir)

    results = []
    success_count = 0
    fail_count = 0
    total = len(all_ops)
    for index, (source_name, op_name, op_file) in enumerate(all_ops, start=1):
        # No newline: the OK/FAIL verdict is appended to the same console line.
        print(f"[{index}/{total}] {source_name}/{op_name} ... ", end="", flush=True)
        r = await gen_reference_for_op(
            op_name, op_file, source_name, output_dir, args.device, args.timeout
        )
        results.append(r)
        if r["success"]:
            success_count += 1
            size_kb = r.get("size_bytes", 0) / 1024
            print(f"OK ({r['elapsed_s']}s, {size_kb:.1f}KB)")
        else:
            fail_count += 1
            print(f"FAIL ({r['elapsed_s']}s) - {(r['error'] or '')[:80]}")

    manifest_path = _write_manifest(
        output_dir, args.device, results, success_count, fail_count
    )
    _print_summary(results, success_count, fail_count, manifest_path, output_dir)
    return 1 if fail_count else 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status so that shell
    # pipelines can detect failures; a ``None`` result still exits with 0.
    raise SystemExit(asyncio.run(main()))