# /// script
# requires-python = ">=3.10"
# dependencies = [
# "torch",
# "trl>=0.12.0",
# "peft>=0.7.0",
# "transformers>=4.45.0",
# "accelerate>=0.24.0",
# "huggingface_hub>=0.20.0",
# "trackio",
# "nvidia-ml-py",
# "datasets",
# "bitsandbytes",
# "pyyaml",
# "gguf",
# ]
# ///
"""
Unified training script for QMD query expansion models.
Primary pipeline is SFT-only:
sft - Supervised fine-tuning on labeled examples
GRPO has been moved to `experiments/grpo/` and is not part of the main training
pipeline; this script only exposes the `sft` stage.
Usage:
uv run train.py sft --config configs/sft.yaml
"""
import argparse
import os
import subprocess
import sys
import time
from pathlib import Path
import yaml
from transformers import TrainerCallback
def _run_gguf_tool(cmd: list[str], description: str, **kwargs) -> bool:
    """Run one external command of the GGUF export and report whether it succeeded.

    Output is captured unless the caller redirects stdout/stderr itself. On a
    non-zero exit the tail of captured stderr is printed; a missing executable
    (``OSError``) is reported instead of raised so the export stays best-effort.
    """
    if "stdout" not in kwargs and "stderr" not in kwargs:
        kwargs["capture_output"] = True
    try:
        result = subprocess.run(cmd, text=True, **kwargs)
    except OSError as err:
        print(f"{description} failed: {err}")
        return False
    if result.returncode != 0:
        print(f"{description} failed with exit code {result.returncode}.")
        if result.stderr:
            print("\n".join(result.stderr.strip().splitlines()[-20:]))
        return False
    return True


def _ensure_llama_cpp(llama_cpp: Path) -> bool:
    """Clone llama.cpp into *llama_cpp* if absent; return True if its converter exists."""
    if not llama_cpp.exists():
        print("Cloning llama.cpp...")
        cloned = _run_gguf_tool(
            [
                "git",
                "clone",
                "--depth",
                "1",
                "https://github.com/ggerganov/llama.cpp.git",
                str(llama_cpp),
            ],
            "Cloning llama.cpp",
        )
        if not cloned:
            return False
        # Best effort: the converter's dependencies (e.g. `gguf`) may already
        # be installed, so a pip failure is reported but not fatal.
        _run_gguf_tool(
            [
                sys.executable,
                "-m",
                "pip",
                "install",
                "-q",
                "-r",
                str(llama_cpp / "requirements.txt"),
            ],
            "Installing llama.cpp requirements",
        )
    convert_script = llama_cpp / "convert_hf_to_gguf.py"
    if not convert_script.exists():
        print(f"llama.cpp converter not found: {convert_script}")
        return False
    return True


def _ensure_llama_quantize(llama_cpp: Path) -> Path | None:
    """Build ``llama-quantize`` (CPU only) if needed; return its path or None."""
    quantize_bin = llama_cpp / "build" / "bin" / "llama-quantize"
    if not quantize_bin.exists():
        print("Building llama-quantize...")
        build_dir = llama_cpp / "build"
        build_dir.mkdir(parents=True, exist_ok=True)
        configured = _run_gguf_tool(
            [
                "cmake",
                "-B",
                str(build_dir),
                "-S",
                str(llama_cpp),
                "-DGGML_CUDA=OFF",
            ],
            "Configuring llama.cpp",
        )
        if configured:
            _run_gguf_tool(
                [
                    "cmake",
                    "--build",
                    str(build_dir),
                    "--target",
                    "llama-quantize",
                    "-j",
                    "4",
                ],
                "Building llama-quantize",
            )
    if not quantize_bin.exists():
        print(f"llama-quantize is not available at {quantize_bin}.")
        return None
    return quantize_bin


