#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "transformers>=4.36.0",
#     "peft>=0.7.0",
#     "torch>=2.0.0",
#     "accelerate>=0.24.0",
#     "huggingface_hub>=0.20.0",
#     "sentencepiece>=0.1.99",
#     "protobuf>=3.20.0",
#     "numpy",
#     "optimum[onnxruntime]",
#     "onnx>=1.15.0",
#     "onnxruntime>=1.17.0",
#     "onnxconverter-common>=1.14.0",
# ]
# ///
"""
Convert QMD query expansion model to ONNX format for Transformers.js.

Loads the base model, merges SFT and GRPO adapters, then exports to ONNX
with quantization for browser deployment via Transformers.js + WebGPU.

Usage:
    uv run convert_onnx.py --size 1.7B
    uv run convert_onnx.py --size 1.7B --no-upload
    uv run convert_onnx.py --base Qwen/Qwen3-1.7B \
                           --sft tobil/qmd-query-expansion-1.7B-sft \
                           --grpo tobil/qmd-query-expansion-1.7B-grpo \
                           --output tobil/qmd-query-expansion-1.7B-ONNX

Quantization options:
    --quantize q4    MatMulNBits 4-bit (default, smallest)
    --quantize q8    8-bit dynamic quantization
    --quantize fp16  FP16 weights (converted after the CPU export; no GPU needed)
    --quantize none  No quantization (FP32; roughly 7GB for 1.7B, more for 4B)
"""

import argparse
import json
import os
import shutil
import subprocess
import sys
from pathlib import Path

import torch
from huggingface_hub import HfApi, login
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Per-size model configuration. Every preset follows the same naming scheme:
# the Qwen3 base checkpoint, the SFT and GRPO LoRA adapters trained on top of
# it, and the Hub repo that receives the ONNX export.
PRESETS = {
    size: {
        "base": f"Qwen/Qwen3-{size}",
        "sft": f"tobil/qmd-query-expansion-{size}-sft",
        "grpo": f"tobil/qmd-query-expansion-{size}-grpo",
        "output": f"tobil/qmd-query-expansion-{size}-ONNX",
    }
    for size in ("1.7B", "4B")
}


def merge_adapters(base_model: str, sft_model: str, grpo_model: str) -> tuple:
    """Build the final merged model from the base checkpoint and two adapters.

    The base model is loaded in FP32. Then the SFT adapter and the GRPO adapter
    are applied in that order, each folded into the weights with
    ``merge_and_unload`` before the next is loaded.

    Args:
        base_model: Hub id or path of the base causal-LM checkpoint.
        sft_model: Hub id or path of the SFT PEFT adapter.
        grpo_model: Hub id or path of the GRPO PEFT adapter.

    Returns:
        ``(model, tokenizer)``. The tokenizer is loaded from *base_model*.
    """
    print(f"\nStep 1: Loading base model {base_model}...")
    # NOTE(review): ``dtype=`` is the newer transformers spelling of
    # ``torch_dtype=``; confirm the installed transformers version accepts it.
    merged = AutoModelForCausalLM.from_pretrained(
        base_model, dtype=torch.float32, trust_remote_code=True,
    )

    # Order matters: GRPO was trained on top of the SFT-merged weights.
    for step, label, adapter in ((2, "SFT", sft_model), (3, "GRPO", grpo_model)):
        print(f"Step {step}: Merging {label} adapter {adapter}...")
        merged = PeftModel.from_pretrained(merged, adapter).merge_and_unload()

    tok = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    return merged, tok


