"""LoRA configuration arguments for FSDP2 training.
This module defines the dataclass for LoRA-specific configuration
parameters used in FSDP2 distributed training.
"""
import re
from dataclasses import dataclass, field
from typing import List, Literal, Optional
from mindspeed_mm.config.arguments.base_args import BaseArguments
class LoraArguments(BaseArguments):
"""Configuration arguments for LoRA (Low-Rank Adaptation) training.
This class contains all parameters needed to configure LoRA adapters
for efficient fine-tuning of large models.
Attributes:
enable: Whether to enable LoRA fine-tuning.
rank: Rank of the low-rank matrices.
alpha: Scaling factor for LoRA weights.
target_modules: List of target module names/patterns for LoRA.
dropout: Dropout rate for LoRA layers.
init_lora_weights: Weight initialization method.
pretrained_lora_path: Path to pretrained LoRA weights (optional).
lora_target_modules_support: List of supported module types.
"""
enable: bool = field(
default=False,
metadata={"help": "Enable LoRA fine-tuning."},
)
rank: int = field(
default=8,
metadata={"help": "Rank of the low-rank matrices."},
)
alpha: int = field(
default=16,
metadata={"help": "Scaling factor for LoRA weights."},
)
target_modules: List[str] = field(
default_factory=lambda: ["q_proj", "k_proj", "v_proj"],
metadata={
"help": "List of target module names/patterns for LoRA. "
"Supports wildcard patterns (e.g., 'language_model.layers.{*}.q_proj')."
},
)
dropout: float = field(
default=0.0,
metadata={"help": "Dropout rate for LoRA layers."},
)
init_lora_weights: (
bool
| Literal[
"gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq", "orthogonal"]
) = field(
default=True,
metadata={
"help": "How to initialize the weights of the LoRA layers. ",
},
)
pretrained_lora_path: Optional[str] = field(
default=None,
metadata={"help": "Path to pretrained LoRA weights to load."},
)
lora_target_modules_support: Optional[List[str]] = field(
default=None,
metadata={
"help": "List of supported module types for validation. "
"If None, validation is skipped."
},
)
def model_post_init(self, __context):
"""Validate LoRA configuration after initialization."""
if self.enable:
if self.rank <= 0:
raise ValueError(f"LoRA rank must be positive, got {self.rank}")
if self.alpha <= 0:
raise ValueError(f"LoRA alpha must be positive, got {self.alpha}")
if not self.target_modules:
raise ValueError("target_modules cannot be empty when LoRA is enabled")
if not 0.0 <= self.dropout < 1.0:
raise ValueError(f"LoRA dropout must be in [0, 1), got {self.dropout}")
valid_init_methods = [
"gaussian", "eva", "olora", "pissa", "corda", "loftq", "orthogonal"
]
pissa_niter_pattern = re.compile(r"^pissa_niter_\d+$")
if isinstance(self.init_lora_weights, str):
init_val = self.init_lora_weights.lower()
if init_val not in valid_init_methods and not pissa_niter_pattern.match(init_val):
raise ValueError(
f"init_lora_weights must be True, False, one of {valid_init_methods}, "
f"or 'pissa_niter_[number of iters]' (e.g., 'pissa_niter_5'), "
f"but got {self.init_lora_weights}"
)
elif not isinstance(self.init_lora_weights, bool):
raise ValueError(
f"init_lora_weights must be bool or str, got {type(self.init_lora_weights)}"
)