#!/usr/bin/env python3
"""
Strict schema for QMD training data.
Every JSONL file in data/ MUST conform to this format:
{"query": "auth config", "output": [["hyde", "..."], ["lex", "..."], ["vec", "..."]]}
- query: non-empty string
- output: list of [type, text] pairs where type is "lex", "vec", or "hyde"
- Extra fields are allowed: category, intent and is_short are parsed (and type-checked) when present but unused for training; any other keys are silently dropped
There is exactly ONE format. No alternatives, no legacy fallbacks.
"""
from __future__ import annotations
import json
from enum import Enum
from pathlib import Path
from typing import Annotated, Iterable
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
field_validator,
)
# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------
class OutputType(str, Enum):
    """The closed set of expansion kinds allowed in an ``output`` pair.

    Because the enum mixes in ``str``, every member compares equal to its raw
    string value (``OutputType.lex == "lex"``), and ``OutputType("lex")``
    looks a member up by value.
    """

    lex = "lex"
    vec = "vec"
    hyde = "hyde"


# Plain-string values of every OutputType member, for cheap membership checks
# on untyped input such as raw ``[type, text]`` lists.
VALID_OUTPUT_TYPES = {member.value for member in OutputType}
class OutputPair(BaseModel):
    """One immutable ``[type, text]`` expansion entry.

    ``type`` is validated into an :class:`OutputType`; ``text`` must contain
    at least one non-whitespace character and is stored unmodified.
    """

    model_config = ConfigDict(frozen=True)

    type: OutputType
    text: str

    @field_validator("text")
    @classmethod
    def text_not_empty(cls, v: str) -> str:
        """Reject empty or whitespace-only text; return it unchanged otherwise."""
        if v.strip():
            return v
        raise ValueError("text must not be empty")

    def to_list(self) -> list[str]:
        """Serialize back to the JSON ``[type, text]`` form."""
        kind = self.type.value
        return [kind, self.text]
def _coerce_output_pairs(v: object) -> list[OutputPair]:
    """Coerce the raw ``output`` value into a list of :class:`OutputPair`.

    Used as a pydantic ``BeforeValidator``, so it receives the value exactly as
    decoded from JSON (or as passed by a Python caller).

    Args:
        v: A list or tuple whose items are either ``OutputPair`` instances or
            2-element ``[type, text]`` lists/tuples.

    Returns:
        A new list of ``OutputPair`` objects, in input order.

    Raises:
        ValueError: if ``v`` is not a list/tuple, if an item is neither an
            ``OutputPair`` nor a 2-element list/tuple, or if building an
            ``OutputPair`` from an item fails validation (message prefixed
            with the item index).
    """
    # Check the container first. Iterating None or a number would raise
    # TypeError, which pydantic validators are not expected to raise. Iterating
    # a string or dict would yield confusing per-character / per-key errors.
    if not isinstance(v, (list, tuple)):
        raise ValueError(
            f"output must be a list of [type, text] pairs, got {type(v).__name__}"
        )
    pairs: list[OutputPair] = []
    for i, item in enumerate(v):
        if isinstance(item, (list, tuple)):
            if len(item) != 2:
                raise ValueError(f"output[{i}] must be [type, text], got {item!r}")
            try:
                pairs.append(OutputPair(type=item[0], text=item[1]))
            except ValueError as e:
                # Re-raise with the index so the offending pair is identifiable
                # in the final validation message.
                raise ValueError(f"output[{i}]: {e}") from e
        elif isinstance(item, OutputPair):
            pairs.append(item)
        else:
            raise ValueError(f"output[{i}] must be [type, text], got {item!r}")
    return pairs
# ---------------------------------------------------------------------------
# Pydantic model — single source of truth for the JSONL schema
# ---------------------------------------------------------------------------
class TrainingExample(BaseModel):
    """A single validated line of a training JSONL file.

    ``query`` and ``output`` are required and must be non-empty. ``output``
    accepts raw ``[type, text]`` lists, which are coerced to ``OutputPair``.
    The optional metadata fields are parsed when present; any other keys are
    dropped (``extra="ignore"``).
    """

    model_config = ConfigDict(extra="ignore")

    query: str
    output: Annotated[list[OutputPair], BeforeValidator(_coerce_output_pairs)]
    # Optional metadata — present in some files, ignored during training.
    category: str | None = None
    intent: str | None = None
    is_short: bool | None = None

    @field_validator("query")
    @classmethod
    def query_not_empty(cls, v: str) -> str:
        """Reject an empty or whitespace-only query."""
        if v.strip():
            return v
        raise ValueError("query must not be empty")

    @field_validator("output")
    @classmethod
    def output_not_empty(cls, v: list[OutputPair]) -> list[OutputPair]:
        """Reject an output list that contains no pairs."""
        if len(v) > 0:
            return v
        raise ValueError("output must not be empty")

    def output_as_lists(self) -> list[list[str]]:
        """Return ``output`` as ``[[type, text], ...]`` for JSON serialization."""
        return [pair.to_list() for pair in self.output]
