from typing import Optional

import torch
import torch_npu

from mindspeed.op_builder import MoeTokenUnpermuteOpBuilder
# One builder instance shared by the whole module; the wrapper function
# calls its load() to obtain the compiled extension when the op is used.
moe_token_unpermute_op_builder = MoeTokenUnpermuteOpBuilder()

# Names exported by `from <this module> import *`.
__all__ = ["npu_moe_token_unpermute"]
def npu_moe_token_unpermute(
    permuted_tokens: torch.Tensor,
    sorted_indices: torch.Tensor,
    probs: Optional[torch.Tensor] = None,
    padded_mode: bool = False,
    restore_shape: Optional[torch.Size] = None,
):
    """Invoke the compiled ``npu_moe_token_unpermute`` kernel.

    The extension module is obtained from the module-level op builder via
    ``load()`` on every call. All five arguments are forwarded positionally
    and unchanged, and the kernel's result is returned as-is.

    NOTE(review): the per-argument meanings below are inferred from the
    parameter names and the usual MoE "unpermute" contract (the inverse of a
    token-permute step). Confirm them against the kernel implementation.

    Args:
        permuted_tokens: Tokens in permuted order; presumably the output of
            the matching token-permute op.
        sorted_indices: Indices describing the permutation; presumably used
            to scatter tokens back to their original positions.
        probs: Optional tensor forwarded to the kernel (``None`` if omitted);
            presumably routing probabilities used to weight/merge each
            token's expert outputs.
        padded_mode: Boolean flag forwarded to the kernel; presumably
            selects a padded (capacity-based) permutation layout.
        restore_shape: Optional shape forwarded to the kernel (``None`` if
            omitted); presumably the shape of the original, unpermuted
            token tensor.

    Returns:
        Whatever the loaded kernel returns. Presumably the unpermuted tokens
        (TODO confirm).
    """
    # The extension is requested lazily here rather than at import time.
    # Whether repeated load() calls are cached is up to the builder, which
    # is not visible in this module.
    moe_token_unpermute_ops = moe_token_unpermute_op_builder.load()
    return moe_token_unpermute_ops.npu_moe_token_unpermute(
        permuted_tokens, sorted_indices, probs, padded_mode, restore_shape)