#!/usr/bin/env python3
"""
QMD Query Expansion fine-tuning with Unsloth (Qwen3.5 support).

Usage:
    python train_unsloth.py --model 0.8B
    python train_unsloth.py --model 2B
    python train_unsloth.py --model 4B --epochs 3

Requires: pip install unsloth unsloth_zoo
"""

import argparse
import json
import sys
from pathlib import Path

# CLI size tag -> Unsloth Hugging Face repo for the matching Qwen3.5 checkpoint.
# Dict insertion order is preserved, so `--model` choices stay smallest-first.
MODEL_MAP = {
    size: f"unsloth/Qwen3.5-{size}"
    for size in ("0.8B", "2B", "4B", "9B", "27B")
}

def _parse_args():
    """Parse the command-line options for one fine-tuning run."""
    parser = argparse.ArgumentParser(description="QMD fine-tuning with Unsloth")
    parser.add_argument("--model", required=True, choices=list(MODEL_MAP.keys()),
                        help="Model size to train")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--grad-accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--max-seq-len", type=int, default=512)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--data", type=str, default="data/train/train.jsonl")
    parser.add_argument("--output", type=str, default=None,
                        help="Output directory (default: outputs/qwen3.5-{size})")
    parser.add_argument("--push-hub", type=str, default=None,
                        help="Push to HF hub (e.g. tobil/qmd-query-expansion-qwen3.5-0.8B)")
    parser.add_argument("--no-gguf", action="store_true")
    parser.add_argument("--no-eval", action="store_true")
    parser.add_argument("--dry-run", action="store_true")
    return parser.parse_args()


def _eval_command(output_dir):
    """Build the argv that runs ``eval.py`` on *output_dir*.

    ``eval.py`` is launched with the script's own directory as its cwd, so
    *output_dir* is made absolute here. Otherwise a relative path (the
    default) would be resolved against the wrong directory.
    """
    return [sys.executable, "eval.py", str(Path(output_dir).resolve())]


