# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Huawei Technologies Co., Ltd.  All rights reserved.
from typing import Optional
import torch
import torch_npu

from megatron.training import get_args
from megatron.core import parallel_state
from megatron.core.transformer.moe.moe_utils import (reduce_aux_losses_tracker_across_ranks,
                                                     clear_aux_losses_tracker, group_limited_topk,
                                                     get_capacity)
from megatron.core.tensor_parallel.utils import divide

AG_TP_HIDDEN_STATUS = None
AG_SHARED_EXPERTS_INPUTS = []
GEMM_BACKWARD_NEED_TENSORS = None
RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE = None
SWAP_STREAM = None
SWAP_STREAM2 = None
SWAP_TENSOR = None
MATMUL_OUTPUT_GRAD = None
UNPERMUTED_TOKENS = None


def get_swap_stream():
    global SWAP_STREAM2
    if SWAP_STREAM2 is None:
        _ = torch_npu.npu.Stream(device=torch.npu.current_device())
        SWAP_STREAM2 = torch_npu.npu.Stream(device=torch.npu.current_device())
    stream = SWAP_STREAM2
    return stream


def set_swap_status(tensor):
    global SWAP_TENSOR
    SWAP_TENSOR = tensor


def get_swap_status():
    global SWAP_STREAM
    if SWAP_STREAM is None:
        SWAP_STREAM = torch_npu.npu.Stream(device=torch.npu.current_device())
    global SWAP_TENSOR
    stream = SWAP_STREAM
    tensor = SWAP_TENSOR
    SWAP_TENSOR = None
    return stream, tensor


def set_prob_backward_need_tensors(matmul_output_grad, unpermuted_tokens):
    global MATMUL_OUTPUT_GRAD
    MATMUL_OUTPUT_GRAD = matmul_output_grad
    global UNPERMUTED_TOKENS
    UNPERMUTED_TOKENS = unpermuted_tokens


def get_prob_backward_need_tensors():
    global SWAP_STREAM2
    if SWAP_STREAM2 is None:
        _ = torch_npu.npu.Stream(device=torch.npu.current_device())
        SWAP_STREAM2 = torch_npu.npu.Stream(device=torch.npu.current_device())
    global MATMUL_OUTPUT_GRAD
    global UNPERMUTED_TOKENS
    stream = SWAP_STREAM2
    matmul_output_grad = MATMUL_OUTPUT_GRAD
    unpermuted_tokens = UNPERMUTED_TOKENS
    MATMUL_OUTPUT_GRAD = None
    UNPERMUTED_TOKENS = None
    return stream, matmul_output_grad, unpermuted_tokens


def set_ag_tp_hidden_status(_inputs):
    global AG_TP_HIDDEN_STATUS
    AG_TP_HIDDEN_STATUS = _inputs


def get_ag_tp_hidden_status():
    global AG_TP_HIDDEN_STATUS
    result = AG_TP_HIDDEN_STATUS
    AG_TP_HIDDEN_STATUS = None
    return result


def set_gemm_backward_need_tensors(_inputs):
    global GEMM_BACKWARD_NEED_TENSORS
    GEMM_BACKWARD_NEED_TENSORS = _inputs


def get_gemm_backward_need_tensors():
    global GEMM_BACKWARD_NEED_TENSORS
    result = GEMM_BACKWARD_NEED_TENSORS
    GEMM_BACKWARD_NEED_TENSORS = None
    return result


def set_rs_global_hidden_states_grad_with_handle(_inputs):
    global RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE
    RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE = _inputs


def get_rs_global_hidden_states_grad_with_handle():
    global RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE
    result = RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE
    RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE = None
    return result


ALL2ALL_EXPERTS_OUTPUT = None


def set_all2all_experts_output(_input):
    global ALL2ALL_EXPERTS_OUTPUT
    ALL2ALL_EXPERTS_OUTPUT = _input


def get_all2all_experts_output():
    global ALL2ALL_EXPERTS_OUTPUT
    result = ALL2ALL_EXPERTS_OUTPUT
    ALL2ALL_EXPERTS_OUTPUT = None
    return result


