"""
DEPRECATED: 此测试对应的 Selector Agent (core/agent/selector.py) 已标记废弃,
后续将迁移到 op/agents,届时此测试也需要更新。
Selector Agent的单元测试 - 精简版
只测试核心流程
Selector Agent 的作用:
1. 接收候选文档列表
2. 使用 LLM 筛选相关文档
3. 返回按相关性排序的文档名列表(第1个最相关)
重要:Selector 返回的排序列表会被 HandwriteSampler 使用,
通过加权采样确保相关性高的文档有更高的被选中概率。
"""
import pytest
from unittest.mock import patch, AsyncMock
from akg_agents.core.agent.selector import Selector
@pytest.fixture
def mock_config():
    """Minimal Selector configuration with a single ``default`` model entry.

    The values are placeholders: the tests in this module that reach the LLM
    replace ``run_llm`` with a mock, so no real model is contacted.
    """
    default_model = dict(
        model_name='test-model',
        temperature=0.3,
        max_tokens=1000,
    )
    return {'agent_model_config': {'default': default_model}}
@pytest.fixture
def sample_candidates():
    """Three candidate documents (ReLU, GELU, MatMul) for the Selector to rank.

    Each candidate is a dict with ``name``, ``framework_code``, ``impl_code``
    and ``improvement_doc`` keys; ``name`` is the identifier the Selector
    returns in its ranked list.
    """
    keys = ('name', 'framework_code', 'impl_code', 'improvement_doc')
    rows = (
        (
            'relu_001',
            'def relu_torch(x): return torch.relu(x)',
            '@triton.jit\ndef relu_kernel(): pass',
            '# ReLU优化建议\n使用向量化操作',
        ),
        (
            'gelu_001',
            'def gelu_torch(x): return F.gelu(x)',
            '@triton.jit\ndef gelu_kernel(): pass',
            '# GELU优化建议\n使用近似计算',
        ),
        (
            'matmul_001',
            'def matmul_torch(a, b): return torch.matmul(a, b)',
            '@triton.jit\ndef matmul_kernel(): pass',
            '# MatMul优化建议\n使用tile优化',
        ),
    )
    return [dict(zip(keys, row)) for row in rows]
class TestSelectorCore:
    """Core-flow tests for the Selector agent.

    Covers construction, a normal selection, filtering of unknown names, the
    empty-selection fallback, and preservation of the LLM's ranking order.
    Tests that call ``run`` patch ``run_llm`` with an ``AsyncMock`` so the LLM
    response is fixed and no model is contacted.
    """

    @pytest.fixture
    def selector(self, mock_config):
        """Return the ReLU Selector that every test in this class exercises."""
        return Selector(
            op_name="relu_op",
            task_desc="ReLU activation function",
            dsl="triton_ascend",
            config=mock_config,
        )

    @staticmethod
    def _patch_llm(selector, response):
        """Patch ``selector.run_llm`` so awaiting it yields ``(response, None, None)``.

        Returns the ``patch.object`` context manager; entering it yields the
        ``AsyncMock`` so callers can inspect how it was awaited. The tests only
        vary the first tuple element (the raw JSON text the LLM "returned").
        """
        return patch.object(
            selector,
            'run_llm',
            new_callable=AsyncMock,
            return_value=(response, None, None),
        )

    def test_initialization(self, selector):
        """Test 1: constructor arguments are stored and the LLM step counter starts at 0."""
        assert selector.op_name == "relu_op"
        assert selector.task_desc == "ReLU activation function"
        assert selector.dsl == "triton_ascend"
        assert selector.llm_step_count == 0

    @pytest.mark.asyncio
    async def test_run_with_valid_selection(self, selector, sample_candidates):
        """Test 2: valid names from the LLM are returned and one LLM step is counted."""
        llm_response = '{"selected_names": ["relu_001", "gelu_001"]}'
        with self._patch_llm(selector, llm_response) as mock_run_llm:
            selected_names = await selector.run(sample_candidates)
            assert selected_names == ["relu_001", "gelu_001"]
            assert selector.llm_step_count == 1
            # awaited (not merely called): run must consume the coroutine result
            mock_run_llm.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_filter_invalid_names(self, selector, sample_candidates):
        """Test 3: names that are not among the candidates are dropped."""
        llm_response = '{"selected_names": ["relu_001", "invalid_001", "gelu_001"]}'
        with self._patch_llm(selector, llm_response):
            selected_names = await selector.run(sample_candidates)
            assert set(selected_names) == {"relu_001", "gelu_001"}
            assert "invalid_001" not in selected_names

    @pytest.mark.asyncio
    async def test_fallback_on_empty_selection(self, selector, sample_candidates):
        """Test 4: an empty LLM selection falls back to returning every candidate."""
        llm_response = '{"selected_names": []}'
        with self._patch_llm(selector, llm_response):
            selected_names = await selector.run(sample_candidates)
            assert len(selected_names) == 3
            assert set(selected_names) == {"relu_001", "gelu_001", "matmul_001"}

    @pytest.mark.asyncio
    async def test_selection_order_preserved(self, selector, sample_candidates):
        """Test 5: the LLM's ranking order is returned unchanged.

        This matters because HandwriteSampler uses the returned order for
        weighted sampling: the first name is treated as the most relevant.
        """
        llm_response = '{"selected_names": ["relu_001", "gelu_001", "matmul_001"]}'
        with self._patch_llm(selector, llm_response):
            selected_names = await selector.run(sample_candidates)
            assert isinstance(selected_names, list)
            # LLM ranking -> Selector return value -> HandwriteSampler weights
            assert selected_names == ["relu_001", "gelu_001", "matmul_001"]