"""Run DSPy GEPA using reward.py as the metric."""
from __future__ import annotations
import argparse
import importlib
import json
import sys
from pathlib import Path
def _import_dspy():
script_dir = Path(__file__).parent
repo_root = script_dir.parent
original_sys_path = list(sys.path)
try:
sys.path = [p for p in sys.path if p and str(p) != str(script_dir)]
return importlib.import_module("dspy")
finally:
sys.path = original_sys_path
dspy = _import_dspy()

# Put the repository root (two levels above this file) at the front of the
# import path, once, so the project-local `dataset.schema` import below can
# resolve. Presumably `dataset` lives at the repo root; verify if moved.
repo_root = Path(__file__).parent.parent
if not any(entry == str(repo_root) for entry in sys.path):
    sys.path[:0] = [str(repo_root)]
from dataset.schema import normalize_output_items, output_items_to_text, parse_output_text
from reward import score_expansion_detailed
class ExpandSignature(dspy.Signature):
    """Expand a search query into lex/vec/hyde lines."""

    # NOTE: the docstring above is runtime data, not only documentation. DSPy
    # treats a Signature's docstring as the task instructions given to the LM,
    # and main() writes `optimized.predict.signature.__doc__` to --save-prompt.
    # The `desc` strings below are also rendered into the prompt, so editing
    # any of this text (or reordering the fields) changes model behavior.

    # The user's raw search query; to_examples() marks `query` as the only input.
    query = dspy.InputField(desc="User search query")
    # Model response. _coerce_output_items() accepts either a list/tuple here
    # or text, which it tries to parse as JSON before falling back to
    # parse_output_text().
    output = dspy.OutputField(
        desc=(
            "JSON array of [kind, text] pairs. kind is lex|vec|hyde. "
            "Return 2-3 lex, 2-3 vec, optional 0-1 hyde. "
            "Lex items are short keywords and must not echo the query. "
            "Vec items are natural language search phrases. "
            "Hyde is 50-200 chars, single line."
        )
    )
class Expander(dspy.Module):
    """Single-predictor DSPy program that expands one query per call."""

    def __init__(self):
        super(Expander, self).__init__()
        # main() reads `optimized.predict.signature` to save the optimized
        # prompt, so this attribute must stay named `predict`.
        predictor = dspy.Predict(ExpandSignature)
        self.predict = predictor

    def forward(self, query: str):
        """Run the predictor on *query* and return its prediction unchanged."""
        prediction = self.predict(query=query)
        return prediction
def reward_metric(gold, pred, trace=None, pred_name=None, pred_trace=None):
    """Score a prediction with reward.py and return score plus textual feedback.

    The signature matches what dspy.GEPA passes to its metric. ``trace``,
    ``pred_name`` and ``pred_trace`` are accepted but not used.

    Returns:
        ``dspy.Prediction(score=..., feedback=...)``. ``score`` is
        ``percentage / 100``. ``feedback`` joins the reported deductions with
        "; ", or is ``score=<pct>`` when there are none.
    """
    items = _coerce_output_items(pred)
    expansion_text = output_items_to_text(items)
    result = score_expansion_detailed(gold.query, expansion_text)

    pct = result["percentage"]
    score = pct / 100.0

    feedback = "; ".join(result.get("deductions", []))
    if not feedback:
        feedback = f"score={pct:.1f}"
    return dspy.Prediction(score=score, feedback=feedback)
