import os
import sys
import argparse
import asyncio
import logging
from pathlib import Path
# Directory two levels above this file, pushed to the front of sys.path when it
# is not already listed. This must run before the akg_agents import below --
# presumably so the package resolves when the script is launched directly
# (NOTE(review): confirm against the repository layout).
_root_path = Path(__file__).parent.parent.absolute()
project_root = str(_root_path)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from akg_agents.op.verifier.kernel_verifier import KernelVerifier  # noqa: E402

# Root logging configuration: INFO and above, prefixed with timestamp and level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
async def run_batch(args):
    """Verify every SOL-ExecBench case of one level against pre-generated code.

    Each sub-directory of the level's benchmark data directory is treated as
    one case, named after the directory. For every case the implementation is
    read from ``<code_dir>/<op_name>_<dsl>.py`` (falling back to
    ``<code_dir>/<op_name>.py``) and handed to a ``KernelVerifier``. Cases are
    processed one at a time in sorted order, and a passed/failed/missing
    summary is logged at the end.

    A problem with a single case (unreadable code file, verifier construction
    error, exception raised by the verifier) is logged with its traceback and
    counted as a failure; it does not stop the remaining cases.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``level``, ``code_dir``, ``dsl``, ``backend``, ``arch``,
            ``device_id``, ``timeout`` and ``log_dir``.

    Returns:
        None. Results are reported through the module logger only. The
        function returns early, after logging an error, when the benchmark
        data directory or the code directory is not an existing directory.
    """
    base_dir = (
        Path(__file__).parent.parent.parent
        / "thirdparty" / "sol-execbench" / "data" / "benchmark"
    )
    level_dir = base_dir / args.level
    # is_dir() rather than exists(): a regular file at this path would pass an
    # exists() check and then make iterdir() raise further down.
    if not level_dir.is_dir():
        logger.error(f"SOL Bench data directory not found: {level_dir}")
        logger.error("Please run 'bash download.sh --with_sol_execbench' first.")
        return

    code_dir = Path(args.code_dir)
    # Same reasoning: a file passed as --code-dir would otherwise make every
    # case look like it merely has no code file.
    if not code_dir.is_dir():
        logger.error(f"Code directory not found: {code_dir}")
        return

    cases = sorted(d for d in level_dir.iterdir() if d.is_dir())
    logger.info(f"Found {len(cases)} cases in {args.level}")

    max_log_chars = 500  # length of the verifier-log excerpt shown on failure
    passed_count = 0
    failed_count = 0
    missing_count = 0

    for case_dir in cases:
        op_name = case_dir.name

        # Prefer the DSL-specific file name, then fall back to the plain one.
        candidates = (
            code_dir / f"{op_name}_{args.dsl}.py",
            code_dir / f"{op_name}.py",
        )
        code_file = next((path for path in candidates if path.is_file()), None)
        if code_file is None:
            logger.warning(f"[{op_name}] Missing code file, skipping.")
            missing_count += 1
            continue

        logger.info(f"[{op_name}] Starting verification...")
        try:
            # Reading the file and building the verifier are inside the try so
            # that one broken case cannot abort the whole batch.
            impl_code = code_file.read_text(encoding="utf-8")
            config = {
                "log_dir": args.log_dir,
                "sol_problem_dir": str(case_dir),
                "verify_timeout": args.timeout,
            }
            verifier = KernelVerifier(
                op_name=op_name,
                framework_code="",
                framework="torch",
                dsl=args.dsl,
                backend=args.backend,
                arch=args.arch,
                config=config,
                bench_type="sol",
            )
            task_info = {"coder_code": impl_code}
            passed, log = await verifier.run(task_info, device_id=args.device_id)
        except Exception as e:
            # logger.exception keeps the traceback, which str(e) alone loses.
            logger.exception(f"[{op_name}] ❌ ERROR: {e}")
            failed_count += 1
            continue

        if passed:
            logger.info(f"[{op_name}] ✅ PASSED")
            passed_count += 1
        else:
            # Tolerate a missing or non-string log, and only add the ellipsis
            # when the excerpt really was cut short.
            log_text = "" if log is None else str(log)
            excerpt = log_text[:max_log_chars]
            if len(log_text) > max_log_chars:
                excerpt += "..."
            logger.error(f"[{op_name}] ❌ FAILED\n{excerpt}")
            failed_count += 1

    logger.info("=" * 50)
    logger.info(f"Batch Verification Summary for {args.level}")
    logger.info(
        f"Total: {len(cases)} | Passed: {passed_count} | "
        f"Failed: {failed_count} | Missing Code: {missing_count}"
    )
    logger.info("=" * 50)
if __name__ == "__main__":
    # Default log directory: the AKG_AGENTS_LOG_DIR environment variable when
    # it is set to something non-blank, otherwise the built-in location.
    fallback_log_dir = "~/akg_agents_logs/sol_batch"
    env_log_dir = os.environ.get("AKG_AGENTS_LOG_DIR", fallback_log_dir).strip()
    default_log_dir = env_log_dir if env_log_dir else fallback_log_dir

    # One (flag, add_argument keyword arguments) pair per command-line option,
    # registered in this order.
    option_specs = (
        ("--level", dict(
            type=str,
            default="L1",
            choices=["L1", "L2", "Quant", "FlashInfer-Bench"],
            help="SOL Bench level",
        )),
        ("--code-dir", dict(
            type=str,
            required=True,
            help="Directory containing the generated kernel codes",
        )),
        ("--dsl", dict(
            type=str,
            default="triton_cuda",
            help="DSL type (e.g., triton_cuda, cpp, ascendc)",
        )),
        ("--backend", dict(
            type=str,
            default="cuda",
            choices=["cuda", "ascend", "cpu"],
            help="Backend type",
        )),
        ("--arch", dict(type=str, default="a100", help="Architecture type")),
        ("--device-id", dict(type=int, default=0, help="Device ID to use")),
        ("--timeout", dict(
            type=int,
            default=300,
            help="Timeout per case in seconds",
        )),
        ("--log-dir", dict(
            type=str,
            default=default_log_dir,
            help="Directory to save logs",
        )),
    )

    parser = argparse.ArgumentParser(description="Batch run SOL-ExecBench cases")
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    # Expand a leading "~" and make sure the log directory exists before the
    # batch starts.
    args.log_dir = os.path.expanduser(args.log_dir)
    os.makedirs(args.log_dir, exist_ok=True)

    asyncio.run(run_batch(args))