import argparse
import copy
import re
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.distributed
from torch.optim.lr_scheduler import LRScheduler
from transformers.optimization import get_scheduler
from megatron.core import mpu
from megatron.core.transformer.module import MegatronModule
from mindspeed_llm.tasks.posttrain.lu_lora.layers.tp_lu_lora_linear import CCLULoRAParallelLinear
from mindspeed_llm.tasks.posttrain.lu_lora.lr_scheduler import LULoRALRScheduler
def _get_children_layers(layer: torch.nn.Module, prefix_name: Optional[str] = None) -> list[tuple[str, torch.nn.Module]]:
    """
    Collect every descendant of a module together with its dotted name.

    The direct children come first (in the order reported by ``named_children``),
    followed by the descendants of each child in turn.

    Args:
        layer (torch.nn.Module): The module whose descendants are collected.
        prefix_name (Optional[str]): Dotted name of ``layer`` that is prepended to the names
            of its children. A missing or empty prefix leaves the names unprefixed.

    Returns:
        list[tuple[str, torch.nn.Module]]: Pairs of (qualified name, module); empty when
            ``layer`` has no children.
    """
    direct_children = []
    for child_name, child in layer.named_children():
        qualified_name = f'{prefix_name}.{child_name}' if prefix_name else child_name
        direct_children.append((qualified_name, child))

    # Keep the direct children up front, then append each child's own subtree.
    collected = list(direct_children)
    for qualified_name, child in direct_children:
        collected += _get_children_layers(layer=child, prefix_name=qualified_name)
    return collected
def apply_to_lu_lora_layers(
    callback: Callable[[CCLULoRAParallelLinear], None],
    model: Union[List[torch.nn.Module], torch.nn.Module],
    adapter_pattern: str,
    lu_lora_transformer_blocks: Tuple[int, int] = (),
) -> None:
    """
    Apply a callback to the LU-LoRA layers.

    Args:
        callback (Callable[[CCLULoRAParallelLinear], None]): a callback to run for a LU-LoRA layer,
            it is called with the layer as its only argument.
        model (Union[List[torch.nn.Module], torch.nn.Module]): a model; when a list is given,
            only its first element is processed.
        adapter_pattern (str): a regex pattern to detect LU-LoRA layers.
        lu_lora_transformer_blocks (Tuple[int, int]): indices of the first and the last transformer
            block trained with LU-LoRA. An empty tuple (the default) means there are no such
            blocks, so the callback is never invoked.
    """
    # Compile up front so that an invalid pattern is still reported when there is nothing to do.
    compiled_pattern = re.compile(adapter_pattern)

    # An empty range cannot be unpacked into (first, last); treat it as "no LU-LoRA blocks".
    if not lu_lora_transformer_blocks:
        return

    if isinstance(model, list):
        model = model[0]

    modules_for_lu_lora = _get_lora_layers_for_lu_lora(
        model, compiled_pattern, lu_lora_transformer_blocks
    )
    for _, module in modules_for_lu_lora:
        # Only layers exposing the LU-LoRA switch are handed to the callback.
        if hasattr(module, 'enable_lu_lora'):
            callback(module)
def get_lu_lora_layers_indices_on_rank(
    num_model_layers: int,
    pipeline_parallel_size: int,
    current_pipeline_parallel_rank: int,
    lu_lora_final_layer_index: int
) -> Tuple[int, int]:
    """
    Compute the span of LU-LoRA transformer blocks that live on a pipeline rank.

    Each rank owns ``num_model_layers // pipeline_parallel_size`` consecutive blocks, and the
    LU-LoRA blocks are those with indices ``0..lu_lora_final_layer_index``. The result is the
    overlap of the two index ranges.

    Args:
        num_model_layers (int): number of model layers.
        pipeline_parallel_size (int): number of pipeline stages of the current run.
        current_pipeline_parallel_rank (int): current pipeline parallel rank.
        lu_lora_final_layer_index (int): the index of the final transformer block trained with LU-LoRA.

    Returns:
        Tuple[int, int]: indices of the first and the last transformer block trained with LU-LoRA
            on the rank, or an empty tuple when the rank holds no such block.
    """
    blocks_per_rank = num_model_layers // pipeline_parallel_size
    rank_start = current_pipeline_parallel_rank * blocks_per_rank
    rank_stop = rank_start + blocks_per_rank  # exclusive upper bound of the rank's blocks

    # Overlap of [rank_start, rank_stop) with [0, lu_lora_final_layer_index].
    first_block = max(rank_start, 0)
    last_block = min(rank_stop - 1, lu_lora_final_layer_index)
    if first_block > last_block:
        return ()
    return first_block, last_block
def configure_lr_for_lu_lora_layers(
    model: Union[List[torch.nn.Module], torch.nn.Module],
    lr_scheduler: LRScheduler,
    args: argparse.Namespace
) -> LULoRALRScheduler:
    """
    Build the LU-LoRA learning rate scheduler and attach it to the enabled LU-LoRA layers.

    Args:
        model (Union[List[torch.nn.Module], torch.nn.Module]): a model; when a list is given,
            only its first element is processed.
        lr_scheduler (LRScheduler): the existing learning rate scheduler, passed on as the
            backprop scheduler of the returned object.
        args (argparse.Namespace): command line arguments.

    Returns:
        LULoRALRScheduler: the scheduler whose ``lu_lora_scheduler`` was set on every layer
            that exposes ``enable_lu_lora`` and has ``is_lu_lora_enabled`` set.
    """
    target_model = model[0] if isinstance(model, list) else model

    combined_scheduler = _get_lu_lora_scheduler(
        args=args,
        backprop_scheduler=lr_scheduler,
    )

    for _, module in _get_children_layers(target_model):
        # Only LU-LoRA capable layers that have been switched on receive the scheduler.
        if hasattr(module, 'enable_lu_lora') and module.is_lu_lora_enabled:
            module.set_lr_scheduler(combined_scheduler.lu_lora_scheduler)

    return combined_scheduler
