from functools import wraps
import torch


def communication_wrapper(fn):
    """Decorate a communication function with high-availability handling.

    Behaviour of the returned callable, decided per call from ``get_args()``:

    * ``enable_high_availability`` is falsy: ``fn`` is called unchanged.
    * High availability is on and ``tft_is_arf_reboot_node()`` is truthy:
      ``fn`` is not called and ``None`` is returned.
    * High availability and ``enable_elastic_training`` are both on: the call
      is routed through ``torch_wrapper`` with positional slot 2 treated as
      the ``group`` argument.
    * Otherwise ``fn`` is called unchanged.
    """
    # Positional slot that torch_wrapper inspects for the process group.
    group_slot = 2

    @wraps(fn)
    def guarded_call(*args, **kwargs):
        # Imported lazily, at call time, exactly as the feature flags are read.
        from megatron.training import get_args
        cfg = get_args()
        if not cfg.enable_high_availability:
            return fn(*args, **kwargs)

        from mindspeed_llm.core.high_availability import tft_is_arf_reboot_node
        if tft_is_arf_reboot_node():
            # Reboot nodes skip the communication entirely.
            return None
        if not cfg.enable_elastic_training:
            return fn(*args, **kwargs)
        return torch_wrapper(fn, group_slot, *args, **kwargs)

    return guarded_call


def barrier_wrapper(fn):
    """Decorate a barrier function with high-availability handling.

    Behaviour of the returned callable, decided per call from ``get_args()``:

    * ``enable_high_availability`` is falsy: ``fn`` is called unchanged.
    * High availability is on and ``tft_is_arf_reboot_node()`` is truthy:
      the caller's arguments are discarded; ``fn`` is called with the value
      of ``tft_get_node_group()`` as its only argument, or ``None`` is
      returned without calling ``fn`` when that value is ``None``.
    * High availability and ``enable_elastic_training`` are both on: the call
      is routed through ``torch_wrapper`` with positional slot 0 treated as
      the ``group`` argument.
    * Otherwise ``fn`` is called unchanged.
    """
    @wraps(fn)
    def guarded_barrier(*args, **kwargs):
        from megatron.training import get_args
        cfg = get_args()
        if not cfg.enable_high_availability:
            return fn(*args, **kwargs)

        from mindspeed_llm.core.high_availability import tft_is_arf_reboot_node, tft_get_node_group
        if tft_is_arf_reboot_node():
            node_group = tft_get_node_group()
            if node_group is None:
                return None
            # Only the node group is forwarded; original args are ignored.
            return fn(node_group)

        if cfg.enable_elastic_training:
            return torch_wrapper(fn, 0, *args, **kwargs)
        return fn(*args, **kwargs)

    return guarded_barrier


def new_group_wrapper(fn):
    """Decorate a group-creation function with high-availability handling.

    The returned callable:

    * returns ``None`` without calling ``fn`` when ``tft_is_arf_reboot_node()``
      is truthy and the ``backend`` keyword argument is a string containing
      ``'gloo'``;
    * otherwise forces ``use_local_synchronization=True`` into the keyword
      arguments when no ``backend`` keyword was given (or it is ``None``), or
      when ``torch.distributed.distributed_c10d._is_barrier_after_init()`` is
      truthy, and then calls ``fn``.

    Note: only a ``backend`` passed by keyword is inspected; a positional
    backend is not seen by this wrapper.
    """
    @wraps(fn)
    def create_group(*args, **kwargs):
        from mindspeed_llm.core.high_availability import tft_is_arf_reboot_node
        backend = kwargs.get('backend')

        if tft_is_arf_reboot_node():
            gloo_requested = isinstance(backend, str) and 'gloo' in backend
            if gloo_requested:
                return None

        if backend is None:
            kwargs['use_local_synchronization'] = True
        elif torch.distributed.distributed_c10d._is_barrier_after_init():
            kwargs['use_local_synchronization'] = True
        return fn(*args, **kwargs)

    return create_group


