# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
import logging
from typing import Set, List, Any, Optional

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
from torch.nn.parallel import DistributedDataParallel as DDP

from mindspeed.fsdp.utils.log import print_rank
from mindspeed.fsdp.utils.str_match import module_name_match
from mindspeed_mm.fsdp.distributed.parallel_state import get_parallel_state
from mindspeed_mm.fsdp.params.argument import Arguments, parse_args
from mindspeed_mm.fsdp.params.parallel_args import FSDPPlanConfig
from mindspeed_mm.fsdp.utils.device import get_torch_device, get_device_type
from mindspeed_mm.fsdp.utils.dtype import get_dtype
from mindspeed_mm.fsdp.params.training_args import TrainingArguments


logger = logging.getLogger(__name__)


def pregather_fsdp_params(model: torch.nn.Module):
    """
    Pre-gather (unshard) FSDP2 parameters before the forward pass.

    Calls ``unshard()`` on every submodule of ``model`` that exposes a callable
    ``unshard`` attribute, then synchronizes the local device so the queued
    gather work has finished before any timed computation starts. This is meant
    to reduce straggler effects caused by uneven allGather times.

    Unsharding is best-effort: a module whose ``unshard()`` raises is logged at
    debug level and skipped, and the remaining modules are still processed.

    Args:
        model: The model with FSDP2 applied modules.
    """
    for name, module in model.named_modules():
        unshard = getattr(module, 'unshard', None)
        if not callable(unshard):
            continue
        try:
            unshard()
        except Exception as e:  # best-effort: one failing module must not abort the rest
            # Log through the module logger rather than the root-level
            # ``logging.debug``, which may implicitly configure the root logger.
            logger.debug("Failed to unshard module %s: %s", name, e)
    get_torch_device().synchronize()


def fully_shard_parallel_modules(model: torch.nn.Module, fsdp_mesh: DeviceMesh, fsdp_plan: FSDPPlanConfig, training_config: TrainingArguments, **kwargs):
    """
    Apply Fully Sharded Data Parallelism (FSDP2) to specified modules in the model.

    If the fully-shard parallel size is 1 and the model is not being initialized
    on the meta device, the model is wrapped in DDP over the data-parallel group
    instead of being sharded.

    Otherwise, if the model provides its own callable ``fully_shard(fsdp_plan=...)``
    and it returns a truthy value, that custom sharding is used as-is. In all
    other cases the generic path runs: matching child modules are wrapped first,
    then the root model, and forward/backward prefetching is configured.

    Args:
        model: The neural network model to apply FSDP to.
        fsdp_mesh: Device mesh defining the FSDP process group (not used on the DDP path).
        fsdp_plan: Configuration specifying which modules to apply FSDP to and mixed precision settings.
        training_config: Training arguments; ``init_model_with_meta_device`` is read here.
        **kwargs: Additional keyword arguments (currently unused).

    Returns:
        The DDP-wrapped model on the DDP path; otherwise ``model`` itself, sharded in place.
    """
    ps = get_parallel_state()

    if ps.fully_shard_parallel_size == 1 and not training_config.init_model_with_meta_device:
        # No sharding requested: fall back to plain DDP over the data-parallel group.
        dp_group = ps.get_dp_group()
        model = DDP(
            model.to(get_device_type()),
            process_group=dp_group,
            find_unused_parameters=True,
            # DDP expects device indices (int / torch.device). get_torch_device()
            # returns the device backend module (it is the object `.synchronize()`
            # is called on), so ask it for the current device index explicitly.
            device_ids=[get_torch_device().current_device()],
        )

        print_rank(logger.info,
                   "DDP mode is enabled (fully_shard_parallel_size=1) instead of FSDP wrapping")
        return model

    # Models may implement their own sharding; a truthy result means it was applied.
    if hasattr(model, 'fully_shard') and callable(getattr(model, 'fully_shard')):
        execute_result = model.fully_shard(fsdp_plan=fsdp_plan)
        if execute_result:
            return model

    # Get modules and parameters that should be ignored for FSDP
    ignored_modules, ignored_params = get_ignored_modules(model, fsdp_plan)
    # Get modules that should have FSDP applied
    fsdp_modules = get_fsdp_modules(model, fsdp_plan, ignored_modules)
    # Get modules that FSDP hooks are attached to
    hook_modules = get_fsdp_hook_modules(model, fsdp_plan)

    config = {'mesh': fsdp_mesh, 'ignored_params': ignored_params, "reshard_after_forward": fsdp_plan.reshard_after_forward}
    config["mp_policy"] = get_mixprecision_policy(fsdp_plan)
    # Wrap child modules before the root so that the root-level wrap only
    # manages the parameters not already owned by a child.
    for module in fsdp_modules:
        hook_module = find_hook_module(module, hook_modules)
        fully_shard(module, hook_module=hook_module, **config)
    fully_shard(model, **config)

    set_modules_to_prefetch(model, fsdp_modules, fsdp_plan)
    return model