# ---------------------------------------------------------------------------
# Loading
# ---------------------------------------------------------------------------
def load_examples(path: str | Path) -> list[TrainingExample]:
    """Load and validate every example in a JSONL file.

    Blank (whitespace-only) lines are skipped. Every other line must be valid
    JSON that validates as a :class:`TrainingExample`; the first line that
    does not aborts the load.

    Args:
        path: Path to the ``.jsonl`` file.

    Returns:
        The validated examples, in file order.

    Raises:
        ValueError: on the first invalid line, with a ``path:line_num:``
            prefix; the underlying error is chained as ``__cause__``.
        OSError: if the file cannot be opened.
    """
    path = Path(path)
    examples: list[TrainingExample] = []
    # "utf-8-sig" decodes plain UTF-8 identically but also drops a leading
    # byte-order mark. A BOM is not whitespace, so it would survive .strip()
    # and make the first line fail json.loads.
    with path.open("r", encoding="utf-8-sig") as f:
        for line_num, raw_line in enumerate(f, 1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as e:
                raise ValueError(f"{path}:{line_num}: invalid JSON: {e}") from e
            try:
                example = TrainingExample.model_validate(obj)
            except Exception as e:
                # Deliberately broad: any failure while validating this line is
                # reported with its file location, then chained.
                raise ValueError(f"{path}:{line_num}: {e}") from e
            examples.append(example)
    return examples
# ---------------------------------------------------------------------------
# Helpers (used by prepare_data.py, reward.py, and other tools)
# ---------------------------------------------------------------------------
def parse_output_text(text: str) -> list[list[str]]:
    """Parse newline-separated ``kind: text`` lines into ``[kind, text]`` pairs.

    Each line is stripped. Blank lines, and lines that do not start with one
    of the exact (case-sensitive) prefixes ``lex:``, ``vec:`` or ``hyde:``, are
    skipped. The text after the prefix is stripped but not checked for
    emptiness, so a bare ``lex:`` line yields ``['lex', '']``.

    >>> parse_output_text("lex: foo\\nvec: bar")
    [['lex', 'foo'], ['vec', 'bar']]
    >>> parse_output_text("hyde: a doc\\n\\nnoise\\nlex:")
    [['hyde', 'a doc'], ['lex', '']]
    """
    kinds = ("lex", "vec", "hyde")
    items: list[list[str]] = []
    for raw_line in text.strip().split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        for kind in kinds:
            prefix = kind + ":"
            if line.startswith(prefix):
                items.append([kind, line[len(prefix):].strip()])
                break
    return items
def reorder_hyde_first(items: Iterable[list[str]]) -> list[list[str]]:
    """Return a new list with hyde items first, then lex, then vec.

    Relative order within each kind is preserved. Items that are empty, or
    whose first element is not ``"hyde"``, ``"lex"`` or ``"vec"``, are dropped.
    ``items`` is traversed exactly once, so any iterable (including a
    generator) is accepted; the returned list holds the original item objects.
    """
    # Single pass into per-kind buckets. Three separate comprehensions would
    # exhaust an iterator on the first pass and silently lose lex/vec items.
    buckets: dict[str, list[list[str]]] = {"hyde": [], "lex": [], "vec": []}
    for item in items:
        if not item:
            continue
        for kind, bucket in buckets.items():
            if item[0] == kind:
                bucket.append(item)
                break
    return buckets["hyde"] + buckets["lex"] + buckets["vec"]
def output_items_to_text(
    items: Iterable, hyde_first: bool = True
) -> str:
    """Render output pairs as newline-separated ``kind: text`` lines.

    Accepts ``OutputPair`` instances or ``[kind, text]`` sequences. Items are
    cleaned exactly as by :func:`normalize_output_items` (invalid entries
    dropped, text stripped, optionally reordered hyde/lex/vec) before being
    rendered. Returns ``""`` when no item survives normalization.
    """
    # Delegate rather than duplicating the filtering rules, so the text and
    # list renderings of the same items cannot drift apart.
    pairs = normalize_output_items(items, hyde_first=hyde_first)
    return "\n".join(f"{kind}: {text}" for kind, text in pairs)
def normalize_output_items(
    items: Iterable, hyde_first: bool = True
) -> list[list[str]]:
    """Normalize output pairs into clean ``[kind, text]`` string lists.

    Best-effort: malformed entries are skipped rather than raising. An entry
    is kept if it is an ``OutputPair``, or an indexable sequence whose first
    element is a valid type string and whose second element is not None and
    is non-blank after ``str()`` + strip. Elements past the second are ignored.

    Args:
        items: Iterable of ``OutputPair`` or ``[kind, text]`` sequences.
        hyde_first: If true, reorder to hyde, lex, vec (order within each
            kind is preserved).

    Returns:
        A new list of ``[kind, text]`` lists with plain-``str`` kinds and
        stripped text.
    """
    normalized: list[list[str]] = []
    for item in items:
        if isinstance(item, OutputPair):
            normalized.append([item.type.value, item.text.strip()])
            continue
        if not item:
            continue
        try:
            kind, text = item[0], item[1]
        except Exception:
            # Not indexable or too short: skip, consistent with best-effort.
            continue
        # Type check before the set lookup: an unhashable kind (e.g. a list)
        # would otherwise raise TypeError from `in` instead of being skipped.
        if not isinstance(kind, str) or kind not in VALID_OUTPUT_TYPES:
            continue
        if text is None:
            continue
        text = str(text).strip()
        if not text:
            continue
        # Canonicalize str-subclass kinds (e.g. OutputType members) to the
        # plain value string, so formatting the kind later (e.g. in an
        # f-string) yields "lex" rather than a possible "OutputType.lex".
        normalized.append([OutputType(kind).value, text])
    if hyde_first:
        normalized = reorder_hyde_first(normalized)
    return normalized
def has_type(items: Iterable, kind: str) -> bool:
    """Return True if any entry in *items* has output type *kind*.

    Entries may be ``OutputPair`` instances (matched on ``type.value``) or
    ``[kind, text]`` sequences (matched on their first element). Falsy
    entries are skipped. Iteration stops at the first match.
    """

    def _matches(entry) -> bool:
        if isinstance(entry, OutputPair):
            return entry.type.value == kind
        return bool(entry) and entry[0] == kind

    return any(_matches(entry) for entry in items)