#!/usr/bin/env python3
"""Run one Codex migration command for each operator in the JSON list."""

from __future__ import annotations

import argparse
import json
import re
import subprocess
import sys
from pathlib import Path
from typing import Any


# Operator list used when no JSON path is given on the command line.
DEFAULT_JSON = "llm_common_torch_ops_20.json"
# Directory that receives one log file per operator command.
DEFAULT_LOG_DIR = "codex_migration_logs"
# Prompt passed to `codex exec` for each operator; `{cu_file}` and
# `{torch_api}` are filled in with str.format. English gist: "Please migrate
# the {torch_api} operator in {cu_file} under pytorch/aten/src/ATen/native/cuda
# to the Ascend SIMT platform; at every user decision point, consistently
# choose the option that rejects downgrading and insist on a one-to-one
# migration."
PROMPT_TEMPLATE = (
    "帮忙pytorch/aten/src/ATen/native/cuda下{cu_file}中{torch_api}"
    "算子迁移到Ascend SIMT平台,"
    "迁移过程中遇到用户决策点,一致选不同意降级的方案,坚持一对一迁移。"
)


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for the batch migration runner.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so existing
            zero-argument calls behave exactly as before. Passing an explicit
            list lets tests or other scripts drive the parser directly.

    Returns:
        A namespace with ``json_file``, ``codex_bin``, ``log_dir``,
        ``continue_on_error`` and ``dry_run`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Read operator JSON and run codex exec once for each operator."
    )
    parser.add_argument(
        "json_file",
        nargs="?",
        default=DEFAULT_JSON,
        help=f"operator JSON file path, default: {DEFAULT_JSON}",
    )
    parser.add_argument(
        "--codex-bin",
        default="codex",
        help="codex executable name or path, default: codex",
    )
    parser.add_argument(
        "--log-dir",
        default=DEFAULT_LOG_DIR,
        help=f"directory for per-operator logs, default: {DEFAULT_LOG_DIR}",
    )
    parser.add_argument(
        "--continue-on-error",
        action="store_true",
        help="continue running the remaining operators when one command fails",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="print commands without executing them",
    )
    return parser.parse_args(argv)


def load_ops(json_file: Path) -> list[dict[str, Any]]:
    """Read the operator list from *json_file* and validate its shape.

    The file must hold a JSON array whose elements are objects, each with
    non-blank string values for ``torch_api`` and ``cu_file``. Validated
    entries are returned unchanged (extra keys are kept).

    Raises:
        OSError: the file cannot be read.
        json.JSONDecodeError: the content is not valid JSON.
        ValueError: the top level is not a list, or an entry is malformed.
            The message uses 1-based item numbers.
    """
    payload = json.loads(json_file.read_text(encoding="utf-8"))

    if not isinstance(payload, list):
        raise ValueError(f"{json_file} must contain a JSON list")

    required_fields = ("torch_api", "cu_file")
    for position, entry in enumerate(payload, start=1):
        if not isinstance(entry, dict):
            raise ValueError(f"item {position} must be a JSON object")
        for field in required_fields:
            value = entry.get(field)
            # Accept only strings that contain something besides whitespace.
            if isinstance(value, str) and value.strip():
                continue
            raise ValueError(f"item {position} missing non-empty string field: {field}")

    return payload


def safe_log_name(index: int, torch_api: str, cu_file: str) -> str:
    """Build a filesystem-safe log file name for one operator.

    The name is ``<index>_<api>_<cu_file>.log``:

    - The index is zero-padded to two digits.
    - A leading ``torch.`` namespace is dropped from *torch_api*.
    - Every run of characters outside ``[A-Za-z0-9_.-]`` (spaces, path
      separators, non-ASCII, ...) is collapsed to a single ``_``.

    Only the *prefix* of *torch_api* is stripped. The previous global
    ``str.replace`` also removed ``torch.`` from the middle of names and from
    *cu_file*, e.g. ``my_torch.cu`` became ``my_cu``.
    """
    namespace = "torch."
    if torch_api.startswith(namespace):
        torch_api = torch_api[len(namespace):]
    stem = f"{index:02d}_{torch_api}_{cu_file}"
    stem = re.sub(r"[^A-Za-z0-9_.-]+", "_", stem)
    return f"{stem}.log"


def build_prompt(item: dict[str, Any]) -> str:
    """Render PROMPT_TEMPLATE for one operator entry.

    ``cu_file`` and ``torch_api`` are stripped of surrounding whitespace
    before substitution. A missing key raises KeyError (``cu_file`` is
    looked up first).
    """
    substitutions = {field: item[field].strip() for field in ("cu_file", "torch_api")}
    return PROMPT_TEMPLATE.format(**substitutions)


def run_one(
    *,
    codex_bin: str,
    prompt: str,
    log_file: Path,
    dry_run: bool,
) -> int:
    """Run (or, in dry-run mode, only print) one ``codex exec`` command.

    The command line is written as a ``$ ...`` header at the top of
    *log_file*, followed by the child's combined stdout/stderr.

    Args:
        codex_bin: codex executable name or path.
        prompt: Migration prompt, passed as the final positional argument.
        log_file: Destination for the header and output. Its parent
            directories are created if needed, and any existing file is
            overwritten.
        dry_run: When true, print the command and return 0 without touching
            the filesystem or starting a process.

    Returns:
        The child's exit status, or 0 in dry-run mode. If the process cannot
        be started, the error is logged and 127 is returned (executable not
        found) or 126 (any other start-up OSError), mirroring shell
        conventions, instead of raising.
    """
    command = [
        codex_bin,
        "exec",
        # NOTE(review): per its name, this flag skips codex approvals and
        # sandboxing so the unattended batch never blocks on a prompt. It
        # also means codex runs with full access to the machine.
        "--dangerously-bypass-approvals-and-sandbox",
        prompt,
    ]
    command_line = subprocess.list2cmdline(command)

    if dry_run:
        print("DRY-RUN:", command_line)
        return 0

    log_file.parent.mkdir(parents=True, exist_ok=True)
    with log_file.open("w", encoding="utf-8") as f:
        f.write("$ " + command_line + "\n\n")
        # Flush the buffered header before the child writes directly to the
        # same file descriptor; otherwise it could land after the output.
        f.flush()
        try:
            completed = subprocess.run(command, stdout=f, stderr=subprocess.STDOUT)
        except OSError as exc:
            # A missing/unexecutable binary used to abort the whole batch with
            # a traceback; report it as a failed command instead.
            message = f"failed to start {codex_bin!r}: {exc}"
            f.write(message + "\n")
            print(f"  {message}", file=sys.stderr)
            return 127 if isinstance(exc, FileNotFoundError) else 126
    return completed.returncode


def main() -> int:
    """Load the operator list and run one codex command per entry.

    Progress goes to stdout and failures go to stderr. Unless
    ``--continue-on-error`` is given, the loop stops at the first command
    that exits non-zero.

    Returns:
        0 when every command succeeded (always the case in dry-run mode),
        1 when at least one command failed, 2 when the JSON could not be loaded.
    """
    args = parse_args()
    source = Path(args.json_file)

    try:
        operators = load_ops(source)
    except (OSError, json.JSONDecodeError, ValueError) as exc:
        print(f"failed to load operator JSON: {exc}", file=sys.stderr)
        return 2

    logs_root = Path(args.log_dir)
    count = len(operators)
    failures: list[tuple[int, str, int]] = []

    for position, entry in enumerate(operators, start=1):
        api_name = entry["torch_api"].strip()
        kernel_file = entry["cu_file"].strip()
        prompt_text = build_prompt(entry)
        log_path = logs_root / safe_log_name(position, api_name, kernel_file)

        print(f"[{position}/{count}] running {api_name} from {kernel_file}")
        if not args.dry_run:
            print(f"  log: {log_path}")

        status = run_one(
            codex_bin=args.codex_bin,
            prompt=prompt_text,
            log_file=log_path,
            dry_run=args.dry_run,
        )
        if status == 0:
            continue

        failures.append((position, api_name, status))
        print(f"  failed with exit code {status}", file=sys.stderr)
        if not args.continue_on_error:
            break

    if not failures:
        print(f"\ncompleted {count} operator commands")
        return 0

    print("\nfailed operators:", file=sys.stderr)
    for position, api_name, status in failures:
        print(f"  {position}: {api_name} exit={status}", file=sys.stderr)
    return 1


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())