"""Run autoresearch on a custom operator task.
Reference (--desc or --ref) is required. --kernel is optional — KernelGen
generates the initial kernel from the reference if not provided.
python scripts/run_autoresearch.py --desc "fused ReLU + LayerNorm, (32,1024), fp16" --backend cuda
python scripts/run_autoresearch.py --ref reference.py --backend cuda
python scripts/run_autoresearch.py --ref reference.py --kernel kernel.py --backend cuda
"""
import argparse
import asyncio
import os
import re
import logging
from akg_agents.op.langgraph_op.task import LangGraphTask
from akg_agents.op.config.config_validator import load_config
from akg_agents.core.worker.manager import register_local_worker, register_remote_worker
from akg_agents.utils.task_label import resolve_task_label
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Prompt template sent to the LLM to turn a natural-language operator
# description into a reference module (class Model, get_inputs,
# get_init_inputs). It is filled in with str.format(); the placeholders are
# {device} and {description}.
_GENERATE_PROMPT = """\
You are a PyTorch expert. Given a natural language description of a tensor operation,
generate a complete Python reference implementation following this EXACT format:
```python
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, ...) -> torch.Tensor:
# implement the operation here
...
def get_inputs():
# create input tensors with the specified shapes, dtypes, and device
return [...]
def get_init_inputs():
return []
```
Rules:
- Model.forward() implements the EXACT operation described, nothing more
- get_inputs() creates random tensors matching the specified shapes, dtypes
- Device must be '{device}'
- Output ONLY the Python code block, no explanation
- Do NOT add any optimization — this is the reference (baseline) implementation
Operator description:
{description}
"""
async def generate_task_desc(
    description: str, device: str,
    model_level: str = "standard", gen_retries: int = 5,
) -> str:
    """Generate and validate task_desc from natural language.

    Retries on structural validation failure — the LLM sometimes omits
    required symbols. Each retry gets a fresh LLM call (no error feedback
    needed; the prompt is deterministic, variance comes from sampling).

    Args:
        description: Natural-language description of the operator.
        device: Device string substituted into the prompt template.
        model_level: Passed through to ``create_llm_client``.
        gen_retries: Maximum number of LLM calls; must be at least 1.

    Returns:
        The first generated reference source that passes
        ``_validate_ref_static``.

    Raises:
        ValueError: If ``gen_retries`` is below 1, or if no attempt produced
            a structurally valid reference.
    """
    # Reject a non-positive retry budget up front: otherwise the loop below
    # never runs and the final error carries an empty failure reason.
    if gen_retries < 1:
        raise ValueError(f"gen_retries must be >= 1, got {gen_retries}")

    from akg_agents.core_v2.llm.factory import create_llm_client

    llm = create_llm_client(model_level=model_level)
    prompt = _GENERATE_PROMPT.format(description=description, device=device)
    last_error = ""
    for attempt in range(1, gen_retries + 1):
        print(f"[generate_task_desc] Generating reference "
              f"(attempt {attempt}/{gen_retries}) …")
        result = await llm.generate(
            messages=[{"role": "user", "content": prompt}],
            stream=False,
        )
        # "content" may be present but None; normalise to a string so the
        # regex search below cannot raise TypeError.
        text = result.get("content") or ""
        # Prefer a ```python fenced block; fall back to any fenced block so a
        # missing or different language tag does not leak the fences into the
        # code handed to the AST check.
        m = re.search(r"```python\s*\n(.*?)```", text, re.DOTALL)
        if m is None:
            m = re.search(r"```[^\n]*\n(.*?)```", text, re.DOTALL)
        code = m.group(1).strip() if m else text.strip()
        try:
            _validate_ref_static(code, "LLM output")
        except ValueError as e:
            last_error = str(e)
            print(f"[generate_task_desc] Validation failed: {last_error}")
        else:
            print("[generate_task_desc] Reference validated OK")
            return code
    raise ValueError(
        f"Failed to generate valid reference after {gen_retries} attempts "
        f"({last_error}). Please write a reference.py manually and use "
        f"--ref instead of --desc."
    )