def activate_lu_lora_layers(
    model: torch.nn.Module,
    args: argparse.Namespace,
) -> None:
    """
    Switch on LU-LoRA for the selected layers held by the current pipeline rank.

    Does nothing when the current pipeline rank holds no LU-LoRA transformer block.

    Args:
        model (torch.nn.Module): a model.
        args (argparse.Namespace): command line arguments; ``num_layers``,
            ``pipeline_model_parallel_size``, ``lu_lora_final_layer_index`` and
            ``lora_target_modules`` are read.
    """
    def _enable(module):
        module.enable_lu_lora()

    pipeline_rank = mpu.get_pipeline_model_parallel_rank()
    blocks_on_rank = get_lu_lora_layers_indices_on_rank(
        num_model_layers=args.num_layers,
        pipeline_parallel_size=args.pipeline_model_parallel_size,
        current_pipeline_parallel_rank=pipeline_rank,
        lu_lora_final_layer_index=args.lu_lora_final_layer_index
    )
    if blocks_on_rank:
        apply_to_lu_lora_layers(
            callback=_enable,
            model=model,
            adapter_pattern=_get_adapter_pattern(args.lora_target_modules),
            lu_lora_transformer_blocks=blocks_on_rank,
        )
def _get_lora_layers_for_lu_lora(
    model: torch.nn.Module,
    adapter_pattern: re.Pattern,
    lu_lora_transformer_blocks: Tuple[int, int] = ()
) -> Tuple[Tuple[str, torch.nn.Module], ...]:
    """
    Get LoRA layers that will be trained with LU-LoRA.

    Args:
        model (torch.nn.Module): a model.
        adapter_pattern (re.Pattern): a regex pattern to detect LU-LoRA layers; it must define
            the named group ``transformer_block_number``.
        lu_lora_transformer_blocks (Tuple[int, int]): indices (inclusive) of the first and the last
            transformer block trained with LU-LoRA. An empty tuple (the default) selects nothing.

    Returns:
        Tuple[Tuple[str, torch.nn.Module], ...]: names and modules of the layers that will be
            trained with LU-LoRA.
    """
    # An empty range cannot be unpacked into (first, last); it means no block is selected.
    if not lu_lora_transformer_blocks:
        return ()

    first_block, last_block = lu_lora_transformer_blocks
    selected = []
    for name, module in _get_children_layers(model):
        pattern_match = adapter_pattern.match(name)
        if pattern_match is None:
            continue
        block_number = int(pattern_match.group('transformer_block_number'))
        if first_block <= block_number <= last_block:
            selected.append((name, module))
    return tuple(selected)
def _get_adapter_pattern(target_modules: List[str]) -> str:
    """
    Build the regex that recognises LoRA adapters of the target modules.

    The pattern matches dotted module names containing ``.layers.<number>.`` and ending with
    one of the target module names; the number is captured in the named group
    ``transformer_block_number``.

    Args:
        target_modules (List[str]): list of target modules to fine-tune with LoRA; the names
            are inserted into the pattern as-is.

    Returns:
        str: regex pattern to match modules to fine-tune with LoRA.
    """
    alternatives = '|'.join(target_modules)
    return rf'^.*\.layers\.(?P<transformer_block_number>\d+)\..*\.({alternatives})$'
def _get_lu_lora_scheduler(
    args: argparse.Namespace,
    backprop_scheduler: LRScheduler,
) -> LULoRALRScheduler:
    """
    Get learning rate scheduler.

    A scheduler for the LU-LoRA learning rates is created with ``get_scheduler`` and wrapped,
    together with ``backprop_scheduler``, into a ``LULoRALRScheduler``.

    Args:
        args (argparse.Namespace): command line arguments; ``lu_lora_lr``, ``lu_lora_lr_ratio``,
            ``lr_decay_style``, ``lr_warmup_iters``, ``lr_warmup_fraction`` and ``train_iters``
            are read.
        backprop_scheduler (LRScheduler): learning rate scheduler for LoRA.

    Returns:
        LULoRALRScheduler: learning rate scheduler.
    """
    # Warmup length: an explicit iteration count wins, otherwise a fraction of the training
    # iterations is used. When neither is configured (the fraction is None), no warmup is
    # applied instead of failing on ``train_iters * None``.
    if args.lr_warmup_iters > 0:
        num_warmup_steps = args.lr_warmup_iters
    elif args.lr_warmup_fraction is not None:
        num_warmup_steps = int(args.train_iters * args.lr_warmup_fraction)
    else:
        num_warmup_steps = 0

    # Placeholder parameters created locally (not taken from any model): they only give the
    # optimizer two parameter groups carrying the two LU-LoRA learning rates.
    optim_params = [
        {"params": [torch.nn.Parameter(torch.randn(1))], "lr": args.lu_lora_lr},
        {"params": [torch.nn.Parameter(torch.randn(1))], "lr": args.lu_lora_lr * args.lu_lora_lr_ratio}
    ]
    empty_optimizer = torch.optim.Adam(optim_params)
    lu_lora_scheduler = get_scheduler(
        name=args.lr_decay_style,
        optimizer=empty_optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=args.train_iters
    )
    return LULoRALRScheduler(backprop_scheduler, lu_lora_scheduler)