# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
from functools import wraps
from mindspeed.core.context_parallel import mpu
from mindspeed.core.tensor_parallel_y_union_cp import TensorParallelYUnionCP
from mindspeed.core.context_parallel.model_parallel_utils import get_context_parallel_group_for_hybrid_ulysses
from mindspeed.core.context_parallel.ulysses_context_parallel.ulysses_context_parallel import UlyssesContextAttention
from mindspeed.core.context_parallel import DotProductAttention as MegatronDotProductAttention

from mindspeed_llm.core.context_parallel.dot_product_attention import CPDotProductAttentionImpl


class CPDotProductAttention(CPDotProductAttentionImpl, MegatronDotProductAttention):
    """
    Dot product attention with context parallelism support.

    This class inherits from both CPDotProductAttentionImpl and MegatronDotProductAttention,
    combining context parallel capabilities with Megatron's dot product attention implementation.

    Attributes:
        Inherits all attributes from parent classes.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize context parallel dot product attention.

        Args:
            *args: Positional arguments passed to parent class.
            **kwargs: Keyword arguments passed to parent class.
        """
        CPDotProductAttentionImpl.__init__(self, *args, **kwargs)


def attention_init_wrapper(fn):
    """
    Wrapper function for attention initialization with context parallel support.

    This decorator wraps the attention initialization function to add support for
    Ulysses context parallelism and hybrid context parallelism algorithms.

    Args:
        fn: The original attention initialization function to be wrapped.

    Returns:
        Callable: Wrapped function that initializes attention with context parallel support.

    The wrapper supports:
        - Ulysses context parallel algorithm (ulysses_cp_algo)
        - Hybrid context parallel algorithm (hybrid_cp_algo)
        - 2D tensor parallel combined with context parallel
    """
    @wraps(fn)
    def wrapper(
        self,
        config,
        submodules,
        layer_number,
        attn_mask_type,
        attention_type,
        cp_comm_type: str = None,):
        fn(self, config, submodules, layer_number, attn_mask_type, attention_type, cp_comm_type)
        cp = config.context_parallel_size
        if config.tp_2d:
            tp_y_cp_sz = cp * config.tp_y
        else:
            tp_y_cp_sz = cp
        if tp_y_cp_sz > 1 and config.context_parallel_algo in ['ulysses_cp_algo', 'hybrid_cp_algo']:
            if config.tp_2d:
                tp_y_cp = TensorParallelYUnionCP()
                ulysses_group = tp_y_cp.group
            else:
                ulysses_group = mpu.get_context_parallel_group()
            if config.context_parallel_algo in ['hybrid_cp_algo']:
                ulysses_group = get_context_parallel_group_for_hybrid_ulysses()
            self.core_attention = UlyssesContextAttention(self.core_attention, ulysses_group)

    return wrapper