import textwrap
import pytest
from akg_agents.op.verifier.kernel_verifier import KernelVerifier
from akg_agents.utils.common_utils import create_log_dir
from akg_agents.op.config.config_validator import load_config
from akg_agents.core.worker.manager import register_local_worker, get_worker_manager
from ..utils import get_device_id
# Device index used by the tests below; resolved once when the module is imported.
device_id = get_device_id()
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.ascend
@pytest.mark.ascend910b4
@pytest.mark.parametrize("op_name", ["relu", "linear"])
@pytest.mark.asyncio
async def test_kernel_verifier_ascend910b4_torch(op_name):
    """Verify the stored triton_ascend kernel for ``op_name`` against its torch task.

    The framework task description and the kernel source are read from
    ``./tests/op/resources``; a worker is selected for ascend/ascend910b4 and
    ``KernelVerifier.run`` must report success.

    Args:
        op_name: Operator name supplied by ``parametrize``.

    Raises:
        RuntimeError: If no worker is available for the backend/arch.
    """
    framework = "torch"
    dsl = "triton_ascend"
    backend = "ascend"
    arch = "ascend910b4"
    config = load_config(dsl, backend=backend)

    # Framework-side task description (dedented before use).
    op_task_file = f"./tests/op/resources/{op_name}_op/{op_name}_{framework}.py"
    with open(op_task_file, "r", encoding="utf-8") as f:
        op_task_str = textwrap.dedent(f.read())

    # Kernel implementation to be verified.
    kernel_path = f"./tests/op/resources/{op_name}_op/{op_name}_{dsl}_{framework}.py"
    with open(kernel_path, "r", encoding="utf-8") as f:
        kernel_code = f.read()

    # The returned path is not needed here; the call itself is kept because it
    # presumably creates the log directory (TODO confirm).
    create_log_dir(f'{op_name}_{framework}_{backend}_{arch}_{dsl}_test')

    await register_local_worker([device_id], backend=backend, arch=arch)
    worker = await get_worker_manager().select(backend=backend, arch=arch)
    if not worker:
        raise RuntimeError(f"No available worker for backend={backend}, arch={arch}. Please register a worker first.")

    try:
        verifier = KernelVerifier(
            op_name=op_name,
            framework_code=op_task_str,
            framework=framework,
            dsl=dsl,
            backend=backend,
            arch=arch,
            impl_func_name="ModelNew",
            config=config,
            worker=worker,
        )
        task_info = {"coder_code": kernel_code}
        result, error_log = await verifier.run(task_info, device_id=device_id)
        assert result, f"验证失败: {error_log}"
    finally:
        # Give the selected worker back even if verification fails, so it is
        # not left held for the tests that follow.
        await get_worker_manager().release(worker)
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.cuda
@pytest.mark.a100
@pytest.mark.parametrize("op_name", ["relu"])
@pytest.mark.asyncio
async def test_kernel_verifier_a100(op_name):
    """Verify the stored triton_cuda kernel for ``op_name`` with arch ``a100``.

    Loads the torch task description and the triton_cuda kernel source from
    ``./tests/op/resources``, selects a cuda/a100 worker and requires
    ``KernelVerifier.run`` to succeed.

    Args:
        op_name: Operator name supplied by ``parametrize``.

    Raises:
        RuntimeError: If no worker is available for the backend/arch.
    """
    framework = "torch"
    dsl = "triton_cuda"
    backend = "cuda"
    arch = "a100"
    config = load_config(dsl, backend=backend)

    op_task_file = f"./tests/op/resources/{op_name}_op/{op_name}_{framework}.py"
    with open(op_task_file, "r", encoding="utf-8") as f:
        op_task_str = textwrap.dedent(f.read())

    kernel_path = f"./tests/op/resources/{op_name}_op/{op_name}_{dsl}.py"
    with open(kernel_path, "r", encoding="utf-8") as f:
        kernel_code = f.read()

    # Return value unused; call kept because it presumably creates the log
    # directory (TODO confirm).
    create_log_dir(f'{op_name}_{framework}_{backend}_{arch}_{dsl}_test')

    await register_local_worker([device_id], backend=backend, arch=arch)
    worker = await get_worker_manager().select(backend=backend, arch=arch)
    if not worker:
        raise RuntimeError(f"No available worker for backend={backend}, arch={arch}. Please register a worker first.")

    try:
        verifier = KernelVerifier(
            op_name=op_name,
            framework_code=op_task_str,
            framework=framework,
            dsl=dsl,
            backend=backend,
            arch=arch,
            impl_func_name="ModelNew",
            config=config,
            worker=worker,
        )
        task_info = {"coder_code": kernel_code}
        result, error_log = await verifier.run(task_info, device_id=device_id)
        assert result, f"验证失败: {error_log}"
    finally:
        # Release the worker on both the success and the failure path.
        await get_worker_manager().release(worker)
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.cuda
@pytest.mark.v100
@pytest.mark.parametrize("op_name", ["relu"])
@pytest.mark.asyncio
async def test_kernel_verifier_v100(op_name):
    """Verify the stored triton_cuda kernel for ``op_name`` with arch ``v100``.

    Loads the torch task description and the triton_cuda kernel source from
    ``./tests/op/resources``, selects a cuda/v100 worker and requires
    ``KernelVerifier.run`` to succeed.

    Args:
        op_name: Operator name supplied by ``parametrize``.

    Raises:
        RuntimeError: If no worker is available for the backend/arch.
    """
    framework = "torch"
    dsl = "triton_cuda"
    backend = "cuda"
    arch = "v100"
    config = load_config(dsl, backend=backend)

    op_task_file = f"./tests/op/resources/{op_name}_op/{op_name}_{framework}.py"
    with open(op_task_file, "r", encoding="utf-8") as f:
        op_task_str = textwrap.dedent(f.read())

    kernel_path = f"./tests/op/resources/{op_name}_op/{op_name}_{dsl}.py"
    with open(kernel_path, "r", encoding="utf-8") as f:
        kernel_code = f.read()

    # Return value unused; call kept because it presumably creates the log
    # directory (TODO confirm).
    create_log_dir(f'{op_name}_{framework}_{backend}_{arch}_{dsl}_test')

    await register_local_worker([device_id], backend=backend, arch=arch)
    worker = await get_worker_manager().select(backend=backend, arch=arch)
    if not worker:
        raise RuntimeError(f"No available worker for backend={backend}, arch={arch}. Please register a worker first.")

    # NOTE(review): the a100 test reads a kernel file with the same naming
    # pattern but passes "ModelNew" here; confirm this derived name is intended.
    impl_func_name = f"{op_name}_{dsl}_{framework}"

    try:
        verifier = KernelVerifier(
            op_name=op_name,
            framework_code=op_task_str,
            framework=framework,
            dsl=dsl,
            backend=backend,
            arch=arch,
            impl_func_name=impl_func_name,
            config=config,
            worker=worker,
        )
        task_info = {"coder_code": kernel_code}
        result, error_log = await verifier.run(task_info, device_id=device_id)
        assert result, f"验证失败: {error_log}"
    finally:
        # Release the worker on both the success and the failure path.
        await get_worker_manager().release(worker)
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.ascend
@pytest.mark.ascend910b4
@pytest.mark.profiling
@pytest.mark.parametrize("op_name", ["relu"])
@pytest.mark.asyncio
async def test_kernel_verifier_profiling_ascend910b4_torch(op_name):
    """Verify, then profile, the triton_ascend kernel for ``op_name`` (torch).

    After ``KernelVerifier.run`` succeeds, ``run_profile`` is called with 50
    runs and 5 warm-up runs and the reported ``base_time``, ``gen_time`` and
    ``speedup`` are printed.

    Args:
        op_name: Operator name supplied by ``parametrize``.

    Raises:
        RuntimeError: If no worker is available for the backend/arch.
    """
    framework = "torch"
    dsl = "triton_ascend"
    backend = "ascend"
    arch = "ascend910b4"
    config = load_config(dsl, backend=backend)

    op_task_file = f"./tests/op/resources/{op_name}_op/{op_name}_{framework}.py"
    with open(op_task_file, "r", encoding="utf-8") as f:
        op_task_str = textwrap.dedent(f.read())

    kernel_path = f"./tests/op/resources/{op_name}_op/{op_name}_{dsl}_{framework}.py"
    with open(kernel_path, "r", encoding="utf-8") as f:
        kernel_code = f.read()

    # Return value unused; call kept because it presumably creates the log
    # directory (TODO confirm).
    create_log_dir(f'{op_name}_{framework}_{backend}_{arch}_{dsl}_profiling_test')

    await register_local_worker([device_id], backend=backend, arch=arch)
    worker = await get_worker_manager().select(backend=backend, arch=arch)
    if not worker:
        raise RuntimeError(f"No available worker for backend={backend}, arch={arch}. Please register a worker first.")

    try:
        verifier = KernelVerifier(
            op_name=op_name,
            framework_code=op_task_str,
            task_id="profiling_test_001",
            framework=framework,
            dsl=dsl,
            backend=backend,
            arch=arch,
            impl_func_name="ModelNew",
            config=config,
            worker=worker,
        )
        task_info = {"coder_code": kernel_code}

        # Correctness first: profiling is only attempted on a verified kernel.
        result, error_log = await verifier.run(task_info, device_id=device_id)
        assert result, f"验证失败: {error_log}"

        profile_settings = {
            "run_times": 50,
            "warmup_times": 5,
        }
        profile_result = await verifier.run_profile(
            task_info, current_step=0, device_id=device_id, profile_settings=profile_settings)
        gen_time = profile_result['gen_time']
        base_time = profile_result['base_time']
        speedup = profile_result['speedup']
        print(f"orig performance is {base_time:.2f} us")
        print(f"akg_agents performance is {gen_time:.2f} us")
        print(f"speedup is {speedup:.2f}x")
    finally:
        # Release the worker whether verification/profiling passed or raised.
        await get_worker_manager().release(worker)
