from typing import List, Optional, Tuple

import torch

from ..utils import register_tensor_cast_op


@register_tensor_cast_op("init_routing_v2")
def _(
    x: torch.Tensor,
    topk_indices: torch.Tensor,
) -> torch.Tensor:
    """
    Shape inference for the routing step that expands tokens for their experts.

    The op being modelled replicates every token once per selected expert and
    orders the copies by expert. No data is moved here: only an uninitialized
    tensor with the resulting shape is allocated.

    Args:
        x: (bsz, seq_len, hidden_size), the tokens
        topk_indices: (bsz, seq_len, top_k), the top-k experts selected by each token

    Returns:
        permuted_x: (bsz * seq_len * top_k, hidden_size), uninitialized, with the
        dtype and device of `x`
    """
    hidden_size = x.size(-1)
    # One output row per (token, selected expert) pair.
    permuted_shape = (topk_indices.numel(), hidden_size)
    return torch.empty(permuted_shape, device=x.device, dtype=x.dtype)


@register_tensor_cast_op("unpermute_tokens")
def _(
    x: torch.Tensor,
    topk_indices: torch.Tensor,
) -> torch.Tensor:
    """
    Rearrange the input tokens (initially sorted by their selected experts) by token indices.

    Only the output shape is produced; the returned tensor is uninitialized.

    Args:
        x: (bsz * seq_len * top_k, hidden_size), the tokens
        topk_indices: (bsz, seq_len, top_k), the top-k experts selected by each token

    Returns:
        unpermuted_x: (bsz, seq_len, top_k, hidden_size), contiguous, with the dtype
        and device of `x`

    Raises:
        RuntimeError: if the number of elements in `x` does not match
            `topk_indices.numel() * hidden_size` (raised by `view`).
    """
    # `empty_like` defaults to `torch.preserve_format`, which copies the strides of a
    # dense but non-contiguous `x`; `view` can then fail on a perfectly valid input.
    # Forcing a contiguous buffer makes the reshape depend only on the element count.
    buffer = torch.empty_like(x, memory_format=torch.contiguous_format)
    return buffer.view(*topk_indices.shape, x.shape[-1])


@register_tensor_cast_op("dispatch_ffn_combine")
def _(
    x: torch.Tensor,
    expert_indices: torch.Tensor,
    gmm1_w: List[torch.Tensor],
    gmm1_bias: List[Optional[torch.Tensor]],
    gmm2_w: List[torch.Tensor],
    gmm2_bias: List[Optional[torch.Tensor]],
    rank: int,
    rank_group: List[int],
) -> torch.Tensor:
    """Shape inference for the fused MoE FFN op (BF16 variant).

    The fused op covers routing, gate_up_proj (SwiGLU), down_proj and all_to_all.
    The GMM arguments are weights/biases taken from the region, not activations.
    Only `x` and `expert_indices` determine the result here; the weight, bias and
    rank arguments are part of the op signature but are not read.

    Returns:
        An uninitialized tensor of shape `(*expert_indices.shape, x.shape[-1])`
        with the dtype and device of `x`.
    """
    # Output keeps the layout of expert_indices with the hidden dimension appended.
    out_shape = tuple(expert_indices.shape) + (x.shape[-1],)
    return torch.empty(out_shape, device=x.device, dtype=x.dtype)