def export_onnx(model, tokenizer, output_dir: str):
    """Export a merged model to ONNX using Optimum.

    Optimum's exporter reads a Hugging Face checkpoint from disk. So the
    in-memory *model* and *tokenizer* are first saved to a fresh temporary
    directory. That checkpoint is then exported to *output_dir* as a
    ``text-generation-with-past`` (KV-cache) graph, on CPU, in FP32. The
    temporary checkpoint is always removed afterwards, even if the export fails.

    Args:
        model: Merged causal-LM model exposing ``save_pretrained``.
        tokenizer: Tokenizer exposing ``save_pretrained``.
        output_dir: Directory that receives the exported ONNX files.
    """
    import tempfile

    from optimum.exporters.onnx import main_export

    # A unique, freshly created directory (honours $TMPDIR). Stale shards from
    # an earlier or concurrent run can therefore never leak into this export.
    merged_dir = tempfile.mkdtemp(prefix="merged_model_onnx_")
    try:
        print(f"\nStep 4: Saving merged model to {merged_dir}...")
        model.save_pretrained(merged_dir, safe_serialization=True)
        tokenizer.save_pretrained(merged_dir)

        print(f"\nStep 5: Exporting to ONNX at {output_dir}...")
        # no_post_process=True avoids the 2GB protobuf serialization limit
        # that occurs during tied-weight deduplication on large FP32 models.
        # The exported model still works correctly. The tied weights just
        # aren't deduplicated in the graph: quantization shrinks them
        # afterwards, and an unquantized export is simply somewhat larger.
        main_export(
            model_name_or_path=merged_dir,
            output=output_dir,
            task="text-generation-with-past",
            device="cpu",
            fp16=False,
            no_post_process=True,
        )
    finally:
        # The FP32 checkpoint is several GB; never leave it behind on failure.
        shutil.rmtree(merged_dir, ignore_errors=True)


def _find_onnx_model(onnx_dir: str) -> Path:
    """Find the main ONNX model file in the output directory."""
    model_path = Path(onnx_dir) / "model.onnx"
    if model_path.exists():
        return model_path
    candidates = list(Path(onnx_dir).glob("*.onnx"))
    if not candidates:
        raise FileNotFoundError(f"No .onnx files found in {onnx_dir}")
    return candidates[0]


def quantize_onnx(onnx_dir: str, quantize_type: str):
    """Quantize the exported ONNX model in *onnx_dir* in place.

    Args:
        onnx_dir: Directory holding the exported ONNX model.
        quantize_type: One of ``"q4"``, ``"q8"``, ``"fp16"`` or ``"none"``.
            ``"none"`` leaves the FP32 export untouched.

    Raises:
        ValueError: If *quantize_type* is not one of the supported values.
            Previously an unknown value silently skipped quantization.
    """
    if quantize_type not in ("q4", "q8", "fp16", "none"):
        raise ValueError(f"Unsupported quantize_type: {quantize_type!r}")

    if quantize_type == "none":
        print("\nSkipping quantization (FP32).")
        return

    model_path = _find_onnx_model(onnx_dir)
    print(f"\nStep 6: Quantizing {model_path.name} ({quantize_type})...")

    converters = {
        "q4": _quantize_q4,
        "q8": _quantize_q8,
        "fp16": _convert_fp16,
    }
    converters[quantize_type](model_path)


def _quantize_q4(model_path: Path):
    """4-bit MatMulNBits quantization via onnxruntime. Needs ~16GB RAM for 1.7B models.

    MatMul weights are quantized to 4 bits (block size 32, symmetric). The
    result replaces *model_path*: the FP32 ``.onnx`` file and its
    ``<name>_data`` external-data file are deleted, and the quantized graph
    is renamed to *model_path*.

    The quantized tensors are stored in an external data file named
    ``<stem>_q4.onnx.data``. Non-MatMul weights (presumably including the
    embedding table) are not 4-bit quantized and can push a single-file
    model past protobuf's 2GB limit.
    """
    from onnxruntime.quantization import matmul_nbits_quantizer

    q_path = model_path.with_name(model_path.stem + "_q4" + model_path.suffix)
    # ORT's save_model_to_file(path, True) writes tensors to "<path.name>.data".
    q_data = q_path.with_name(q_path.name + ".data")
    # onnx appends to an existing external-data file, so a leftover from a
    # previous run would silently bloat the output. Start clean.
    q_data.unlink(missing_ok=True)

    quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
        model=str(model_path),
        block_size=32,
        is_symmetric=True,
        bits=4,
    )
    quant.process()
    # Same call ORT's own matmul_nbits CLI uses; True => external data format.
    quant.model.save_model_to_file(str(q_path), True)

    # Remove original FP32 files, keep only quantized
    if q_path.exists():
        _report_size(q_path)
        if q_data.exists():
            _report_size(q_data)
        model_path.unlink(missing_ok=True)
        data_path = model_path.with_name(model_path.name + "_data")
        data_path.unlink(missing_ok=True)
        # Rename quantized to model.onnx for Transformers.js compatibility.
        # The data file keeps its name: the graph references it by that name.
        q_path.rename(model_path)
        print(f"  Renamed {q_path.name} -> {model_path.name}")


