import torch
from mindspeed_mm.fsdp.utils.device import IS_NPU_AVAILABLE

if IS_NPU_AVAILABLE:
    import torch_npu


def eager_permute(tokens, indices):
    topk = 1 if indices.dim() == 1 else indices.size(1)
    indices_dtype = indices.dtype
    sorted_indices = torch.argsort(indices.float().view(-1), stable=True)
    permuted_tokens = tokens.index_select(0, sorted_indices // topk)
    sorted_indices1 = torch.argsort(sorted_indices.float(), stable=True).to(indices_dtype)
    return permuted_tokens, sorted_indices1


def fused_permute(tokens, indices):
    return torch_npu.npu_moe_token_permute(tokens, indices)


def permute(tokens, indices, fused=True):
    if fused and IS_NPU_AVAILABLE:
        return fused_permute(tokens, indices)
    else:
        return eager_permute(tokens, indices)