# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
import torch
import torch_npu
from mindspeed.op_builder import MoeTokenPermuteOpBuilder
# Names exported by `from <this module> import *`.
__all__ = ["npu_moe_token_permute"]
# Single builder instance created at import time; npu_moe_token_permute calls
# its load() on every invocation to obtain the op library.
# NOTE(review): presumably load() caches the built extension - confirm in MoeTokenPermuteOpBuilder.
moe_token_permute_op_builder = MoeTokenPermuteOpBuilder()
def npu_moe_token_permute(
        tokens: torch.Tensor,
        indices: torch.Tensor,
        num_out_tokens: int = None,
        padded_mode: bool = False
):
    """Call the ``npu_moe_token_permute`` op from the loaded extension library.

    The op library is obtained from the module-level builder on every call,
    and the four arguments are forwarded positionally and unmodified.

    Args:
        tokens: Tensor forwarded as the op's first argument.
        indices: Tensor forwarded as the op's second argument.
        num_out_tokens: Forwarded as-is; ``None`` when not supplied.
        padded_mode: Forwarded as-is; ``False`` when not supplied.

    Returns:
        Whatever the underlying ``npu_moe_token_permute`` op returns.

    NOTE(review): presumably this reorders ``tokens`` according to ``indices``
    for MoE expert routing - shapes, dtypes and the exact return structure are
    defined by the op implementation; confirm there.
    """
    permute_lib = moe_token_permute_op_builder.load()
    permuted = permute_lib.npu_moe_token_permute(
        tokens, indices, num_out_tokens, padded_mode
    )
    return permuted