#!/usr/bin/env python3
"""
Strict schema for QMD training data.

Every JSONL file in data/ MUST conform to this format:

    {"query": "auth config", "output": [["hyde", "..."], ["lex", "..."], ["vec", "..."]]}

- query: non-empty string
- output: list of [type, text] pairs where type is "lex", "vec", or "hyde"
- Extra fields are allowed. Known metadata (category, intent, is_short) is
  type-checked when present; any other field is ignored. None of it is used
  in training.

There is exactly ONE format. No alternatives, no legacy fallbacks.
"""

from __future__ import annotations

import json
from enum import Enum
from pathlib import Path
from typing import Annotated, Iterable

from pydantic import (
    BaseModel,
    BeforeValidator,
    ConfigDict,
    field_validator,
)


# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------

class OutputType(str, Enum):
    """The three expansion kinds allowed in a training example's ``output``.

    Members mix in ``str``, so each compares equal to its plain string value
    (``OutputType.lex == "lex"``).
    """

    lex = "lex"
    vec = "vec"
    hyde = "hyde"


# Plain-string values of every OutputType member; used for membership checks
# on untyped [type, text] lists in the helpers below.
VALID_OUTPUT_TYPES = {t.value for t in OutputType}


class OutputPair(BaseModel):
    """One ``[type, text]`` expansion entry of a training example's output.

    Instances are immutable (``frozen=True``). ``text`` must contain at least
    one non-whitespace character; it is stored as given, without stripping.
    """

    model_config = ConfigDict(frozen=True)

    type: OutputType
    text: str

    @field_validator("text")
    @classmethod
    def text_not_empty(cls, value: str) -> str:
        # Empty and whitespace-only strings both strip to "" and are rejected.
        if not value.strip():
            raise ValueError("text must not be empty")
        return value

    def to_list(self) -> list[str]:
        """Return the pair as a JSON-ready ``[type, text]`` list."""
        kind, body = self.type.value, self.text
        return [kind, body]


def _coerce_output_pairs(v: object) -> list[OutputPair]:
    """Coerce the raw ``output`` value into a list of :class:`OutputPair`.

    Used as a pydantic ``BeforeValidator`` on ``TrainingExample.output``, so it
    receives the value as it came out of ``json.loads`` (or from a caller).

    Args:
        v: Expected to be a list (or tuple) whose items are either
            ``OutputPair`` instances or two-element ``[type, text]``
            lists/tuples.

    Returns:
        The items converted to ``OutputPair``, in their original order.

    Raises:
        ValueError: If ``v`` is not a list/tuple, an item is not a two-element
            pair, or a pair fails ``OutputPair`` validation. Pydantic wraps it
            into a ``ValidationError`` for the field.
    """
    # Reject non-sequences up front. Iterating None raises TypeError, which
    # pydantic v2 does not convert into a ValidationError. Iterating a str or
    # dict would produce misleading per-character / per-key errors.
    if not isinstance(v, (list, tuple)):
        raise ValueError(
            f"output must be a list of [type, text] pairs, got {type(v).__name__}"
        )
    pairs: list[OutputPair] = []
    for i, item in enumerate(v):
        if isinstance(item, OutputPair):
            pairs.append(item)
        elif isinstance(item, (list, tuple)) and len(item) == 2:
            try:
                pairs.append(OutputPair(type=item[0], text=item[1]))
            except ValueError as e:
                # pydantic's ValidationError subclasses ValueError; prefix the
                # index so the offending pair is identifiable in the report.
                raise ValueError(f"output[{i}]: {e}") from e
        else:
            raise ValueError(
                f"output[{i}] must be [type, text], got {item!r}"
            )
    return pairs


# ---------------------------------------------------------------------------
# Pydantic model — single source of truth for the JSONL schema
# ---------------------------------------------------------------------------

class TrainingExample(BaseModel):
    """A single record of the canonical training JSONL.

    ``query`` and ``output`` are required and must be non-empty. ``output`` is
    coerced from ``[[type, text], ...]`` by ``_coerce_output_pairs``.
    ``category``, ``intent`` and ``is_short`` are optional metadata that is
    type-checked when present. Any other key is dropped (``extra="ignore"``).
    """

    model_config = ConfigDict(extra="ignore")

    query: str
    output: Annotated[list[OutputPair], BeforeValidator(_coerce_output_pairs)]

    # Optional metadata: validated if supplied, unused by training.
    category: str | None = None
    intent: str | None = None
    is_short: bool | None = None

    @field_validator("query")
    @classmethod
    def query_not_empty(cls, value: str) -> str:
        # A whitespace-only query strips to "" and is rejected as empty.
        if not value.strip():
            raise ValueError("query must not be empty")
        return value

    @field_validator("output")
    @classmethod
    def output_not_empty(cls, value: list[OutputPair]) -> list[OutputPair]:
        # Runs after the list has been coerced and each pair validated.
        if len(value) == 0:
            raise ValueError("output must not be empty")
        return value

    def output_as_lists(self) -> list[list[str]]:
        """Serialize ``output`` back to JSON-ready ``[[type, text], ...]``."""
        return [pair.to_list() for pair in self.output]


# ---------------------------------------------------------------------------
# Loading
# ---------------------------------------------------------------------------

