# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch",
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.45.0",
#     "accelerate>=0.24.0",
#     "huggingface_hub>=0.20.0",
#     "trackio",
#     "nvidia-ml-py",
#     "datasets",
#     "bitsandbytes",
#     "pyyaml",
#     "gguf",
# ]
# ///
"""
Unified training script for QMD query expansion models.

Primary pipeline is SFT-only:
  sft  - Supervised fine-tuning on labeled examples, followed by GGUF export
         and evaluation with eval.py

GRPO has moved to `experiments/grpo/` and is no longer part of the main
training pipeline.

Usage:
    uv run train.py sft  --config configs/sft.yaml
"""

import argparse
import os
import subprocess
import sys
import time
from pathlib import Path

import yaml
from transformers import TrainerCallback


def _run_step(cmd: list[str], description: str) -> bool:
    """Run *cmd* quietly and report failure; return True on exit status 0.

    A missing executable (e.g. ``cmake`` or ``git`` not installed) is reported
    like any other failure instead of raising, and the last lines of stderr are
    printed to make the cause visible.
    """
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
    except FileNotFoundError as err:
        print(f"{description} failed: {err}")
        return False
    if result.returncode != 0:
        print(f"{description} failed (exit code {result.returncode}).")
        for line in (result.stderr or "").strip().splitlines()[-20:]:
            print(f"  {line}")
        return False
    return True


def export_gguf(model, tokenizer, output_dir: str, model_name: str):
    """Export *model* to GGUF files at Q4_K_M, Q6_K and Q8_0 quantizations.

    The model and tokenizer are saved to a temporary directory, converted to an
    FP16 GGUF with llama.cpp's ``convert_hf_to_gguf.py`` and then quantized with
    ``llama-quantize``. On first use llama.cpp is cloned to ``/tmp/llama.cpp``
    and ``llama-quantize`` is built with ``-DGGML_CUDA=OFF``.

    Output files are ``<output_dir>/gguf/<model_name>-<quant>.gguf``.

    Toolchain failures are reported and end the export early instead of
    raising, so a broken llama.cpp setup does not abort the caller. The FP16
    file is removed only when every quantization succeeded; otherwise it is
    kept as a usable fallback.
    """
    import tempfile

    gguf_dir = Path(output_dir) / "gguf"
    gguf_dir.mkdir(parents=True, exist_ok=True)

    # Save merged model to temp dir
    print("Saving merged model for GGUF conversion...")
    with tempfile.TemporaryDirectory() as tmp:
        merged_path = Path(tmp) / "merged"
        model.save_pretrained(merged_path, safe_serialization=True)
        tokenizer.save_pretrained(merged_path)

        # Setup llama.cpp
        llama_cpp = Path("/tmp/llama.cpp")
        if not llama_cpp.exists():
            print("Cloning llama.cpp...")
            cloned = _run_step(
                [
                    "git",
                    "clone",
                    "--depth",
                    "1",
                    "https://github.com/ggerganov/llama.cpp.git",
                    str(llama_cpp),
                ],
                "Cloning llama.cpp",
            )
            if not cloned:
                return
            # Not fatal: the conversion dependencies may already be installed.
            _run_step(
                [
                    sys.executable,
                    "-m",
                    "pip",
                    "install",
                    "-q",
                    "-r",
                    str(llama_cpp / "requirements.txt"),
                ],
                "Installing llama.cpp requirements",
            )

        # Build quantize tool if needed
        quantize_bin = llama_cpp / "build" / "bin" / "llama-quantize"
        if not quantize_bin.exists():
            print("Building llama-quantize...")
            build_dir = llama_cpp / "build"
            build_dir.mkdir(exist_ok=True)
            configured = _run_step(
                [
                    "cmake",
                    "-B",
                    str(build_dir),
                    "-S",
                    str(llama_cpp),
                    "-DGGML_CUDA=OFF",
                ],
                "Configuring llama.cpp build",
            )
            if configured:
                _run_step(
                    [
                        "cmake",
                        "--build",
                        str(build_dir),
                        "--target",
                        "llama-quantize",
                        "-j",
                        "4",
                    ],
                    "Building llama-quantize",
                )

        # Convert to FP16 first; full logs go to files because they can be long.
        fp16_file = gguf_dir / f"{model_name}-f16.gguf"
        print(f"Converting to FP16: {fp16_file}")
        log_out = Path("/tmp/qmd-gguf-convert.log")
        log_err = Path("/tmp/qmd-gguf-convert.err")
        with log_out.open("w") as out_f, log_err.open("w") as err_f:
            result = subprocess.run(
                [
                    sys.executable,
                    str(llama_cpp / "convert_hf_to_gguf.py"),
                    str(merged_path),
                    "--outfile",
                    str(fp16_file),
                    "--outtype",
                    "f16",
                ],
                stdout=out_f,
                stderr=err_f,
                text=True,
            )
        if result.returncode != 0:
            print("GGUF conversion failed.")
            print(f"stdout: {log_out}")
            print(f"stderr: {log_err}")
            return

        # Without the quantize tool, the FP16 file is the only output.
        if not quantize_bin.exists():
            print(f"llama-quantize is unavailable; keeping FP16 GGUF only: {fp16_file}")
            print(f"GGUF files saved to: {gguf_dir}")
            return

        # Quantize to 4, 6, 8 bit
        all_ok = True
        for quant_type in ("Q4_K_M", "Q6_K", "Q8_0"):
            out_file = gguf_dir / f"{model_name}-{quant_type.lower()}.gguf"
            print(f"Quantizing {quant_type}: {out_file}")
            ok = _run_step(
                [str(quantize_bin), str(fp16_file), str(out_file), quant_type],
                f"Quantizing {quant_type}",
            )
            if ok and out_file.exists():
                size_mb = out_file.stat().st_size / (1024 * 1024)
                print(f"  {quant_type}: {size_mb:.1f} MB")
            else:
                all_ok = False

        # Remove FP16 to save space, but only if every quantized file exists.
        if all_ok:
            fp16_file.unlink(missing_ok=True)
        else:
            print(f"Some quantizations failed; keeping FP16 GGUF: {fp16_file}")

    print(f"GGUF files saved to: {gguf_dir}")