def _quantize_q8(model_path: Path):
    """8-bit dynamic quantization via onnxruntime, replacing *model_path* in place.

    Weights are quantized to ``QuantType.QUInt8`` by ``quantize_dynamic``.
    The output uses the external data format (``<stem>_q8.onnx.data``).
    Multi-billion-parameter models can exceed protobuf's 2GB single-file
    limit even at 8 bits. Afterwards the FP32 ``.onnx`` and its
    ``<name>_data`` file are deleted, and the quantized graph is renamed to
    *model_path*.
    """
    from onnxruntime.quantization import quantize_dynamic, QuantType

    q_path = model_path.with_name(model_path.stem + "_q8" + model_path.suffix)
    q_data = q_path.with_name(q_path.name + ".data")
    # External-data writes append to an existing file; drop any leftover.
    q_data.unlink(missing_ok=True)

    quantize_dynamic(
        model_input=str(model_path),
        model_output=str(q_path),
        weight_type=QuantType.QUInt8,
        use_external_data_format=True,
    )

    if q_path.exists():
        _report_size(q_path)
        if q_data.exists():
            _report_size(q_data)
        model_path.unlink(missing_ok=True)
        data_path = model_path.with_name(model_path.name + "_data")
        data_path.unlink(missing_ok=True)
        # The external data file keeps its name; the graph references it by that name.
        q_path.rename(model_path)
        print(f"  Renamed {q_path.name} -> {model_path.name}")


def _convert_fp16(model_path: Path):
    """Convert ONNX model weights to FP16, replacing *model_path* in place.

    ``keep_io_types=True`` leaves the graph inputs and outputs in FP32,
    including the past-key-value cache inputs. The FP16 tensors are written
    to an external data file (``<stem>_fp16.onnx_data``) because an FP16
    multi-billion-parameter model exceeds protobuf's 2GB single-file limit.
    Afterwards the FP32 ``.onnx`` and its ``<name>_data`` file are deleted,
    and the FP16 graph is renamed to *model_path*.
    """
    from onnxconverter_common import float16
    import onnx

    print("  Converting to FP16...")
    # The *_model_path variant runs shape inference from disk. That works for
    # models over 2GB, whereas convert_float_to_float16(onnx.load(...)) must
    # serialise the whole model in memory and fails past 2GB.
    model_fp16 = float16.convert_float_to_float16_model_path(
        str(model_path), keep_io_types=True,
    )

    fp16_path = model_path.with_name(model_path.stem + "_fp16" + model_path.suffix)
    fp16_data = fp16_path.with_name(fp16_path.name + "_data")
    # External-data writes append to an existing file; drop any leftover.
    fp16_data.unlink(missing_ok=True)
    onnx.save_model(
        model_fp16,
        str(fp16_path),
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=fp16_data.name,
    )

    if fp16_path.exists():
        _report_size(fp16_path)
        if fp16_data.exists():
            _report_size(fp16_data)
        model_path.unlink(missing_ok=True)
        data_path = model_path.with_name(model_path.name + "_data")
        data_path.unlink(missing_ok=True)
        # The external data file keeps its name; the graph references it by that name.
        fp16_path.rename(model_path)
        print(f"  Renamed {fp16_path.name} -> {model_path.name}")


