"""
Skill 选择 ST — 需要 LLM,验证真实选择准确性
覆盖:
- LLM 对不同算子类型选择正确的 guide(看护核心匹配准确性)
- initial 阶段不含 case,debug 阶段含 fundamental
- exclude_skill_names / force_skill_names 运行时生效
- 选择耗时 baseline
"""
import time
import pytest
import logging
# Module-level logger; the tests below use it to record selection timing and
# the names of the selected skills.
logger = logging.getLogger(__name__)
# Task description for a ReLU operator: the text of a torch module together
# with get_inputs / get_init_inputs. It is passed verbatim as ``task_desc`` to
# skill selection, so the literal below is runtime data, not executed code.
RELU_TASK = """\
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.relu(x)
def get_inputs():
return [torch.randn(16, 16384)]
def get_init_inputs():
return []
"""
# Task description for a matmul operator with two 1024x1024 inputs. It is
# passed verbatim as ``task_desc`` to skill selection (runtime data).
MATMUL_TASK = """\
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super().__init__()
def forward(self, a, b):
return torch.matmul(a, b)
def get_inputs():
return [torch.randn(1024, 1024), torch.randn(1024, 1024)]
def get_init_inputs():
return []
"""
# Task description for a matmul whose shared inner dimension is large:
# A is (M, K) and B is (K, N) with M = N = 128 and K = 131072. It is used by
# test_case_selection_matmul_large_k as ``task_desc`` (runtime data).
MATMUL_LARGE_K_TASK = """\
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
return torch.matmul(A, B)
M = 128
N = 128
K = 131072
def get_inputs():
A = torch.randn(M, K)
B = torch.randn(K, N)
return [A, B]
def get_init_inputs():
return []
"""
@pytest.fixture(scope="module")
def kernel_gen():
    """Build a single ``KernelGen`` that is shared by the tests in this module.

    The fixture is module-scoped, so tests requesting it receive the same
    instance.
    """
    from akg_agents.op.agents.kernel_gen import KernelGen

    shared_instance = KernelGen()
    return shared_instance
# Parametrization table for test_guide_selection_accuracy. Each entry is
# (op_name, task_desc, expected_keyword); the first element is also used as
# the pytest case id, and expected_keyword must appear in the name of at least
# one selected skill whose category is "guide".
MUST_MATCH = [
("relu", RELU_TASK, "elementwise"),
("matmul", MATMUL_TASK, "matmul"),
]
@pytest.mark.level2
@pytest.mark.use_model
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "op_name,task_desc,expected_keyword",
    MUST_MATCH,
    ids=[case[0] for case in MUST_MATCH],
)
async def test_guide_selection_accuracy(kernel_gen, op_name, task_desc, expected_keyword):
    """Guard test: a typical operator must get its matching guide selected.

    Runs initial-stage skill selection for the given task and checks that at
    least one returned skill whose ``category`` is ``"guide"`` contains
    ``expected_keyword`` in its name. The elapsed time and the number of
    selected skills are logged for reference only.

    Args:
        kernel_gen: Module-scoped ``KernelGen`` fixture.
        op_name: Operator name forwarded to the selector (also the case id).
        task_desc: Task description text forwarded to the selector.
        expected_keyword: Substring expected in a selected guide's name.
    """
    # perf_counter is monotonic, so the measured interval cannot be distorted
    # by wall-clock adjustments the way time.time() can.
    started = time.perf_counter()
    skills = await kernel_gen._select_skills_by_stage(
        stage="initial",
        op_name=op_name,
        task_desc=task_desc,
        verifier_error="",
        dsl="triton_ascend",
        backend="ascend",
        framework="torch",
    )
    elapsed = time.perf_counter() - started
    logger.info(f"[{op_name}] {elapsed:.1f}s, {len(skills)} skills")

    # Skills without a ``category`` attribute are treated as non-guides.
    guide_names = [s.name for s in skills if getattr(s, "category", "") == "guide"]
    assert any(expected_keyword in name for name in guide_names), \
        f"Expected guide containing '{expected_keyword}', got: {guide_names}"
@pytest.mark.level2
@pytest.mark.use_model
@pytest.mark.asyncio
async def test_initial_stage_no_case(kernel_gen):
    """The initial stage must not return any skill whose category is ``case``."""
    selected = await kernel_gen._select_skills_by_stage(
        stage="initial",
        op_name="relu",
        task_desc=RELU_TASK,
        verifier_error="",
        dsl="triton_ascend",
        backend="ascend",
        framework="torch",
    )

    # Skills lacking a ``category`` attribute are recorded as the empty string.
    seen_categories = set()
    for skill in selected:
        seen_categories.add(getattr(skill, "category", ""))

    assert "case" not in seen_categories