class TimedSaveCallback(TrainerCallback):
    """Request a checkpoint save whenever a wall-clock interval has elapsed.

    The timer restarts when training begins, so model loading and dataset
    preparation done before ``train()`` do not count toward the first interval.
    It also restarts after every checkpoint save, including saves requested by
    the Trainer's own step/epoch strategy.

    Args:
        interval_minutes: Minimum wall-clock minutes between requested saves.
    """

    def __init__(self, interval_minutes: float):
        self.interval_seconds = float(interval_minutes) * 60.0
        self.last_save_time = time.time()

    def on_train_begin(self, args, state, control, **kwargs):
        # Start measuring from the beginning of training, not construction time.
        self.last_save_time = time.time()
        return control

    def on_step_end(self, args, state, control, **kwargs):
        # Only the main process requests saves.
        # NOTE(review): the decision is not broadcast to other ranks; confirm
        # this is safe for the distributed backend in use (FSDP/DeepSpeed may
        # require every rank to take part in a save).
        if not getattr(state, "is_world_process_zero", False):
            return control

        now = time.time()
        if now - self.last_save_time >= self.interval_seconds:
            control.should_save = True
            self.last_save_time = now
        return control

    def on_save(self, args, state, control, **kwargs):
        # Any checkpoint (timed or not) restarts the interval.
        self.last_save_time = time.time()
        return control


def run_eval(model_path: str) -> float | None:
    """Run eval.py on the trained model and return average score."""
    print("\n" + "=" * 60)
    print("Running evaluation...")
    print("=" * 60)

    eval_script = Path(__file__).parent / "eval.py"
    result = subprocess.run(
        [sys.executable, str(eval_script), model_path],
        cwd=str(Path(__file__).parent),
        capture_output=True,
        text=True,
    )
    if result.stdout:
        print(result.stdout, end="")
    if result.stderr:
        print(result.stderr, end="")

    avg = None
    for line in (result.stdout or "").splitlines():
        if line.strip().startswith("Average:"):
            try:
                avg = float(line.split("Average:", 1)[1].split("%", 1)[0].strip())
            except ValueError:
                pass
            break
    return avg