def derive_op_name(description: str) -> str:
    """Build a short snake_case operator name from free-form text.

    The first four runs of ASCII letters found in *description* are
    lower-cased and joined with underscores. When the text contains no
    letters at all, the fallback name "custom_op" is returned.
    """
    alpha_runs = re.findall(r"[a-zA-Z]+", description)
    leading = [token.lower() for token in alpha_runs[:4]]
    if not leading:
        return "custom_op"
    return "_".join(leading)
def _validate_ref_static(code: str, source: str):
"""AST-level validation of reference code structure.
Same checks as KernelVerifier.check_task_desc_static(): parse the AST
and verify class Model, get_inputs, get_init_inputs exist as top-level
definitions. Runs at CLI layer for early failure before workflow starts.
"""
import ast as _ast
try:
tree = _ast.parse(code)
except SyntaxError as e:
raise ValueError(f"Reference from {source} has syntax error: {e}")
names = {
node.name for node in tree.body
if isinstance(node, (_ast.ClassDef, _ast.FunctionDef))
}
required = {"Model": "class Model", "get_inputs": "get_inputs()",
"get_init_inputs": "get_init_inputs()"}
missing = [label for name, label in required.items() if name not in names]
if missing:
raise ValueError(f"Reference from {source} missing: {', '.join(missing)}")
def _read_and_validate_ref(path: str) -> str:
"""Read reference file and validate structure."""
if not os.path.isfile(path):
raise FileNotFoundError(f"Reference file not found: {path}")
with open(path, "r", encoding="utf-8") as f:
code = f.read()
_validate_ref_static(code, path)
return code
async def main():
    """Parse CLI arguments, build the task, and run the autoresearch workflow.

    Steps, in order:
      1. Parse and validate command-line arguments.
      2. Fill in missing dsl/backend/arch values from a per-backend preset.
      3. Load the optional initial kernel (checked before any LLM call so a
         bad path fails fast).
      4. Obtain the reference, either generated from ``--desc`` or read from
         ``--ref``.
      5. Register remote workers (``--worker-url``) or a local worker.
      6. Run a ``LangGraphTask`` with the "autoresearch" workflow and print
         a summary of the outcome.

    Raises:
        FileNotFoundError: If the ``--kernel`` or ``--ref`` file is missing.
        ValueError: If the reference is invalid or no worker URL parses.
        SystemExit: Via argparse on invalid command-line arguments.
    """
    parser = argparse.ArgumentParser(description="Run autoresearch optimization")
    ref_group = parser.add_mutually_exclusive_group(required=True)
    ref_group.add_argument("--desc", type=str,
                           help="Natural language description → LLM generates reference")
    ref_group.add_argument("--ref", type=str,
                           help="Path to reference.py (Model/get_inputs format)")
    parser.add_argument("--kernel", type=str, default=None,
                        help="Path to initial kernel file (skips KernelGen)")
    parser.add_argument("--op-name", type=str, default=None,
                        help="Operator name (auto-derived from --desc if omitted)")
    parser.add_argument("--dsl", type=str, default=None,
                        choices=["triton_ascend", "triton_cuda", "torch",
                                 "cuda_c", "cpp", "ascendc", "tilelang_cuda"])
    parser.add_argument("--backend", type=str, default=None,
                        choices=["ascend", "cuda", "cpu"])
    parser.add_argument("--arch", type=str, default=None)
    parser.add_argument("--framework", type=str, default="torch")
    parser.add_argument("--device-id", type=int, default=None)
    parser.add_argument("--worker-url", "--worker_url", dest="worker_url",
                        type=str, default=None,
                        help="Remote Worker Service URL(s), comma-separated "
                             "(e.g. 127.0.0.1:9111). Mutually exclusive with --device-id.")
    parser.add_argument("--max-rounds", type=int, default=20)
    parser.add_argument("--gen-retries", type=int, default=5,
                        help="Max retries for code generation (reference and seed)")
    args = parser.parse_args()

    if args.worker_url and args.device_id is not None:
        parser.error("--worker-url and --device-id are mutually exclusive")
    # An empty --desc satisfies the "one of --desc/--ref" requirement but is
    # falsy, which would otherwise fall through to the --ref branch with
    # args.ref set to None.
    if args.desc is not None and not args.desc.strip():
        parser.error("--desc must not be empty")

    _BACKEND_PRESETS = {
        "ascend": {"dsl": "triton_ascend", "backend": "ascend", "arch": "ascend910b4"},
        "cuda": {"dsl": "triton_cuda", "backend": "cuda", "arch": "a100"},
        "cpu": {"dsl": "cpp", "backend": "cpu", "arch": "x86_64"},
    }
    # Without --backend, infer the preset from the DSL name; anything that is
    # neither a CUDA DSL nor "cpp" falls back to the ascend preset.
    preset_key = args.backend or (
        "cuda" if args.dsl and "cuda" in args.dsl else
        "cpu" if args.dsl == "cpp" else
        "ascend"
    )
    preset = _BACKEND_PRESETS.get(preset_key, _BACKEND_PRESETS["ascend"])
    if args.dsl is None:
        args.dsl = preset["dsl"]
    if args.backend is None:
        args.backend = preset["backend"]
    if args.arch is None:
        args.arch = preset["arch"]

    device_map = {"ascend": "npu", "cuda": "cuda", "cpu": "cpu"}
    device = device_map.get(args.backend, "npu")

    # Load the optional kernel before generating the reference so that a
    # wrong path is reported without first spending LLM calls.
    previous_code = ""
    if args.kernel:
        if not os.path.isfile(args.kernel):
            raise FileNotFoundError(f"Kernel file not found: {args.kernel}")
        with open(args.kernel, "r", encoding="utf-8") as f:
            previous_code = f.read()
        print(f"[run_autoresearch] Initial kernel loaded: {args.kernel}")

    if args.desc is not None:
        print("[run_autoresearch] Generating reference from description...")
        print(f" \"{args.desc}\"")
        task_desc = await generate_task_desc(args.desc, device=device,
                                             gen_retries=args.gen_retries)
        print("[run_autoresearch] Reference generated OK")
    else:
        task_desc = _read_and_validate_ref(args.ref)
        print(f"[run_autoresearch] Reference loaded: {args.ref}")

    op_name = args.op_name or (derive_op_name(args.desc) if args.desc else "custom_op")

    if args.worker_url:
        from akg_agents.cli.service.worker_service import WorkerService

        urls = WorkerService.parse_workers(args.worker_url)
        if not urls:
            raise ValueError(f"No valid worker URLs parsed from: {args.worker_url}")
        for url in urls:
            await register_remote_worker(
                backend=args.backend, arch=args.arch, worker_url=url,
            )
            print(f"[run_autoresearch] Registered remote worker: {url}")
    else:
        # No remote worker requested: use a local one, defaulting to device 0.
        device_id = args.device_id if args.device_id is not None else 0
        await register_local_worker([device_id], backend=args.backend, arch=args.arch)

    config = load_config(dsl=args.dsl, backend=args.backend)
    config["task_label"] = resolve_task_label(op_name=op_name, parallel_index=1)
    config["max_step"] = args.max_rounds

    task = LangGraphTask(
        op_name=op_name,
        task_desc=task_desc,
        task_id=f"{op_name}_001",
        backend=args.backend,
        arch=args.arch,
        dsl=args.dsl,
        config=config,
        framework=args.framework,
        workflow="autoresearch",
        previous_code=previous_code,
    )
    result_op_name, success, final_state = await task.run()

    print(f"\n{'=' * 60}")
    print(f"Op: {result_op_name} | Result: {'SUCCESS' if success else 'FAILED'}")
    if success:
        # Guard against a present-but-None "profile_res" entry.
        profile = final_state.get("profile_res") or {}
        print(f"gen_time: {profile.get('gen_time', '?')} us")
        print(f"base_time: {profile.get('base_time', '?')} us")
        print(f"speedup: {profile.get('speedup', '?')}")
        print(f"\nFinal kernel:\n{final_state.get('coder_code', '')}")
    else:
        print(f"Error: {final_state.get('verifier_error', '')}")
# Script entry point: when executed directly, run the async main() to
# completion via asyncio.run().
if __name__ == "__main__":
    asyncio.run(main())