def get_mixprecision_policy(fsdp_plan: FSDPPlanConfig):
    """
    Build the FSDP2 ``MixedPrecisionPolicy`` described by ``fsdp_plan``.

    Each of ``param_dtype``, ``reduce_dtype`` and ``output_dtype`` is resolved
    through ``get_dtype`` when it is truthy. A falsy value is passed to the
    policy as ``None``. ``cast_forward_inputs`` is forwarded unchanged.
    """
    def _resolve(dtype_name):
        # Unset/empty dtype names are passed through as None.
        return get_dtype(dtype_name) if dtype_name else None

    resolved = {
        "param_dtype": _resolve(fsdp_plan.param_dtype),
        "reduce_dtype": _resolve(fsdp_plan.reduce_dtype),
        "output_dtype": _resolve(fsdp_plan.output_dtype),
    }
    return MixedPrecisionPolicy(cast_forward_inputs=fsdp_plan.cast_forward_inputs, **resolved)


def _post_order_traverse(model: torch.nn.Module, parent_path: str = ""):
    """
    Perform post-order traversal of model submodules.
    
    Post-order traversal ensures child modules are visited before their parents,
    which is important for FSDP to properly handle nested modules.
    
    Args:
        model: The model to traverse.
        parent_path: The path to the current module in the hierarchy.
    
    Yields:
        Tuple of (module_path, module) for each module in the model.
    """
    for name, child in model.named_children():
        child_path = f"{parent_path}.{name}" if parent_path else name
        yield from _post_order_traverse(child, child_path)
    yield parent_path, model


def get_fsdp_modules(model: torch.nn.Module, fsdp_plan: FSDPPlanConfig, ignored_modules: Set[str]) -> List[Any]:
    """
    Collect the submodules of ``model`` that should be individually wrapped with FSDP2.

    Modules are visited in post-order (children before parents). A module is
    selected when its qualified name matches at least one pattern in
    ``fsdp_plan.apply_modules`` and is not in ``ignored_modules``. Each module
    is selected at most once, even when several patterns match its name.

    NOTE(review): only exact names in ``ignored_modules`` are excluded;
    descendants of an ignored module are not filtered out here.

    Args:
        model: Model whose submodules are searched.
        fsdp_plan: Plan whose ``apply_modules`` patterns select modules.
        ignored_modules: Qualified module names that must not be selected.

    Returns:
        The selected modules in post-order; empty when ``apply_modules`` is unset or empty.

    Raises:
        RuntimeError: If ``apply_modules`` is non-empty but no module is selected.
    """
    patterns = fsdp_plan.apply_modules
    if not patterns:
        return []

    fsdp_modules = []
    for name, module in _post_order_traverse(model):
        if name in ignored_modules:
            continue
        # any() stops at the first matching pattern, so a module matched by
        # several patterns is added (and later wrapped) only once.
        if any(module_name_match(pattern, name) for pattern in patterns):
            print_rank(logger.debug, f'[FSDP2]: Apply fsdp2 to module <{name}>')
            fsdp_modules.append(module)

    # Ensure at least one module matches the FSDP plan
    if not fsdp_modules:
        raise RuntimeError(f'[FSDP2] No module named {fsdp_plan.apply_modules}.')
    return fsdp_modules


def get_fsdp_hook_modules(model: torch.nn.Module, fsdp_plan: FSDPPlanConfig) -> List[Any]:
    """
    Collect the modules matching ``fsdp_plan.hook_modules``, in post-order.

    Each module is collected at most once, even when several patterns match
    its qualified name.

    Args:
        model: Model whose submodules are searched.
        fsdp_plan: Plan whose ``hook_modules`` patterns select modules.

    Returns:
        The matching modules; empty when ``apply_modules`` is ``None`` or
        ``hook_modules`` is unset/empty.

    Raises:
        RuntimeError: If ``hook_modules`` is non-empty but no module matches.
    """
    # NOTE(review): this guard tests ``apply_modules`` (not ``hook_modules``), so
    # hook modules are only collected when an apply plan exists; confirm intent.
    if fsdp_plan.apply_modules is None:
        return []

    patterns = fsdp_plan.hook_modules
    if not patterns:
        return []

    fsdp_hook_modules = []
    for name, module in _post_order_traverse(model):
        # Stop at the first matching pattern to avoid duplicate entries.
        if any(module_name_match(pattern, name) for pattern in patterns):
            print_rank(logger.debug, f'[FSDP2]: Apply fsdp2 hook to hook_module <{name}>')
            fsdp_hook_modules.append(module)

    # Ensure at least one module matches the FSDP plan
    if not fsdp_hook_modules:
        raise RuntimeError(f'[FSDP2] No module named {fsdp_plan.hook_modules}.')
    return fsdp_hook_modules