def cmd_sft(args):
    """Run supervised fine-tuning with a LoRA adapter, then export and evaluate.

    Reads the YAML config at ``args.config``. With ``args.dry_run`` the parsed
    config is printed and nothing else happens. Otherwise this:

    1. loads the dataset (a Hub dataset id, a local JSONL file, a directory
       containing ``train.jsonl``, or a glob such as ``data/*.jsonl``),
       shuffles it and holds out ``dataset.eval_split`` for evaluation;
    2. trains a LoRA adapter on ``model.base`` with TRL's ``SFTTrainer``;
    3. on the main process only, pushes to the Hub (when ``model.output``
       looks like a repo id or ``model.push_to_hub`` is set) or saves locally;
    4. merges the adapter into a fresh copy of the base model, exports GGUF
       quantizations and runs ``eval.py``, logging the average to Trackio when
       tracking is enabled.
    """
    import gc
    import glob

    import torch
    import torch.distributed as dist
    from datasets import load_dataset
    from peft import LoraConfig, PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.utils import logging as hf_logging

    hf_logging.set_verbosity_error()
    from trl import SFTConfig, SFTTrainer

    with open(args.config) as f:
        cfg = yaml.safe_load(f)

    os.environ.setdefault("HF_LOG_CUDA_MEMORY", "0")

    if args.dry_run:
        print("SFT Training Configuration:")
        print(yaml.dump(cfg, default_flow_style=False))
        return

    dataset_name = cfg["dataset"]["name"]
    print(f"Loading dataset: {dataset_name}...")

    # Local JSONL sources: a glob, a directory holding train.jsonl, or a file;
    # anything else is treated as a Hub dataset id. ``Path`` here is the
    # module-level import on purpose: importing it inside this function would
    # make it a local name and break the later ``Path(output_dir)`` calls
    # (UnboundLocalError) whenever this branch is skipped.
    if dataset_name.startswith("data/") or dataset_name.endswith(".jsonl"):
        if "*" in dataset_name:
            jsonl_files = sorted(glob.glob(dataset_name))
            if not jsonl_files:
                raise ValueError(f"No files found matching: {dataset_name}")
            print(
                f"  Found {len(jsonl_files)} JSONL files: {[Path(f).name for f in jsonl_files]}"
            )
            dataset = load_dataset("json", data_files=jsonl_files, split="train")
        else:
            data_path = Path(dataset_name)
            if data_path.is_dir():
                train_file = data_path / "train.jsonl"
                dataset = load_dataset(
                    "json", data_files=str(train_file), split="train"
                )
            else:
                dataset = load_dataset("json", data_files=dataset_name, split="train")
    else:
        dataset = load_dataset(dataset_name, split=cfg["dataset"]["split"])
    print(f"Dataset loaded: {len(dataset)} examples")

    dataset = dataset.shuffle(seed=42)
    split = dataset.train_test_split(test_size=cfg["dataset"]["eval_split"], seed=42)
    train_dataset = split["train"]
    eval_dataset = split["test"]
    print(f"  Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")

    # Check if output looks like a HF Hub path (contains /)
    output_name = cfg["model"]["output"]
    push_to_hub = "/" in output_name and not output_name.startswith("outputs/")
    if "push_to_hub" in cfg["model"]:
        push_to_hub = bool(cfg["model"]["push_to_hub"])
    output_dir = output_name.split("/")[-1] if push_to_hub else output_name

    # Trackio is used only when an HF token is present and the package imports.
    report_to = "none"
    if os.environ.get("HF_TOKEN"):
        try:
            import trackio  # noqa: F401

            report_to = "trackio"
        except Exception:
            print("Trackio not installed; disabling tracking.")

    tracking = cfg.get("tracking", {})
    if report_to == "trackio":
        project = tracking.get("project")
        if project:
            os.environ.setdefault("TRACKIO_PROJECT", project)

    # Expand the {day} / {time} placeholders in the configured run name.
    run_name = tracking.get("run_name")
    if run_name and "{" in run_name:
        from datetime import datetime

        now = datetime.now()
        run_name = run_name.replace("{day}", now.strftime("%b %d")).replace(
            "{time}", now.strftime("%H:%M")
        )

    # PyYAML parses exponent-only literals such as ``2e-4`` as strings; the
    # optimizer needs a number (the GRPO path does the same coercion).
    learning_rate = float(cfg["training"]["learning_rate"])

    save_interval_minutes = cfg["training"].get("save_interval_minutes")
    save_steps = cfg["training"].get("save_steps", 200)
    save_total_limit = cfg["training"].get("save_total_limit", 2)
    if save_interval_minutes:
        # Prefer wall-clock checkpointing (for long jobs / preemption safety)
        save_steps = max(save_steps, 10_000_000)

    callbacks = []
    if save_interval_minutes:
        try:
            interval_value = float(save_interval_minutes)
        except (TypeError, ValueError):
            interval_value = None
        if interval_value and interval_value > 0:
            callbacks.append(TimedSaveCallback(interval_value))

    config = SFTConfig(
        output_dir=output_dir,
        push_to_hub=push_to_hub,
        hub_model_id=output_name if push_to_hub else None,
        hub_strategy="every_save" if push_to_hub else "end",
        num_train_epochs=cfg["training"]["epochs"],
        per_device_train_batch_size=cfg["training"]["batch_size"],
        gradient_accumulation_steps=cfg["training"]["gradient_accumulation_steps"],
        learning_rate=learning_rate,
        max_length=cfg["training"]["max_length"],
        logging_steps=10,
        save_strategy="steps",
        save_steps=save_steps,
        save_total_limit=save_total_limit,
        eval_strategy="steps",
        eval_steps=cfg["training"].get("eval_steps", 200),
        warmup_ratio=cfg["training"]["warmup_ratio"],
        lr_scheduler_type=cfg["training"]["lr_scheduler"],
        ddp_find_unused_parameters=cfg["training"].get(
            "ddp_find_unused_parameters", False
        ),
        bf16=True,
        report_to=report_to,
        run_name=run_name if report_to == "trackio" else None,
    )

    # LoRA config with modules_to_save for embedding layers
    # This prevents token ID mismatches during inference
    peft_config = LoraConfig(
        r=cfg["lora"]["rank"],
        lora_alpha=cfg["lora"]["alpha"],
        lora_dropout=cfg["lora"]["dropout"],
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=cfg["lora"]["target_modules"],
        modules_to_save=["embed_tokens", "lm_head"],  # Critical for special tokens
        ensure_weight_tying=True,
    )

    print("Loading tokenizer...")
    base_model = cfg["model"]["base"]
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Initializing SFT trainer...")
    trainer = SFTTrainer(
        model=base_model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=config,
        peft_config=peft_config,
        processing_class=tokenizer,
        callbacks=callbacks,
    )

    print("Starting SFT training...")
    trainer.train()

    # Everything after training (saving, export, eval) runs on rank 0 only.
    is_main = os.environ.get("RANK", "0") == "0"
    if dist.is_available() and dist.is_initialized():
        dist.barrier()

    if not is_main:
        return

    if push_to_hub:
        print("Pushing to Hub...")
        trainer.push_to_hub()
        print(f"Done! Model: https://huggingface.co/{output_name}")
    else:
        trainer.save_model()
        print(f"Done! Model saved to: {output_dir}")

    # The trained model is no longer needed; free it before loading a second
    # full copy of the base model for the merge below.
    del trainer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Export GGUF from the adapter merged into a fresh base model.
    print("\nExporting to GGUF...")
    print("Loading model for GGUF export...")
    base = AutoModelForCausalLM.from_pretrained(
        base_model, torch_dtype=torch.bfloat16, device_map="auto"
    )
    base.config.tie_word_embeddings = False
    model = PeftModel.from_pretrained(base, output_dir, local_files_only=True)
    model = model.merge_and_unload()
    export_gguf(model, tokenizer, output_dir, Path(output_dir).name)

    # eval.py runs with the script directory as its cwd, so pass an absolute
    # path; a relative output_dir would otherwise resolve against the wrong dir.
    eval_avg = run_eval(str(Path(output_dir).resolve()))
    if report_to == "trackio" and eval_avg is not None:
        try:
            import trackio

            trackio.log({"eval.avg": eval_avg})
        except Exception as err:
            # Best effort: the model is already saved; just report the failure.
            print(f"Could not log eval.avg to Trackio: {err}")