def export_gguf(
    model,
    tokenizer,
    output_dir: str,
    model_name: str,
    *,
    quant_types: tuple[str, ...] = ("Q4_K_M", "Q6_K", "Q8_0"),
    llama_cpp_dir: str = "/tmp/llama.cpp",
):
    """Export model to GGUF at Q4_K_M, Q6_K, Q8_0 quantizations.

    The model and tokenizer are saved (``save_pretrained``) to a temporary
    directory, converted to an FP16 GGUF with llama.cpp's
    ``convert_hf_to_gguf.py`` and then quantized with ``llama-quantize``.
    llama.cpp is cloned and built under *llama_cpp_dir* on first use.

    Args:
        model: Object with ``save_pretrained``; expected to be an already
            merged (non-PEFT) model.
        tokenizer: Object with ``save_pretrained``.
        output_dir: Directory in which a ``gguf/`` subdirectory is created.
        model_name: File-name prefix for the generated ``.gguf`` files.
        quant_types: llama-quantize type names to produce.
        llama_cpp_dir: Where llama.cpp is cloned/built.

    Tool failures (clone, build, conversion, quantization) are printed and end
    the export early instead of raising. The FP16 GGUF is deleted only if at
    least one quantized file was produced, so a failed quantize step still
    leaves a usable artifact.
    """
    import tempfile

    gguf_dir = Path(output_dir) / "gguf"
    gguf_dir.mkdir(parents=True, exist_ok=True)

    llama_cpp = Path(llama_cpp_dir)
    if not _ensure_llama_cpp(llama_cpp):
        return

    fp16_file = gguf_dir / f"{model_name}-f16.gguf"
    log_out = Path("/tmp/qmd-gguf-convert.log")
    log_err = Path("/tmp/qmd-gguf-convert.err")

    # The merged HF checkpoint is only needed until the converter has read it,
    # so the temporary directory is scoped to the conversion step.
    print("Saving merged model for GGUF conversion...")
    with tempfile.TemporaryDirectory() as tmp:
        merged_path = Path(tmp) / "merged"
        model.save_pretrained(merged_path, safe_serialization=True)
        tokenizer.save_pretrained(merged_path)

        print(f"Converting to FP16: {fp16_file}")
        with log_out.open("w") as out_f, log_err.open("w") as err_f:
            converted = _run_gguf_tool(
                [
                    sys.executable,
                    str(llama_cpp / "convert_hf_to_gguf.py"),
                    str(merged_path),
                    "--outfile",
                    str(fp16_file),
                    "--outtype",
                    "f16",
                ],
                "convert_hf_to_gguf.py",
                stdout=out_f,
                stderr=err_f,
            )

    if not converted or not fp16_file.exists():
        print("GGUF conversion failed.")
        print(f"stdout: {log_out}")
        print(f"stderr: {log_err}")
        return

    quantize_bin = _ensure_llama_quantize(llama_cpp)
    if quantize_bin is None:
        print(f"Keeping unquantized FP16 GGUF: {fp16_file}")
        return

    produced = []
    for quant_type in quant_types:
        out_file = gguf_dir / f"{model_name}-{quant_type.lower()}.gguf"
        print(f"Quantizing {quant_type}: {out_file}")
        ok = _run_gguf_tool(
            [str(quantize_bin), str(fp16_file), str(out_file), quant_type],
            f"Quantizing {quant_type}",
        )
        if ok and out_file.exists():
            size_mb = out_file.stat().st_size / (1024 * 1024)
            print(f" {quant_type}: {size_mb:.1f} MB")
            produced.append(out_file)

    if produced:
        # Remove FP16 to save space now that quantized files exist.
        fp16_file.unlink(missing_ok=True)
    else:
        print(f"No quantized files were produced; keeping FP16 GGUF: {fp16_file}")
    print(f"GGUF files saved to: {gguf_dir}")
