"""
AgentBase 工具配置加载功能 - 综合测试
"""
import sys
import os
import yaml
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'python'))
from akg_agents.core_v2.agents.base import AgentBase
def test_1_basic_agent():
"""测试 1: 基本 Agent"""
print("\n" + "=" * 80)
print("测试 1: 基本 Agent - OpTaskBuilder")
print("=" * 80)
class OpTaskBuilderAgent(AgentBase):
TOOL_NAME = "call_op_task_builder"
DESCRIPTION = """将用户需求转换为 KernelBench 格式的 task 代码,并验证其可运行性。
这是算子生成的必要前置步骤,必须先确保 task 代码正确,才能调用后续的代码生成工具。"""
PARAMETERS_SCHEMA = {
"type": "object",
"properties": {
"op_name": {
"type": "string",
"description": "算子名称"
},
"user_request": {
"type": "string",
"description": "用户的自然语言需求描述"
},
"user_feedback": {
"type": "string",
"description": "用户对之前生成的 task_desc 的反馈(可选)",
"default": ""
}
},
"required": ["op_name", "user_request"]
}
agent = OpTaskBuilderAgent()
config = agent.load_tool_config()
print("生成的配置字典:")
print(yaml.dump(config, allow_unicode=True, default_flow_style=False, sort_keys=False))
print("✅ 测试通过\n")
return True
def test_2_missing_metadata():
"""测试 2: 缺少元数据(预期失败)"""
print("=" * 80)
print("测试 2: 缺少元数据 - 应该抛出 ValueError")
print("=" * 80)
class IncompleteAgent(AgentBase):
pass
agent = IncompleteAgent()
try:
config = agent.load_tool_config()
print("❌ 测试失败:应该抛出 ValueError")
return False
except ValueError as e:
print(f"✅ 成功捕获 ValueError:\n{e}\n")
return True
def test_3_no_parameters():
"""测试 3: 无参数的 Agent"""
print("=" * 80)
print("测试 3: 无参数的 Agent")
print("=" * 80)
class NoParamAgent(AgentBase):
TOOL_NAME = "call_no_param_agent"
DESCRIPTION = "不需要参数的 Agent"
PARAMETERS_SCHEMA = {
"type": "object",
"properties": {},
"required": []
}
agent = NoParamAgent()
config = agent.load_tool_config()
print("生成的配置字典:")
print(yaml.dump(config, allow_unicode=True, default_flow_style=False, sort_keys=False))
print("✅ 测试通过\n")
return True
def test_4_complex_parameters():
"""测试 4: 复杂参数类型"""
print("=" * 80)
print("测试 4: 复杂参数类型(包含数组、对象等)")
print("=" * 80)
class ComplexAgent(AgentBase):
TOOL_NAME = "call_complex_agent"
DESCRIPTION = "支持复杂参数类型的 Agent"
PARAMETERS_SCHEMA = {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "名称"
},
"count": {
"type": "integer",
"description": "数量",
"default": 10
},
"ratio": {
"type": "number",
"description": "比例",
"default": 0.5
},
"enabled": {
"type": "boolean",
"description": "是否启用",
"default": True
},
"tags": {
"type": "array",
"items": {"type": "string"},
"description": "标签列表"
},
"config": {
"type": "object",
"properties": {
"timeout": {"type": "integer"},
"retry": {"type": "boolean"}
},
"description": "配置对象"
}
},
"required": ["name"]
}
agent = ComplexAgent()
config = agent.load_tool_config()
print("生成的配置字典:")
print(yaml.dump(config, allow_unicode=True, default_flow_style=False, sort_keys=False))
print("✅ 测试通过\n")
return True
def test_5_custom_indent():
"""测试 5: 字典格式输出"""
print("=" * 80)
print("测试 5: 字典格式输出")
print("=" * 80)
class CustomIndentAgent(AgentBase):
TOOL_NAME = "call_custom_indent"
DESCRIPTION = "测试字典格式"
PARAMETERS_SCHEMA = {
"type": "object",
"properties": {
"param": {"type": "string", "description": "参数"}
},
"required": ["param"]
}
agent = CustomIndentAgent()
config = agent.load_tool_config()
print("生成的配置字典:")
print(config)
print("\n转换为 YAML (4 空格缩进):")
print(yaml.dump(config, allow_unicode=True, default_flow_style=False, sort_keys=False, indent=4))
print("✅ 测试通过\n")
return True
def test_6_multiline_description():
"""测试 6: 多行描述"""
print("=" * 80)
print("测试 6: 多行描述格式化")
print("=" * 80)
class MultilineAgent(AgentBase):
TOOL_NAME = "call_multiline"
DESCRIPTION = """这是第一行描述
这是第二行描述
这是第三行描述
这是空行后的描述"""
PARAMETERS_SCHEMA = {
"type": "object",
"properties": {
"input1": {"type": "string"},
"input2": {"type": "integer"}
},
"required": ["input1", "input2"]
}
agent = MultilineAgent()
config = agent.load_tool_config()
print("生成的配置字典:")
print(yaml.dump(config, allow_unicode=True, default_flow_style=False, sort_keys=False))
print("✅ 测试通过\n")
return True
def main():
"""运行所有测试"""
print("\n" + "=" * 80)
print("AgentBase.load_tool_config() - 综合功能测试")
print("=" * 80)
tests = [
test_1_basic_agent,
test_2_missing_metadata,
test_3_no_parameters,
test_4_complex_parameters,
test_5_custom_indent,
test_6_multiline_description,
]
results = []
for test_func in tests:
try:
result = test_func()
results.append(result)
except Exception as e:
print(f"❌ 测试异常: {e}")
import traceback
traceback.print_exc()
results.append(False)
print("\n" + "=" * 80)
print("测试汇总")
print("=" * 80)
passed = sum(results)
total = len(results)
print(f"通过: {passed}/{total}")
if passed == total:
print("🎉 所有测试通过!")
print("=" * 80)
return 0
else:
print("❌ 部分测试失败")
print("=" * 80)
return 1
if __name__ == "__main__":
sys.exit(main())