"""ToolExecutor 单元测试
重点覆盖 required 参数校验功能(_validate_required_params / _get_parameters_schema),
防止 LLM 遗漏必需参数时静默传递空值导致下游不可预期的错误。
"""
import pytest
from unittest.mock import patch, AsyncMock, MagicMock
from akg_agents.core_v2.tools.tool_executor import ToolExecutor
WORKFLOW_SCHEMA = {
"type": "object",
"properties": {
"op_name": {"type": "string", "description": "算子名称"},
"task_desc": {"type": "string", "description": "任务描述(框架代码)"},
"dsl": {"type": "string", "description": "目标 DSL"},
"framework": {"type": "string", "description": "框架"},
"backend": {"type": "string", "description": "后端"},
"arch": {"type": "string", "description": "架构"},
"task_id": {"type": "string", "description": "任务 ID(可选)", "default": ""},
"user_requirements": {
"type": "string",
"description": "用户额外需求(可选)",
"default": "",
},
},
"required": ["op_name", "task_desc", "dsl", "framework", "backend", "arch"],
}
AGENT_SCHEMA = {
"type": "object",
"properties": {
"user_input": {"type": "string", "description": "用户需求描述"},
"framework": {
"type": "string",
"description": "目标框架",
"default": "torch",
},
},
"required": ["user_input"],
}
NO_REQUIRED_SCHEMA = {
"type": "object",
"properties": {
"mode": {"type": "string", "description": "模式"},
},
"required": [],
}
def _make_registry_entry(schema):
"""构造 agent_registry / workflow_registry 中的 config 结构"""
return {
"config": {
"function": {
"name": "test_tool",
"description": "test",
"parameters": schema,
}
}
}
@pytest.fixture
def executor_with_workflow():
"""创建包含 workflow 注册的 ToolExecutor"""
return ToolExecutor(
workflow_registry={
"call_kernelgen_workflow": {
**_make_registry_entry(WORKFLOW_SCHEMA),
"workflow_class": MagicMock(),
"workflow_name": "KernelGenOnlyWorkflow",
}
},
)
@pytest.fixture
def executor_with_agent():
"""创建包含 agent 注册的 ToolExecutor"""
return ToolExecutor(
agent_registry={
"call_op_task_builder": {
**_make_registry_entry(AGENT_SCHEMA),
"agent_class": MagicMock(),
}
},
)
@pytest.fixture
def executor_with_no_required():
"""创建 required 为空的 ToolExecutor"""
return ToolExecutor(
agent_registry={
"call_skill_evolution": {
**_make_registry_entry(NO_REQUIRED_SCHEMA),
"agent_class": MagicMock(),
}
},
)
@pytest.fixture
def executor_empty():
"""创建空注册表的 ToolExecutor"""
return ToolExecutor()
class TestGetParametersSchema:
"""测试 _get_parameters_schema 从三种注册表中正确获取 schema"""
def test_from_workflow_registry(self, executor_with_workflow):
schema = executor_with_workflow._get_parameters_schema("call_kernelgen_workflow")
assert schema is not None
assert "required" in schema
assert "task_desc" in schema["required"]
def test_from_agent_registry(self, executor_with_agent):
schema = executor_with_agent._get_parameters_schema("call_op_task_builder")
assert schema is not None
assert "required" in schema
assert "user_input" in schema["required"]
@patch("akg_agents.core_v2.tools.tool_executor.ToolRegistry")
def test_from_tool_registry(self, mock_registry, executor_empty):
mock_tool = MagicMock()
mock_tool.parameters = {
"type": "object",
"properties": {"file_path": {"type": "string"}},
"required": ["file_path"],
}
mock_registry.get_tool.return_value = mock_tool
schema = executor_empty._get_parameters_schema("read_file")
assert schema is not None
assert "file_path" in schema["required"]
@patch("akg_agents.core_v2.tools.tool_executor.ToolRegistry")
def test_unknown_tool_returns_none(self, mock_registry, executor_empty):
mock_registry.get_tool.return_value = None
schema = executor_empty._get_parameters_schema("nonexistent_tool")
assert schema is None
def test_workflow_takes_priority_over_tool_registry(self, executor_with_workflow):
"""workflow_registry 优先于 ToolRegistry"""
with patch("akg_agents.core_v2.tools.tool_executor.ToolRegistry") as mock_registry:
mock_registry.get_tool.return_value = MagicMock(
parameters={"required": ["something_else"]}
)
schema = executor_with_workflow._get_parameters_schema(
"call_kernelgen_workflow"
)
assert "task_desc" in schema["required"]
mock_registry.get_tool.assert_not_called()
class TestValidateRequiredParams:
"""测试 _validate_required_params 校验逻辑"""
def test_all_required_present_passes(self, executor_with_workflow):
"""所有必需参数都提供时校验通过"""
args = {
"op_name": "relu",
"task_desc": "import torch\nclass Model...",
"dsl": "cpp",
"framework": "torch",
"backend": "cpu",
"arch": "x86_64",
}
result = executor_with_workflow._validate_required_params(
"call_kernelgen_workflow", args
)
assert result is None
def test_missing_param_returns_error(self, executor_with_workflow):
"""缺少必需参数时返回错误"""
args = {"op_name": "relu"}
result = executor_with_workflow._validate_required_params(
"call_kernelgen_workflow", args
)
assert result is not None
assert result["status"] == "error"
assert "task_desc" in result["error_information"]
assert "dsl" in result["error_information"]
def test_empty_string_param_returns_error(self, executor_with_workflow):
"""必需参数为空字符串时返回错误"""
args = {
"op_name": "relu",
"task_desc": "",
"dsl": "cpp",
"framework": "torch",
"backend": "cpu",
"arch": "x86_64",
}
result = executor_with_workflow._validate_required_params(
"call_kernelgen_workflow", args
)
assert result is not None
assert result["status"] == "error"
assert "task_desc" in result["error_information"]
assert "为空字符串" in result["error_information"]
def test_whitespace_only_param_returns_error(self, executor_with_workflow):
"""必需参数只有空格时也视为空"""
args = {
"op_name": "relu",
"task_desc": " \n ",
"dsl": "cpp",
"framework": "torch",
"backend": "cpu",
"arch": "x86_64",
}
result = executor_with_workflow._validate_required_params(
"call_kernelgen_workflow", args
)
assert result is not None
assert result["status"] == "error"
assert "task_desc" in result["error_information"]
def test_none_param_returns_error(self, executor_with_workflow):
"""必需参数为 None 时返回错误"""
args = {
"op_name": "relu",
"task_desc": None,
"dsl": "cpp",
"framework": "torch",
"backend": "cpu",
"arch": "x86_64",
}
result = executor_with_workflow._validate_required_params(
"call_kernelgen_workflow", args
)
assert result is not None
assert result["status"] == "error"
assert "task_desc" in result["error_information"]
assert "未提供" in result["error_information"]
def test_optional_param_with_default_allows_empty(self, executor_with_workflow):
"""有 default 值的可选参数允许为空(不在 required 中的参数不校验)"""
args = {
"op_name": "relu",
"task_desc": "import torch\nclass Model...",
"dsl": "cpp",
"framework": "torch",
"backend": "cpu",
"arch": "x86_64",
"task_id": "",
"user_requirements": "",
}
result = executor_with_workflow._validate_required_params(
"call_kernelgen_workflow", args
)
assert result is None
def test_no_required_field_passes(self, executor_with_no_required):
"""required 列表为空时直接通过"""
result = executor_with_no_required._validate_required_params(
"call_skill_evolution", {}
)
assert result is None
def test_unknown_tool_passes(self, executor_empty):
"""未注册的工具直接通过(无 schema 可校验)"""
with patch("akg_agents.core_v2.tools.tool_executor.ToolRegistry") as mock_reg:
mock_reg.get_tool.return_value = None
result = executor_empty._validate_required_params("unknown_tool", {})
assert result is None
def test_multiple_missing_params_all_reported(self, executor_with_workflow):
"""多个必需参数缺失时全部报告"""
args = {}
result = executor_with_workflow._validate_required_params(
"call_kernelgen_workflow", args
)
assert result is not None
error_info = result["error_information"]
for param in ["op_name", "task_desc", "dsl", "framework", "backend", "arch"]:
assert param in error_info
def test_error_result_format(self, executor_with_workflow):
"""错误返回格式符合标准"""
args = {"op_name": "relu"}
result = executor_with_workflow._validate_required_params(
"call_kernelgen_workflow", args
)
assert result is not None
assert "status" in result
assert "output" in result
assert "error_information" in result
assert result["status"] == "error"
assert result["output"] == ""
def test_agent_required_validation(self, executor_with_agent):
"""Agent 工具的 required 校验同样生效"""
result = executor_with_agent._validate_required_params(
"call_op_task_builder", {}
)
assert result is not None
assert "user_input" in result["error_information"]
def test_agent_with_valid_params_passes(self, executor_with_agent):
"""Agent 参数完整时校验通过"""
result = executor_with_agent._validate_required_params(
"call_op_task_builder", {"user_input": "生成 relu 算子"}
)
assert result is None
class TestExecuteWithValidation:
"""测试 execute() 中的 required 校验集成行为"""
@pytest.mark.asyncio
async def test_execute_blocks_missing_required_workflow(self, executor_with_workflow):
"""execute 在 workflow 缺少 required 参数时直接返回错误,不执行 workflow"""
result = await executor_with_workflow.execute(
"call_kernelgen_workflow", {"op_name": "relu"}
)
assert result["status"] == "error"
assert "task_desc" in result["error_information"]
@pytest.mark.asyncio
async def test_execute_blocks_empty_required_workflow(self, executor_with_workflow):
"""execute 在 workflow 的 required 参数为空字符串时阻止执行"""
result = await executor_with_workflow.execute(
"call_kernelgen_workflow",
{
"op_name": "relu",
"task_desc": "",
"dsl": "cpp",
"framework": "torch",
"backend": "cpu",
"arch": "x86_64",
},
)
assert result["status"] == "error"
assert "task_desc" in result["error_information"]
@pytest.mark.asyncio
async def test_execute_blocks_missing_required_agent(self, executor_with_agent):
"""execute 在 agent 缺少 required 参数时直接返回错误"""
result = await executor_with_agent.execute("call_op_task_builder", {})
assert result["status"] == "error"
assert "user_input" in result["error_information"]
@pytest.mark.asyncio
async def test_execute_proceeds_when_all_required_present(
self, executor_with_workflow
):
"""所有 required 参数存在时正常进入 workflow 执行流程"""
mock_workflow_class = executor_with_workflow.workflow_registry[
"call_kernelgen_workflow"
]["workflow_class"]
mock_app = AsyncMock()
mock_app.ainvoke = AsyncMock(
return_value={"verifier_result": True, "coder_code": "code"}
)
mock_workflow_instance = MagicMock()
mock_workflow_instance.compile.return_value = mock_app
mock_workflow_instance.format_result.return_value = {
"status": "success",
"code": "code",
}
mock_workflow_class.return_value = mock_workflow_instance
mock_workflow_class.build_initial_state = MagicMock(
return_value={"task_desc": "code", "op_name": "relu"}
)
mock_workflow_class.ensure_resources = AsyncMock()
mock_workflow_class.prepare_config = MagicMock()
executor_with_workflow.agent_context = {
"get_workflow_resources": lambda: {
"agents": {"kernel_gen": MagicMock(), "verifier": MagicMock()},
"config": {},
"trace": MagicMock(),
"device_pool": None,
"private_worker": None,
"worker_manager": MagicMock(),
"backend": "cpu",
"arch": "x86_64",
}
}
result = await executor_with_workflow.execute(
"call_kernelgen_workflow",
{
"op_name": "relu",
"task_desc": "import torch\nclass Model(nn.Module):\n pass",
"dsl": "cpp",
"framework": "torch",
"backend": "cpu",
"arch": "x86_64",
},
)
assert result["status"] == "success"
@pytest.mark.asyncio
async def test_execute_resolves_expression_then_validates(
self, executor_with_workflow
):
"""验证 resolve_arguments 在校验之前执行:表达式解析失败时保留原值再校验"""
args = {
"op_name": "relu",
"task_desc": "read_json_file('/nonexistent/path.json')['key']",
"dsl": "cpp",
"framework": "torch",
"backend": "cpu",
"arch": "x86_64",
}
result = await executor_with_workflow.execute(
"call_kernelgen_workflow", args
)
assert result is not None
class TestRegressionEmptyTaskDesc:
"""回归测试:复现 LLM 遗漏 task_desc 导致 relu_torch.py 为空的 bug"""
@pytest.mark.asyncio
async def test_only_op_name_provided_is_rejected(self, executor_with_workflow):
"""
复现场景:LLM 调用 call_kernelgen_workflow 时只传了 {"op_name": "relu"},
遗漏了 task_desc 等所有其他必需参数。
期望:校验立即拦截并返回清晰错误,而不是让空 task_desc 流入 verifier。
"""
result = await executor_with_workflow.execute(
"call_kernelgen_workflow", {"op_name": "relu"}
)
assert result["status"] == "error"
assert "task_desc" in result["error_information"]
assert "dsl" in result["error_information"]
assert "framework" in result["error_information"]
assert "backend" in result["error_information"]
assert "arch" in result["error_information"]
class TestErrorMessageReadability:
"""测试错误信息的可读性,防止超长 tool_name 污染报错"""
@pytest.mark.asyncio
async def test_unknown_tool_error_truncates_long_name(self, executor_empty):
"""超长工具名在报错信息中被截断"""
long_name = "ask_user` 而不是 `ask_user`)\n\n3. 当前状态" + "x" * 500
with patch("akg_agents.core_v2.tools.tool_executor.ToolRegistry") as mock_reg:
mock_reg.get_tool.return_value = None
mock_reg.list_names.return_value = ["ask_user", "read_file"]
result = await executor_empty.execute(long_name, {})
assert result["status"] == "error"
assert len(result["error_information"]) < 500
assert "..." in result["error_information"]
@pytest.mark.asyncio
async def test_unknown_tool_error_message_is_actionable(self, executor_empty):
"""未知工具的报错信息应该包含可操作指引"""
with patch("akg_agents.core_v2.tools.tool_executor.ToolRegistry") as mock_reg:
mock_reg.get_tool.return_value = None
mock_reg.list_names.return_value = ["ask_user", "read_file"]
result = await executor_empty.execute("nonexistent", {})
assert "请检查工具名拼写" in result["error_information"]
assert "可用工具" in result["error_information"]
@pytest.mark.asyncio
async def test_short_tool_name_not_truncated(self, executor_empty):
"""短工具名不被截断"""
with patch("akg_agents.core_v2.tools.tool_executor.ToolRegistry") as mock_reg:
mock_reg.get_tool.return_value = None
mock_reg.list_names.return_value = ["ask_user"]
result = await executor_empty.execute("bad_name", {})
assert "'bad_name'" in result["error_information"]
class TestInvalidToolNameDetection:
"""测试非法 function.name 到达 ToolExecutor 后的清晰报错
核心理念:不做偷偷修正,而是清晰报错让 LLM 自行修正。
"""
@pytest.mark.asyncio
async def test_polluted_name_rejected_by_tool_executor(self, executor_empty):
"""被 LLM 思考链污染的 function.name 到达 ToolExecutor 后应返回可读错误"""
polluted = (
'ask_user` 而不是 `ask_user`)\n\n3. **当前状态**:\n'
' - task_desc 已经生成成功(在 node_002)\n'
' - 但是上一步的 ask_user 失败了\n\n'
'</think>我看到上一步出现了错误。<tool_call>ask_user'
)
with patch("akg_agents.core_v2.tools.tool_executor.ToolRegistry") as mock_reg:
mock_reg.get_tool.return_value = None
mock_reg.list_names.return_value = ["ask_user", "call_kernelgen_workflow"]
result = await executor_empty.execute(polluted, {})
assert result["status"] == "error"
assert "..." in result["error_information"]
assert "请检查工具名拼写" in result["error_information"]
@pytest.mark.asyncio
async def test_short_polluted_name_shows_full_name(self, executor_empty):
"""短的非法 tool_name(<60字符)应完整展示"""
polluted = '<tool_call>ask_user'
with patch("akg_agents.core_v2.tools.tool_executor.ToolRegistry") as mock_reg:
mock_reg.get_tool.return_value = None
mock_reg.list_names.return_value = ["ask_user"]
result = await executor_empty.execute(polluted, {})
assert result["status"] == "error"
assert "请检查工具名拼写" in result["error_information"]
assert "..." not in result["error_information"]
@pytest.mark.asyncio
async def test_normal_tool_name_passes_through(self, executor_empty):
"""正常工具名应正常传递,不触发报错"""
with patch("akg_agents.core_v2.tools.tool_executor.ToolRegistry") as mock_reg:
mock_tool = MagicMock()
mock_tool.parameters = None
mock_reg.get_tool.return_value = mock_tool
mock_reg.aexecute = AsyncMock(return_value={"status": "success", "output": "ok"})
result = await executor_empty.execute("ask_user", {"message": "hello"})
assert result["status"] == "success"