class TimedSaveCallback(TrainerCallback):
    """Trigger periodic checkpoint saves based on elapsed wall-clock time.

    On the world-process-zero rank, ``control.should_save`` is set at the end
    of the first step after ``interval_minutes`` have elapsed since the last
    checkpoint. The clock starts when training begins and is reset by every
    save (timed or step-based).

    NOTE(review): other ranks never set ``should_save`` here. That is fine for
    rank-0-only saving (e.g. plain DDP); confirm before using with sharded
    setups (FSDP/DeepSpeed) where every rank may need to take part in a save.
    """

    def __init__(self, interval_minutes: float):
        self.interval_seconds = float(interval_minutes) * 60.0
        self.last_save_time = time.time()

    def on_train_begin(self, args, state, control, **kwargs):
        # The callback is built before the trainer loads the model, so restart
        # the clock here; otherwise setup time counts toward the first interval.
        self.last_save_time = time.time()
        return control

    def on_step_end(self, args, state, control, **kwargs):
        if not getattr(state, "is_world_process_zero", False):
            return control
        now = time.time()
        if now - self.last_save_time >= self.interval_seconds:
            control.should_save = True
            self.last_save_time = now
        return control

    def on_save(self, args, state, control, **kwargs):
        # Any checkpoint, including step-based ones, restarts the interval.
        self.last_save_time = time.time()
        return control
def run_eval(model_path: str) -> float | None:
"""Run eval.py on the trained model and return average score."""
print("\n" + "=" * 60)
print("Running evaluation...")
print("=" * 60)
eval_script = Path(__file__).parent / "eval.py"
result = subprocess.run(
[sys.executable, str(eval_script), model_path],
cwd=str(Path(__file__).parent),
capture_output=True,
text=True,
)
if result.stdout:
print(result.stdout, end="")
if result.stderr:
print(result.stderr, end="")
avg = None
for line in (result.stdout or "").splitlines():
if line.strip().startswith("Average:"):
try:
avg = float(line.split("Average:", 1)[1].split("%", 1)[0].strip())
except ValueError:
pass
break
return avg
def _sft_load_dataset(dataset_cfg: dict):
    """Load the full (unsplit) SFT dataset described by the ``dataset`` config.

    Names starting with ``data/`` or ending in ``.jsonl`` are local: a glob
    pattern loads every matching file, a directory loads its ``train.jsonl``,
    anything else is loaded as one JSONL file. Other names are loaded from the
    Hugging Face Hub using ``dataset_cfg["split"]`` (default ``"train"``).
    """
    import glob

    from datasets import load_dataset

    name = dataset_cfg["name"]
    print(f"Loading dataset: {name}...")
    if not (name.startswith("data/") or name.endswith(".jsonl")):
        return load_dataset(name, split=dataset_cfg.get("split", "train"))

    # Handle glob patterns like "data/*.jsonl"
    if "*" in name:
        jsonl_files = sorted(glob.glob(name))
        if not jsonl_files:
            raise ValueError(f"No files found matching: {name}")
        print(
            f" Found {len(jsonl_files)} JSONL files: {[Path(f).name for f in jsonl_files]}"
        )
        return load_dataset("json", data_files=jsonl_files, split="train")

    data_path = Path(name)
    if data_path.is_dir():
        train_file = data_path / "train.jsonl"
        if not train_file.is_file():
            raise FileNotFoundError(
                f"Dataset directory {data_path} has no train.jsonl"
            )
        return load_dataset("json", data_files=str(train_file), split="train")
    return load_dataset("json", data_files=name, split="train")


def _sft_run_name(tracking: dict):
    """Return the configured run name with ``{day}``/``{time}`` expanded."""
    run_name = tracking.get("run_name")
    if run_name and "{" in run_name:
        from datetime import datetime

        now = datetime.now()
        run_name = run_name.replace("{day}", now.strftime("%b %d")).replace(
            "{time}", now.strftime("%H:%M")
        )
    return run_name


