"""
CoderDatabase 测试用例
测试内容:
1. 数据库创建与 auto_update
2. 算子的增删操作(包含 skip/overwrite 模式)
3. samples 采样功能
4. 异步并发操作
"""
import os
import pytest
import asyncio
import shutil
from pathlib import Path
os.environ['AKG_AGENTS_STREAM_OUTPUT'] = 'on'
from akg_agents import get_project_root
from akg_agents.op.database.coder_database import CoderDatabase, DEFAULT_BENCHMARK_PATH
from akg_agents.op.database.coder_vector_store import CoderVectorStore
from akg_agents.op.config.config_validator import load_config
from akg_agents.op.config.config_validator import load_config
# Target configuration shared by every test in this module.
TEST_DSL = "triton_ascend"
TEST_FRAMEWORK = "torch"
TEST_BACKEND = "ascend"
TEST_ARCH = "ascend910b4"
# Upper bound on how many benchmark files are loaded / passed to auto_update.
MAX_TEST_FILES = 3
# Scratch database directory, located two levels above the project root.
TEST_DB_DIR = Path(get_project_root()).parent.parent / "test_temp_db"
def clear_all_instances():
    """Empty the singleton instance caches of both database classes.

    Clears ``_instances`` on ``CoderDatabase`` first and then on
    ``CoderVectorStore``.
    """
    for singleton_cls in (CoderDatabase, CoderVectorStore):
        singleton_cls._instances.clear()
def get_test_op_files(dsl: str = TEST_DSL, max_files: int = MAX_TEST_FILES):
    """Load operator sources from the benchmark tree for use in tests.

    Args:
        dsl: Name of the DSL sub-directory that holds the implementations.
        max_files: Maximum number of implementation files to consider.

    Returns:
        A list of dicts with the keys ``op_name``, ``impl_code`` and
        ``framework_code``. Implementation files without a framework file
        of the same name are left out, so the list may be shorter than
        ``max_files``.

    The calling test is skipped (via ``pytest.skip``) when the
    implementation directory does not exist.
    """
    benchmark_root = Path(DEFAULT_BENCHMARK_PATH)
    impl_dir = benchmark_root / dsl / "impl" / "static_shape" / "elemwise"
    if not impl_dir.exists():
        pytest.skip(f"Benchmark directory not found: {impl_dir}")

    framework_dir = benchmark_root / "static_shape" / "elemwise"
    collected = []
    # The limit is applied to the globbed files before the framework-file
    # filter below, exactly as the candidates come back from glob().
    for impl_path in list(impl_dir.glob("*.py"))[:max_files]:
        name = impl_path.stem
        framework_path = framework_dir / f"{name}.py"
        if not framework_path.exists():
            continue
        impl_source = impl_path.read_text(encoding='utf-8')
        framework_source = framework_path.read_text(encoding='utf-8')
        collected.append({
            "op_name": name,
            "impl_code": impl_source,
            "framework_code": framework_source,
        })
    return collected
@pytest.fixture(scope="module")
def temp_database_path():
    """Provide the scratch database directory, shared across the module.

    Yields:
        The directory path as a string. The directory is created before
        the first test that uses it and removed again at teardown.
    """
    scratch_dir = TEST_DB_DIR
    scratch_dir.mkdir(parents=True, exist_ok=True)

    yield str(scratch_dir)

    # Teardown: remove the directory tree if it is still present.
    if scratch_dir.exists():
        shutil.rmtree(scratch_dir)
@pytest.fixture(scope="module")
def config():
    """Return the configuration loaded for ``TEST_DSL`` (module-scoped)."""
    loaded_config = load_config(TEST_DSL)
    return loaded_config
@pytest.fixture(autouse=True)
def setup_test():
    """Clear the singleton caches before every test runs.

    Being ``autouse``, this applies to each test without being requested.
    Nothing is done after the test finishes.
    """
    clear_all_instances()
    yield None
