# -*- coding: utf-8 -*-
"""一键执行最终训练与评估(简洁命令入口)。"""

from __future__ import annotations

import argparse
import os
import subprocess
import sys
from typing import List


# Absolute directory that contains this script. Child stages are launched with
# it as their working directory, and train.py / evaluate_physics.py are looked
# up inside it.
PROJECT_DIR = os.path.split(os.path.abspath(__file__))[0]


def _build_env() -> dict[str, str]:
    """Return a copy of the current environment for the child stages.

    If ``CUDA_VISIBLE_DEVICES`` is not already set, it is set to the empty
    string so the children default to CPU. This avoids interruptions caused by
    a CUDA version mismatch in the current environment. A value the user has
    already exported is left untouched. ``os.environ`` itself is never
    modified.
    """
    child_env = dict(os.environ)
    if "CUDA_VISIBLE_DEVICES" not in child_env:
        child_env["CUDA_VISIBLE_DEVICES"] = ""
    return child_env


def _run(stage: str, cmd: List[str], env: dict[str, str]) -> None:
    """Run one pipeline stage as a child process and abort on failure.

    The command runs with ``PROJECT_DIR`` as its working directory and the
    given environment. A start line is printed before the stage runs, and a
    completion line is printed only if it succeeds.

    Args:
        stage: Human-readable stage name used in progress and error messages.
        cmd: Argument vector passed to ``subprocess.run`` (no shell).
        env: Environment for the child process.

    Raises:
        SystemExit: If the child exits with a non-zero code. The message names
            the stage and the exit code.
    """
    print(f"\n[{stage}] 开始...", flush=True)
    exit_code = subprocess.run(cmd, cwd=PROJECT_DIR, env=env, check=False).returncode
    if exit_code == 0:
        print(f"[{stage}] 完成", flush=True)
        return
    raise SystemExit(f"{stage}失败(退出码 {exit_code})。")