@pytest.mark.level0
@pytest.mark.mindspore
@pytest.mark.triton
@pytest.mark.ascend
@pytest.mark.ascend910b4
@pytest.mark.profiling
@pytest.mark.parametrize("op_name", ["relu"])
@pytest.mark.asyncio
async def test_kernel_verifier_profiling_ascend910b4_mindspore(op_name):
    """Verify, then profile, the triton_ascend kernel for ``op_name`` (mindspore).

    After ``KernelVerifier.run`` succeeds, ``run_profile`` is called with 50
    runs and 5 warm-up runs and the reported ``base_time``, ``gen_time`` and
    ``speedup`` are printed.

    Args:
        op_name: Operator name supplied by ``parametrize``.

    Raises:
        RuntimeError: If no worker is available for the backend/arch.
    """
    framework = "mindspore"
    dsl = "triton_ascend"
    backend = "ascend"
    arch = "ascend910b4"
    config = load_config(dsl, backend=backend)

    op_task_file = f"./tests/op/resources/{op_name}_op/{op_name}_{framework}.py"
    with open(op_task_file, "r", encoding="utf-8") as f:
        op_task_str = textwrap.dedent(f.read())

    kernel_path = f"./tests/op/resources/{op_name}_op/{op_name}_{dsl}_{framework}.py"
    with open(kernel_path, "r", encoding="utf-8") as f:
        kernel_code = f.read()

    # Return value unused; call kept because it presumably creates the log
    # directory (TODO confirm).
    create_log_dir(f'{op_name}_{framework}_{backend}_{arch}_{dsl}_profiling_test')

    await register_local_worker([device_id], backend=backend, arch=arch)
    worker = await get_worker_manager().select(backend=backend, arch=arch)
    if not worker:
        raise RuntimeError(f"No available worker for backend={backend}, arch={arch}. Please register a worker first.")

    try:
        verifier = KernelVerifier(
            op_name=op_name,
            framework_code=op_task_str,
            task_id="profiling_test_001",
            framework=framework,
            dsl=dsl,
            backend=backend,
            arch=arch,
            impl_func_name="ModelNew",
            config=config,
            worker=worker,
        )
        task_info = {"coder_code": kernel_code}

        # Correctness first: profiling is only attempted on a verified kernel.
        result, error_log = await verifier.run(task_info, device_id=device_id)
        assert result, f"验证失败: {error_log}"

        profile_settings = {
            "run_times": 50,
            "warmup_times": 5,
        }
        profile_result = await verifier.run_profile(
            task_info, current_step=0, device_id=device_id, profile_settings=profile_settings)
        gen_time = profile_result['gen_time']
        base_time = profile_result['base_time']
        speedup = profile_result['speedup']
        print(f"orig performance is {base_time:.2f} us")
        print(f"akg_agents performance is {gen_time:.2f} us")
        print(f"speedup is {speedup:.2f}x")
    finally:
        # Release the worker whether verification/profiling passed or raised.
        await get_worker_manager().release(worker)
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.cuda
@pytest.mark.a100
@pytest.mark.profiling
@pytest.mark.parametrize("op_name", ["relu"])
@pytest.mark.asyncio
async def test_kernel_verifier_profiling_a100(op_name):
    """Verify, then profile, the triton_cuda kernel for ``op_name`` with arch ``a100``.

    After ``KernelVerifier.run`` succeeds, ``run_profile`` is called with 50
    runs and 5 warm-up runs and the reported ``base_time``, ``gen_time`` and
    ``speedup`` are printed.

    Args:
        op_name: Operator name supplied by ``parametrize``.

    Raises:
        RuntimeError: If no worker is available for the backend/arch.
    """
    framework = "torch"
    dsl = "triton_cuda"
    backend = "cuda"
    arch = "a100"
    config = load_config(dsl, backend=backend)

    op_task_file = f"./tests/op/resources/{op_name}_op/{op_name}_{framework}.py"
    with open(op_task_file, "r", encoding="utf-8") as f:
        op_task_str = textwrap.dedent(f.read())

    kernel_path = f"./tests/op/resources/{op_name}_op/{op_name}_{dsl}.py"
    with open(kernel_path, "r", encoding="utf-8") as f:
        kernel_code = f.read()

    # Return value unused; call kept because it presumably creates the log
    # directory (TODO confirm).
    create_log_dir(f'{op_name}_{framework}_{backend}_{arch}_{dsl}_profiling_test')

    await register_local_worker([device_id], backend=backend, arch=arch)
    worker = await get_worker_manager().select(backend=backend, arch=arch)
    if not worker:
        raise RuntimeError(f"No available worker for backend={backend}, arch={arch}. Please register a worker first.")

    try:
        verifier = KernelVerifier(
            op_name=op_name,
            framework_code=op_task_str,
            task_id="profiling_test_001",
            framework=framework,
            dsl=dsl,
            backend=backend,
            arch=arch,
            impl_func_name="ModelNew",
            config=config,
            worker=worker,
        )
        task_info = {"coder_code": kernel_code}

        # Correctness first: profiling is only attempted on a verified kernel.
        result, error_log = await verifier.run(task_info, device_id=device_id)
        assert result, f"验证失败: {error_log}"

        profile_settings = {
            "run_times": 50,
            "warmup_times": 5,
        }
        profile_result = await verifier.run_profile(
            task_info, current_step=0, device_id=device_id, profile_settings=profile_settings)
        gen_time = profile_result['gen_time']
        base_time = profile_result['base_time']
        speedup = profile_result['speedup']
        print(f"orig performance is {base_time:.2f} us")
        print(f"akg_agents performance is {gen_time:.2f} us")
        print(f"speedup is {speedup:.2f}x")
    finally:
        # Release the worker whether verification/profiling passed or raised.
        await get_worker_manager().release(worker)
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.ascend
@pytest.mark.ascend910b4
@pytest.mark.profiling
@pytest.mark.parametrize("op_name", ["linear"])
@pytest.mark.asyncio
async def test_kernel_verifier_profiling_linear_ascend910b4_torch(op_name):
    """Verify, then profile, the triton_ascend linear kernel on ascend910b4 (torch).

    After ``KernelVerifier.run`` succeeds, ``run_profile`` is called with 50
    runs and 5 warm-up runs and the reported timings and speedup are printed.

    Args:
        op_name: Operator name supplied by ``parametrize``.

    Raises:
        RuntimeError: If no worker is available for the backend/arch.
    """
    framework = "torch"
    dsl = "triton_ascend"
    backend = "ascend"
    arch = "ascend910b4"
    config = load_config(dsl, backend=backend)

    op_task_file = f"./tests/op/resources/{op_name}_op/{op_name}_{framework}.py"
    with open(op_task_file, "r", encoding="utf-8") as f:
        op_task_str = textwrap.dedent(f.read())

    kernel_path = f"./tests/op/resources/{op_name}_op/{op_name}_{dsl}_{framework}.py"
    with open(kernel_path, "r", encoding="utf-8") as f:
        kernel_code = f.read()

    await register_local_worker([device_id], backend=backend, arch=arch)
    worker = await get_worker_manager().select(backend=backend, arch=arch)
    if not worker:
        raise RuntimeError(f"No available worker for backend={backend}, arch={arch}. Please register a worker first.")

    try:
        verifier = KernelVerifier(
            op_name=op_name,
            framework_code=op_task_str,
            task_id="linear_profiling_test_001",
            framework=framework,
            dsl=dsl,
            backend=backend,
            arch=arch,
            impl_func_name="ModelNew",
            config=config,
            worker=worker,
        )
        task_info = {"coder_code": kernel_code}

        # Correctness first: profiling is only attempted on a verified kernel.
        result, error_log = await verifier.run(task_info, device_id=device_id)
        assert result, f"验证失败: {error_log}"

        profile_settings = {
            "run_times": 50,
            "warmup_times": 5,
        }
        profile_result = await verifier.run_profile(
            task_info, current_step=0, device_id=device_id, profile_settings=profile_settings)
        gen_time = profile_result['gen_time']
        base_time = profile_result['base_time']
        speedup = profile_result['speedup']
        print("Linear Profiling Results:")
        print(f"Operation: {op_name}")
        print(f"orig performance is {base_time:.2f} us")
        print(f"akg_agents performance is {gen_time:.2f} us")
        print(f"speedup is {speedup:.2f}x")
    finally:
        # Release the worker whether verification/profiling passed or raised.
        await get_worker_manager().release(worker)
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.cpp
@pytest.mark.cpu
@pytest.mark.x86_64
@pytest.mark.profiling
@pytest.mark.parametrize("op_name", ["relu"])
@pytest.mark.asyncio
async def test_kernel_verifier_profiling_cpp(op_name):
    """Verify the stored cpp kernel for ``op_name`` on cpu/x86_64.

    The configuration is loaded from ``cpp_coderonly_config.yaml`` and
    ``KernelVerifier.run`` must report success.

    NOTE(review): despite the ``profiling`` marker and the test name, this
    test only calls ``verifier.run``; no ``run_profile`` step is present.
    Confirm whether one is missing.

    Args:
        op_name: Operator name supplied by ``parametrize``.

    Raises:
        RuntimeError: If no worker is available for the backend/arch.
    """
    framework = "torch"
    dsl = "cpp"
    backend = "cpu"
    arch = "x86_64"
    config = load_config(config_path="./python/akg_agents/op/config/cpp_coderonly_config.yaml")

    op_task_file = f"./tests/op/resources/{op_name}_op/{op_name}_{framework}.py"
    with open(op_task_file, "r", encoding="utf-8") as f:
        op_task_str = textwrap.dedent(f.read())

    kernel_path = f"./tests/op/resources/{op_name}_op/{op_name}_{dsl}.py"
    with open(kernel_path, "r", encoding="utf-8") as f:
        kernel_code = f.read()

    # Return value unused; call kept because it presumably creates the log
    # directory (TODO confirm).
    create_log_dir(f'{op_name}_{framework}_{backend}_{arch}_{dsl}_profiling_test')

    await register_local_worker([device_id], backend=backend, arch=arch)
    worker = await get_worker_manager().select(backend=backend, arch=arch)
    if not worker:
        raise RuntimeError(f"No available worker for backend={backend}, arch={arch}. Please register a worker first.")

    try:
        verifier = KernelVerifier(
            op_name=op_name,
            framework_code=op_task_str,
            task_id="profiling_test_001",
            framework=framework,
            dsl=dsl,
            backend=backend,
            arch=arch,
            impl_func_name="ModelNew",
            config=config,
            worker=worker,
        )
        task_info = {"coder_code": kernel_code}
        result, error_log = await verifier.run(task_info, device_id=device_id)
        assert result, f"验证失败: {error_log}"
    finally:
        # Release the worker on both the success and the failure path.
        await get_worker_manager().release(worker)
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.cpp
@pytest.mark.cpu
@pytest.mark.x86_64
@pytest.mark.parametrize("op_name", ["linear"])
@pytest.mark.asyncio
async def test_kernel_verifier_linear_cpp(op_name):
    """Verify the stored cpp linear kernel on cpu/x86_64.

    Per the original description, this case is meant to exercise alignment of
    the weight random seed between the reference and the generated kernel.

    Args:
        op_name: Operator name supplied by ``parametrize``.

    Raises:
        RuntimeError: If no worker is available for the backend/arch.
    """
    framework = "torch"
    dsl = "cpp"
    backend = "cpu"
    arch = "x86_64"
    config = load_config(config_path="./python/akg_agents/op/config/cpp_coderonly_config.yaml")

    op_task_file = f"./tests/op/resources/{op_name}_op/{op_name}_{framework}.py"
    with open(op_task_file, "r", encoding="utf-8") as f:
        op_task_str = textwrap.dedent(f.read())

    kernel_path = f"./tests/op/resources/{op_name}_op/{op_name}_{dsl}.py"
    with open(kernel_path, "r", encoding="utf-8") as f:
        kernel_code = f.read()

    # Return value unused; call kept because it presumably creates the log
    # directory (TODO confirm).
    create_log_dir(f'{op_name}_{framework}_{backend}_{arch}_{dsl}_test')

    await register_local_worker([device_id], backend=backend, arch=arch)
    worker = await get_worker_manager().select(backend=backend, arch=arch)
    if not worker:
        raise RuntimeError(f"No available worker for backend={backend}, arch={arch}. Please register a worker first.")

    try:
        verifier = KernelVerifier(
            op_name=op_name,
            framework_code=op_task_str,
            task_id="linear_test_001",
            framework=framework,
            dsl=dsl,
            backend=backend,
            arch=arch,
            impl_func_name="ModelNew",
            config=config,
            worker=worker,
        )
        task_info = {"coder_code": kernel_code}
        result, error_log = await verifier.run(task_info, device_id=device_id)
        assert result, f"验证失败: {error_log}"
    finally:
        # Release the worker on both the success and the failure path.
        await get_worker_manager().release(worker)
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.cuda
@pytest.mark.a100
@pytest.mark.parametrize("op_name", ["linear"])
@pytest.mark.asyncio
async def test_kernel_verifier_linear_triton_cuda(op_name):
    """Verify the stored triton_cuda linear kernel with arch ``a100``.

    Per the original description, this case is meant to exercise alignment of
    the weight random seed between the reference and the generated kernel.

    Args:
        op_name: Operator name supplied by ``parametrize``.

    Raises:
        RuntimeError: If no worker is available for the backend/arch.
    """
    framework = "torch"
    dsl = "triton_cuda"
    backend = "cuda"
    arch = "a100"
    config = load_config(dsl, backend=backend)

    op_task_file = f"./tests/op/resources/{op_name}_op/{op_name}_{framework}.py"
    with open(op_task_file, "r", encoding="utf-8") as f:
        op_task_str = textwrap.dedent(f.read())

    kernel_path = f"./tests/op/resources/{op_name}_op/{op_name}_{dsl}.py"
    with open(kernel_path, "r", encoding="utf-8") as f:
        kernel_code = f.read()

    # Return value unused; call kept because it presumably creates the log
    # directory (TODO confirm).
    create_log_dir(f'{op_name}_{framework}_{backend}_{arch}_{dsl}_test')

    await register_local_worker([device_id], backend=backend, arch=arch)
    worker = await get_worker_manager().select(backend=backend, arch=arch)
    if not worker:
        raise RuntimeError(f"No available worker for backend={backend}, arch={arch}. Please register a worker first.")

    try:
        verifier = KernelVerifier(
            op_name=op_name,
            framework_code=op_task_str,
            task_id="linear_triton_cuda_test_001",
            framework=framework,
            dsl=dsl,
            backend=backend,
            arch=arch,
            impl_func_name="ModelNew",
            config=config,
            worker=worker,
        )
        task_info = {"coder_code": kernel_code}
        result, error_log = await verifier.run(task_info, device_id=device_id)
        assert result, f"验证失败: {error_log}"
    finally:
        # Release the worker on both the success and the failure path.
        await get_worker_manager().release(worker)
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.ascend
@pytest.mark.ascend910b4
@pytest.mark.parametrize("op_name", ["relu"])
@pytest.mark.asyncio
async def test_check_task_desc_static_valid(op_name):
    """Static check: a valid task description read from the resources must pass.

    Args:
        op_name: Operator name supplied by ``parametrize``.
    """
    framework, dsl, backend, arch = "torch", "triton_ascend", "ascend", "ascend910b4"
    config = load_config(dsl, backend=backend)

    # The task description is used as-is (no dedent) for both the verifier's
    # framework code and the text being checked.
    op_task_file = f"./tests/op/resources/{op_name}_op/{op_name}_{framework}.py"
    with open(op_task_file, "r", encoding="utf-8") as src:
        task_desc = src.read()

    checker = KernelVerifier(
        op_name=op_name,
        framework_code=task_desc,
        framework=framework,
        dsl=dsl,
        backend=backend,
        arch=arch,
        config=config,
    )

    ok, message = checker.check_task_desc_static(task_desc)
    assert ok, f"静态检查应该通过,但失败了: {message}"
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.ascend
@pytest.mark.ascend910b4
@pytest.mark.asyncio
async def test_check_task_desc_static_missing_model():
    """Static check: a task description without a ``Model`` class must be rejected.

    The check is expected to fail and the returned error text must mention
    ``class Model``.
    """
    # NOTE(review): the snippet below is reproduced exactly as it appears in
    # this file, including its flat (unindented) layout — confirm that layout
    # is intended.
    invalid_task_desc = """
import torch
def get_inputs():
return [torch.randn(16, 16384)]
def get_init_inputs():
return []
"""

    dsl, backend, arch = "triton_ascend", "ascend", "ascend910b4"
    checker = KernelVerifier(
        op_name="test",
        framework_code="",
        framework="torch",
        dsl=dsl,
        backend=backend,
        arch=arch,
        config=load_config(dsl, backend=backend),
    )

    ok, message = checker.check_task_desc_static(invalid_task_desc)
    assert not ok, "静态检查应该失败(缺少 Model 类)"
    assert "class Model" in message
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.ascend
@pytest.mark.ascend910b4
@pytest.mark.asyncio
async def test_check_task_desc_static_missing_get_inputs():
    """Static check: a task description without ``get_inputs`` must be rejected.

    The check is expected to fail and the returned error text must mention
    ``get_inputs``.
    """
    # NOTE(review): the snippet below is reproduced exactly as it appears in
    # this file, including its flat (unindented) layout — confirm that layout
    # is intended.
    invalid_task_desc = """
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x):
return torch.relu(x)
def get_init_inputs():
return []
"""

    dsl, backend, arch = "triton_ascend", "ascend", "ascend910b4"
    checker = KernelVerifier(
        op_name="test",
        framework_code="",
        framework="torch",
        dsl=dsl,
        backend=backend,
        arch=arch,
        config=load_config(dsl, backend=backend),
    )

    ok, message = checker.check_task_desc_static(invalid_task_desc)
    assert not ok, "静态检查应该失败(缺少 get_inputs)"
    assert "get_inputs" in message
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.cuda
@pytest.mark.a100
@pytest.mark.parametrize("op_name", ["relu"])
@pytest.mark.asyncio
async def test_check_task_desc_runtime_valid(op_name):
    """Runtime check: a valid task description must pass on a cuda/a100 worker.

    Args:
        op_name: Operator name supplied by ``parametrize``.

    Raises:
        RuntimeError: If no worker is available for the backend/arch.
    """
    framework, dsl, backend, arch = "torch", "triton_cuda", "cuda", "a100"
    config = load_config(dsl, backend=backend)

    op_task_file = f"./tests/op/resources/{op_name}_op/{op_name}_{framework}.py"
    with open(op_task_file, "r", encoding="utf-8") as src:
        task_desc = src.read()

    await register_local_worker([device_id], backend=backend, arch=arch)
    worker = await get_worker_manager().select(backend=backend, arch=arch)
    if not worker:
        raise RuntimeError(f"No available worker for backend={backend}, arch={arch}")

    try:
        checker = KernelVerifier(
            op_name=op_name,
            framework_code=task_desc,
            task_id="runtime_check_test",
            framework=framework,
            dsl=dsl,
            backend=backend,
            arch=arch,
            config=config,
            worker=worker,
        )
        ok, message = await checker.check_task_desc_runtime(task_desc, timeout=60)
        assert ok, f"运行时检查应该通过,但失败了: {message}"
    finally:
        # The worker is handed back regardless of the check's outcome.
        await get_worker_manager().release(worker)
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.cpu
@pytest.mark.x86_64
@pytest.mark.profiling
@pytest.mark.parametrize("op_name", ["relu"])
@pytest.mark.asyncio
async def test_profile_baseline_only_cpp(op_name):
    """Profile the cpp kernel without skipping the baseline measurement.

    Scenario (from the original description): when the baseline is not
    skipped, the baseline profile runs normally. The profile settings carry no
    skip/override keys, and both ``base_time`` and ``gen_time`` are required
    to be present, positive and finite.

    Args:
        op_name: Operator name supplied by ``parametrize``.

    Raises:
        RuntimeError: If no worker is available for the backend/arch.
    """
    framework = "torch"
    dsl = "cpp"
    backend = "cpu"
    arch = "x86_64"
    config = load_config(config_path="./python/akg_agents/op/config/cpp_coderonly_config.yaml")

    op_task_file = f"./tests/op/resources/{op_name}_op/{op_name}_{framework}.py"
    with open(op_task_file, "r", encoding="utf-8") as f:
        op_task_str = textwrap.dedent(f.read())

    kernel_path = f"./tests/op/resources/{op_name}_op/{op_name}_{dsl}.py"
    with open(kernel_path, "r", encoding="utf-8") as f:
        kernel_code = f.read()

    # Return value unused; call kept because it presumably creates the log
    # directory (TODO confirm).
    create_log_dir(f'{op_name}_{framework}_{backend}_{arch}_{dsl}_baseline_only_test')

    await register_local_worker([device_id], backend=backend, arch=arch)
    worker = await get_worker_manager().select(backend=backend, arch=arch)
    if not worker:
        raise RuntimeError(f"No available worker for backend={backend}, arch={arch}")

    try:
        verifier = KernelVerifier(
            op_name=op_name,
            framework_code=op_task_str,
            task_id="baseline_only_test",
            framework=framework,
            dsl=dsl,
            backend=backend,
            arch=arch,
            impl_func_name="ModelNew",
            config=config,
            worker=worker,
        )
        task_info = {"coder_code": kernel_code}

        # Correctness first: profiling is only attempted on a verified kernel.
        result, error_log = await verifier.run(task_info, device_id=device_id)
        assert result, f"验证失败: {error_log}"

        profile_settings = {
            "warmup_times": 2,
            "run_times": 5,
        }
        profile_result = await verifier.run_profile(
            task_info,
            current_step=1,
            device_id=device_id,
            profile_settings=profile_settings,
        )

        assert profile_result is not None, "Profile 结果不应为 None"
        assert 'base_time' in profile_result, "应该包含 base_time"
        assert 'gen_time' in profile_result, "应该包含 gen_time"

        base_time = profile_result['base_time']
        gen_time = profile_result['gen_time']
        assert base_time is not None and 0 < base_time < float('inf'), \
            f"Baseline 时间应该被正常测量,但得到: {base_time}"
        assert gen_time is not None and 0 < gen_time < float('inf'), \
            f"Generation 时间应该被正常测量,但得到: {gen_time}"
        print(f"✅ Baseline profile 测试通过: base_time={base_time:.2f}us, gen_time={gen_time:.2f}us")
    finally:
        # Release the worker whether verification/profiling passed or raised.
        await get_worker_manager().release(worker)
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.cpu
@pytest.mark.x86_64
@pytest.mark.profiling
@pytest.mark.parametrize("op_name", ["relu"])
@pytest.mark.asyncio
async def test_profile_generation_only_cpp(op_name):
    """Profile only the generated cpp kernel, reusing a cached baseline time.

    Scenario (from the original description): the baseline profile is skipped
    and a cached baseline time is used instead. The settings pass
    ``skip_base_profile=True`` and ``override_base_time_us``; the test then
    requires ``base_time`` to equal the cached value, ``gen_time`` to be
    positive and finite, and ``speedup`` to match ``base_time / gen_time``
    within 0.01.

    Args:
        op_name: Operator name supplied by ``parametrize``.

    Raises:
        RuntimeError: If no worker is available for the backend/arch.
    """
    framework = "torch"
    dsl = "cpp"
    backend = "cpu"
    arch = "x86_64"
    config = load_config(config_path="./python/akg_agents/op/config/cpp_coderonly_config.yaml")

    op_task_file = f"./tests/op/resources/{op_name}_op/{op_name}_{framework}.py"
    with open(op_task_file, "r", encoding="utf-8") as f:
        op_task_str = textwrap.dedent(f.read())

    kernel_path = f"./tests/op/resources/{op_name}_op/{op_name}_{dsl}.py"
    with open(kernel_path, "r", encoding="utf-8") as f:
        kernel_code = f.read()

    # Return value unused; call kept because it presumably creates the log
    # directory (TODO confirm).
    create_log_dir(f'{op_name}_{framework}_{backend}_{arch}_{dsl}_generation_only_test')

    await register_local_worker([device_id], backend=backend, arch=arch)
    worker = await get_worker_manager().select(backend=backend, arch=arch)
    if not worker:
        raise RuntimeError(f"No available worker for backend={backend}, arch={arch}")

    try:
        verifier = KernelVerifier(
            op_name=op_name,
            framework_code=op_task_str,
            task_id="generation_only_test",
            framework=framework,
            dsl=dsl,
            backend=backend,
            arch=arch,
            impl_func_name="ModelNew",
            config=config,
            worker=worker,
        )
        task_info = {"coder_code": kernel_code}

        # Correctness first: profiling is only attempted on a verified kernel.
        result, error_log = await verifier.run(task_info, device_id=device_id)
        assert result, f"验证失败: {error_log}"

        # Stand-in for a previously measured baseline, in microseconds
        # (unit taken from the ``override_base_time_us`` key name).
        cached_baseline_time = 20.0
        profile_settings = {
            "warmup_times": 2,
            "run_times": 5,
            "override_base_time_us": cached_baseline_time,
            "skip_base_profile": True,
        }
        profile_result = await verifier.run_profile(
            task_info,
            current_step=1,
            device_id=device_id,
            profile_settings=profile_settings,
        )

        assert profile_result is not None, "Profile 结果不应为 None"
        assert 'base_time' in profile_result, "应该包含 base_time"
        assert 'gen_time' in profile_result, "应该包含 gen_time"

        base_time = profile_result['base_time']
        gen_time = profile_result['gen_time']
        assert base_time == cached_baseline_time, \
            f"Baseline 时间应该使用缓存值 {cached_baseline_time},但得到: {base_time}"
        assert gen_time is not None and 0 < gen_time < float('inf'), \
            f"Generation 时间应该被正常测量,但得到: {gen_time}"

        # gen_time is known to be positive here (asserted just above), so the
        # division needs no zero guard.
        expected_speedup = cached_baseline_time / gen_time
        actual_speedup = profile_result.get('speedup', 0)
        assert abs(actual_speedup - expected_speedup) < 0.01, \
            f"Speedup 计算错误: expected={expected_speedup:.2f}, actual={actual_speedup:.2f}"
        print(f"✅ Generation-only profile 测试通过: base_time={base_time:.2f}us (cached), gen_time={gen_time:.2f}us, speedup={actual_speedup:.2f}x")
    finally:
        # Release the worker whether verification/profiling passed or raised.
        await get_worker_manager().release(worker)