def is_need_change_group(group_index, *args, **kwargs):
    """
    Decide whether the 'group' argument of a wrapped call must be replaced in
    the scale-in training scenario, and where that argument lives.

    The group is read from exactly one place: the positional slot
    ``group_index`` when the call supplies that many positional arguments,
    otherwise the ``group`` keyword (treated as ``None`` when absent). It
    needs replacing when it is ``None`` or equal to
    ``torch.distributed.group.WORLD``.

    Args:
        group_index: Positional index of the 'group' parameter in the wrapped
            function's signature. A negative value disables the check.
        *args: Positional arguments of the wrapped call.
        **kwargs: Keyword arguments of the wrapped call.

    Returns:
        A ``(need_change, location)`` tuple. ``location`` is ``'args'`` or
        ``'kwargs'`` when ``need_change`` is ``True`` and ``""`` otherwise.
    """
    if group_index < 0:
        return False, ""

    # Resolve the group from a single source. Looking at kwargs when the group
    # was already given positionally could report 'kwargs' for a call that has
    # a positional group, making the caller pass 'group' twice.
    if len(args) > group_index:
        location, group = 'args', args[group_index]
    else:
        location, group = 'kwargs', kwargs.get('group', None)

    # The `is None` test short-circuits so WORLD is only looked up when needed.
    if group is None or group == torch.distributed.group.WORLD:
        return True, location
    return False, ""


def group_index_two_torch_wrapper(fn):
    """
    Decorator for functions whose 'group' parameter is at positional index 2.

    When ``get_args().enable_elastic_training`` is falsy the wrapped function
    is called unchanged. Otherwise the call is delegated to ``torch_wrapper``
    with ``group_index=2``, which handles replacing the 'group' argument in
    the scale-in training scenario.
    """
    @wraps(fn)
    def elastic_call(*args, **kwargs):
        from megatron.training import get_args
        if get_args().enable_elastic_training:
            return torch_wrapper(fn, 2, *args, **kwargs)
        return fn(*args, **kwargs)
    return elastic_call


def group_index_three_torch_wrapper(fn):
    """
    Decorator for functions whose 'group' parameter is at positional index 3.

    When ``get_args().enable_elastic_training`` is falsy the wrapped function
    is called unchanged. Otherwise the call is delegated to ``torch_wrapper``
    with ``group_index=3``, which handles replacing the 'group' argument in
    the scale-in training scenario.
    """
    @wraps(fn)
    def elastic_call(*args, **kwargs):
        from megatron.training import get_args
        if get_args().enable_elastic_training:
            return torch_wrapper(fn, 3, *args, **kwargs)
        return fn(*args, **kwargs)
    return elastic_call


def all_to_all_single_wrapper(fn):
    """
    Decorator for functions whose 'group' parameter is at positional index 4.

    When ``get_args().enable_elastic_training`` is falsy the wrapped function
    is called unchanged. Otherwise the call is delegated to ``torch_wrapper``
    with ``group_index=4``, which handles replacing the 'group' argument in
    the scale-in training scenario.
    """
    @wraps(fn)
    def elastic_call(*args, **kwargs):
        from megatron.training import get_args
        if get_args().enable_elastic_training:
            return torch_wrapper(fn, 4, *args, **kwargs)
        return fn(*args, **kwargs)
    return elastic_call


def torch_wrapper(fn, group_index, *args, **kwargs):
    """
    Call ``fn``, substituting the scale-in world group for its 'group'
    argument when required.

    Steps:

    1. If ``tft_is_arf_reboot_node()`` is truthy, return ``None`` without
       calling ``fn``.
    2. If ``zit_scale_in_running_state()`` is falsy, call ``fn`` unchanged.
    3. Otherwise ask ``is_need_change_group`` whether the group must be
       replaced and where it lives ('args' or 'kwargs'); if so, put the value
       of ``zit_get_scale_in_world_group()`` there before calling ``fn``.

    Args:
        fn: The function to invoke.
        group_index: Positional index of the 'group' parameter of ``fn``.
        *args: Positional arguments for ``fn``.
        **kwargs: Keyword arguments for ``fn``.

    Returns:
        ``None`` on a reboot node, otherwise whatever ``fn`` returns.
    """
    from mindspeed_llm.core.high_availability.tft_arf_group_repair import tft_is_arf_reboot_node
    from mindspeed_llm.core.high_availability import elastic_training_common

    if tft_is_arf_reboot_node():
        return None
    if not elastic_training_common.zit_scale_in_running_state():
        return fn(*args, **kwargs)

    must_replace, location = is_need_change_group(group_index, *args, **kwargs)
    if must_replace and location == 'args':
        # Tuples are immutable, so patch a list copy and rebuild the tuple.
        patched = list(args)
        patched[group_index] = elastic_training_common.zit_get_scale_in_world_group()
        args = tuple(patched)
    elif must_replace and location == 'kwargs':
        kwargs['group'] = elastic_training_common.zit_get_scale_in_world_group()
    return fn(*args, **kwargs)