def load_examples(path: str | Path) -> list[TrainingExample]:
    """Load and validate a JSONL file, failing loudly on the first bad line.

    Blank (whitespace-only) lines are skipped. The file is decoded as UTF-8.
    A leading byte-order mark, as written by some editors, is tolerated rather
    than being reported as invalid JSON.

    Args:
        path: Path to the ``.jsonl`` file.

    Returns:
        The validated examples, in file order.

    Raises:
        ValueError: On the first line that is not valid JSON or does not match
            :class:`TrainingExample`. The message is prefixed with
            ``"<path>:<line>:"``. Bytes that are not valid UTF-8 raise
            ``UnicodeDecodeError`` (also a ``ValueError``) without that prefix.
        OSError: If the file cannot be opened.
    """
    path = Path(path)
    examples: list[TrainingExample] = []
    # "utf-8-sig" strips an optional BOM. With plain "utf-8" the BOM survives
    # as U+FEFF at the start of line 1, which str.strip() keeps (it is not
    # whitespace) and json.loads rejects.
    with path.open("r", encoding="utf-8-sig") as f:
        for line_num, raw_line in enumerate(f, 1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as e:
                raise ValueError(f"{path}:{line_num}: invalid JSON: {e}") from e
            try:
                examples.append(TrainingExample.model_validate(obj))
            except Exception as e:
                # Deliberately broad: whatever the schema layer raises is
                # re-raised with file:line context so the bad record is easy
                # to locate.
                raise ValueError(f"{path}:{line_num}: {e}") from e
    return examples


# ---------------------------------------------------------------------------
# Helpers (used by prepare_data.py, reward.py, and other tools)
# ---------------------------------------------------------------------------

def parse_output_text(text: str) -> list[list[str]]:
    """Parse ``"<type>: <text>"`` lines into ``[type, text]`` pairs.

    Each line is stripped. Blank lines are skipped, and so are lines whose
    prefix before the first ``:`` is not exactly ``lex``, ``vec`` or ``hyde``.
    The text after the first colon is stripped but otherwise kept verbatim (it
    may be empty or contain further colons).

    >>> parse_output_text("lex: foo\\nvec: bar")
    [['lex', 'foo'], ['vec', 'bar']]
    """
    known_kinds = ("lex", "vec", "hyde")
    items: list[list[str]] = []
    for raw_line in text.strip().split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        # partition splits on the FIRST colon only, matching a prefix check.
        kind, colon, rest = line.partition(":")
        if colon and kind in known_kinds:
            items.append([kind, rest.strip()])
    return items


def reorder_hyde_first(items: list[list[str]]) -> list[list[str]]:
    """Group items as hyde, then lex, then vec, keeping order within a group.

    Empty items and items of any other type are dropped. The inner lists are
    returned as-is (not copied).
    """
    priority = ("hyde", "lex", "vec")
    return [
        entry
        for wanted in priority
        for entry in items
        if entry and entry[0] == wanted
    ]


def output_items_to_text(
    items: Iterable, hyde_first: bool = True
) -> str:
    """Render output pairs as newline-separated ``"<type>: <text>"`` lines.

    Accepts list[OutputPair] or list[list[str]]. Filtering of invalid entries,
    whitespace trimming and the optional hyde/lex/vec reordering are delegated
    to ``normalize_output_items``, so the two helpers cannot drift apart.

    Args:
        items: Output pairs to render.
        hyde_first: If true, emit hyde lines first, then lex, then vec.

    Returns:
        The rendered text, or ``""`` when no valid item remains.
    """
    normalized = normalize_output_items(items, hyde_first=hyde_first)
    return "\n".join(f"{kind}: {text}" for kind, text in normalized)


def normalize_output_items(
    items: Iterable, hyde_first: bool = True
) -> list[list[str]]:
    """Normalize output pairs: drop invalid ones, trim text, optionally reorder.

    Accepts list[OutputPair] or list[list[str]]. An entry is silently dropped
    when it:
    - is falsy or cannot be indexed as ``[kind, text]``;
    - has a kind that is not one of the OutputType values;
    - has ``None`` text;
    - has text that is empty after ``str(...).strip()``.

    Kinds are always returned as plain ``str`` values, even when the input
    used ``OutputType`` members.

    Args:
        items: Output pairs to normalize.
        hyde_first: If true, reorder the result as hyde, then lex, then vec.

    Returns:
        A new list of ``[kind, text]`` string pairs.
    """
    normalized: list[list[str]] = []
    for item in items:
        if isinstance(item, OutputPair):
            kind, text = item.type.value, item.text
        else:
            if not item:
                continue
            try:
                kind, text = item[0], item[1]
            except (LookupError, TypeError):
                # Too short, missing keys, or not subscriptable at all.
                continue
            # The isinstance check also avoids a TypeError from the set lookup
            # when kind is unhashable (e.g. a nested list).
            if not isinstance(kind, str) or kind not in VALID_OUTPUT_TYPES:
                continue
            # Collapse OutputType members to their plain value. Recent Python
            # versions format (str, Enum) members as "OutputType.lex" in
            # f-strings, which would corrupt rendered text.
            kind = OutputType(kind).value
            if text is None:
                continue
        # Applied to OutputPair text as well, so both input forms filter
        # whitespace-only text consistently.
        text = str(text).strip()
        if not text:
            continue
        normalized.append([kind, text])

    if hyde_first:
        normalized = reorder_hyde_first(normalized)

    return normalized


def has_type(items: Iterable, kind: str) -> bool:
    """Return True if any entry in *items* has output type *kind*.

    ``OutputPair`` entries are compared via ``type.value``. List-style
    ``[type, text]`` entries are compared via their first element, and
    empty/falsy entries never match. Iteration stops at the first match.
    """
    return any(
        (
            entry.type.value == kind
            if isinstance(entry, OutputPair)
            else bool(entry) and entry[0] == kind
        )
        for entry in items
    )