def _positive_int(text: str) -> int:
    """argparse ``type=`` callback: parse *text* as an int that must be > 0."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"不是整数:{text}") from None
    if value <= 0:
        raise argparse.ArgumentTypeError(f"必须为正整数:{text}")
    return value


def _positive_float(text: str) -> float:
    """argparse ``type=`` callback: parse *text* as a float that must be > 0."""
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"不是数字:{text}") from None
    # `not value > 0` (instead of `value <= 0`) also rejects NaN.
    if not value > 0:
        raise argparse.ArgumentTypeError(f"必须为正数:{text}")
    return value


def _parse_args() -> argparse.Namespace:
    """Build the CLI parser and parse ``sys.argv``.

    Epochs, batch size and learning rate must be strictly positive, so obviously
    invalid values are rejected before the slow training stage starts.
    """
    parser = argparse.ArgumentParser(description="一键训练并评估最终模型。")
    parser.add_argument("--dataset-dir", type=str, required=True, help="数据集目录(例如 BPT-V/dataset)。")
    parser.add_argument("--epochs", type=_positive_int, default=20, help="训练轮数。")
    parser.add_argument("--batch-size", type=_positive_int, default=8, help="批大小。")
    parser.add_argument("--lr", type=_positive_float, default=5e-4, help="学习率。")
    parser.add_argument("--seed", type=int, default=42, help="随机种子。")
    parser.add_argument("--run-name", type=str, default="final_optimized", help="本次运行名称。")
    return parser.parse_args()


def _build_train_cmd(args: argparse.Namespace, dataset_dir: str) -> List[str]:
    """Return the argv for train.py.

    The user-tunable options come from *args*. All other flags are the fixed
    hyperparameters of the final configuration and are passed through as-is.
    """
    return [
        sys.executable,
        os.path.join(PROJECT_DIR, "train.py"),
        "--dataset-dir", dataset_dir,
        "--epochs", str(args.epochs),
        "--batch-size", str(args.batch_size),
        "--lr", str(args.lr),
        "--seed", str(args.seed),
        "--run-name", args.run_name,
        "--loss-mode", "mse",
        "--center-loss-weight", "1.0",
        "--augment-prob", "0.2",
        "--augment-occlusion-ratio", "0.35",
        "--augment-noise-ratio", "0.01",
        "--event-loss-weight", "1.0",
        "--event-pos-weight", "6.0",
        "--event-focal-gamma", "2.0",
        "--uncertainty-loss-weight", "0.3",
        "--uncertainty-residual-reg-weight", "0.05",
        "--event-loss-start-weight", "0.0",
        "--uncertainty-loss-start-weight", "0.0",
        "--aux-warmup-epochs", "6",
        # NOTE(review): fixed at 20 while --epochs defaults to 20. Depending on
        # how train.py counts epochs, stage 2 may never trigger for runs with
        # <= 20 epochs. Confirm against train.py.
        "--aux-stage2-epoch", "20",
        "--event-stage2-scale", "0.8",
        "--uncertainty-stage2-scale", "0.8",
        "--select-metric", "hybrid",
        "--hybrid-event-weight", "0.3",
        "--use-hard-mining",
    ]


def _build_eval_cmd(args: argparse.Namespace, dataset_dir: str) -> List[str]:
    """Return the argv for evaluate_physics.py.

    The ``--ckpt`` path is relative. It resolves against ``PROJECT_DIR``
    because ``_run`` launches every child with that directory as cwd.
    """
    return [
        sys.executable,
        os.path.join(PROJECT_DIR, "evaluate_physics.py"),
        "--dataset-dir", dataset_dir,
        "--ckpt", os.path.join("checkpoints", args.run_name, "best_model.pt"),
        "--tag", args.run_name,
    ]


def main() -> None:
    """Run the final training stage and then the evaluation stage.

    Steps:
        1. Parse the CLI arguments.
        2. Check that the dataset directory looks valid.
        3. Run train.py, then evaluate_physics.py, as child processes.
        4. Print where the results are expected.

    Raises:
        FileNotFoundError: If ``<dataset-dir>/video_001/data.json`` is not a
            file.
        SystemExit: If a child stage exits with a non-zero code (raised by
            ``_run``), or if argparse rejects the arguments.
    """
    args = _parse_args()

    dataset_dir = os.path.abspath(args.dataset_dir)
    # Minimal sanity check: only verifies that video_001/data.json exists.
    if not os.path.isfile(os.path.join(dataset_dir, "video_001", "data.json")):
        raise FileNotFoundError(f"数据集目录无效:{dataset_dir}(未找到 video_001/data.json)")

    train_cmd = _build_train_cmd(args, dataset_dir)
    eval_cmd = _build_eval_cmd(args, dataset_dir)
    env = _build_env()

    print("=" * 60, flush=True)
    print("开始执行:最终训练 + 评估", flush=True)
    print(f"数据集目录:{dataset_dir}", flush=True)
    print(f"运行名称:{args.run_name}", flush=True)
    print("=" * 60, flush=True)

    _run("训练阶段", train_cmd, env)
    _run("评估阶段", eval_cmd, env)

    # Build the reported paths with os.path.join so they use the native
    # separator on every OS; the original hard-coded Windows backslashes.
    # Anchor them at PROJECT_DIR, where the children ran, so they are correct
    # whatever cwd this script was launched from.
    ckpt_dir = os.path.join(PROJECT_DIR, "checkpoints", args.run_name)
    # Location of the metrics file is decided by evaluate_physics.py; this
    # mirrors the path the summary has always reported.
    metrics_path = os.path.join(PROJECT_DIR, "outputs", "metrics", f"model_eval_{args.run_name}.json")

    print("\n" + "=" * 60, flush=True)
    print("流程完成", flush=True)
    print(f"模型目录:{ckpt_dir}", flush=True)
    print(f"评估结果:{metrics_path}", flush=True)
    print("=" * 60, flush=True)


# Script entry point: run the train + evaluate pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()