def find_hook_module(target_module: torch.nn.Module, hook_module_list: List[torch.nn.Module]) -> Optional[torch.nn.Module]:
    """
    Return the first module in ``hook_module_list`` whose subtree contains ``target_module``.

    Containment is checked by identity and includes each candidate itself, so
    a target that appears directly in the list matches. Returns ``None`` when
    no candidate contains the target.
    """
    def _contains(candidate: torch.nn.Module) -> bool:
        return any(descendant is target_module for descendant in candidate.modules())

    return next((candidate for candidate in hook_module_list if _contains(candidate)), None)


def get_ignored_modules(model: torch.nn.Module, fsdp_plan: FSDPPlanConfig):
    """
    Resolve ``fsdp_plan.ignored_modules`` patterns into module names and parameters.

    Args:
        model: Model whose submodules are searched (via ``named_modules``).
        fsdp_plan: Plan whose ``ignored_modules`` patterns select modules to skip.

    Returns:
        A tuple ``(ignored_modules, ignored_params)``: the set of qualified names
        of matching modules, and the set of all parameters (recursively) owned by
        those modules. Both are empty when ``ignored_modules`` is ``None``.
    """
    ignored_modules = set()
    ignored_params = set()
    if fsdp_plan.ignored_modules is None:
        return ignored_modules, ignored_params
    for name, module in model.named_modules():
        # Process each module once, even if several patterns match its name.
        if any(module_name_match(pattern, name) for pattern in fsdp_plan.ignored_modules):
            print_rank(logger.debug, f'[FSDP2]: Ignored module to apply fsdp2 <{name}>')
            ignored_modules.add(name)
            ignored_params.update(module.parameters(recurse=True))
    return ignored_modules, ignored_params


def _prefetch_windows(modules_in_order: list[torch.nn.Module], num_to_prefetch: int):
    """Yield ``(module, successors)`` pairing each module with up to ``num_to_prefetch`` following modules."""
    for index, module in enumerate(modules_in_order):
        # Slicing clamps at the end of the list, so no explicit upper bound is needed.
        successors = modules_in_order[index + 1:index + 1 + num_to_prefetch]
        if successors:
            yield module, successors


def set_modules_to_prefetch(model: torch.nn.Module, fsdp_modules: list[torch.nn.Module], fsdp_plan: FSDPPlanConfig):
    """
    Configure FSDP2 forward and backward prefetching between wrapped modules.

    Wrapped modules are ordered as they appear in ``model.modules()`` (pre-order).
    During forward, each one prefetches the next ``fsdp_plan.num_to_forward_prefetch``
    modules in that order. During backward, each one prefetches the next
    ``fsdp_plan.num_to_backward_prefetch`` modules in the reversed order. A count
    of 0 or less disables the corresponding prefetching.

    NOTE(review): module registration order is used as a proxy for execution
    order; confirm it matches the model's actual forward order.

    Args:
        model: Root model containing the wrapped modules.
        fsdp_modules: Modules that were individually wrapped with ``fully_shard``.
        fsdp_plan: Plan providing the forward/backward prefetch counts.
    """
    # Identity set gives O(1) membership instead of scanning fsdp_modules per submodule.
    wrapped_ids = {id(module) for module in fsdp_modules}
    wrapped_modules_in_order = [m for m in model.modules() if id(m) in wrapped_ids]

    if fsdp_plan.num_to_forward_prefetch > 0:
        for module, targets in _prefetch_windows(wrapped_modules_in_order, fsdp_plan.num_to_forward_prefetch):
            module.set_modules_to_forward_prefetch(targets)

    if fsdp_plan.num_to_backward_prefetch > 0:
        reversed_order = wrapped_modules_in_order[::-1]
        for module, targets in _prefetch_windows(reversed_order, fsdp_plan.num_to_backward_prefetch):
            module.set_modules_to_backward_prefetch(targets)