"""
测试 KernelGen Agent 的基本功能
演示如何使用新的 KernelGen agent 生成内核代码
"""
import asyncio
import logging
import json
from pathlib import Path
from akg_agents.op.agents import KernelGen
from akg_agents.core_v2.filesystem import ActionRecord
from akg_agents.utils.common_utils import ParserFactory
# Configure the root logger once at import time: INFO and above, with each
# record tagged by timestamp, level, source file/line and calling function.
logging.basicConfig(
    format=(
        "%(asctime)s - %(levelname)s - "
        "[%(filename)s:%(lineno)d] - %(funcName)s() - %(message)s"
    ),
    level=logging.INFO,
)

# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)
def print_code(raw_output: str, keys: list = None):
    """Print selected fields of an agent's JSON output (designer-test format).

    The raw text is passed to ``ParserFactory._extract_json_comprehensive``;
    whatever it returns is treated as a JSON string and decoded. When the
    decoded value is a dict, each entry of ``keys`` that is present in it is
    printed under its own upper-cased header. In every other situation (the
    extractor returns nothing, extraction or decoding raises, or the decoded
    value is not a dict) the raw text is printed instead, truncated to its
    first 1000 characters.

    Args:
        raw_output: Text produced by the agent.
        keys: Names of the JSON fields to display; defaults to ``["code"]``.
            Keys missing from the decoded dict are silently skipped.

    Returns:
        None. All output goes to stdout.
    """
    if keys is None:
        keys = ["code"]

    # ``selected`` stays None unless a JSON object was successfully decoded,
    # so the display code below never touches an unbound or non-dict value.
    selected = None
    try:
        extracted_json = ParserFactory._extract_json_comprehensive(raw_output)
        if extracted_json:
            parsed = json.loads(extracted_json)
            if isinstance(parsed, dict):
                selected = {key: parsed[key] for key in keys if key in parsed}
    except Exception:
        # Best-effort display helper: any extraction/decoding failure simply
        # degrades to the raw dump below.
        selected = None

    if selected is None:
        print(f"{'=' * 50}")
        print("📋 [RAW OUTPUT]")
        print(raw_output[:1000] + "..." if len(raw_output) > 1000 else raw_output)
        print("=" * 50)
        return

    print(f"{'=' * 50}")
    for key, value in selected.items():
        print(f"📋 [{key.upper()}]")
        print(value)
        print("-" * 50)
    print("=" * 50)
async def test_kernel_gen_basic():
    """Run KernelGen once for a vector-add kernel and print the result.

    Creates a ``KernelGen`` agent, awaits ``agent.run`` with a two-record
    history, and prints the ``code`` field of the first returned value via
    ``print_code``. The other two returned values are not used.

    Returns:
        bool: True when every step completed without raising; False when an
        exception was caught (it is logged with its traceback).
    """
    try:
        agent = KernelGen()
        logger.info("✓ KernelGen agent created successfully")

        # Records passed as ``history_compress`` — presumably a condensed view
        # of earlier steps; confirm the exact semantics against KernelGen.
        summary_record = ActionRecord(
            action_id="summary",
            tool_name="history_summary",
            arguments={},
            result={"summary": "用户请求实现一个向量加法算子,目标是 Triton Ascend"},
        )
        build_record = ActionRecord(
            action_id="act_001",
            tool_name="op_task_build",
            arguments={"user_input": "向量加法"},
            result={"task_spec": "..."},
        )
        history = [summary_record, build_record]

        # Runtime text is kept identical to the multi-line literal it replaces,
        # including the leading and trailing newline.
        task_desc = (
            "\n"
            "实现一个简单的向量加法内核:\n"
            "- 输入:两个大小为 N 的一维张量 A 和 B\n"
            "- 输出:张量 C = A + B\n"
            "- 要求:\n"
            "* 处理任意大小\n"
            "* 使用高效的内存访问模式\n"
            "* 包含边界检查\n"
        )
        run_kwargs = {
            "op_name": "vector_add",
            "task_desc": task_desc,
            "dsl": "triton_ascend",
            "framework": "torch",
            "backend": "ascend",
            "arch": "ascend910b4",
            "task_id": "test_vector_add_001",
            "history_compress": history,
        }

        logger.info("Running KernelGen agent...")
        generated_code, _prompt, _reasoning = await agent.run(**run_kwargs)
        logger.info("✓ Code generation completed")

        print_code(generated_code, ["code"])
        return True
    except Exception as exc:
        logger.exception(f"✗ Test failed: {exc}")
        return False
