#!/usr/bin/env python3
"""Prepare QMD query expansion data for LFM2.5-1.2B-Instruct training.

LFM2.5 uses ChatML format:
  <|startoftext|><|im_start|>user
  Expand this search query: {query}<|im_end|>
  <|im_start|>assistant
  {output}<|im_end|>

No /no_think needed (that's Qwen3-specific).
"""

import json
import os
import random
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))
from dataset.schema import normalize_output_items, output_items_to_text

from transformers import AutoTokenizer


def format_for_training(query_text: str, output_items: list[list[str]], tokenizer) -> dict:
    """Render one query/expansion pair as a single chat-formatted training string.

    The pair becomes a two-turn conversation: a user turn carrying the
    "Expand this search query: ..." prompt and an assistant turn carrying the
    text produced by ``output_items_to_text``. The conversation is then passed
    through the tokenizer's chat template.

    Args:
        query_text: The raw search query to be expanded.
        output_items: Expansion items, serialized via ``output_items_to_text``.
        tokenizer: Object exposing ``apply_chat_template`` (here the LFM2.5
            tokenizer loaded by the caller).

    Returns:
        A dict with a single ``"text"`` key holding the templated conversation.
    """
    assistant_turn = {
        "role": "assistant",
        "content": output_items_to_text(output_items),
    }
    user_turn = {
        "role": "user",
        "content": f"Expand this search query: {query_text}",
    }

    # tokenize=False asks for the rendered template rather than token ids;
    # add_generation_prompt=False because the assistant reply is already
    # part of the conversation.
    rendered = tokenizer.apply_chat_template(
        [user_turn, assistant_turn],
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": rendered}


def _write_jsonl(path: Path, examples: list[dict]) -> None:
    """Write *examples* to *path* as JSON Lines, one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(ex) + "\n" for ex in examples)


def main():
    """Build the LFM2.5 SFT train/val splits from the raw expansion dataset.

    Reads ``data/qmd_expansion_v2.jsonl`` (one JSON object per line with
    ``"query"`` and ``"output"`` keys), formats every row with the LFM2.5
    tokenizer's chat template, shuffles with a fixed seed, and writes a 90/10
    split to ``data/train-lfm2/train.jsonl`` and ``data/train-lfm2/val.jsonl``.
    Both paths are relative to the current working directory.

    Raises:
        FileNotFoundError: If the input JSONL file does not exist.
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
        KeyError: If a row lacks the ``"query"`` or ``"output"`` key.
    """
    input_path = Path("data/qmd_expansion_v2.jsonl")
    output_dir = Path("data/train-lfm2")
    output_dir.mkdir(parents=True, exist_ok=True)

    print("Loading LFM2.5 tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        "LiquidAI/LFM2.5-1.2B-Instruct", trust_remote_code=True
    )

    examples = []
    # Explicit UTF-8: without it open() falls back to the locale encoding,
    # which corrupts or rejects non-ASCII queries on some platforms.
    with open(input_path, encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file),
            # which json.loads would otherwise reject.
            if not line.strip():
                continue
            row = json.loads(line)
            items = normalize_output_items(row["output"])
            examples.append(format_for_training(row["query"], items, tokenizer))

    # Fixed seed keeps the shuffle, and therefore the split, reproducible.
    random.seed(42)
    random.shuffle(examples)

    split_idx = int(len(examples) * 0.9)
    train = examples[:split_idx]
    val = examples[split_idx:]

    # Write as JSONL
    _write_jsonl(output_dir / "train.jsonl", train)
    _write_jsonl(output_dir / "val.jsonl", val)

    print(f"Written {len(train)} train, {len(val)} val examples to {output_dir}")
    # With 0 or 1 input rows the train split is empty; skip the preview
    # instead of failing with IndexError after the files are written.
    if train:
        print("\nSample formatted text:")
        print(train[0]["text"][:500])


# Script entry point: run the data preparation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()