def _report_size(path: Path):
    """Print file size in MB."""
    size_mb = path.stat().st_size / (1024 * 1024)
    print(f"  {path.name}: {size_mb:.1f} MB")



def validate_onnx(onnx_dir: str, base_model: str):
    """Run a sample inference through the ONNX model to verify it works.

    Loads the model found in *onnx_dir* with ONNX Runtime (CPU provider). It
    tokenizes a chat-formatted test query with the tokenizer stored in the
    same directory and runs a single prefill step with an empty KV cache.
    It then prints and sanity-checks the logits.

    Args:
        onnx_dir: Directory containing the ONNX model and tokenizer files.
        base_model: Currently unused; kept for interface compatibility.

    Raises:
        AssertionError: If a sanity check on the logits fails. The error is
            raised explicitly, not via ``assert``, so the checks still run
            under ``python -O``.
    """
    import onnxruntime as ort
    import numpy as np

    def _require(condition, message):
        if not condition:
            raise AssertionError(message)

    model_path = _find_onnx_model(onnx_dir)
    print(f"\nValidation: loading {model_path.name}...")

    tokenizer = AutoTokenizer.from_pretrained(onnx_dir, trust_remote_code=True)
    session = ort.InferenceSession(
        str(model_path),
        providers=["CPUExecutionProvider"],
    )

    # Tokenize a test prompt
    test_query = "/no_think Expand this search query: distributed consensus"
    chat_prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": test_query}],
        add_generation_prompt=True,
        tokenize=False,
    )
    inputs = tokenizer(chat_prompt, return_tensors="np")
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    # Build feed dict with all required inputs
    seq_len = input_ids.shape[1]
    feed = {"input_ids": input_ids, "attention_mask": attention_mask}

    all_inputs = {inp.name: inp for inp in session.get_inputs()}
    if "position_ids" in all_inputs:
        feed["position_ids"] = np.arange(seq_len, dtype=np.int64).reshape(1, -1)

    # Empty KV cache for a prefill step. Symbolic dims become 0 (zero past
    # length) and the batch dim is forced to 1. The element type follows the
    # declared input type, so a graph with FP16 cache inputs is fed correctly.
    kv_dtypes = {"tensor(float16)": np.float16, "tensor(float)": np.float32}
    for name, inp in all_inputs.items():
        if not name.startswith("past_key_values"):
            continue
        shape = [dim if isinstance(dim, int) else 0 for dim in inp.shape]
        if shape and shape[0] == 0:
            shape[0] = 1
        feed[name] = np.zeros(shape, dtype=kv_dtypes.get(inp.type, np.float32))

    # Run inference
    output_names = [o.name for o in session.get_outputs()]
    results = session.run(output_names, feed)

    # Prefer the output named "logits"; fall back to the first output.
    logits_idx = output_names.index("logits") if "logits" in output_names else 0
    logits = results[logits_idx]
    print(f"  Input tokens: {input_ids.shape[1]}")
    print(f"  Output logits shape: {logits.shape}")
    print(f"  Logits range: [{logits.min():.2f}, {logits.max():.2f}]")

    # Greedy decode next token
    next_token_id = int(np.argmax(logits[0, -1, :]))
    next_token = tokenizer.decode([next_token_id])
    print(f"  Next token: {repr(next_token)} (id={next_token_id})")

    # Check KV cache outputs exist
    kv_outputs = [n for n in output_names if n.startswith("present")]
    if kv_outputs:
        print(f"  KV cache outputs: {len(kv_outputs)} tensors (generation-ready)")
    else:
        print("  WARNING: No KV cache outputs — model may not support efficient generation")

    # Sanity checks. NaN/Inf come before the constant check: with a NaN
    # present, max() > min() is False, which would be misreported as "constant".
    _require(logits.shape[0] == 1, "Batch size mismatch")
    _require(logits.shape[1] == input_ids.shape[1], "Sequence length mismatch")
    _require(not np.isnan(logits).any(), "Logits contain NaN")
    _require(not np.isinf(logits).any(), "Logits contain Inf")
    _require(logits.max() > logits.min(), "Logits are constant (broken model)")

    print("  Validation PASSED")


