# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
from functools import wraps, partial

import inspect
import torch


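# Replacement for the model chunk's `start_param_sync` (used by
# step_with_ready_grads_distrib_opti_wrapper): the optional `dense_or_moe_group`
# argument limits the param all-gather to the dense or the expert-parallel (MoE)
# bucket groups, so each chained optimizer only gathers its own buckets instead of
# issuing a duplicate all-gather over every bucket group.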
def start_param_sync(
    self,
    *unused,
    force_sync: bool = False,
    force_dispatch: bool = False,
    dense_or_moe_group: str = None,
):
    """
    Initiates param sync (all-gather) communication operations for model parameters.

    By default, when overlap_param_gather is set to True, dispatches asynchronous communication
    calls; when overlap_param_gather is set to False, calls synchronous communication
    ops. Can override this default behavior using flags below.

    Args:
        force_sync (bool, optional): force synchronous collective regardless of
            other settings.
        force_dispatch (bool, optional): force dispatch regardless of other settings.
        dense_or_moe_group (str, optional): which bucket groups to sync.
            ``None`` (default) syncs ``self.bucket_groups`` followed by
            ``self.expert_parallel_bucket_groups``; ``'dense'`` syncs only
            ``self.bucket_groups``; ``'moe'`` syncs only
            ``self.expert_parallel_bucket_groups``.

    Raises:
        ValueError: if ``dense_or_moe_group`` is not ``None``, ``'dense'`` or ``'moe'``
            (checked only when a sync is actually going to be started).
    """
    if not force_sync:
        # If overlapping param AG with optimizer step, AG should not be dispatched again
        # in forward_backward_step.
        if self.overlap_param_gather_with_optimizer_step and not force_dispatch:
            return

    if dense_or_moe_group is None:
        bucket_groups = self.bucket_groups + self.expert_parallel_bucket_groups
    elif dense_or_moe_group == 'dense':
        bucket_groups = self.bucket_groups
    elif dense_or_moe_group == 'moe':
        bucket_groups = self.expert_parallel_bucket_groups
    else:
        # Previously this fell through and crashed with UnboundLocalError on
        # `bucket_groups`; fail with a clear message instead.
        raise ValueError(
            f"dense_or_moe_group must be None, 'dense' or 'moe', got {dense_or_moe_group!r}"
        )

    for bucket_group in bucket_groups:
        bucket_group.start_param_sync(force_sync=force_sync)


def step_with_ready_grads_distrib_opti_wrapper(func):
    """Decorate an optimizer step so param sync only touches this optimizer's buckets.

    For the duration of ``func``, every model chunk in ``self.model_chunks`` whose
    ``start_param_sync`` accepts a ``dense_or_moe_group`` argument gets that argument
    bound to ``self.is_moe_param`` (or ``None`` when the optimizer has no such
    attribute). The original ``start_param_sync`` of each patched chunk is restored
    afterwards, even if ``func`` raises.

    Args:
        func: the step function to wrap; its first positional argument is the
            optimizer instance (``self``).

    Returns:
        The wrapped function, which returns whatever ``func`` returns.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        self = args[0]
        is_moe_param = getattr(self, 'is_moe_param', None)

        # Remember exactly which chunks were patched (and with what) so that only
        # those are restored; unpatched chunks have no `.func` to unwrap.
        patched = []
        for model_chunk in self.model_chunks:
            original = model_chunk.start_param_sync
            if 'dense_or_moe_group' in inspect.signature(original).parameters:
                model_chunk.start_param_sync = partial(original, dense_or_moe_group=is_moe_param)
                patched.append((model_chunk, original))

        try:
            return func(*args, **kwargs)
        finally:
            # Restore in reverse order so a chunk listed more than once ends up with
            # its very first (unpatched) callable.
            for model_chunk, original in reversed(patched):
                model_chunk.start_param_sync = original
    return wrapper


def _has_expert_parallel_params(model_chunks):
    """Return True if any trainable parameter in ``model_chunks`` has ``allreduce=False``.

    Parameters without an ``allreduce`` attribute are treated as ``allreduce=True``
    (i.e. not expert-parallel). Parameters with ``requires_grad=False`` are ignored.
    """
    for model_chunk in model_chunks:
        ddp_config = model_chunk.ddp_config
        if ddp_config.use_custom_fsdp:
            named_parameters = model_chunk.optimizer_named_parameters()
        else:
            named_parameters = model_chunk.named_parameters()

        # With the "optim_grads_params" sharding strategy the iterated objects are
        # presumably shards; the checks below run against their `orig_param`.
        unwrap_orig_param = (
            ddp_config.use_custom_fsdp
            and ddp_config.data_parallel_sharding_strategy == "optim_grads_params"
        )
        for _, param in named_parameters:
            if unwrap_orig_param:
                param = param.orig_param
            if param.requires_grad and not getattr(param, 'allreduce', True):
                return True
    return False


def get_megatron_optimizer_wrapper(func):
    """Decorate an optimizer factory to tag each chained optimizer as dense or MoE.

    When ``func`` returns an object with a ``chained_optimizers`` attribute and the
    model chunks can be found (``model_chunks`` keyword, else the second positional
    argument), every chained optimizer gets ``is_moe_param = 'dense'``. If any
    trainable parameter is expert-parallel (``allreduce=False``), the *last*
    chained optimizer is re-tagged ``'moe'``. NOTE(review): this assumes the factory
    appends the expert-parallel optimizer last — confirm against ``func``.

    Args:
        func: the optimizer factory to wrap.

    Returns:
        The wrapped factory, which returns ``func``'s result (possibly tagged).
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        chained_optimizer = func(*args, **kwargs)

        if not hasattr(chained_optimizer, 'chained_optimizers'):
            return chained_optimizer

        if 'model_chunks' in kwargs:
            model_chunks = kwargs['model_chunks']
        elif len(args) > 1:
            model_chunks = args[1]
        else:
            # Model chunks unknown: leave optimizers untagged, which makes the step
            # wrapper fall back to syncing all bucket groups.
            return chained_optimizer

        optimizers = chained_optimizer.chained_optimizers
        for optimizer in optimizers:
            optimizer.is_moe_param = 'dense'

        # Guard against an empty chain, where `optimizers[-1]` would raise IndexError.
        if optimizers and _has_expert_parallel_params(model_chunks):
            optimizers[-1].is_moe_param = 'moe'

        return chained_optimizer
    return wrapper