@register_tensor_cast_op("dispatch_ffn_combine_quant")
@register_tensor_cast_op("dispatch_ffn_combine_quant_int4")
def _(
    x: torch.Tensor,
    expert_indices: torch.Tensor,
    gmm1_w: List[torch.Tensor],
    gmm1_w_scale: List[torch.Tensor],
    gmm1_w_offset: List[Optional[torch.Tensor]],
    gmm1_bias: List[Optional[torch.Tensor]],
    gmm1_out_dtype: Optional[torch.dtype],
    gmm2_w: List[torch.Tensor],
    gmm2_w_scale: List[torch.Tensor],
    gmm2_w_offset: List[Optional[torch.Tensor]],
    gmm2_bias: List[Optional[torch.Tensor]],
    gmm2_out_dtype: Optional[torch.dtype],
    rank: int,
    rank_group: List[int],
) -> torch.Tensor:
    """Shape inference for the fused MoE FFN op (W8A8/W4A8 quant variant).

    GMM1 and GMM2 each receive only static, weight-side arguments. Quantization
    of the routed input and of the SwiGLU output happens inside the fused kernel
    instead of being expressed as separate graph nodes.

    Only `x` and `expert_indices` determine the result here; every other argument
    (including the `gmm*_out_dtype` values) is accepted but not read.

    Returns:
        An uninitialized tensor of shape `(*expert_indices.shape, x.shape[-1])`
        with the dtype and device of `x`.
    """
    # Output keeps the layout of expert_indices with the hidden dimension appended.
    result_shape = tuple(expert_indices.shape) + (x.shape[-1],)
    return torch.empty(result_shape, device=x.device, dtype=x.dtype)


@register_tensor_cast_op("dispatch_ffn_combine_fp8")
@register_tensor_cast_op("dispatch_ffn_combine_mxfp4")
def _(
    x: torch.Tensor,
    expert_indices: torch.Tensor,
    gmm1_w: List[torch.Tensor],
    gmm1_w_scale: List[torch.Tensor],
    gmm1_x_scale: List[torch.Tensor],
    gmm1_bias: List[Optional[torch.Tensor]],
    gmm1_out_dtype: Optional[torch.dtype],
    gmm2_w: List[torch.Tensor],
    gmm2_w_scale: List[torch.Tensor],
    gmm2_x_scale: List[torch.Tensor],
    gmm2_bias: List[Optional[torch.Tensor]],
    gmm2_out_dtype: Optional[torch.dtype],
    rank: int,
    rank_group: List[int],
) -> torch.Tensor:
    """Shape inference for the fused MoE FFN op (FP8/MXFP4 quant variant).

    The GMM weight/scale arguments follow grouped_matmul_fp8[_swiglu], minus
    its `x` argument.

    Only `x` and `expert_indices` determine the result here; every other argument
    (including the `gmm*_out_dtype` values) is accepted but not read.

    Returns:
        An uninitialized tensor of shape `(*expert_indices.shape, x.shape[-1])`
        with the dtype and device of `x`.
    """
    # Output keeps the layout of expert_indices with the hidden dimension appended.
    combined_shape = tuple(expert_indices.shape) + (x.shape[-1],)
    return torch.empty(combined_shape, device=x.device, dtype=x.dtype)


@register_tensor_cast_op("moe_gating_top_k_softmax")
def _(x: torch.Tensor, top_k: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Shape inference for the fused MoE gating op (softmax followed by top-k selection).

    The op being modelled applies softmax over the gating logits and picks the
    top-k experts in a single step. Nothing is computed here: two uninitialized
    tensors with the expected shapes and dtypes are returned.

    Args:
        x (torch.Tensor): Raw, unnormalized gating logits; the last dimension
                          indexes the experts.
        top_k (int): Number of experts to select per token.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - topk_weights (torch.Tensor): Weights of the selected experts,
              with shape `(*x.shape[:-1], top_k)` and the dtype and device of `x`.
            - topk_indices (torch.Tensor): Indices of the selected experts,
              with shape `(*x.shape[:-1], top_k)`, dtype int64, on the device of `x`.
    """
    # Replace the expert dimension with the number of selected experts.
    selected_shape = tuple(x.shape[:-1]) + (top_k,)
    topk_weights = torch.empty(selected_shape, device=x.device, dtype=x.dtype)
    topk_indices = torch.empty(selected_shape, device=x.device, dtype=torch.int64)
    return topk_weights, topk_indices