async def test_kernel_gen_with_error():
    """Run KernelGen for softmax with a failed earlier attempt in the history.

    The history holds a summary record, a ``kernel_gen`` record and a
    ``verifier`` record carrying a compilation error; ``user_requirements``
    asks for that error to be fixed. Only the first value returned by
    ``agent.run`` is used: its ``code`` field is printed via ``print_code``.

    Returns:
        bool: True when every step completed without raising; False when an
        exception was caught (it is logged with its traceback).
    """
    try:
        agent = KernelGen()
        logger.info("✓ KernelGen agent created for softmax")

        summary_record = ActionRecord(
            action_id="summary1",
            tool_name="history_summary",
            arguments={},
            result={"summary": "用户请求实现 softmax 算子,第一次生成的代码有编译错误"},
        )
        first_attempt = ActionRecord(
            action_id="act_001",
            tool_name="kernel_gen",
            arguments={"task_desc": "softmax implementation"},
            result={"code": "code_v1"},
        )
        # NOTE(review): "passed" is the string "False", not the boolean —
        # confirm this matches what the verifier tool really records.
        failed_check = ActionRecord(
            action_id="act_002",
            tool_name="verifier",
            arguments={"code": "code_v1"},
            result={
                "passed": "False",
                "error": "Error: Compilation failed\n Line 42: undefined variable 'max_val'\n Hint: You need to compute max before exp",
            },
        )
        history = [summary_record, first_attempt, failed_check]

        # Runtime text is kept identical to the multi-line literal it replaces,
        # including the leading and trailing newline.
        task_desc = (
            "\n"
            "实现一个 softmax 内核:\n"
            "- 输入:二维张量 (batch_size, seq_len)\n"
            "- 输出:在最后一个维度上应用 softmax\n"
            "- 使用数值稳定的实现\n"
        )
        run_kwargs = {
            "op_name": "softmax",
            "task_desc": task_desc,
            "dsl": "triton_ascend",
            "framework": "torch",
            "backend": "ascend",
            "arch": "ascend910b4",
            "user_requirements": "修复之前的编译错误:需要先计算 max_val",
            "task_id": "test_softmax_001",
            "history_compress": history,
        }

        logger.info("Running KernelGen agent with error feedback...")
        generated_code, _, _ = await agent.run(**run_kwargs)
        logger.info("✓ Code generation with error feedback completed")

        print_code(generated_code, ["code"])
        return True
    except Exception as exc:
        logger.exception(f"✗ Test failed: {exc}")
        return False
async def main():
    """Run every KernelGen test in sequence and log a pass/fail summary.

    Each test coroutine is awaited in turn. A truthy result counts as a pass;
    a falsy result or an exception escaping the coroutine counts as a failure.

    Returns:
        bool: True when all tests passed, False otherwise.
    """
    banner = "=" * 60

    logger.info(banner)
    logger.info("Testing KernelGen Agent")
    logger.info(banner)

    tests = [
        ("Basic generation", test_kernel_gen_basic),
        ("Generation with error feedback", test_kernel_gen_with_error),
    ]

    results = []
    for test_name, test_func in tests:
        logger.info(banner)
        logger.info(f"Test: {test_name}")
        logger.info(banner)

        try:
            result = await test_func()
        except Exception as e:
            # Include the traceback, matching how the test functions report
            # their own failures.
            logger.error(f"Test '{test_name}' crashed: {e}", exc_info=True)
            result = False
        results.append((test_name, result))

    logger.info(banner)
    logger.info("Test Summary")
    logger.info(banner)

    for test_name, result in results:
        status = "✓ PASSED" if result else "✗ FAILED"
        logger.info(f"{status}: {test_name}")

    passed = sum(1 for _, r in results if r)
    total = len(results)
    logger.info(f"\nTotal: {passed}/{total} tests passed")

    return passed == total


if __name__ == "__main__":
    # Reflect the outcome in the process exit status (0 = all passed,
    # 1 = at least one failure) so shells and CI can detect failures.
    raise SystemExit(0 if asyncio.run(main()) else 1)