def only_recompute_activation(layer_number):
    args = get_args()
    vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
    vpp_size = args.virtual_pipeline_model_parallel_size
    pp_size = args.pipeline_model_parallel_size

    if vpp_size is not None:
        layer_per_chunk = args.num_layers_per_virtual_pipeline_stage
    elif pp_size is not None:
        layer_per_chunk = args.num_layers // pp_size
    else:
        layer_per_chunk = args.num_layers

    if vpp_rank is None:
        vpp_rank = 0
    if vpp_size is None:
        vpp_size = 1
    recompute_priority = ((layer_number - 1) % layer_per_chunk) * vpp_size + vpp_rank
    moe_zero_memory_num_layers = args.moe_zero_memory_num_layers

    if moe_zero_memory_num_layers:
        if recompute_priority < moe_zero_memory_num_layers:
            return False
        else:
            return True
    else:
        return False


def forward_func(func, inputs):
    def detach_tensor(input_):
        if input_.requires_grad and input_.grad_fn is None:
            return input_
        else:
            new_input = input_.detach()
            new_input.requires_grad = True
        return new_input

    detach_inputs = []
    if isinstance(inputs, tuple):
        for input_ in inputs:
            if isinstance(input_, tuple):
                detach_input = []
                for i in input_:
                    if isinstance(i, torch.Tensor) and torch.is_floating_point(i):
                        detach_input.append(detach_tensor(i))
                    else:
                        detach_input.append(i)
                detach_inputs.append(tuple(detach_input))
            else:
                if isinstance(input_, torch.Tensor) and torch.is_floating_point(input_):
                    detach_input = detach_tensor(input_)
                else:
                    detach_input = input_
                detach_inputs.append(detach_input)
    elif isinstance(inputs, torch.Tensor):
        detach_inputs.append(detach_tensor(inputs))

    with torch.enable_grad():
        output = func(*detach_inputs)

    return output, *detach_inputs


def backward_func(func_tensor, gradinputs):
    if gradinputs is None or func_tensor.grad_fn is None:
        return
    if isinstance(gradinputs, torch.Tensor):
        func_tensor.backward(gradinputs)
    elif isinstance(gradinputs, tuple):
        func_tensor.backward(*gradinputs)


def permute(tokens, routing_map, num_out_tokens: int = None):

    if routing_map.dim() == 1:
        topk = 1
    else:
        topk = routing_map.size(1)
    flatten_indices = routing_map.view(-1)
    # previous use argsort, argsort int64 will be run on host cpu
    sorted_indices = torch.sort(flatten_indices.float(), stable=True)[1]
    if num_out_tokens is not None:
        sorted_indices = sorted_indices[:num_out_tokens]
    permuted_tokens = tokens.index_select(0, sorted_indices // topk)
    return permuted_tokens, sorted_indices


def unpermute(
    permuted_tokens: torch.Tensor,
    sorted_indices: torch.Tensor,
    restore_shape: torch.Size,
    probs: torch.Tensor = None,
    routing_map: torch.Tensor = None,
):

    assert sorted_indices.numel() == permuted_tokens.size(0)
    if probs is not None:
        # Unpermute and merge the tokens with their probabilities
        num_unpermuted_tokens = probs.numel()
        topk = probs.size(1)
    else:
        # Unpermute the tokens without merge
        num_unpermuted_tokens = permuted_tokens.size(0)
        topk = 1

    unpermuted_tokens = torch.zeros(
        [num_unpermuted_tokens, permuted_tokens.shape[-1]],
        dtype=permuted_tokens.dtype,
        device=permuted_tokens.device,
    )
    unpermuted_tokens.index_copy_(0, sorted_indices, permuted_tokens)
    unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1))
    if probs is not None:
        unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1)
    unpermuted_tokens = unpermuted_tokens.sum(dim=1)

    return unpermuted_tokens


def get_mean(tensor):
    """
    Calculate the mean of a tensor, excluding specified 'noop_layers'.

    Parameters:
        tensor (torch.Tensor): A one-dimensional tensor.

    Returns:
        float: The mean of the tensor, excluding the 'noop_layers' if specified.

    Notes:
        - If `args.noop_layers` is a set and is not empty, the mean is calculated by excluding these layers.
        - If `args.noop_layers` is empty or None, the mean is calculated directly from the tensor.
        - `args.num_layers` represents the total number of layers, used to adjust the mean calculation when
        excluding 'noop_layers'.
    """
    args = get_args()
    if hasattr(args, 'noop_layers') and isinstance(args.noop_layers, set) and len(args.noop_layers) > 0:
        return tensor.sum() / (args.num_layers - len(args.noop_layers))
    return tensor.mean()