def load_queries(path: Path) -> list[str]:
    """Read query strings from a JSONL file.

    Blank lines are ignored and every other line must be valid JSON. For each
    JSON object, the value of the first of ``"query"`` or ``"input"`` that is a
    non-blank string is collected, stripped of surrounding whitespace. Lines
    that are not JSON objects, or that carry no usable query, are skipped.

    Args:
        path: Path to the JSONL file. A leading UTF-8 BOM is tolerated.

    Returns:
        The collected queries, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    queries: list[str] = []
    # "utf-8-sig" drops a leading byte-order mark if present. Otherwise it is
    # identical to "utf-8", and json.loads would reject a BOM-prefixed line.
    with path.open("r", encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            if not isinstance(obj, dict):
                # No "query"/"input" field is possible; skip like other unusable lines.
                continue
            for key in ("query", "input"):
                value = obj.get(key)
                if isinstance(value, str) and value.strip():
                    queries.append(value.strip())
                    break
    return queries
def to_examples(queries: list[str]) -> list[dspy.Example]:
    """Wrap each query string in a ``dspy.Example`` whose sole input is ``query``."""
    examples = []
    for text in queries:
        example = dspy.Example(query=text)
        examples.append(example.with_inputs("query"))
    return examples
def _coerce_output_items(pred) -> list[list[str]]:
raw_output = getattr(pred, "output", None)
if isinstance(raw_output, (list, tuple)):
return normalize_output_items(raw_output)
raw_text = str(raw_output or getattr(pred, "expansion", "") or "").strip()
if not raw_text:
return []
if raw_text[0] in ("[", "{"):
try:
obj = json.loads(raw_text)
if isinstance(obj, dict) and "output" in obj:
obj = obj["output"]
if isinstance(obj, (list, tuple)):
return normalize_output_items(obj)
except Exception:
pass
return parse_output_text(raw_text)
def write_jsonl(path: Path, queries: list[str], outputs: list[list[list[str]]]) -> None:
with path.open("w", encoding="utf-8") as f:
for query, output in zip(queries, outputs, strict=True):
f.write(json.dumps({"query": query, "output": output}, ensure_ascii=False) + "\n")
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Run DSPy GEPA with reward.py")
    parser.add_argument("--input", type=str, required=True, help="Training JSONL path")
    parser.add_argument(
        "--model",
        type=str,
        # Must carry a provider prefix: main() rejects bare model names.
        default="xai/grok-4-1-fast-reasoning",
        help="LM string in provider/model format (e.g., openai/gpt-4o)",
    )
    parser.add_argument(
        "--reflection-model",
        type=str,
        default="xai/grok-4-1-fast-reasoning",
        help="LM string in provider/model format (e.g., openai/gpt-4o)",
    )
    parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens for student LM")
    parser.add_argument("--reflection-max-tokens", type=int, default=512, help="Max tokens for reflection LM")
    parser.add_argument("--auto", type=str, default="light", choices=["light", "medium", "heavy"])
    parser.add_argument("--max-full-evals", type=int, default=None)
    parser.add_argument("--max-metric-calls", type=int, default=None)
    parser.add_argument("--valset", type=str, default=None, help="Optional valset JSONL path")
    parser.add_argument("--limit", type=int, default=None, help="Limit number of training queries")
    parser.add_argument("--val-limit", type=int, default=None, help="Limit number of val queries")
    parser.add_argument("--emit", type=str, default=None, help="Write generated JSONL after compile")
    parser.add_argument("--save-prompt", type=str, default=None, help="Write best prompt text to file")
    return parser


def main() -> int:
    """Parse CLI args, run dspy.GEPA on the Expander program, and write optional outputs.

    Returns:
        0 on success. 1 on invalid arguments or an empty train/val set; the
        error message is printed to stderr.
    """
    args = _build_arg_parser().parse_args()

    if "/" not in args.model or "/" not in args.reflection_model:
        print(
            "Error: DSPy expects provider/model format for LM strings (e.g., xai/grok-4-1-fast-reasoning).",
            file=sys.stderr,
        )
        return 1
    if args.max_full_evals is not None and args.max_metric_calls is not None:
        print("Provide only one of --max-full-evals or --max-metric-calls", file=sys.stderr)
        return 1
    # GEPA takes a single budget setting, so an explicit budget overrides --auto.
    if args.max_full_evals is not None or args.max_metric_calls is not None:
        args.auto = None

    queries = load_queries(Path(args.input))
    if args.limit is not None:
        queries = queries[: args.limit]
    if not queries:
        print(f"Error: no usable training queries in {args.input}", file=sys.stderr)
        return 1
    trainset = to_examples(queries)

    valset = None
    if args.valset:
        val_queries = load_queries(Path(args.valset))
        if args.val_limit is not None:
            val_queries = val_queries[: args.val_limit]
        if not val_queries:
            print(f"Error: no usable validation queries in {args.valset}", file=sys.stderr)
            return 1
        valset = to_examples(val_queries)

    lm = dspy.LM(model=args.model, max_tokens=args.max_tokens)
    reflection_lm = dspy.LM(model=args.reflection_model, max_tokens=args.reflection_max_tokens)
    student = Expander()
    student.set_lm(lm)

    compiler = dspy.GEPA(
        metric=reward_metric,
        reflection_lm=reflection_lm,
        auto=args.auto,
        max_full_evals=args.max_full_evals,
        max_metric_calls=args.max_metric_calls,
        track_stats=True,
        track_best_outputs=True,
        failure_score=0.0,
        perfect_score=1.0,
    )
    optimized = compiler.compile(student=student, trainset=trainset, valset=valset)

    if args.save_prompt:
        signature = optimized.predict.signature
        # `instructions` is the public Signature accessor. Fall back to the raw
        # docstring for DSPy versions that do not expose it.
        prompt_text = getattr(signature, "instructions", None) or getattr(signature, "__doc__", "") or ""
        prompt_path = Path(args.save_prompt)
        prompt_path.parent.mkdir(parents=True, exist_ok=True)
        prompt_path.write_text(prompt_text.strip() + "\n", encoding="utf-8")
        print(f"Wrote {args.save_prompt}")

    if args.emit:
        # Re-run the optimized program over the (possibly limited) training queries.
        outputs = [_coerce_output_items(optimized(query=q)) for q in queries]
        emit_path = Path(args.emit)
        emit_path.parent.mkdir(parents=True, exist_ok=True)
        write_jsonl(emit_path, queries, outputs)
        print(f"Wrote {args.emit}")

    if hasattr(optimized, "detailed_results"):
        best = getattr(optimized.detailed_results, "best_outputs_valset", None)
        if best:
            print(f"Best outputs tracked: {len(best)}")
    return 0
if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return code.
    sys.exit(main())