@pytest.mark.level2
@pytest.mark.use_model
@pytest.mark.asyncio
async def test_debug_stage_has_fundamentals(kernel_gen):
    """The debug stage must still return baseline skills.

    Passes when at least one selected skill has category ``fundamental`` or
    ``reference``.
    """
    selected = await kernel_gen._select_skills_by_stage(
        stage="debug",
        op_name="relu",
        task_desc=RELU_TASK,
        verifier_error="RuntimeError: shape mismatch",
        dsl="triton_ascend",
        backend="ascend",
        framework="torch",
    )

    # Skills lacking a ``category`` attribute are recorded as the empty string.
    found = {getattr(skill, "category", "") for skill in selected}
    assert any(wanted in found for wanted in ("fundamental", "reference"))
@pytest.mark.level2
@pytest.mark.use_model
@pytest.mark.asyncio
async def test_exclude_skill_names():
    """A skill listed in ``exclude_skill_names`` must not be selected.

    Uses a fresh ``KernelGen`` so that setting the attribute does not affect
    the module-scoped fixture instance.
    """
    from akg_agents.op.agents.kernel_gen import KernelGen

    excluded = "triton-ascend-optimization"
    generator = KernelGen()
    generator.exclude_skill_names = [excluded]

    selected = await generator._select_skills_by_stage(
        stage="initial",
        op_name="relu",
        task_desc=RELU_TASK,
        verifier_error="",
        dsl="triton_ascend",
        backend="ascend",
        framework="torch",
    )

    selected_names = [skill.name for skill in selected]
    assert excluded not in selected_names
@pytest.mark.level2
@pytest.mark.use_model
@pytest.mark.asyncio
async def test_force_skill_names():
    """A skill listed in ``force_skill_names`` must appear in the selection.

    Uses a fresh ``KernelGen`` so that setting the attribute does not affect
    the module-scoped fixture instance.
    """
    from akg_agents.op.agents.kernel_gen import KernelGen

    forced = "triton-ascend-case-reduction-amax-medium"
    generator = KernelGen()
    generator.force_skill_names = [forced]

    selected = await generator._select_skills_by_stage(
        stage="initial",
        op_name="relu",
        task_desc=RELU_TASK,
        verifier_error="",
        dsl="triton_ascend",
        backend="ascend",
        framework="torch",
    )

    selected_names = [skill.name for skill in selected]
    assert forced in selected_names
@pytest.mark.level2
@pytest.mark.use_model
@pytest.mark.asyncio
async def test_case_selection_matmul_large_k():
    """The large-K matmul task must get the matching case skill cached.

    After an initial-stage selection, ``_initial_selection_cache`` must be
    populated and its ``"case"`` entry must contain the skill named
    ``triton-ascend-case-matmul-large-k``. The return value of the selection
    call itself is not inspected.
    """
    from akg_agents.op.agents.kernel_gen import KernelGen

    generator = KernelGen()
    await generator._select_skills_by_stage(
        stage="initial",
        op_name="matmul_large_k",
        task_desc=MATMUL_LARGE_K_TASK,
        verifier_error="",
        dsl="triton_ascend",
        backend="ascend",
        framework="torch",
    )

    selection_cache = generator._initial_selection_cache
    assert selection_cache is not None, "initial stage should populate _initial_selection_cache"

    # A missing "case" key is treated as "no cases selected".
    cached_case_names = []
    for skill in selection_cache.get("case", []):
        cached_case_names.append(skill.name)
    logger.info(f"[matmul_large_k] LLM selected cases: {cached_case_names}")

    assert "triton-ascend-case-matmul-large-k" in cached_case_names, \
        f"Expected 'triton-ascend-case-matmul-large-k' in LLM-selected cases, got: {cached_case_names}"
@pytest.mark.level2
@pytest.mark.use_model
@pytest.mark.asyncio
async def test_selection_latency_baseline(kernel_gen):
    """Record a latency baseline for one initial-stage skill selection.

    This test never fails on latency: it logs the elapsed time and emits a
    warning when the selection takes longer than 30 seconds.

    Args:
        kernel_gen: Module-scoped ``KernelGen`` fixture.
    """
    # The task text is RELU_TASK with every "relu" replaced by "gelu".
    gelu_task = RELU_TASK.replace("relu", "gelu")

    # perf_counter is monotonic, so the measured interval cannot be distorted
    # by wall-clock adjustments the way time.time() can.
    started = time.perf_counter()
    await kernel_gen._select_skills_by_stage(
        stage="initial",
        op_name="gelu",
        task_desc=gelu_task,
        verifier_error="",
        dsl="triton_ascend",
        backend="ascend",
        framework="torch",
    )
    elapsed = time.perf_counter() - started

    logger.info(f"Skill selection latency: {elapsed:.2f}s")
    if elapsed > 30:
        logger.warning(f"Selection took {elapsed:.1f}s (>30s)")