@pytest.mark.level0
@pytest.mark.use_model
@pytest.mark.asyncio
async def test_auto_update(temp_database_path, config):
    """Check ``auto_update``: initial population, then a repeat in skip mode.

    The first call must create between 1 and ``MAX_TEST_FILES`` entries under
    ``<database>/<arch>/<dsl>``. The second call, made with the same
    arguments, must leave the number of entries unchanged.
    """
    db = CoderDatabase(database_path=temp_database_path, config=config)

    # Both auto_update calls use exactly the same arguments.
    update_kwargs = {
        "dsl": TEST_DSL,
        "framework": TEST_FRAMEWORK,
        "backend": TEST_BACKEND,
        "arch": TEST_ARCH,
        "ref_type": "impl",
        "max_files": MAX_TEST_FILES,
        "update_mode": "skip",
    }

    await db.auto_update(**update_kwargs)

    db_dir = Path(temp_database_path) / TEST_ARCH / TEST_DSL
    assert db_dir.exists(), f"Database directory not created: {db_dir}"

    initial_count = len(list(db_dir.iterdir()))
    assert initial_count > 0, "No cases were inserted"
    assert initial_count <= MAX_TEST_FILES, f"More files than expected: {initial_count}"
    print(f"✓ auto_update 成功创建 {initial_count} 个算子记录")

    # NOTE(review): presumably this resets the "already updated" bookkeeping
    # so that the second call really runs — confirm against CoderDatabase.
    db._auto_update_completed.clear()

    await db.auto_update(**update_kwargs)

    final_count = len(list(db_dir.iterdir()))
    assert final_count == initial_count, "Skip mode should not duplicate entries"
    print(f"✓ auto_update skip 模式正常工作,记录数: {final_count}")
@pytest.mark.level0
@pytest.mark.use_model
@pytest.mark.asyncio
async def test_insert_delete_and_modes(temp_database_path, config):
    """Check insert in skip/overwrite mode, followed by delete.

    Flow:
    - ``test_auto_update`` is expected to have already created the database
      and inserted the operator used here.
    - Inserting that operator again with ``mode="skip"`` must leave
      ``metadata.json`` untouched (same modification time).
    - Inserting it with ``mode="overwrite"`` must produce a newer
      modification time.
    - Deleting it must remove the case directory.
    """
    db = CoderDatabase(database_path=temp_database_path, config=config)

    candidates = get_test_op_files(max_files=1)
    if not candidates:
        pytest.skip("No test operators available")
    test_op = candidates[0]

    from akg_agents.utils.common_utils import get_md5_hash

    # Arguments that identify the case; shared by hashing, insert and delete.
    identity = {
        "impl_code": test_op["impl_code"],
        "backend": TEST_BACKEND,
        "arch": TEST_ARCH,
        "dsl": TEST_DSL,
    }
    insert_kwargs = {
        **identity,
        "framework_code": test_op["framework_code"],
        "framework": TEST_FRAMEWORK,
    }

    db_dir = Path(temp_database_path) / TEST_ARCH / TEST_DSL
    case_dir = db_dir / get_md5_hash(**identity)
    metadata_file = case_dir / "metadata.json"

    assert case_dir.exists(), f"Case directory should exist from test_auto_update: {case_dir}"
    initial_mtime = metadata_file.stat().st_mtime
    print(f"✓ 验证数据已存在: {test_op['op_name']}")

    # Short pause so that a rewrite would yield a different timestamp.
    await asyncio.sleep(0.1)
    await db.insert(**insert_kwargs, mode="skip")
    skip_mtime = metadata_file.stat().st_mtime
    assert skip_mtime == initial_mtime, "Skip mode should not modify existing files"
    print("✓ skip 模式正常工作")

    await asyncio.sleep(0.1)
    await db.insert(**insert_kwargs, mode="overwrite")
    overwrite_mtime = metadata_file.stat().st_mtime
    assert overwrite_mtime > initial_mtime, "Overwrite mode should update files"
    print("✓ overwrite 模式正常工作")

    db.delete(**identity)
    assert not case_dir.exists(), f"Case directory should be deleted: {case_dir}"
    print(f"✓ 成功删除算子: {test_op['op_name']}")