def write_transformers_js_config(onnx_dir: str, quantize_type: str = "q4"):
    """Write ``transformers_js_config.json`` into *onnx_dir* and log its name.

    The file records a fixed ``model_type`` of ``"text-generation"``. Its
    ``quantized`` flag is true for every *quantize_type* except ``"none"``.

    NOTE(review): Transformers.js v3 appears to read per-model options from a
    ``"transformers.js_config"`` key inside ``config.json``, not from a
    standalone file. Confirm whether this file is consumed at all.
    """
    payload = {"model_type": "text-generation", "quantized": quantize_type != "none"}
    target = Path(onnx_dir).joinpath("transformers_js_config.json")
    target.write_text(json.dumps(payload, indent=2) + "\n")
    print(f"  Wrote {target.name}")


def _render_model_card(
    output_repo: str,
    base_model: str,
    sft_model: str,
    grpo_model: str,
    quantize_type: str,
) -> str:
    """Return the README.md (Hub model card) text for the uploaded ONNX repo."""
    # Translate the CLI quantization name into the dtype string used in the
    # Transformers.js snippet; anything unrecognised falls back to "fp32".
    tj_dtype = {"q4": "q4", "q8": "q8", "fp16": "fp16", "none": "fp32"}.get(
        quantize_type, "fp32",
    )
    if quantize_type == "none":
        format_desc = "FP32 (no quantization)"
    else:
        format_desc = f"{quantize_type.upper()} quantization"
    repo_name = output_repo.rsplit("/", 1)[-1]

    return f"""---
base_model: {base_model}
tags: [onnx, transformers.js, webgpu, query-expansion, qmd]
library_name: transformers.js
---
# {repo_name}

ONNX conversion of the QMD Query Expansion model for use with
[Transformers.js](https://huggingface.co/docs/transformers.js) and WebGPU.

## Details
- **Base:** {base_model}
- **SFT:** {sft_model}
- **GRPO:** {grpo_model}
- **Task:** Query expansion (lex/vec/hyde format)
- **Format:** ONNX with {format_desc}

## Usage with Transformers.js

```javascript
import {{ AutoTokenizer, AutoModelForCausalLM }} from "@huggingface/transformers";

const tokenizer = await AutoTokenizer.from_pretrained("{output_repo}");
const model = await AutoModelForCausalLM.from_pretrained("{output_repo}", {{
  dtype: "{tj_dtype}",
  device: "webgpu",
}});
```

## Prompt Format
```
<|im_start|>user
/no_think Expand this search query: your query here<|im_end|>
<|im_start|>assistant
```
"""


def upload_to_hub(
    onnx_dir: str,
    output_repo: str,
    base_model: str,
    sft_model: str,
    grpo_model: str,
    quantize_type: str = "q4",
):
    """Publish *onnx_dir* to the Hub model repo *output_repo*, plus a model card.

    The repo is created if missing. Every file in *onnx_dir* is uploaded in
    one commit, and a generated README.md is uploaded afterwards.

    NOTE(review): Transformers.js appears to load weights from an ``onnx/``
    subfolder with dtype-suffixed names (e.g. ``onnx/model_q4.onnx``). Files
    here land at the repo root, so confirm the README snippet actually loads.
    """
    print(f"\nStep 7: Uploading to {output_repo}...")
    hub = HfApi()
    hub.create_repo(repo_id=output_repo, repo_type="model", exist_ok=True)
    hub.upload_folder(
        folder_path=onnx_dir,
        repo_id=output_repo,
        commit_message="Upload ONNX model",
    )

    card = _render_model_card(
        output_repo, base_model, sft_model, grpo_model, quantize_type,
    )
    hub.upload_file(
        path_or_fileobj=card.encode(),
        path_in_repo="README.md",
        repo_id=output_repo,
    )


