import logging
from typing import Set, List, Any, Optional
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
from torch.nn.parallel import DistributedDataParallel as DDP
from mindspeed.fsdp.utils.log import print_rank
from mindspeed.fsdp.utils.str_match import module_name_match
from mindspeed_mm.fsdp.distributed.parallel_state import get_parallel_state
from mindspeed_mm.fsdp.params.argument import Arguments, parse_args
from mindspeed_mm.fsdp.params.parallel_args import FSDPPlanConfig
from mindspeed_mm.fsdp.utils.device import get_torch_device, get_device_type
from mindspeed_mm.fsdp.utils.dtype import get_dtype
from mindspeed_mm.fsdp.params.training_args import TrainingArguments
# Module-level logger; passed to ``print_rank(...)`` for the informational/debug messages
# emitted by the FSDP wrapping helpers in this module.
logger = logging.getLogger(__name__)
def pregather_fsdp_params(model: torch.nn.Module):
    """
    Pre-gather FSDP2 parameters before the forward pass.

    Calls ``unshard()`` on every module yielded by ``model.named_modules()`` (which includes
    ``model`` itself) that exposes a callable ``unshard`` attribute. It then synchronizes the
    device so that all ranks have their parameters ready before timed computation starts.
    This reduces straggler effects caused by uneven all-gather times.

    Unsharding is best-effort: if one module fails, the failure is logged at DEBUG level
    through this module's logger and the remaining modules are still processed.

    Args:
        model: The model with FSDP2 applied modules.
    """
    for name, module in model.named_modules():
        unshard = getattr(module, "unshard", None)
        if not callable(unshard):
            continue
        try:
            unshard()
        except Exception as e:  # deliberate best-effort; keep gathering the other modules
            # Use the module-level logger rather than the root ``logging`` module so the
            # message follows this module's logging configuration.
            logger.debug("Failed to unshard module %s: %s", name, e)
    get_torch_device().synchronize()
def fully_shard_parallel_modules(model: torch.nn.Module, fsdp_mesh: DeviceMesh, fsdp_plan: FSDPPlanConfig, training_config: TrainingArguments, **kwargs):
    """
    Apply Fully Sharded Data Parallelism (FSDP2) to the modules selected by ``fsdp_plan``.

    If the fully-shard parallel size is 1 and the model is not being initialised on the meta
    device, the model is instead moved to the accelerator and wrapped in
    ``DistributedDataParallel`` over the data-parallel group.

    If the model defines its own callable ``fully_shard(fsdp_plan=...)`` and that call returns
    a truthy value, the model is returned as-is. Otherwise the modules matched by the plan
    are wrapped first (post-order), then the root model. After that, forward/backward
    prefetching is configured.

    Args:
        model: The neural network model to apply FSDP to.
        fsdp_mesh: Device mesh defining the FSDP process group.
        fsdp_plan: Configuration specifying which modules to apply FSDP to, which to ignore,
            which act as hook modules, mixed precision settings and prefetch depth.
        training_config: Training arguments; only ``init_model_with_meta_device`` is read here.
        **kwargs: Accepted for interface compatibility; not used by this function.

    Returns:
        The DDP-wrapped model, or the model with FSDP applied to the selected modules and
        to the root.
    """
    ps = get_parallel_state()
    if ps.fully_shard_parallel_size == 1 and not training_config.init_model_with_meta_device:
        # No sharding requested: fall back to plain DDP over the data-parallel group.
        device_module = get_torch_device()
        model = DDP(
            model.to(get_device_type()),
            process_group=ps.get_dp_group(),
            find_unused_parameters=True,
            # DDP's ``device_ids`` expects device indices (or torch.device objects).
            # get_torch_device() returns the backend module (this file calls
            # ``get_torch_device().synchronize()``), so pass its current device index
            # instead of the module object itself.
            device_ids=[device_module.current_device()],
        )
        print_rank(logger.info,
                   "DDP mode is enabled (fully_shard_parallel_size=1) instead of FSDP wrapping")
        return model

    # A model may implement its own sharding; a truthy result means it handled everything.
    custom_fully_shard = getattr(model, 'fully_shard', None)
    if callable(custom_fully_shard) and custom_fully_shard(fsdp_plan=fsdp_plan):
        return model

    ignored_modules, ignored_params = get_ignored_modules(model, fsdp_plan)
    fsdp_modules = get_fsdp_modules(model, fsdp_plan, ignored_modules)
    hook_modules = get_fsdp_hook_modules(model, fsdp_plan)

    config = {
        'mesh': fsdp_mesh,
        'ignored_params': ignored_params,
        "reshard_after_forward": fsdp_plan.reshard_after_forward,
        "mp_policy": get_mixprecision_policy(fsdp_plan),
    }
    # Wrap the selected submodules first (children before parents), then the root model.
    for module in fsdp_modules:
        hook_module = find_hook_module(module, hook_modules)
        # NOTE(review): ``hook_module`` is not a parameter of upstream torch ``fully_shard``;
        # presumably provided by a MindSpeed patch — confirm.
        fully_shard(module, hook_module=hook_module, **config)
    fully_shard(model, **config)
    set_modules_to_prefetch(model, fsdp_modules, fsdp_plan)
    return model