def _sft_checkpoint_settings(training_cfg: dict):
    """Return ``(save_steps, save_total_limit, callbacks)`` for checkpointing.

    A positive ``save_interval_minutes`` switches to wall-clock checkpointing
    (step saves are pushed out to 10M steps and a TimedSaveCallback is added).
    An invalid or non-positive interval is reported and ignored, so step-based
    saves stay active instead of checkpointing being silently disabled.
    """
    save_steps = training_cfg.get("save_steps", 200)
    save_total_limit = training_cfg.get("save_total_limit", 2)
    callbacks = []
    interval = training_cfg.get("save_interval_minutes")
    if interval:
        try:
            interval_minutes = float(interval)
        except (TypeError, ValueError):
            interval_minutes = None
        if interval_minutes and interval_minutes > 0:
            # Prefer wall-clock checkpointing (for long jobs / preemption safety)
            save_steps = max(save_steps, 10_000_000)
            callbacks.append(TimedSaveCallback(interval_minutes))
        else:
            print(
                f"Ignoring invalid save_interval_minutes={interval!r}; "
                f"using save_steps={save_steps}."
            )
    return save_steps, save_total_limit, callbacks


def cmd_sft(args):
    """Run supervised LoRA fine-tuning from the YAML config at ``args.config``.

    With ``args.dry_run`` the parsed config is printed and nothing else runs.
    Otherwise the dataset is loaded and split, SFTTrainer trains a LoRA adapter,
    and on the main rank the adapter is saved (or pushed to the Hub), merged
    into the base model for GGUF export, and evaluated with eval.py.
    """
    with open(args.config) as f:
        cfg = yaml.safe_load(f)
    os.environ.setdefault("HF_LOG_CUDA_MEMORY", "0")
    if args.dry_run:
        # Handled before the heavy ML imports so --dry-run is fast and works
        # without the training stack installed.
        print("SFT Training Configuration:")
        print(yaml.dump(cfg, default_flow_style=False))
        return

    import gc

    import torch
    import torch.distributed as dist
    from peft import LoraConfig, PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.utils import logging as hf_logging

    hf_logging.set_verbosity_error()
    from trl import SFTConfig, SFTTrainer

    training = cfg["training"]

    dataset = _sft_load_dataset(cfg["dataset"])
    print(f"Dataset loaded: {len(dataset)} examples")
    dataset = dataset.shuffle(seed=42)
    split = dataset.train_test_split(test_size=cfg["dataset"]["eval_split"], seed=42)
    train_dataset = split["train"]
    eval_dataset = split["test"]
    print(f" Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")

    # A "user/repo"-style name outside outputs/ means "push to the Hub",
    # unless model.push_to_hub says otherwise explicitly.
    output_name = cfg["model"]["output"]
    push_to_hub = "/" in output_name and not output_name.startswith("outputs/")
    if "push_to_hub" in cfg["model"]:
        push_to_hub = bool(cfg["model"]["push_to_hub"])
    output_dir = output_name.split("/")[-1] if push_to_hub else output_name

    report_to = "none"
    if os.environ.get("HF_TOKEN"):
        try:
            import trackio  # noqa: F401

            report_to = "trackio"
        except Exception:
            print("Trackio not installed; disabling tracking.")
    # `tracking:` may be present but empty (None) in the YAML.
    tracking = cfg.get("tracking") or {}
    run_name = None
    if report_to == "trackio":
        project = tracking.get("project")
        if project:
            os.environ.setdefault("TRACKIO_PROJECT", project)
        run_name = _sft_run_name(tracking)

    save_steps, save_total_limit, callbacks = _sft_checkpoint_settings(training)

    # PyYAML (YAML 1.1) loads exponent-only floats such as `2e-4` as strings.
    learning_rate = float(training["learning_rate"])

    config = SFTConfig(
        output_dir=output_dir,
        push_to_hub=push_to_hub,
        hub_model_id=output_name if push_to_hub else None,
        hub_strategy="every_save" if push_to_hub else "end",
        num_train_epochs=training["epochs"],
        per_device_train_batch_size=training["batch_size"],
        gradient_accumulation_steps=training["gradient_accumulation_steps"],
        learning_rate=learning_rate,
        max_length=training["max_length"],
        logging_steps=10,
        save_strategy="steps",
        save_steps=save_steps,
        save_total_limit=save_total_limit,
        eval_strategy="steps",
        eval_steps=training.get("eval_steps", 200),
        warmup_ratio=training["warmup_ratio"],
        lr_scheduler_type=training["lr_scheduler"],
        ddp_find_unused_parameters=training.get("ddp_find_unused_parameters", False),
        bf16=True,
        report_to=report_to,
        run_name=run_name,
    )

    # LoRA config with modules_to_save for embedding layers
    # This prevents token ID mismatches during inference
    peft_config = LoraConfig(
        r=cfg["lora"]["rank"],
        lora_alpha=cfg["lora"]["alpha"],
        lora_dropout=cfg["lora"]["dropout"],
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=cfg["lora"]["target_modules"],
        modules_to_save=["embed_tokens", "lm_head"],  # Critical for special tokens
        ensure_weight_tying=True,
    )

    print("Loading tokenizer...")
    base_model = cfg["model"]["base"]
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Initializing SFT trainer...")
    trainer = SFTTrainer(
        model=base_model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=config,
        peft_config=peft_config,
        processing_class=tokenizer,
        callbacks=callbacks,
    )
    print("Starting SFT training...")
    trainer.train()

    is_main = os.environ.get("RANK", "0") == "0"
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    if not is_main:
        return

    if push_to_hub:
        print("Pushing to Hub...")
        trainer.push_to_hub()
        print(f"Done! Model: https://huggingface.co/{output_name}")
    else:
        trainer.save_model()
        print(f"Done! Model saved to: {output_dir}")

    # Release the training model before loading a second copy for the merge.
    del trainer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Export GGUF
    print("\nExporting to GGUF...")
    print("Loading model for GGUF export...")
    base = AutoModelForCausalLM.from_pretrained(
        base_model, torch_dtype=torch.bfloat16, device_map="auto"
    )
    # embed_tokens and lm_head were trained as separate modules_to_save, so the
    # merged model must keep them untied.
    base.config.tie_word_embeddings = False
    model = PeftModel.from_pretrained(base, output_dir, local_files_only=True)
    model = model.merge_and_unload()
    export_gguf(model, tokenizer, output_dir, Path(output_dir).name)

    # Run eval
    eval_avg = run_eval(output_dir)
    if report_to == "trackio" and eval_avg is not None:
        try:
            import trackio

            trackio.log({"eval.avg": eval_avg})
        except Exception as err:
            # Best effort: a tracking problem must not fail a finished run.
            print(f"Could not log eval score to Trackio: {err}")
