"""Unit tests for LoRA utility functions.
This module provides unit tests for LoRA utility functions
that don't require full MindSpeed-MM dependencies.
"""
import math
import os
import pytest
import torch
import torch.nn as nn
from mindspeed_mm.fsdp.utils.lora_utils import (
match_target_modules,
validate_lora_config,
get_lora_trainable_params,
is_pattern_match,
)
class SimpleModel(nn.Module):
    """Two-layer perceptron used as a fixture by the LoRA utility tests.

    The network maps 10 input features to 20 hidden features and back to 10,
    with a ReLU in between. The submodules are exposed as ``linear1`` and
    ``linear2`` so tests can refer to them by name.
    """

    def __init__(self) -> None:
        super().__init__()
        # Construction order is kept (linear1 first) so that seeded
        # initialisation yields the same weights as before.
        self.linear1 = nn.Linear(10, 20)
        self.linear2 = nn.Linear(20, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``linear1``, ReLU and ``linear2`` in sequence.

        Args:
            x: Input tensor whose last dimension is 10.

        Returns:
            Tensor with the same leading dimensions and a last dimension of 10.
        """
        return self.linear2(torch.relu(self.linear1(x)))
class TestValidateLoraConfig:
    """Test LoRA configuration validation.

    Every test starts from one known-good configuration and overrides only the
    field under test, so a failure points at exactly one argument. Cases that
    share the same expectation are parametrized so that each value is reported
    as its own test instead of being hidden behind the first failing iteration
    of a loop.
    """

    @staticmethod
    def _validate(**overrides: object) -> None:
        """Call ``validate_lora_config`` with a valid baseline plus overrides.

        Args:
            **overrides: Keyword arguments replacing the baseline values
                (``rank=8``, ``alpha=16``, ``target_modules=["q_proj"]``,
                ``dropout=0.05``, ``init_lora_weights=True``).

        Raises:
            Whatever ``validate_lora_config`` raises for the resulting
            configuration; nothing is caught here.
        """
        # Built on every call so a list mutated by one test cannot leak into
        # another.
        config = {
            "rank": 8,
            "alpha": 16,
            "target_modules": ["q_proj"],
            "dropout": 0.05,
            "init_lora_weights": True,
        }
        config.update(overrides)
        validate_lora_config(**config)

    def test_valid_config(self) -> None:
        """Test validation of valid LoRA configuration."""
        self._validate(target_modules=["q_proj", "k_proj"])

    def test_invalid_rank_zero(self) -> None:
        """Test validation fails with rank=0."""
        with pytest.raises(ValueError, match="rank must be positive"):
            self._validate(rank=0)

    def test_invalid_rank_negative(self) -> None:
        """Test validation fails with negative rank."""
        with pytest.raises(ValueError, match="rank must be positive"):
            self._validate(rank=-1)

    def test_invalid_alpha_zero(self) -> None:
        """Test validation fails with alpha=0."""
        with pytest.raises(ValueError, match="alpha must be positive"):
            self._validate(alpha=0)

    def test_invalid_alpha_negative(self) -> None:
        """Test validation fails with negative alpha."""
        with pytest.raises(ValueError, match="alpha must be positive"):
            self._validate(alpha=-1)

    def test_invalid_target_modules_empty(self) -> None:
        """Test validation fails with empty target_modules."""
        with pytest.raises(ValueError, match="target_modules cannot be empty"):
            self._validate(target_modules=[])

    def test_invalid_dropout_negative(self) -> None:
        """Test validation fails with negative dropout."""
        # Raw string: ``match`` is a regex, so the brackets must be escaped.
        with pytest.raises(ValueError, match=r"dropout must be in \[0, 1\)"):
            self._validate(dropout=-0.1)

    def test_invalid_dropout_too_high(self) -> None:
        """Test validation fails with dropout >= 1.0."""
        with pytest.raises(ValueError, match=r"dropout must be in \[0, 1\)"):
            self._validate(dropout=1.0)

    def test_invalid_init_method(self) -> None:
        """Test validation fails with invalid init method."""
        with pytest.raises(
            ValueError,
            match="init_lora_weights must be True, False, one of",
        ):
            self._validate(init_lora_weights="invalid_method")

    @pytest.mark.parametrize(
        "init_value",
        [
            # Named initialisation methods.
            "gaussian",
            "eva",
            "olora",
            "pissa",
            "corda",
            "loftq",
            "orthogonal",
            # ``pissa_niter_<non-negative integer>`` variants.
            "pissa_niter_0",
            "pissa_niter_5",
            "pissa_niter_10",
            "pissa_niter_100",
            # Plain booleans.
            True,
            False,
        ],
    )
    def test_valid_init_methods(self, init_value: object) -> None:
        """Test validation passes with all valid init methods.

        Args:
            init_value: A string method name or boolean that the validator is
                expected to accept for ``init_lora_weights``.
        """
        self._validate(init_lora_weights=init_value)

    @pytest.mark.parametrize(
        "init_value",
        ["pissa_niter_", "pissa_niter_abc", "pissa_niter_-1", "pissa_niter_1.5"],
    )
    def test_invalid_pissa_niter_format(self, init_value: str) -> None:
        """Test validation fails with invalid pissa_niter format.

        Args:
            init_value: A malformed ``pissa_niter_*`` string (missing,
                non-numeric, negative or fractional suffix).
        """
        with pytest.raises(
            ValueError,
            match="init_lora_weights must be True, False, one of",
        ):
            self._validate(init_lora_weights=init_value)
class TestIsPatternMatch:
    """Checks for ``is_pattern_match(module_name, pattern)``.

    Each assertion compares against the ``True``/``False`` singletons, so the
    function is expected to return a real ``bool`` rather than a truthy value.
    """

    def test_exact_match(self) -> None:
        """A name matches a pattern that is identical to it."""
        q_result = is_pattern_match("q_proj", "q_proj")
        assert q_result is True
        k_result = is_pattern_match("k_proj", "k_proj")
        assert k_result is True

    def test_suffix_match(self) -> None:
        """A dotted name matches a pattern equal to its last component."""
        first = is_pattern_match("model.linear1", "linear1")
        assert first is True
        second = is_pattern_match("model.linear2", "linear2")
        assert second is True

    def test_wildcard_match(self) -> None:
        """``{*}`` in a pattern stands in for a varying path component."""
        single_digit_q = is_pattern_match("layers.0.q_proj", "layers.{*}.q_proj")
        assert single_digit_q is True
        single_digit_k = is_pattern_match("layers.5.k_proj", "layers.{*}.k_proj")
        assert single_digit_k is True
        deep_path = is_pattern_match(
            "model.layers.10.mlp.gate",
            "model.layers.{*}.mlp.gate",
        )
        assert deep_path is True

    def test_wildcard_no_match(self) -> None:
        """A wildcard pattern rejects names whose fixed parts differ."""
        wrong_leaf = is_pattern_match("layers.0.q_proj", "layers.{*}.v_proj")
        assert wrong_leaf is False
        wrong_prefix = is_pattern_match("model.linear1", "layers.{*}.q_proj")
        assert wrong_prefix is False

    def test_exact_no_match(self) -> None:
        """Different plain names do not match each other."""
        proj_result = is_pattern_match("q_proj", "k_proj")
        assert proj_result is False
        linear_result = is_pattern_match("linear1", "linear2")
        assert linear_result is False
class TestMatchTargetModules:
    """Checks for ``match_target_modules(model, patterns)`` on small models."""

    def test_match_exact_names(self) -> None:
        """Two exact names each select their corresponding module."""
        found = match_target_modules(SimpleModel(), ["linear1", "linear2"])
        assert len(found) == 2
        assert "linear1" in found
        assert "linear2" in found

    def test_match_wildcard_pattern(self) -> None:
        """A single ``{*}`` pattern selects both linear layers."""
        found = match_target_modules(SimpleModel(), ["linear{*}"])
        assert len(found) == 2
        assert "linear1" in found
        assert "linear2" in found

    def test_match_mixed_patterns(self) -> None:
        """Overlapping exact and wildcard patterns still give two results."""
        # "linear1" is covered by both patterns; the expected count of 2
        # means it must not be reported twice.
        found = match_target_modules(SimpleModel(), ["linear1", "linear{*}"])
        assert len(found) == 2
        assert "linear1" in found
        assert "linear2" in found

    def test_no_match(self) -> None:
        """Patterns naming no existing module give an empty result."""
        unknown_names = ["nonexistent", "another_nonexistent"]
        found = match_target_modules(SimpleModel(), unknown_names)
        assert len(found) == 0

    def test_nested_model_matching(self) -> None:
        """Modules held in an ``nn.ModuleList`` are reachable by index."""

        class NestedModel(nn.Module):
            """Container with two linear layers stored under ``layers``."""

            def __init__(self) -> None:
                super().__init__()
                first = nn.Linear(10, 20)
                second = nn.Linear(20, 10)
                self.layers = nn.ModuleList([first, second])

        found = match_target_modules(NestedModel(), ["layers.{*}"])
        assert len(found) == 2
        # Substring checks: each result is only required to contain the
        # indexed path, not to equal it.
        assert any("layers.0" in entry for entry in found)
        assert any("layers.1" in entry for entry in found)
class TestGetLoraTrainableParams:
    """Checks the counts and statistics from ``get_lora_trainable_params``.

    The function is called as ``trainable, total, stats = f(model)``, where
    ``stats`` is indexed by ``trainable_params``, ``total_params``,
    ``trainable_ratio``, ``lora_params`` and ``base_params``.
    """

    def test_all_params_trainable(self) -> None:
        """A freshly built model reports every parameter as trainable."""
        n_trainable, n_total, info = get_lora_trainable_params(SimpleModel())
        assert n_trainable > 0
        assert n_total > 0
        assert n_trainable == n_total
        assert info["trainable_params"] == n_trainable
        assert info["total_params"] == n_total
        assert math.isclose(info["trainable_ratio"], 1.0)
        assert info["lora_params"] == 0
        assert info["base_params"] == n_trainable

    def test_no_params_trainable(self) -> None:
        """A fully frozen model reports zero trainable parameters."""
        net = SimpleModel()
        net.requires_grad_(False)
        n_trainable, n_total, info = get_lora_trainable_params(net)
        assert n_trainable == 0
        assert n_total > 0
        assert info["trainable_params"] == 0
        assert info["total_params"] == n_total
        assert math.isclose(info["trainable_ratio"], 0.0)
        assert info["lora_params"] == 0
        assert info["base_params"] == 0

    def test_partial_trainable(self) -> None:
        """Freezing only ``linear1`` gives a ratio strictly between 0 and 1."""
        net = SimpleModel()
        net.linear1.requires_grad_(False)
        n_trainable, n_total, info = get_lora_trainable_params(net)
        assert n_trainable > 0
        assert n_total > 0
        assert n_trainable < n_total
        assert 0 < info["trainable_ratio"] < 1.0
        assert info["lora_params"] == 0
        assert info["base_params"] == n_trainable

    def test_lora_params(self) -> None:
        """Only the attached ``lora_*`` parameters count once the base is frozen."""
        net = SimpleModel()
        net.lora_A_weight = nn.Parameter(torch.randn(4, 10))
        net.lora_B_weight = nn.Parameter(torch.randn(10, 4))
        for base_layer in (net.linear1, net.linear2):
            base_layer.requires_grad_(False)
        n_trainable, _, info = get_lora_trainable_params(net)
        # 4 * 10 elements in lora_A_weight plus 10 * 4 in lora_B_weight.
        assert n_trainable == 80
        assert info["lora_params"] == 80
        assert info["base_params"] == 0
        assert info["trainable_ratio"] > 0
class TestLoraUtilsIntegration:
    """End-to-end check that the LoRA utilities work together on one model."""

    def test_full_workflow(self) -> None:
        """Validate a config, match modules, then count before and after freezing."""
        target_names = ["linear1", "linear2"]
        net = SimpleModel()

        # Step 1: the configuration itself must be accepted.
        validate_lora_config(
            rank=8,
            alpha=16,
            target_modules=["linear1", "linear2"],
            dropout=0.05,
            init_lora_weights=True,
        )

        # Step 2: both target names must resolve on the model.
        resolved = match_target_modules(net, target_names)
        assert len(resolved) == 2

        # Step 3: an unfrozen model has trainable parameters.
        n_trainable, n_total, _ = get_lora_trainable_params(net)
        assert n_trainable > 0
        assert n_total > 0

        # Step 4: after freezing everything, none remain trainable.
        for weight in net.parameters():
            weight.requires_grad = False
        frozen_count = get_lora_trainable_params(net)[0]
        assert frozen_count == 0
if __name__ == "__main__":
    # pytest.main() returns the run's exit code instead of exiting itself.
    # Raise SystemExit with it so that running this file directly reports
    # test failures to the shell/CI rather than always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))