def main():
    """Fine-tune a Qwen3.5 model with LoRA via Unsloth, then export/push/evaluate.

    Steps:
      1. Parse arguments and print the run configuration (``--dry-run`` stops here).
      2. Load the JSONL dataset and hold out 10% as an eval split.
      3. Load the base model in 16-bit and attach LoRA adapters.
      4. Train with TRL's ``SFTTrainer``; save the adapter and tokenizer.
      5. Export GGUF quantizations (skipped with ``--no-gguf``).
      6. Push to the Hugging Face Hub (only when ``--push-hub`` is given).
      7. Run ``eval.py`` on the output directory (skipped with ``--no-eval``).
    """
    args = _parse_args()

    model_name = MODEL_MAP[args.model]
    output_dir = args.output or f"outputs/qwen3.5-{args.model}"
    # Quantizations used for both local GGUF export and Hub upload.
    gguf_quants = ("q4_k_m", "q8_0")

    print(f"{'='*60}")
    print("QMD Query Expansion — Unsloth SFT")
    print(f"  Base model:  {model_name}")
    print(f"  Output:      {output_dir}")
    print(f"  Data:        {args.data}")
    print(f"  Epochs:      {args.epochs}")
    print(f"  Batch:       {args.batch_size} x {args.grad_accum} accum")
    print(f"  LR:          {args.lr}")
    print(f"  LoRA rank:   {args.lora_rank}")
    print(f"  Max seq len: {args.max_seq_len}")
    print(f"{'='*60}")

    if args.dry_run:
        print("Dry run — exiting.")
        return

    # --- Imports (heavy) ---
    # Deferred so --help / --dry-run stay fast; unsloth is imported before
    # trl so its patches are applied first.
    import os
    import torch
    from unsloth import FastLanguageModel
    from datasets import load_dataset
    from trl import SFTTrainer, SFTConfig

    # --- Dataset ---
    # Loaded before the model so a bad --data path fails immediately rather
    # than after a potentially long model download/load. Seeds are fixed, so
    # the split is the same regardless of ordering.
    print(f"Loading dataset from {args.data}...")
    dataset = load_dataset("json", data_files=args.data, split="train")
    dataset = dataset.shuffle(seed=42)
    split = dataset.train_test_split(test_size=0.1, seed=42)
    train_ds = split["train"]
    eval_ds = split["test"]
    print(f"  Train: {len(train_ds)}, Eval: {len(eval_ds)}")

    # --- Load model ---
    print(f"\nLoading {model_name}...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=args.max_seq_len,
        load_in_4bit=False,
        load_in_16bit=True,
        full_finetuning=False,
    )

    # --- LoRA ---
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_rank,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=args.lora_rank,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        max_seq_length=args.max_seq_len,
    )

    # --- Tracking ---
    # trackio reporting is opt-in: only when HF_TOKEN is set and the package
    # is importable; otherwise metrics are not reported anywhere.
    report_to = "none"
    if os.environ.get("HF_TOKEN"):
        try:
            import trackio  # noqa: F401  (availability check only)
            report_to = "trackio"
            os.environ.setdefault("TRACKIO_PROJECT", "qmd-query-expansion")
        except ImportError:
            pass

    # bf16 is not available on every CUDA GPU; requesting it unconditionally
    # makes the trainer config fail. Fall back to fp16 where it is unsupported.
    use_bf16 = torch.cuda.is_bf16_supported()

    # --- Trainer ---
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        args=SFTConfig(
            output_dir=output_dir,
            max_seq_length=args.max_seq_len,
            num_train_epochs=args.epochs,
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=args.grad_accum,
            learning_rate=args.lr,
            warmup_ratio=0.03,
            lr_scheduler_type="cosine",
            logging_steps=10,
            save_strategy="steps",
            save_steps=200,
            save_total_limit=3,
            eval_strategy="steps",
            eval_steps=200,
            bf16=use_bf16,
            fp16=not use_bf16,
            optim="adamw_8bit",
            seed=3407,
            dataset_num_proc=4,
            report_to=report_to,
            run_name=f"sft-qwen3.5-{args.model}",
        ),
    )

    print("\nStarting training...")
    stats = trainer.train()
    print("\nTraining complete!")
    print(f"  Total steps: {stats.global_step}")
    print(f"  Final loss:  {stats.training_loss:.4f}")

    # --- Save ---
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Adapter saved to {output_dir}")

    # --- GGUF export ---
    if not args.no_gguf:
        print("\nExporting GGUF quantizations...")
        gguf_dir = f"{output_dir}/gguf"
        for quant in gguf_quants:
            print(f"  {quant}...")
            try:
                model.save_pretrained_gguf(
                    gguf_dir, tokenizer, quantization_method=quant
                )
                print(f"  ✓ {quant} saved")
            except Exception as e:
                # Best-effort: one failed quantization should not abort the
                # remaining exports, the Hub push, or evaluation.
                print(f"  ✗ {quant} failed: {e}")

    # --- Push to Hub ---
    if args.push_hub:
        print(f"\nPushing to {args.push_hub}...")
        # save_method="lora" uploads the adapter weights, not a merged model.
        model.push_to_hub_merged(args.push_hub, tokenizer, save_method="lora")
        if not args.no_gguf:
            for quant in gguf_quants:
                try:
                    model.push_to_hub_gguf(args.push_hub, tokenizer, quantization_method=quant)
                except Exception as e:
                    print(f"  GGUF push {quant} failed: {e}")

    # --- Eval ---
    if not args.no_eval:
        print("\nRunning evaluation...")
        import subprocess
        result = subprocess.run(
            _eval_command(output_dir),
            cwd=str(Path(__file__).parent),
            check=False,
        )
        # Evaluation is best-effort (training output is already saved), but a
        # failure should be visible rather than silently ignored.
        if result.returncode != 0:
            print(f"  ✗ Evaluation failed (exit code {result.returncode})")

    print(f"\n{'='*60}")
    print(f"Done! Model at: {output_dir}")
    print(f"{'='*60}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()