"""
Skill 选择与 Evolution 逻辑 UT — 不依赖 LLM,纯逻辑验证
覆盖:
1. dsl_to_dir_key 转换
2. KernelGen 实例属性、PARAMETERS_SCHEMA、run() 签名
3. _infer_case_type metadata 推断(兼容旧 case category)
4. _parse_unified_selection JSON 解析
5. _assemble_skill_contents 排序与组装(含 fix/improvement category)
6. Stage → category 注入逻辑(含 fix/improvement 识别,源码检查)
7. nodes.py / evolution_processors 接口清理验证
8. evolved_skill_loader 已删除
9. AB test build_evolve_config A/B 模式
"""
import inspect
import os
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
import pytest
@pytest.fixture(scope="module")
def kg():
    """Provide one KernelGen instance shared by every test in this module.

    Because the scope is ``module``, tests that need to mutate attributes
    should construct their own instance instead of using this one.
    """
    from akg_agents.op.agents.kernel_gen import KernelGen as _KernelGen

    instance = _KernelGen()
    return instance
@pytest.fixture(scope="module")
def project_root():
    """Return the repository root, resolved relative to this test file."""
    # parents[3]: assumes this file sits three directories below the repo
    # root (i.e. <root>/a/b/c/<this file>) -- TODO confirm if the file moves.
    this_file = Path(__file__).resolve()
    return this_file.parents[3]
class TestDslToDirKey:
    """dsl_to_dir_key: DSL name -> skill directory key (lower-case, '_' -> '-')."""

    def test_underscore_to_hyphen(self):
        from akg_agents.core_v2.skill.metadata import dsl_to_dir_key as to_key

        assert to_key("triton_ascend") == "triton-ascend"

    def test_upper_case(self):
        from akg_agents.core_v2.skill.metadata import dsl_to_dir_key as to_key

        assert to_key("Triton_CUDA") == "triton-cuda"

    def test_package_import(self):
        # The helper must also be re-exported from the package itself.
        from akg_agents.core_v2.skill import dsl_to_dir_key as to_key

        assert to_key("triton_ascend") == "triton-ascend"
class TestKernelGenInterface:
    """Skill-selection surface of KernelGen: attributes, schema, and run() params."""

    def test_default_attrs(self, kg):
        for attr_name in ("exclude_skill_names", "force_skill_names", "extra_skills"):
            assert getattr(kg, attr_name) == []

    def test_attrs_settable(self):
        # Use a fresh instance so the module-scoped ``kg`` fixture is not mutated.
        from akg_agents.op.agents.kernel_gen import KernelGen

        fresh = KernelGen()
        fresh.exclude_skill_names = ["a"]
        fresh.force_skill_names = ["b"]
        assert fresh.exclude_skill_names == ["a"]
        assert fresh.force_skill_names == ["b"]

    def test_schema_has_exclude_and_force(self, kg):
        schema_props = kg.PARAMETERS_SCHEMA.get("properties", {})
        for key in ("exclude_skill_names", "force_skill_names"):
            assert key in schema_props

    def test_run_signature(self, kg):
        run_params = inspect.signature(kg.run).parameters
        for param_name in (
            "exclude_skill_names",
            "force_skill_names",
            "handwrite_suggestions",
        ):
            assert param_name in run_params
class TestInferCaseType:
    """How _infer_case_type maps skill metadata to 'fix' / 'improvement'."""

    @dataclass
    class FakeSkill:
        # Minimal stand-in exposing only a ``metadata`` dict.
        metadata: dict = field(default_factory=dict)

    def _case_type(self, kg, **metadata):
        """Run _infer_case_type on a fake skill built from ``metadata``."""
        return kg._infer_case_type(self.FakeSkill(metadata=metadata))

    def test_metadata_case_type_fix(self, kg):
        assert self._case_type(kg, case_type="fix") == "fix"

    def test_metadata_source_error_fix(self, kg):
        # Legacy cases are tagged via ``source`` rather than ``case_type``.
        assert self._case_type(kg, source="error_fix") == "fix"

    def test_metadata_case_type_improvement(self, kg):
        assert self._case_type(kg, case_type="improvement") == "improvement"

    def test_default_improvement(self, kg):
        assert self._case_type(kg) == "improvement"
class TestParseUnifiedSelection:
    """_parse_unified_selection: model text -> selection dict ({} if unparsable)."""

    def test_plain_json(self, kg):
        parsed = kg._parse_unified_selection('{"guides": ["a"], "cases": ["b"]}')
        assert parsed["guides"] == ["a"]

    def test_json_in_code_fence(self, kg):
        # A markdown ```json fence around the payload must be tolerated.
        parsed = kg._parse_unified_selection('```json\n{"guides": ["x"]}\n```')
        assert parsed["guides"] == ["x"]

    def test_empty_and_malformed(self, kg):
        for bad_input in ("", "{bad}"):
            assert kg._parse_unified_selection(bad_input) == {}
