"""Run one Codex migration command for each operator in the JSON list."""
from __future__ import annotations
import argparse
import json
import re
import subprocess
import sys
from pathlib import Path
from typing import Any
# Operator list read when no JSON path is given on the command line.
DEFAULT_JSON = "llm_common_torch_ops_20.json"
# Directory that receives one log file per operator run.
DEFAULT_LOG_DIR = "codex_migration_logs"
# Prompt passed to `codex exec` for each operator; `{cu_file}` and
# `{torch_api}` are substituted per JSON item.  English translation:
# "Please migrate the {torch_api} operator in {cu_file} under
# pytorch/aten/src/ATen/native/cuda to the Ascend SIMT platform. Whenever a
# user decision point comes up during migration, always decline the
# downgrade option and insist on a one-to-one migration."
PROMPT_TEMPLATE = "".join(
    (
        "帮忙pytorch/aten/src/ATen/native/cuda下{cu_file}中{torch_api}算子迁移到Ascend ",
        "SIMT平台,迁移过程中遇到用户决策点,一致选不同意降级的方案,坚持一对一迁移。",
    )
)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for the migration runner.

    Args:
        argv: Argument list to parse instead of the real command line.
            ``None`` (the default) keeps the original behaviour of reading
            ``sys.argv[1:]``; passing an explicit list lets other scripts and
            tests reuse this parser.

    Returns:
        Namespace with ``json_file``, ``codex_bin``, ``log_dir``,
        ``continue_on_error`` and ``dry_run`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Read operator JSON and run codex exec once for each operator."
    )
    parser.add_argument(
        "json_file",
        nargs="?",
        default=DEFAULT_JSON,
        help=f"operator JSON file path, default: {DEFAULT_JSON}",
    )
    parser.add_argument(
        "--codex-bin",
        default="codex",
        help="codex executable name or path, default: codex",
    )
    parser.add_argument(
        "--log-dir",
        default=DEFAULT_LOG_DIR,
        help=f"directory for per-operator logs, default: {DEFAULT_LOG_DIR}",
    )
    parser.add_argument(
        "--continue-on-error",
        action="store_true",
        help="continue running the remaining operators when one command fails",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="print commands without executing them",
    )
    # argparse falls back to sys.argv[1:] when argv is None.
    return parser.parse_args(argv)
def load_ops(json_file: Path) -> list[dict[str, Any]]:
    """Load and validate the operator list stored in *json_file*.

    The file must contain a JSON array whose elements are objects, each with
    string fields ``torch_api`` and ``cu_file`` that are non-empty after
    stripping whitespace.  The parsed list is returned as-is (values are not
    stripped here).

    Raises:
        OSError: the file cannot be opened.
        json.JSONDecodeError: the file is not valid JSON.
        ValueError: the top level is not a list, or an element is not an
            object or lacks a required field.  Element positions in the
            message are 1-based.
    """
    with json_file.open("r", encoding="utf-8") as handle:
        payload = json.load(handle)

    if not isinstance(payload, list):
        raise ValueError(f"{json_file} must contain a JSON list")

    required_fields = ("torch_api", "cu_file")
    for position, entry in enumerate(payload, start=1):
        if not isinstance(entry, dict):
            raise ValueError(f"item {position} must be a JSON object")
        for field in required_fields:
            value = entry.get(field)
            if isinstance(value, str) and value.strip():
                continue
            raise ValueError(f"item {position} missing non-empty string field: {field}")
    return payload
def safe_log_name(index: int, torch_api: str, cu_file: str) -> str:
    """Build a filesystem-safe log file name for one operator run.

    The name is ``<index zero-padded to 2>_<torch_api>_<cu_file>.log``, with
    every ``"torch."`` occurrence removed and each run of characters outside
    ``[A-Za-z0-9_.-]`` (e.g. ``/``, spaces, ``::``) collapsed into one ``_``.
    """
    raw = "{:02d}_{}_{}".format(index, torch_api, cu_file).replace("torch.", "")
    cleaned = re.sub(r"[^A-Za-z0-9_.-]+", "_", raw)
    return cleaned + ".log"
def build_prompt(item: dict[str, Any]) -> str:
    """Fill the module prompt template with one operator's fields.

    ``cu_file`` and ``torch_api`` are read from *item* and whitespace-stripped
    before substitution.
    """
    fields = {name: item[name].strip() for name in ("cu_file", "torch_api")}
    return PROMPT_TEMPLATE.format(**fields)
def run_one(
    *,
    codex_bin: str,
    prompt: str,
    log_file: Path,
    dry_run: bool,
) -> int:
    """Run ``codex exec`` for a single operator prompt.

    The command line is echoed as ``$ <command>`` at the top of *log_file*,
    and the child's stdout and stderr are both redirected into that file.

    Args:
        codex_bin: Name or path of the codex executable.
        prompt: Prompt passed as the final positional argument.
        log_file: File receiving the echoed command and all child output.
            Missing parent directories are created; an existing file is
            overwritten.
        dry_run: When true, only print the command to stdout and return 0
            without creating directories or starting a process.

    Returns:
        The child's exit status, or 127 if the executable could not be
        started at all (for example *codex_bin* was not found or is not
        executable).
    """
    command = [
        codex_bin,
        "exec",
        # NOTE(review): presumably required so the batch runs without
        # interactive approvals; confirm against codex CLI docs.
        "--dangerously-bypass-approvals-and-sandbox",
        prompt,
    ]
    if dry_run:
        print("DRY-RUN:", subprocess.list2cmdline(command))
        return 0

    log_file.parent.mkdir(parents=True, exist_ok=True)
    with log_file.open("w", encoding="utf-8") as f:
        f.write("$ " + subprocess.list2cmdline(command) + "\n\n")
        # Push the buffered header to the file before the child starts writing
        # to the same descriptor; otherwise it would land after the child's output.
        f.flush()
        try:
            completed = subprocess.run(command, stdout=f, stderr=subprocess.STDOUT)
        except OSError as exc:
            # subprocess.run raises (FileNotFoundError, PermissionError, ...)
            # when the executable cannot be launched.  Report it as an ordinary
            # failed run so --continue-on-error and the failure summary still
            # work instead of the whole batch dying with a traceback.
            message = f"failed to start {codex_bin!r}: {exc}"
            f.write(message + "\n")
            print(f" {message}", file=sys.stderr)
            return 127  # mirrors the shell's "command not found" status
    return completed.returncode
def main() -> int:
    """Run the codex migration command for every operator in the JSON file.

    Returns:
        0 when every command succeeded (always the case in dry-run mode),
        1 when at least one command failed, and 2 when the operator JSON
        could not be loaded or validated.
    """
    args = parse_args()
    json_file = Path(args.json_file)
    log_dir = Path(args.log_dir)
    try:
        ops = load_ops(json_file)
    except (OSError, json.JSONDecodeError, ValueError) as exc:
        print(f"failed to load operator JSON: {exc}", file=sys.stderr)
        return 2

    failed: list[tuple[int, str, int, Path]] = []
    total = len(ops)
    for index, item in enumerate(ops, start=1):
        torch_api = item["torch_api"].strip()
        cu_file = item["cu_file"].strip()
        prompt = build_prompt(item)
        log_file = log_dir / safe_log_name(index, torch_api, cu_file)
        # run_one blocks until codex exits, so flush progress now; otherwise it
        # stays buffered when stdout is redirected to a file or pipe.
        print(f"[{index}/{total}] running {torch_api} from {cu_file}", flush=True)
        if not args.dry_run:
            print(f" log: {log_file}", flush=True)
        returncode = run_one(
            codex_bin=args.codex_bin,
            prompt=prompt,
            log_file=log_file,
            dry_run=args.dry_run,
        )
        if returncode != 0:
            failed.append((index, torch_api, returncode, log_file))
            print(f" failed with exit code {returncode}", file=sys.stderr)
            if not args.continue_on_error:
                skipped = total - index
                if skipped:
                    print(
                        f" stopping; {skipped} remaining operator(s) not run "
                        "(use --continue-on-error to keep going)",
                        file=sys.stderr,
                    )
                break

    if failed:
        print("\nfailed operators:", file=sys.stderr)
        for index, torch_api, returncode, log_file in failed:
            print(f" {index}: {torch_api} exit={returncode} log={log_file}", file=sys.stderr)
        return 1
    print(f"\ncompleted {total} operator commands")
    return 0
if __name__ == "__main__":
    # Use main()'s integer result as the process exit status.
    sys.exit(main())