def track_moe_metrics(
    loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None, per_layer_logging=False
):
    # Aux loss logging

    reduce_aux_losses_tracker_across_ranks()
    tracker = parallel_state.get_moe_layer_wise_logging_tracker()
    if writer is not None:
        aux_losses = {k: v['values'].float() * loss_scale for k, v in tracker.items()}
        for name, loss_list in aux_losses.items():
            # adaptation for
            loss_list_mean = get_mean(loss_list)
            if total_loss_dict is not None:
                if name not in total_loss_dict:
                    # adaptation for loss_list.mean()
                    total_loss_dict[name] = loss_list_mean
                else:
                    # adaptation for loss_list.mean()
                    total_loss_dict[name] += loss_list_mean

            # currently when using add_scalars,
            # torch.utils.add_scalars makes each timer its own run, which
            # polutes the runs list, so we just add each as a scalar
            # adaptation for loss_list.mean()
            writer.add_scalar(name, loss_list_mean, iteration)
            if per_layer_logging:
                for i, loss in enumerate(loss_list.tolist()):
                    writer.add_scalar(f"moe/{name}_layer_{i}", loss, iteration)

            # W&B logging lacks support for logging multiple scalars simultaneously.
            # As a workaround, we log each scalar individually first, then we can create
            # a custom panel to manually group them to a single plot.
            if wandb_writer:
                # adaptation for loss_list.mean()
                wandb_writer.log({f"{name}": loss_list_mean}, iteration)
                if per_layer_logging:
                    wandb_writer.log(
                        {
                            f"moe/{name}_layer_{i}": loss
                            for i, loss in enumerate(loss_list.tolist())
                        },
                        iteration,
                    )

    clear_aux_losses_tracker()


def get_grouped_expert_params(model, num_local_experts, tp_size, config):
    hidden_size = config.hidden_size
    ffn_hidden_size = config.moe_ffn_hidden_size
    
    fc1_output_size = ffn_hidden_size
    if config.gated_linear_unit:
        # Project to 4h. If using swiglu double the output width
        fc1_output_size *= 2
    fc1_ffn_hidden_size_per_expert = divide(fc1_output_size, tp_size)

    fc2_input_size = ffn_hidden_size
    fc2_ffn_hidden_size_per_expert = divide(fc2_input_size, tp_size)

    weight1_reshaped = model.weight1.view(num_local_experts, hidden_size, fc1_ffn_hidden_size_per_expert)
    weight2_reshaped = model.weight2.view(num_local_experts, fc2_ffn_hidden_size_per_expert, hidden_size)
    
    group_mlp_expert_params = {}
    
    for idx in range(num_local_experts):
        expert_weight1 = weight1_reshaped[idx]  # shape (hidden_size, fc1_ffn_hidden_size_per_expert)
        expert_weight2 = weight2_reshaped[idx]  # shape (fc2_ffn_hidden_size_per_expert, hidden_size)
        
        total_params = expert_weight1.numel() + expert_weight2.numel()
        
        group_mlp_expert_params[idx] = {
            'weight1': expert_weight1,
            'weight2': expert_weight2,
            'total_params': total_params
        }
    return group_mlp_expert_params
    
    
def get_expert_param_data(group_mlp_expert_params, params, idx):
    expert_param = group_mlp_expert_params[idx]
    offset = 0
    for weight in ['weight1', 'weight2']:
        seg1 = params[offset: offset + expert_param[weight].numel()]
        seg1.copy_(expert_param[weight].data.flatten())
        offset += expert_param[weight].numel()


def set_expert_param_data(group_mlp_expert_params, params, idx):
    expert_param = group_mlp_expert_params[idx]
    offset = 0
    with torch.no_grad():
        for weight in ['weight1', 'weight2']:
            seg = params[offset: offset + expert_param[weight].numel()]
            expert_param[weight].copy_(seg.reshape(expert_param[weight].shape))
            expert_param[weight].grad = None
            offset += expert_param[weight].numel()
        
        
def get_expert_param_dtype(experts, group_mlp_expert_params, idx):
    e = group_mlp_expert_params[idx]["weight1"].dtype
    return e
    
    
def get_expert_param_size(experts, group_mlp_expert_params, idx):
    e = group_mlp_expert_params[idx]["total_params"]
    return e