def main():
    """Command-line entry point.

    The pipeline merges the adapters, exports to ONNX, quantizes, optionally
    validates, then uploads. With ``--validate-only DIR`` only the validation
    step runs, against an existing ONNX directory. Otherwise the model set
    comes from ``--size``, where individual flags override the preset, or from
    all four of ``--base``/``--sft``/``--grpo``/``--output``.
    """
    import gc

    parser = argparse.ArgumentParser(description="Convert QMD model to ONNX")
    parser.add_argument(
        "--size", choices=PRESETS.keys(), help="Use preset config for model size",
    )
    parser.add_argument("--base", help="Base model (overrides preset)")
    parser.add_argument("--sft", help="SFT adapter (overrides preset)")
    parser.add_argument("--grpo", help="GRPO adapter (overrides preset)")
    parser.add_argument("--output", help="Output HF repo (overrides preset)")
    parser.add_argument(
        "--quantize",
        choices=["q4", "q8", "fp16", "none"],
        default="q4",
        help="Quantization type (default: q4)",
    )
    parser.add_argument(
        "--no-upload", action="store_true", help="Don't upload to HF Hub",
    )
    parser.add_argument(
        "--validate", action="store_true",
        help="Run inference validation on exported model",
    )
    parser.add_argument(
        "--validate-only", metavar="DIR",
        help="Skip export, only validate an existing ONNX dir",
    )
    args = parser.parse_args()

    # Validate-only mode: skip export, just run validation
    if args.validate_only:
        validate_onnx(args.validate_only, "")
        return

    # Resolve config
    if args.size:
        preset = PRESETS[args.size]
        base_model = args.base or preset["base"]
        sft_model = args.sft or preset["sft"]
        grpo_model = args.grpo or preset["grpo"]
        output_repo = args.output or preset["output"]
    elif args.base and args.sft and args.grpo and args.output:
        base_model = args.base
        sft_model = args.sft
        grpo_model = args.grpo
        output_repo = args.output
    else:
        parser.error(
            "Either --size or all of --base/--sft/--grpo/--output are required",
        )

    model_name = output_repo.split("/")[-1]
    # model_name becomes a directory under /tmp/onnx_output that is wiped
    # below. Reject names that would resolve to that root or its parent.
    if model_name in ("", ".", ".."):
        parser.error(f"Invalid --output repo name: {output_repo!r}")
    print(f"QMD ONNX Conversion: {model_name}")
    print("=" * 60)

    # Login
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        print("Logging in to HuggingFace...")
        login(token=hf_token)

    # Merge adapters
    model, tokenizer = merge_adapters(base_model, sft_model, grpo_model)

    # Export to ONNX into a clean directory. Leftovers from an earlier run
    # (other quantization outputs, external-data files) would otherwise be
    # picked up by _find_onnx_model or uploaded next to the new model.
    onnx_dir = f"/tmp/onnx_output/{model_name}"
    shutil.rmtree(onnx_dir, ignore_errors=True)
    os.makedirs(onnx_dir, exist_ok=True)
    export_onnx(model, tokenizer, onnx_dir)

    # The merged FP32 PyTorch model is several GB. Release it before the
    # quantization step loads the ONNX copy into memory.
    del model
    gc.collect()

    # Quantize
    quantize_onnx(onnx_dir, args.quantize)

    # Write Transformers.js config
    write_transformers_js_config(onnx_dir, args.quantize)

    # Validate
    if args.validate:
        validate_onnx(onnx_dir, base_model)

    # Upload
    if not args.no_upload:
        upload_to_hub(onnx_dir, output_repo, base_model, sft_model, grpo_model, args.quantize)

    print(f"\nDone! ONNX files at: {onnx_dir}")
    if not args.no_upload:
        print(f"Repository: https://huggingface.co/{output_repo}")


if __name__ == "__main__":
    # main() returns None on success, which sys.exit maps to status 0.
    sys.exit(main())