"""LoRA utilities for FSDP2 training framework.
This module provides utilities for integrating LoRA (Low-Rank Adaptation)
with FSDP2 distributed training, including:
- LoRA adapter injection
- Target module pattern matching
- Configuration validation
- Parameter statistics
"""
import fnmatch
import logging
import re
from typing import List, Optional, Set, Tuple, Dict, Any
import torch
import torch.nn as nn
logger = logging.getLogger(__name__)
def freeze_parameters(model: nn.Module) -> None:
"""Freeze all parameters in the model.
Args:
model: The PyTorch model to freeze.
"""
model.requires_grad_(False)
model.eval()
model.train()
def match_target_modules(model: nn.Module, patterns: List[str]) -> List[str]:
"""Match target modules using wildcard patterns.
Supports both exact matching and wildcard patterns (e.g., "language_model.layers.{*}.q_proj").
Args:
model: The PyTorch model to search.
patterns: List of module name patterns to match.
Returns:
List of matched module names.
Examples:
>>> patterns = ["q_proj", "language_model.layers.{*}.self_attn"]
>>> matched = match_target_modules(model, patterns)
"""
matched_modules: List[str] = []
for name, module in model.named_modules():
for pattern in patterns:
if is_pattern_match(name, pattern):
matched_modules.append(name)
break
return matched_modules
def is_pattern_match(module_name: str, pattern: str) -> bool:
"""Check if a module name matches a pattern.
Supports:
- Exact matching: "q_proj" matches "q_proj"
- Wildcard matching: "layers.{*}.q_proj" matches "layers.0.q_proj", "layers.1.q_proj"
Args:
module_name: Full module name to check.
pattern: Pattern to match against.
Returns:
True if the module name matches the pattern.
"""
if "{*}" in pattern:
pattern_regex = pattern.replace("{*}", "*")
return fnmatch.fnmatch(module_name, pattern_regex)
else:
return module_name == pattern or module_name.endswith("." + pattern)
def validate_lora_config(
rank: int,
alpha: int,
target_modules: List[str],
dropout: float,
init_lora_weights: bool | str,
) -> None:
"""Validate LoRA configuration parameters.
Args:
rank: LoRA rank.
alpha: LoRA alpha scaling factor.
target_modules: List of target module patterns.
dropout: Dropout rate.
init_lora_weights: Weight initialization method (True, False, or str).
Raises:
ValueError: If any configuration parameter is invalid.
"""
if rank <= 0:
raise ValueError(f"LoRA rank must be positive, got {rank}")
if alpha <= 0:
raise ValueError(f"LoRA alpha must be positive, got {alpha}")
if not target_modules:
raise ValueError("target_modules cannot be empty")
if not 0.0 <= dropout < 1.0:
raise ValueError(f"LoRA dropout must be in [0, 1), got {dropout}")
valid_init_methods = [
"gaussian", "eva", "olora", "pissa", "corda", "loftq", "orthogonal"
]
pissa_niter_pattern = re.compile(r"^pissa_niter_\d+$")
if isinstance(init_lora_weights, str):
init_val = init_lora_weights.lower()
if init_val not in valid_init_methods and not pissa_niter_pattern.match(init_val):
raise ValueError(
f"init_lora_weights must be True, False, one of {valid_init_methods}, "
f"or 'pissa_niter_[number of iters]' (e.g., 'pissa_niter_5'), "
f"got {init_lora_weights}"
)
elif not isinstance(init_lora_weights, bool):
raise ValueError(
f"init_lora_weights must be bool or str, got {type(init_lora_weights)}"
)
def get_lora_trainable_params(model: nn.Module) -> Tuple[int, int, Dict[str, Any]]:
"""Get statistics about LoRA trainable parameters.
Args:
model: The PyTorch model with LoRA adapters.
Returns:
Tuple of (trainable_params, total_params, stats_dict) where:
- trainable_params: Number of trainable parameters
- total_params: Total number of parameters
- stats_dict: Dictionary with detailed statistics
"""
trainable_params = 0
total_params = 0
lora_params = 0
base_params = 0
for name, param in model.named_parameters():
total_params += param.numel()
if param.requires_grad:
trainable_params += param.numel()
if "lora" in name:
lora_params += param.numel()
else:
base_params += param.numel()
stats_dict: Dict[str, Any] = {
"trainable_params": trainable_params,
"total_params": total_params,
"trainable_ratio": trainable_params / total_params if total_params > 0 else 0.0,
"lora_params": lora_params,
"base_params": base_params,
}
return trainable_params, total_params, stats_dict
def add_lora_to_model(
model: nn.Module,
lora_rank: int = 8,
lora_alpha: int = 16,
lora_target_modules: Optional[List[str]] = None,
lora_dropout: float = 0.05,
init_lora_weights: bool | str = True,
pretrained_lora_path: Optional[str] = None,
lora_target_modules_support: Optional[List[str]] = None,
) -> nn.Module:
"""Add LoRA adapters to a PyTorch model.
This function injects LoRA adapters into the specified target modules,
optionally loads pretrained LoRA weights, and ensures proper dtype handling.
Args:
model: The PyTorch model to add LoRA to.
lora_rank: LoRA rank (dimension of the low-rank matrices).
lora_alpha: LoRA alpha scaling factor.
lora_target_modules: List of target module names/patterns.
lora_dropout: Dropout rate for LoRA layers.
init_lora_weights: Weight initialization method (True, False, or str).
pretrained_lora_path: Path to pretrained LoRA weights (optional).
lora_target_modules_support: List of supported module types for validation.
Returns:
The model with LoRA adapters injected.
Raises:
ImportError: If PEFT library is not installed.
ValueError: If target modules are not supported.
"""
try:
from peft import LoraConfig, inject_adapter_in_model
except ImportError as e:
raise ImportError(
"PEFT library is required for LoRA training. "
"Please install it with: pip install peft"
) from e
model.lora_alpha = lora_alpha
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
init_lora_weights=init_lora_weights,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
)
if lora_target_modules_support is not None:
for lora_target_module in lora_config.target_modules:
if lora_target_module not in lora_target_modules_support:
raise ValueError(
f"lora_target_module {lora_target_module} not in "
f"lora_target_modules_support"
)
model = inject_adapter_in_model(lora_config, model)
for param in model.parameters():
if param.requires_grad:
param.data = param.to(torch.float32)
for name, param in model.named_parameters():
if "lora" in name:
param.data = param.data.to(dtype=torch.float32)
if pretrained_lora_path is not None:
state_dict = load_state_dict(pretrained_lora_path)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
all_keys = [i for i, _ in model.named_parameters()]
num_updated_keys = len(all_keys) - len(missing_keys)
num_unexpected_keys = len(unexpected_keys)
logger.info(
f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. "
f"{num_unexpected_keys} parameters are unexpected."
)
return model
def load_state_dict(file_path: str, torch_dtype: Optional[torch.dtype] = None) -> Dict[str, torch.Tensor]:
"""Load state dictionary from a checkpoint file.
Supports both safetensors (.safetensors) and PyTorch binary (.pt/.bin) formats.
Args:
file_path: Path to the checkpoint file.
torch_dtype: Optional dtype to cast tensors to.
Returns:
State dictionary mapping parameter names to tensors.
"""
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(
file_path: str,
torch_dtype: Optional[torch.dtype] = None,
) -> Dict[str, torch.Tensor]:
"""Load state dictionary from a safetensors file.
Args:
file_path: Path to the safetensors file.
torch_dtype: Optional dtype to cast tensors to.
Returns:
State dictionary mapping parameter names to tensors.
"""
try:
from safetensors import safe_open
except ImportError as e:
raise ImportError(
"safetensors library is required. "
"Please install it with: pip install safetensors"
) from e
state_dict: Dict[str, torch.Tensor] = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(
file_path: str,
torch_dtype: Optional[torch.dtype] = None,
) -> Dict[str, torch.Tensor]:
"""Load state dictionary from a PyTorch binary file.
Args:
file_path: Path to the binary file.
torch_dtype: Optional dtype to cast tensors to.
Returns:
State dictionary mapping parameter names to tensors.
"""
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
if torch_dtype is not None:
for key in state_dict:
if isinstance(state_dict[key], torch.Tensor):
state_dict[key] = state_dict[key].to(torch_dtype)
return state_dict
def print_lora_config(
rank: int,
alpha: int,
target_modules: List[str],
dropout: float,
init_lora_weights: bool | str,
trainable_params: int,
total_params: int,
) -> None:
"""Print LoRA configuration summary.
Args:
rank: LoRA rank.
alpha: LoRA alpha.
target_modules: List of target modules.
dropout: Dropout rate.
init_lora_weights: Initialization method (True, False, or str).
trainable_params: Number of trainable parameters.
total_params: Total number of parameters.
"""
logger.info("=" * 60)
logger.info("LoRA Configuration:")
logger.info(f" Enabled: True")
logger.info(f" Rank: {rank}")
logger.info(f" Alpha: {alpha}")
logger.info(f" Target modules: {target_modules}")
logger.info(f" Dropout: {dropout}")
logger.info(f" Init weights: {init_lora_weights}")
logger.info(f" Trainable parameters: {trainable_params:,}")
logger.info(f" Total parameters: {total_params:,}")
if total_params > 0:
logger.info(f" Trainable ratio: {trainable_params / total_params:.2%}")
logger.info("=" * 60)