@pytest.mark.level0
@pytest.mark.use_model
@pytest.mark.asyncio
async def test_concurrent_operations(temp_database_path, config):
    """Check that several insert tasks can run concurrently.

    Uses a dedicated ``concurrent_test`` sub-directory as the database.
    At least one insert has to succeed, and the number of entries on disk
    has to equal the number of successful inserts. The test is skipped
    when fewer than two operators are available.
    """
    concurrent_dir = Path(temp_database_path) / "concurrent_test"
    concurrent_dir.mkdir(parents=True, exist_ok=True)
    db = CoderDatabase(database_path=str(concurrent_dir), config=config)

    test_ops = get_test_op_files(max_files=3)
    if len(test_ops) < 2:
        pytest.skip("Need at least 2 test operators")

    async def insert_one(op_info):
        """Insert a single operator in skip mode and return its name."""
        await db.insert(
            impl_code=op_info["impl_code"],
            framework_code=op_info["framework_code"],
            backend=TEST_BACKEND,
            arch=TEST_ARCH,
            dsl=TEST_DSL,
            framework=TEST_FRAMEWORK,
            mode="skip",
        )
        return op_info["op_name"]

    # return_exceptions=True: failures are returned as results, not raised.
    outcomes = await asyncio.gather(
        *(insert_one(op) for op in test_ops),
        return_exceptions=True,
    )

    successful = []
    failed = []
    for outcome in outcomes:
        if isinstance(outcome, Exception):
            failed.append(outcome)
        else:
            successful.append(outcome)

    db_dir = concurrent_dir / TEST_ARCH / TEST_DSL
    if db_dir.exists():
        case_count = len(list(db_dir.iterdir()))
    else:
        case_count = 0

    print(f"✓ 并发操作完成: 成功 {len(successful)}, 失败 {len(failed)}, 数据库记录 {case_count}")
    assert len(successful) > 0, "At least one operation should succeed"
    assert case_count == len(successful), f"Database count mismatch: {case_count} vs {len(successful)}"
@pytest.mark.level0
@pytest.mark.use_model
@pytest.mark.use_vector_store
@pytest.mark.asyncio
async def test_samples(temp_database_path, config):
    """Check the ``samples`` query on a freshly populated database.

    A dedicated ``samples_test`` sub-directory is populated through
    ``auto_update`` and then queried with one benchmark operator. The
    result must be a list; a ``ValueError`` from the query is reported as
    "no matching result" and does not fail the test.
    """
    samples_dir = Path(temp_database_path) / "samples_test"
    samples_dir.mkdir(parents=True, exist_ok=True)
    db = CoderDatabase(database_path=str(samples_dir), config=config)

    await db.auto_update(
        dsl=TEST_DSL,
        framework=TEST_FRAMEWORK,
        backend=TEST_BACKEND,
        arch=TEST_ARCH,
        ref_type="impl",
        max_files=MAX_TEST_FILES,
        update_mode="skip",
    )

    candidates = get_test_op_files(max_files=1)
    if not candidates:
        pytest.skip("No test operators available")
    test_op = candidates[0]

    query_kwargs = {
        "output_content": ["impl_code", "op_name"],
        "sample_num": 2,
        "impl_code": test_op["impl_code"],
        "framework_code": test_op["framework_code"],
        "backend": TEST_BACKEND,
        "arch": TEST_ARCH,
        "dsl": TEST_DSL,
        "framework": TEST_FRAMEWORK,
    }

    try:
        results = await db.samples(**query_kwargs)
        assert isinstance(results, list), "samples should return a list"
        print(f"✓ samples 查询成功,返回 {len(results)} 个结果")
        for position, result in enumerate(results, start=1):
            op_name = result.get("op_name", "unknown")
            print(f" - 结果 {position}: {op_name}")
    except ValueError as e:
        # A ValueError is treated as "no match" and only reported.
        print(f"✓ samples 查询完成(无匹配结果): {e}")
# Direct execution: run pytest on this file, verbose (-v) and with output
# capturing disabled (-s) so the progress prints are shown.
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])