def cmd_grpo(args):
    """Point users to the experimental GRPO stage; it no longer runs here.

    GRPO was removed from the main pipeline and now lives in
    ``experiments/grpo/``. This stub only prints how to run it there and
    returns None. The former implementation that followed an early ``return``
    was unreachable and has been removed.

    Args:
        args: Parsed CLI arguments. Unused; kept so the signature matches
            ``cmd_sft``.
    """
    print(
        "GRPO is not part of the main training pipeline and has been moved to `experiments/grpo/`."
    )
    print("To run experimental GRPO, use:")
    print("  cd finetune && uv run python experiments/grpo/grpo.py")


def main():
    """Parse the command line and hand control to the chosen training stage."""
    usage_examples = (
        "\nExamples:\n"
        "  uv run train.py sft  --config configs/sft.yaml\n"
        "        "
    )
    parser = argparse.ArgumentParser(
        description="QMD Query Expansion Training",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    stages = parser.add_subparsers(dest="stage", required=True)

    sft_stage = stages.add_parser("sft", help="Supervised fine-tuning")
    sft_stage.add_argument("--config", required=True, help="Path to SFT config YAML")
    sft_stage.add_argument("--dry-run", action="store_true", help="Print config and exit")

    parsed = parser.parse_args()

    # "sft" is the only registered stage, so every successful parse lands here.
    handlers = {"sft": cmd_sft}
    handlers[parsed.stage](parsed)


if __name__ == "__main__":
    main()