"""Integration tests for LoRA with FSDP2.
This module provides comprehensive tests for LoRA integration
with FSDP2 distributed training, including:
- LoRA adapter injection
- Parameter validation
- Weight saving and loading
- Multi-card compatibility
"""
import os
import pathlib
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from mindspeed_mm.fsdp.params.argument import Arguments, parse_args
from mindspeed_mm.fsdp.utils.lora_utils import (
match_target_modules,
validate_lora_config,
get_lora_trainable_params,
)
from mindspeed_mm.fsdp.utils.lora_weight_manager import LoraWeightManager
class SimpleModel(nn.Module):
"""Simple test model for LoRA testing."""
def __init__(self) -> None:
super().__init__()
self.linear1 = nn.Linear(10, 20)
self.linear2 = nn.Linear(20, 10)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
x = torch.relu(x)
x = self.linear2(x)
return x
class TestLoraUtils:
"""Test LoRA utility functions."""
def test_validate_lora_config_valid(self) -> None:
"""Test validation of valid LoRA configuration."""
validate_lora_config(
rank=8,
alpha=16,
target_modules=["q_proj", "k_proj"],
dropout=0.05,
init_lora_weights=True,
)
def test_validate_lora_config_invalid_rank(self) -> None:
"""Test validation of invalid LoRA rank."""
with pytest.raises(ValueError):
validate_lora_config(
rank=0,
alpha=16,
target_modules=["q_proj"],
dropout=0.05,
init_lora_weights=True,
)
def test_validate_lora_config_invalid_alpha(self) -> None:
"""Test validation of invalid LoRA alpha."""
with pytest.raises(ValueError):
validate_lora_config(
rank=8,
alpha=0,
target_modules=["q_proj"],
dropout=0.05,
init_lora_weights=True,
)
def test_validate_lora_config_invalid_dropout(self) -> None:
"""Test validation of invalid LoRA dropout."""
with pytest.raises(ValueError):
validate_lora_config(
rank=8,
alpha=16,
target_modules=["q_proj"],
dropout=1.5,
init_lora_weights=True,
)
def test_match_target_modules_exact(self) -> None:
"""Test exact matching of target modules."""
model = SimpleModel()
patterns = ["linear1", "linear2"]
matched = match_target_modules(model, patterns)
assert len(matched) == 2
assert "linear1" in matched
assert "linear2" in matched
def test_match_target_modules_wildcard(self) -> None:
"""Test wildcard matching of target modules."""
model = SimpleModel()
patterns = ["linear{*}"]
matched = match_target_modules(model, patterns)
assert len(matched) == 2
assert "linear1" in matched
assert "linear2" in matched
def test_match_target_modules_no_match(self) -> None:
"""Test when no modules match the pattern."""
model = SimpleModel()
patterns = ["nonexistent"]
matched = match_target_modules(model, patterns)
assert len(matched) == 0
def test_get_lora_trainable_params(self) -> None:
"""Test getting LoRA trainable parameters statistics."""
model = SimpleModel()
trainable, total, stats = get_lora_trainable_params(model)
assert trainable > 0
assert total > 0
assert stats["trainable_params"] == trainable
assert stats["total_params"] == total
assert stats["trainable_ratio"] > 0
assert stats["lora_params"] == 0
class TestLoraWeightManager:
"""Test LoRA weight manager."""
def test_init_weight_manager(self) -> None:
"""Test initialization of LoRA weight manager."""
model = SimpleModel()
manager = LoraWeightManager(model)
assert manager.model is model
assert isinstance(manager._is_distributed, bool)
def test_get_lora_param_count(self) -> None:
"""Test getting LoRA parameter count."""
model = SimpleModel()
manager = LoraWeightManager(model)
num_params, num_elements = manager.get_lora_param_count()
assert num_params == 0
assert num_elements == 0
def test_get_lora_state_dict(self) -> None:
"""Test getting LoRA state dictionary."""
model = SimpleModel()
manager = LoraWeightManager(model)
state_dict = manager.get_lora_state_dict()
assert isinstance(state_dict, dict)
assert len(state_dict) == 0
def test_verify_lora_weights_no_lora(self) -> None:
"""Test verification when no LoRA weights exist."""
model = SimpleModel()
manager = LoraWeightManager(model)
result = manager.verify_lora_weights()
assert result is False
class TestLoraFSDP2Integration:
"""Integration tests for LoRA with FSDP2."""
def test_lora_enable_with_config(self, tmp_path: pathlib.Path) -> None:
"""Test enabling LoRA with configuration file."""
config_path = os.path.join(tmp_path, "test_config.yaml")
with open(config_path, "w", encoding="utf-8") as f:
f.write("""
parallel:
tensor_parallel_size: 1
fully_shard_parallel_size: 1
fsdp_plan:
apply_modules: []
param_dtype: bf16
recompute: false
data:
dataset_param:
dataset_type: huggingface
attr:
images: images
messages: messages
role_tag: role
content_tag: content
user_tag: user
assistant_tag: assistant
preprocess_parameters:
model_name_or_path: ./test_model
use_fast_tokenizer: true
basic_parameters:
cutoff_len: 1024
template: qwen3_vl_nothink
dataset_dir: ./data
dataset: ./data/test.json
cache_dir: ./cache_dir/
dataloader_param:
pin_memory: false
shuffle: false
dataloader_mode: sampler
drop_last: true
num_workers: 0
model:
model_id: qwen3vl
model_name_or_path: ./test_model
trust_remote_code: true
attn_implementation: sdpa
loss_cfg:
loss_type: default
router_aux_loss_coef: 0.0
training:
micro_batch_size: 1
gradient_accumulation_steps: 1
seed: 42
lr: 1e-4
train_iters: 10
clip_grad: 0.0
init_model_with_meta_device: false
optimizer: adamw
adam_fused: false
save_interval: 10
no_load_opt: false
no_load_rng: false
no_save_optim: true
no_save_rng: true
load: null
save: ./test_save
use_deter_comp: false
lora:
enable: true
rank: 4
alpha: 8
target_modules:
- "linear1"
- "linear2"
dropout: 0.0
init_lora_weights: true
pretrained_lora_path: null
tools:
profile:
enable: false
memory_profile:
enable: false
""")
if not os.path.exists("./test_model"):
pytest.skip("Test model not found")
args = parse_args(Arguments, config_path)
assert args.training.lora.enable is True
assert args.training.lora.rank == 4
assert args.training.lora.alpha == 8
assert len(args.training.lora.target_modules) == 2
def test_lora_save_load_cycle(self, tmp_path: str) -> None:
"""Test saving and loading LoRA weights."""
model = SimpleModel()
manager = LoraWeightManager(model)
model.lora_A_weight = nn.Parameter(torch.randn(4, 10))
model.lora_B_weight = nn.Parameter(torch.randn(10, 4))
save_path = os.path.join(tmp_path, "lora_test")
num_saved, num_elements = manager.save_lora_only(save_path)
assert num_saved == 2
assert num_elements > 0
assert os.path.exists(os.path.join(save_path, "lora_adapter.safetensors"))
@pytest.mark.skipif(
not dist.is_initialized() or dist.get_world_size() < 2,
reason="Requires at least 2 GPUs for multi-card test"
)
def test_lora_multicard_compatibility(self) -> None:
"""Test LoRA compatibility with multi-card training."""
rank = dist.get_rank()
world_size = dist.get_world_size()
model = SimpleModel()
manager = LoraWeightManager(model)
num_params, num_elements = manager.get_lora_param_count()
gathered_params = [None] * world_size
dist.all_gather_object(gathered_params, num_params)
if rank == 0:
for i, params in enumerate(gathered_params):
assert params == num_params, f"Rank {i} has different parameter count"
dist.barrier()
def run_tests() -> None:
"""Run all tests."""
pytest.main([__file__, "-v", "-s"])
if __name__ == "__main__":
run_tests()