def topk_softmax_with_capacity(
    logits: torch.Tensor,
    topk: int,
    capacity_factor: Optional[float] = None,
    pad_to_capacity: bool = False,
    drop_policy: str = "probs",
    use_pre_softmax: bool = False,
    num_groups: Optional[int] = None,
    group_topk: Optional[int] = None,
    scaling_factor: Optional[float] = None,
    deterministic_mode: bool = False,
    score_function: str = "softmax",
    expert_bias: Optional[torch.Tensor] = None,
):
    """Apply capacity and padding to the top-k selection.
    Args:
        logits (torch.Tensor): Logits tensor.
        topk (int): The number of experts to select for each token.
        capacity_factor (float): The capacity factor of each expert. Will drop tokens if the number
                               of tokens exceeds the capacity.
        pad_to_capacity (bool): Whether to need padding in token drop mode. The probs for padded
                               tokens will be 0.
        drop_policy (str): The policy to drop tokens. Can be either "prob" or "position".
                           If "prob", the tokens with the lowest probabilities will be dropped.
                           If "position", tokens at the end of each batch will be dropped.
        use_pre_softmax (bool): Whether to apply softmax or sigmoid before top-k selection.
        num_groups (int): Number of groups for routed experts.
        group_topk (int): Number of selected groups for each token.
        scaling_factor (float): Scaling factor of routing score in top-k selection.
        deterministic_mode (bool): Deprecated.
        score_function (str): The score function to use. Can be either "softmax" or "sigmoid".
        expert_bias (torch.Tensor): The bias added to logits for expert routing.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            - routing_probs (torch.Tensor): A tensor of shape [num_tokens, num_experts] containing
              the routing probabilities for each token to each expert.
            - routing_map (torch.Tensor): A mask tensor of shape [num_tokens, num_experts]
              indicating which experts were selected for each token. True values represent
              the selected experts.
            - tokens_per_expert (torch.Tensor): A tensor of shape [num_experts] containing
              the number of local tokens assigned to each expert before dropping and padding.
    """
    assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
    num_tokens, num_experts = logits.shape

    def compute_topk(scores, topk, num_groups=None, group_topk=None):
        if group_topk:
            return group_limited_topk(
                scores=scores,
                topk=topk,
                num_tokens=num_tokens,
                num_experts=num_experts,
                num_groups=num_groups,
                group_topk=group_topk,
            )
        else:
            return torch.topk(scores, k=topk, dim=1)

    if score_function == "softmax":
        if use_pre_softmax:
            scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
            probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
        else:
            scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
            probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
    elif score_function == "sigmoid":
        scores = torch.sigmoid(logits.float()).type_as(logits)
        if expert_bias is not None:
            scores_for_routing = scores + expert_bias
            _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk)
            scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
        else:
            scores, top_indices = compute_topk(scores, topk, num_groups, group_topk)
        probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
    else:
        raise ValueError(f"Invalid score_function: {score_function}")

    if scaling_factor:
        probs = probs * scaling_factor

    args = get_args()
    if args.fix_router:
        top_indices = torch.arange(top_indices.numel(), device=top_indices.device,
                                   dtype=torch.int64).view(top_indices.shape) % logits.shape[-1]

    topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
    topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
    tokens_per_expert = topk_map.sum(dim=0)

    if capacity_factor is None:
        # TopK without capacity
        return topk_masked_gates, topk_map, tokens_per_expert
    else:
        # TopK with capacity
        expert_capacity = get_capacity(
            num_tokens=num_tokens * topk, num_experts=num_experts, capacity_factor=capacity_factor
        )

        # Maskout exceeded tokens
        if drop_policy == "probs":
            _, capacity_indices = torch.topk(
                topk_masked_gates, k=expert_capacity, dim=0, sorted=False
            )
            capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1).bool()
        elif drop_policy == "position":
            _, capacity_indices = torch.topk(topk_map.int(), k=expert_capacity, dim=0, sorted=False)
            capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1).bool()
        else:
            raise ValueError(f"Invalid drop_policy: {drop_policy}")

        if pad_to_capacity:
            final_map = capacity_mask
            final_probs = topk_masked_gates * final_map
        else:
            # Get exceed mask and maskout exceeded probs and indices
            final_map = torch.logical_and(topk_map, capacity_mask)
            final_probs = topk_masked_gates * final_map
        return final_probs, final_map, tokens_per_expert