# GRPO Training Config for QMD Query Expansion
# Target: Qwen3-1.7B, trained on top of merged SFT weights
#
# Usage: uv run train.py grpo --config experiments/grpo/grpo.yaml
#
# The reward function (reward.py) scores expansions on format compliance,
# diversity, hyde quality, content quality, and named entity preservation.
# beta > 0 is critical to prevent drift from the SFT checkpoint.
# Model checkpoints, output location, and load-time precision settings.
model:
  base: "Qwen/Qwen3-1.7B"
  # Use local SFT output (or HF path if uploaded)
  sft: "outputs/sft"
  # Local training output (push to HF manually after eval)
  output: "outputs/grpo"
  push_to_hub: false
  torch_dtype: "bfloat16"
  # Both quantized-loading flags are off.
  load_in_4bit: false
  load_in_8bit: false
# Training data source and the field that holds the prompts.
dataset:
  # Local: run `uv run dataset/prepare_data.py` first, then use "data/train/"
  # HuggingFace: use "tobil/qmd-query-expansion-train" (already prepared)
  name: "data/train/"
  prompt_field: "messages"
  # NOTE(review): presumably caps the number of examples used - confirm in train.py
  max_samples: 1000
# Optimizer schedule and checkpoint cadence.
training:
  epochs: 1
  batch_size: 2
  # NOTE(review): presumably 2 x 8 = 16 samples per optimizer step - confirm
  gradient_accumulation_steps: 8
  # Keep decimal notation: YAML 1.1 loaders such as PyYAML read an exponent
  # form without a dot (e.g. 5e-7) as a string, not a float.
  learning_rate: 0.0000005
  max_grad_norm: 0.5
  # NOTE(review): how max_steps interacts with epochs depends on train.py - confirm
  max_steps: 200
  # Save a checkpoint every 30 minutes of wall-clock time.
  save_interval_minutes: 30
  # Step-based fallback: save every 50 steps (not used in wall-clock mode).
  save_steps: 50
# GRPO-specific sampling and regularization settings.
grpo:
  # NOTE(review): presumably completions sampled per prompt - confirm in train.py
  num_generations: 4
  max_completion_length: 200
  # KL regularization - prevents drift from SFT checkpoint. Keep this > 0.
  beta: 0.04
# LoRA adapter hyperparameters.
lora:
  rank: 4
  alpha: 8
  dropout: 0.05
  # Modules that receive adapters.
  target_modules:
    - "q_proj"
    - "v_proj"
# Experiment-tracking identifiers for this run.
tracking:
  project: "qmd-query-expansion"
  run_name: "grpo-1.7B"