class TestAssembleSkillContents:
    """_assemble_skill_contents: section ordering and category grouping."""

    @dataclass
    class FakeSkill:
        # Minimal stand-in exposing the attributes the assembler is given.
        name: str = ""
        category: str = ""
        content: str = "content"

    def test_empty_list(self, kg):
        assert kg._assemble_skill_contents([]) == ""

    def test_section_order(self, kg):
        # Input is deliberately shuffled relative to the expected output order.
        skills = [
            self.FakeSkill(name="c", category="case", content="c"),
            self.FakeSkill(name="f", category="fundamental", content="f"),
            self.FakeSkill(name="g", category="guide", content="g"),
            self.FakeSkill(name="e", category="example", content="e"),
        ]
        result = kg._assemble_skill_contents(skills)

        headings = ["基础知识与规范", "算子优化指南", "代码示例参考", "优化/修复案例"]
        positions = [result.find(heading) for heading in headings]
        # str.find returns -1 for a missing heading; without this check a
        # missing first heading would satisfy the ordering assertion vacuously.
        missing = [h for h, pos in zip(headings, positions) if pos < 0]
        assert not missing, f"missing section headings: {missing}"
        assert all(a < b for a, b in zip(positions, positions[1:])), positions

    def test_fix_and_improvement_in_case_section(self, kg):
        skills = [
            self.FakeSkill(name="fx", category="fix", content="fix-content"),
            self.FakeSkill(name="imp", category="improvement", content="imp-content"),
            self.FakeSkill(name="f", category="fundamental", content="f"),
        ]
        result = kg._assemble_skill_contents(skills)
        assert "优化/修复案例" in result
        assert "fix-content" in result
        assert "imp-content" in result
class TestStageCategories:
    """Check the per-stage category injection logic in _select_skills_by_stage.

    The logic lives inline in the method body (not in class attributes), so it
    is verified by inspecting the method's source text:
    - initial: ``extras = []`` (no fix/improvement skills injected)
    - debug: ``extras = case_fix`` (all fix-category skills injected)
    - optimize: ``extras = _sample_cases(...)`` (improvement skills are sampled)
    - classification recognizes the fix / improvement / case categories
    """

    @pytest.fixture(autouse=True)
    def _load_source(self):
        from akg_agents.op.agents.kernel_gen import KernelGen as _KernelGen

        self.source = inspect.getsource(_KernelGen._select_skills_by_stage)

    def test_initial_no_case(self):
        for snippet in ('extras = []\n', '"none (initial)"', 'always_skills'):
            assert snippet in self.source

    def test_debug_and_optimize_have_case(self):
        src = self.source
        debug_forms = ('extras = case_fix', "extras = [s for s in case_fix")
        assert any(form in src for form in debug_forms)
        assert '_sample_cases' in src

    def test_recognizes_fix_and_improvement_categories(self):
        for snippet in ('cat == "fix"', 'cat == "improvement"'):
            assert snippet in self.source
class TestInterfaceCleanup:
    """Source-level checks that removed parameters/modules are no longer referenced."""

    def test_kernel_gen_node_no_handwrite(self):
        from akg_agents.op.langgraph_op.nodes import NodeFactory

        node_src = inspect.getsource(NodeFactory.create_kernel_gen_node)
        assert "handwrite_suggestions" not in node_src

    def test_evolution_processors_no_old_refs(self):
        from akg_agents.op.utils.evolve.evolution_processors import InitializationProcessor

        init_src = inspect.getsource(InitializationProcessor.initialize)
        for stale_name in ("evolved_suggestions", "evolved_skill_loader"):
            assert stale_name not in init_src
class TestEvolvedLoaderDeleted:
    """The legacy evolved_skill_loader module file must no longer exist."""

    def test_file_not_exists(self, project_root):
        legacy_module = project_root.joinpath(
            "python", "akg_agents", "op", "utils", "evolved_skill_loader.py"
        )
        assert not legacy_module.exists()
class TestBuildEvolveConfig:
    """A/B-mode checks for ``build_evolve_config`` from the skill_evolution example.

    ``ab_test_utils`` lives outside the package, under
    ``examples/kernel_related/skill_evolution``. That directory is prepended to
    ``sys.path`` for the duration of each test only.
    """

    @pytest.fixture(autouse=True)
    def _setup_path(self, project_root, monkeypatch):
        # monkeypatch undoes the sys.path change after each test, so the
        # example directory does not leak into the rest of the session.
        monkeypatch.syspath_prepend(
            str(project_root / "examples" / "kernel_related" / "skill_evolution")
        )

    @staticmethod
    def _build_and_load(project_root, ab_mode):
        """Run build_evolve_config in ``ab_mode`` and return the parsed agent_config.yaml.

        Skips the calling test when no base config YAML is available.
        """
        import yaml
        from ab_test_utils import build_evolve_config

        config_dir = project_root / "python" / "akg_agents" / "op" / "config"
        # glob() order is filesystem-dependent; sort so every machine picks
        # the same base config.
        config_files = sorted(config_dir.glob("*.yaml"))
        if not config_files:
            pytest.skip("no config yaml found")
        with tempfile.TemporaryDirectory() as run_dir:
            path = build_evolve_config(
                group=1, ab_mode=ab_mode, run_dir=run_dir, device=0,
                evolved_skill_dir="", base_config_path=str(config_files[0]),
                max_rounds=1, project_root=project_root,
            )
            # Read before the temporary directory is cleaned up.
            agent_config = Path(path).parent / "agent_config.yaml"
            return yaml.safe_load(agent_config.read_text(encoding="utf-8")) or {}

    def test_a_mode_has_exclude(self, project_root):
        cfg = self._build_and_load(project_root, "A")
        assert "exclude_skill_names" in cfg

    def test_b_mode_has_force(self, project_root):
        cfg = self._build_and_load(project_root, "B")
        assert "force_skill_names" in cfg