def get_mixprecision_policy(fsdp_plan: FSDPPlanConfig):
"""Construct the MixedPrecisionPolicy object."""
param_dtype = get_dtype(fsdp_plan.param_dtype) if fsdp_plan.param_dtype else None
reduce_dtype = get_dtype(fsdp_plan.reduce_dtype) if fsdp_plan.reduce_dtype else None
output_dtype = get_dtype(fsdp_plan.output_dtype) if fsdp_plan.output_dtype else None
return MixedPrecisionPolicy(
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
output_dtype=output_dtype,
cast_forward_inputs=fsdp_plan.cast_forward_inputs
)
def _post_order_traverse(model: torch.nn.Module, parent_path: str = ""):
    """
    Yield ``(module_path, module)`` pairs for ``model`` and all its submodules in post-order.

    Every child subtree is emitted before the module that owns it, and the root (with path
    ``parent_path``) comes last. Child paths are dot-joined onto the parent path. When the
    parent path is empty, the bare child name is used.

    Args:
        model: The model to traverse.
        parent_path: Path prefix assigned to ``model`` itself.

    Yields:
        Tuple of (module_path, module) for each module in the model.
    """
    # Explicit stack of (path, module, pending-children iterator) replaces recursion.
    pending = [(parent_path, model, iter(model.named_children()))]
    while pending:
        path, node, children = pending[-1]
        next_child = next(children, None)
        if next_child is None:
            # All children emitted: the node itself is now due.
            pending.pop()
            yield path, node
            continue
        child_name, child = next_child
        child_path = f"{path}.{child_name}" if path else child_name
        pending.append((child_path, child, iter(child.named_children())))
def get_fsdp_modules(model: torch.nn.Module, fsdp_plan: FSDPPlanConfig, ignored_modules: Set[str]) -> List[Any]:
    """
    Collect the submodules of ``model`` that should be wrapped with ``fully_shard``.

    Modules are visited in post-order (children before parents). A module is selected when
    its qualified name matches any pattern in ``fsdp_plan.apply_modules`` and that name is
    not in ``ignored_modules``. Each module object is returned at most once, even if it
    matches several patterns or is reachable under several names. This avoids applying
    ``fully_shard`` to the same module twice.

    Args:
        model: Root model to search.
        fsdp_plan: Plan whose ``apply_modules`` holds the name patterns to match.
        ignored_modules: Qualified module names that must not be wrapped.

    Returns:
        The selected modules in post-order. Empty when ``apply_modules`` is ``None`` or empty.

    Raises:
        RuntimeError: If ``apply_modules`` is non-empty but no module was selected.
    """
    fsdp_modules = []
    if not fsdp_plan.apply_modules:
        return fsdp_modules

    selected_ids = set()
    for name, module in _post_order_traverse(model):
        if name in ignored_modules or id(module) in selected_ids:
            continue
        if any(module_name_match(pattern, name) for pattern in fsdp_plan.apply_modules):
            print_rank(logger.debug, f'[FSDP2]: Apply fsdp2 to module <{name}>')
            selected_ids.add(id(module))
            fsdp_modules.append(module)

    if not fsdp_modules:
        raise RuntimeError(f'[FSDP2] No module named {fsdp_plan.apply_modules}.')
    return fsdp_modules
def get_fsdp_hook_modules(model: torch.nn.Module, fsdp_plan: FSDPPlanConfig) -> List[Any]:
    """
    Collect the modules whose qualified names match ``fsdp_plan.hook_modules``.

    Modules are visited in post-order (children before parents), so a nested hook module is
    listed before any hook module that contains it. Each module object is returned at most
    once, even if several patterns match it or it is reachable under several names.

    Args:
        model: Root model to search.
        fsdp_plan: Plan whose ``hook_modules`` holds the name patterns to match.

    Returns:
        The matched modules in post-order. Empty when ``apply_modules`` is ``None`` or when
        ``hook_modules`` is unset/empty.

    Raises:
        RuntimeError: If ``hook_modules`` is non-empty but nothing matched.
    """
    fsdp_hook_modules = []
    # NOTE(review): this guards on ``apply_modules`` rather than ``hook_modules``. In this
    # file hook modules are only looked up for modules selected via ``apply_modules``, so
    # skipping them here looks intentional — confirm.
    if fsdp_plan.apply_modules is None:
        return fsdp_hook_modules
    if not fsdp_plan.hook_modules:
        return fsdp_hook_modules

    selected_ids = set()
    for name, module in _post_order_traverse(model):
        if id(module) in selected_ids:
            continue
        if any(module_name_match(pattern, name) for pattern in fsdp_plan.hook_modules):
            print_rank(logger.debug, f'[FSDP2]: Apply fsdp2 hook to hook_module <{name}>')
            selected_ids.add(id(module))
            fsdp_hook_modules.append(module)

    if not fsdp_hook_modules:
        raise RuntimeError(f'[FSDP2] No module named {fsdp_plan.hook_modules}.')
    return fsdp_hook_modules
