"""一键执行最终训练与评估(简洁命令入口)。"""
from __future__ import annotations
import argparse
import os
import subprocess
import sys
from typing import List
# Absolute path of the directory that contains this file; train.py and
# evaluate_physics.py are looked up inside it.
PROJECT_DIR = os.path.split(os.path.abspath(__file__))[0]
def _build_env() -> dict[str, str]:
    """Return a fresh copy of the current environment for the child stages.

    ``CUDA_VISIBLE_DEVICES`` is set to an empty string only when the caller
    has not set it; an inherited value is passed through unchanged. (An
    empty value is the CUDA convention for exposing no GPUs, so by default
    the children presumably run on CPU.) ``os.environ`` itself is never
    modified.
    """
    child_env = dict(os.environ)
    if "CUDA_VISIBLE_DEVICES" not in child_env:
        child_env["CUDA_VISIBLE_DEVICES"] = ""
    return child_env
def _run(stage: str, cmd: List[str], env: dict[str, str]) -> None:
    """Execute one pipeline stage as a child process and wait for it.

    The child is started with ``PROJECT_DIR`` as its working directory and
    ``env`` as its environment; its output is not captured. Start and
    finish markers tagged with ``stage`` are printed around the run.

    Raises:
        SystemExit: when the child exits with a non-zero status; the
            message names the stage and the exit code.
    """
    print(f"\n[{stage}] 开始...", flush=True)
    completed = subprocess.run(cmd, cwd=PROJECT_DIR, env=env, check=False)
    exit_code = completed.returncode
    if exit_code:
        raise SystemExit(f"{stage}失败(退出码 {exit_code})。")
    print(f"[{stage}] 完成", flush=True)
# Hyperparameters of the final configuration, forwarded verbatim to train.py
# in this order. Values are strings because they become argv entries.
_FIXED_TRAIN_OPTIONS = (
    ("--loss-mode", "mse"),
    ("--center-loss-weight", "1.0"),
    ("--augment-prob", "0.2"),
    ("--augment-occlusion-ratio", "0.35"),
    ("--augment-noise-ratio", "0.01"),
    ("--event-loss-weight", "1.0"),
    ("--event-pos-weight", "6.0"),
    ("--event-focal-gamma", "2.0"),
    ("--uncertainty-loss-weight", "0.3"),
    ("--uncertainty-residual-reg-weight", "0.05"),
    ("--event-loss-start-weight", "0.0"),
    ("--uncertainty-loss-start-weight", "0.0"),
    ("--aux-warmup-epochs", "6"),
    # NOTE(review): fixed at 20 regardless of --epochs; with fewer epochs the
    # second auxiliary stage may never be reached — confirm against train.py.
    ("--aux-stage2-epoch", "20"),
    ("--event-stage2-scale", "0.8"),
    ("--uncertainty-stage2-scale", "0.8"),
    ("--select-metric", "hybrid"),
    ("--hybrid-event-weight", "0.3"),
)


def _parse_args() -> argparse.Namespace:
    """Parse the command line and reject values that cannot work downstream."""
    parser = argparse.ArgumentParser(description="一键训练并评估最终模型。")
    parser.add_argument("--dataset-dir", type=str, required=True, help="数据集目录(例如 BPT-V/dataset)。")
    parser.add_argument("--epochs", type=int, default=20, help="训练轮数。")
    parser.add_argument("--batch-size", type=int, default=8, help="批大小。")
    parser.add_argument("--lr", type=float, default=5e-4, help="学习率。")
    parser.add_argument("--seed", type=int, default=42, help="随机种子。")
    parser.add_argument("--run-name", type=str, default="final_optimized", help="本次运行名称。")
    args = parser.parse_args()
    # Fail fast with a usage error instead of deep inside train.py. An empty
    # run name would also silently put the checkpoint directly in checkpoints/.
    if args.epochs <= 0:
        parser.error("--epochs 必须为正整数。")
    if args.batch_size <= 0:
        parser.error("--batch-size 必须为正整数。")
    if not args.lr > 0:  # written this way so NaN is rejected as well
        parser.error("--lr 必须为正数。")
    if not args.run_name.strip():
        parser.error("--run-name 不能为空。")
    return args


def _build_train_cmd(args: argparse.Namespace, dataset_dir: str) -> List[str]:
    """Build the argv for train.py: user options first, then the fixed ones."""
    cmd = [
        sys.executable,
        os.path.join(PROJECT_DIR, "train.py"),
        "--dataset-dir", dataset_dir,
        "--epochs", str(args.epochs),
        "--batch-size", str(args.batch_size),
        "--lr", str(args.lr),
        "--seed", str(args.seed),
        "--run-name", args.run_name,
    ]
    for flag, value in _FIXED_TRAIN_OPTIONS:
        cmd.extend((flag, value))
    cmd.append("--use-hard-mining")
    return cmd


def _build_eval_cmd(args: argparse.Namespace, dataset_dir: str) -> List[str]:
    """Build the argv for evaluate_physics.py on this run's best checkpoint."""
    # The checkpoint path is relative on purpose: the stages are launched with
    # cwd=PROJECT_DIR, so it resolves to PROJECT_DIR/checkpoints/<run>/.
    return [
        sys.executable,
        os.path.join(PROJECT_DIR, "evaluate_physics.py"),
        "--dataset-dir", dataset_dir,
        "--ckpt", os.path.join("checkpoints", args.run_name, "best_model.pt"),
        "--tag", args.run_name,
    ]


def main() -> None:
    """Train the final model, then evaluate it, each as a child process.

    Raises:
        SystemExit: on invalid command-line options (argparse usage error)
            or when either stage exits with a non-zero status.
        FileNotFoundError: if ``<dataset-dir>/video_001/data.json`` is not a
            file.
    """
    args = _parse_args()
    dataset_dir = os.path.abspath(args.dataset_dir)
    # Cheap layout sanity check: only the first sample's data.json is probed.
    if not os.path.isfile(os.path.join(dataset_dir, "video_001", "data.json")):
        raise FileNotFoundError(f"数据集目录无效:{dataset_dir}")

    train_cmd = _build_train_cmd(args, dataset_dir)
    eval_cmd = _build_eval_cmd(args, dataset_dir)
    env = _build_env()

    banner = "=" * 60
    print(banner, flush=True)
    print("开始执行:最终训练 + 评估", flush=True)
    print(f"数据集目录:{dataset_dir}", flush=True)
    print(f"运行名称:{args.run_name}", flush=True)
    print(banner, flush=True)

    _run("训练阶段", train_cmd, env)
    _run("评估阶段", eval_cmd, env)

    # Report artifact locations with the platform's own separator and anchored
    # at PROJECT_DIR, because that is the children's working directory.
    # Presumably evaluate_physics.py writes outputs/ relative to its cwd —
    # confirm.
    ckpt_dir = os.path.join(PROJECT_DIR, "checkpoints", args.run_name)
    metrics_file = os.path.join(
        PROJECT_DIR, "outputs", "metrics", f"model_eval_{args.run_name}.json"
    )
    print("\n" + banner, flush=True)
    print("流程完成", flush=True)
    print(f"模型目录:{ckpt_dir}", flush=True)
    print(f"评估结果:{metrics_file}", flush=True)
    print(banner, flush=True)
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()