def cmd_grpo(args):
    """Point users at the experimental GRPO script; no training happens here.

    GRPO reinforcement learning is no longer part of this script's pipeline
    (``main`` only exposes the ``sft`` stage). This stub keeps the old entry
    point and prints where GRPO now lives. ``args`` is accepted for interface
    compatibility and ignored.
    """
    print(
        "GRPO is not part of the main training pipeline and has been moved to `experiments/grpo/`."
    )
    print("To run experimental GRPO, use:")
    print(" cd finetune && uv run python experiments/grpo/grpo.py")
def main():
    """Parse command-line arguments and hand them to the chosen stage handler."""
    cli = argparse.ArgumentParser(
        description="QMD Query Expansion Training",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
uv run train.py sft --config configs/sft.yaml
""",
    )
    stages = cli.add_subparsers(dest="stage", required=True)

    sft_cli = stages.add_parser("sft", help="Supervised fine-tuning")
    sft_cli.add_argument("--config", required=True, help="Path to SFT config YAML")
    sft_cli.add_argument("--dry-run", action="store_true", help="Print config and exit")

    parsed = cli.parse_args()
    # Only one stage exists today; the table makes adding stages a one-liner.
    handlers = {"sft": cmd_sft}
    handlers[parsed.stage](parsed)
# Script entry point: parse the CLI and run the selected training stage.
if __name__ == "__main__":
    main()