# GRPO Training Config for QMD Query Expansion
# Target: Qwen3-1.7B, trained on top of merged SFT weights
#
# Usage: uv run train.py grpo --config experiments/grpo/grpo.yaml
#
# The reward function (reward.py) scores expansions on format compliance,
# diversity, HyDE quality, content quality, and named-entity preservation.
# Keep grpo.beta > 0: the KL penalty is what prevents the policy from
# drifting away from the SFT checkpoint.

model:
  # Base model ID; training starts from the merged SFT weights below.
  base: 'Qwen/Qwen3-1.7B'
  # Local SFT output directory (or an HF path once uploaded).
  sft: 'outputs/sft'
  # Local training output; push to HF manually after eval.
  output: 'outputs/grpo'
  push_to_hub: false
  torch_dtype: 'bfloat16'
  # No quantized loading: both 4-bit and 8-bit modes are disabled.
  load_in_4bit: false
  load_in_8bit: false

dataset:
  # Choose one source for `name`:
  #   - local: run `uv run dataset/prepare_data.py` first, then use 'data/train/'
  #   - Hugging Face: 'tobil/qmd-query-expansion-train' (already prepared)
  name: 'data/train/'
  prompt_field: 'messages'
  max_samples: 1000

training:
  # NOTE(review): if these map onto HF/TRL TrainingArguments, a positive
  # max_steps (below) takes precedence over epochs — confirm in train.py.
  epochs: 1
  # NOTE(review): TRL's GRPOTrainer counts completions (not prompts) in the
  # batch, and requires a batch divisible by grpo.num_generations (4). Older
  # TRL checks batch_size itself (2 -> error); newer TRL checks
  # batch_size * gradient_accumulation_steps (16 -> ok). Confirm against the
  # installed TRL version.
  batch_size: 2
  gradient_accumulation_steps: 8
  learning_rate: 0.0000005  # 5e-7
  max_grad_norm: 0.5
  max_steps: 200
  # Save checkpoints every 30 minutes of wall-clock time.
  save_interval_minutes: 30
  # Step-based fallback save cadence; not used when saving on the
  # wall-clock interval above.
  save_steps: 50

grpo:
  # Completions sampled per prompt.
  num_generations: 4
  # Maximum length of each sampled completion.
  max_completion_length: 200
  # KL regularization coefficient; keep > 0 so the policy does not drift
  # away from the SFT checkpoint.
  beta: 0.04

lora:
  rank: 4
  alpha: 8
  dropout: 0.05
  # Adapters on the query and value projections only.
  target_modules: ['q_proj', 'v_proj']

tracking:
  # Experiment-tracking project and run identifiers.
  project: 'qmd-query-expansion'
  run_name: 'grpo-1.7B'