def find_hook_module(target_module: torch.nn.Module, hook_module_list: List[torch.nn.Module]) -> Optional[torch.nn.Module]:
    """
    Return the first entry of ``hook_module_list`` whose module tree contains ``target_module``.

    A hook module's tree includes the hook module itself, so a target that *is* a hook module
    maps to itself. Containment is checked by identity. Returns ``None`` when no hook module
    contains the target, including when the list is empty.
    """
    def _owns(candidate: torch.nn.Module) -> bool:
        return any(sub is target_module for sub in candidate.modules())

    return next((candidate for candidate in hook_module_list if _owns(candidate)), None)
def get_ignored_modules(model: torch.nn.Module, fsdp_plan: FSDPPlanConfig) -> tuple[set[str], set[torch.nn.Parameter]]:
    """
    Resolve ``fsdp_plan.ignored_modules`` patterns into module names and parameters.

    Every module from ``model.named_modules()`` whose qualified name matches at least one
    pattern is recorded by name. All of its parameters, including those of its submodules,
    are added to the ignored parameter set. Each matching module is logged and processed
    once, even when several patterns match it.

    Args:
        model: Root model to search.
        fsdp_plan: Plan whose ``ignored_modules`` holds the name patterns to match.

    Returns:
        ``(ignored_modules, ignored_params)``: the matched qualified names, and the set of
        parameters owned by those modules. Both are empty when ``ignored_modules`` is ``None``.
    """
    ignored_modules = set()
    ignored_params = set()
    if fsdp_plan.ignored_modules is None:
        return ignored_modules, ignored_params
    for name, module in model.named_modules():
        if not any(module_name_match(pattern, name) for pattern in fsdp_plan.ignored_modules):
            continue
        print_rank(logger.debug, f'[FSDP2]: Ignored module to apply fsdp2 <{name}>')
        ignored_modules.add(name)
        ignored_params.update(module.parameters(recurse=True))
    return ignored_modules, ignored_params
def set_modules_to_prefetch(model: torch.nn.Module, fsdp_modules: list[torch.nn.Module], fsdp_plan: FSDPPlanConfig):
    """
    Configure FSDP2 forward and backward prefetching between wrapped modules.

    Wrapped modules are ordered as they appear in ``model.modules()``. Entries of
    ``fsdp_modules`` that are not part of ``model`` are ignored.

    - Forward: each module prefetches the next ``num_to_forward_prefetch`` modules in that
      order.
    - Backward: each module prefetches the next ``num_to_backward_prefetch`` modules in the
      reversed order.

    A count that is not positive disables that direction. Modules with nothing left to
    prefetch are not configured.

    Args:
        model: Root model whose module order defines the prefetch order.
        fsdp_modules: Modules wrapped with ``fully_shard``. Each must expose
            ``set_modules_to_forward_prefetch`` / ``set_modules_to_backward_prefetch``.
        fsdp_plan: Plan supplying ``num_to_forward_prefetch`` and ``num_to_backward_prefetch``.
    """
    # Identity set: O(1) membership instead of scanning fsdp_modules for every submodule.
    wrapped_ids = {id(m) for m in fsdp_modules}
    wrapped_modules_in_order: list[torch.nn.Module] = [
        sub_module for sub_module in model.modules() if id(sub_module) in wrapped_ids
    ]

    def _prefetch_windows(ordered, count):
        # Pair each module with the next ``count`` modules. Slicing clamps at the end of the
        # list, so trailing modules get fewer targets or none.
        for i, layer in enumerate(ordered):
            targets = ordered[i + 1:i + 1 + count]
            if targets:
                yield layer, targets

    if fsdp_plan.num_to_forward_prefetch > 0:
        for layer, targets in _prefetch_windows(wrapped_modules_in_order, fsdp_plan.num_to_forward_prefetch):
            layer.set_modules_to_forward_prefetch(targets)
    if fsdp_plan.num_to_backward_prefetch > 0:
        reversed_order = wrapped_modules_in_order[::-1]
        for layer, targets in _prefetch_windows(reversed_order, fsdp_plan.num_to_backward_prefetch):
            layer.set_modules_to_backward_prefetch(targets)