"""MoE Permutaion API"""
from typing import Tuple, Optional
import torch
import torch_npu
class MoePermuteMaskMap(torch.autograd.Function):
"""functional Permute with mask router map"""
@staticmethod
def forward(
ctx,
tokens,
routing_map,
probs: Optional[torch.Tensor] = None,
num_out_tokens: Optional[int] = None,
drop_and_pad: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
permuted_input, permuted_probs, sorted_indices = torch_npu.npu_moe_token_permute_with_routing_map(
tokens, routing_map, probs=probs, num_out_tokens=num_out_tokens, drop_and_pad=drop_and_pad)
num_tokens, _ = tokens.shape
num_experts = routing_map.shape[1]
ctx.num_tokens = num_tokens
ctx.num_experts = num_experts
ctx.drop_and_pad = drop_and_pad
ctx.sorted_indices = sorted_indices
ctx.routing_map = routing_map
return permuted_input, permuted_probs, sorted_indices
@staticmethod
def backward(
ctx,
permuted_act_grad: torch.Tensor,
permuted_probs_grad: torch.Tensor,
_,
) -> Tuple[torch.Tensor, ...]:
act_grad = None
probs_grad = None
if ctx.needs_input_grad[0]:
sorted_indices = ctx.sorted_indices
routing_map = ctx.routing_map
num_tokens = ctx.num_tokens
num_experts = ctx.num_experts
drop_and_pad = ctx.drop_and_pad
act_grad, probs_grad = torch_npu.npu_moe_token_permute_with_routing_map_grad(
permuted_act_grad,
permuted_probs_grad,
sorted_indices,
routing_map,
num_experts,
num_tokens,
drop_and_pad
)
if not ctx.needs_input_grad[2]:
probs_grad = None
return act_grad, None, probs_grad, None, None
class MoeUnpermuteMaskMap(torch.autograd.Function):
"""functional Unpermute with mask router map"""
@staticmethod
def forward(
ctx,
permuted_tokens: torch.Tensor,
sorted_indices: torch.Tensor,
restore_shape: torch.Size,
probs: Optional[torch.Tensor] = None,
routing_map: Optional[torch.Tensor] = None,
drop_and_pad: bool = False,
) -> torch.Tensor:
if not permuted_tokens.numel():
ctx.probs = probs
return permuted_tokens
unpermuted_output, out_index, permuted_token_id, _ = torch_npu._npu_moe_token_unpermute_with_routing_map(
permuted_tokens, sorted_indices, restore_shape, probs=probs, routing_map=routing_map,
drop_and_pad=drop_and_pad)
with_probs = probs is not None
if with_probs:
ctx.save_for_backward(permuted_tokens, probs)
ctx.restore_shape = restore_shape
ctx.drop_and_pad = drop_and_pad
ctx.sorted_indices = sorted_indices
ctx.routing_map = routing_map
ctx.out_index = out_index
ctx.permuted_token_id = permuted_token_id
ctx.with_probs = with_probs
return unpermuted_output
@staticmethod
def backward(ctx, unpermuted_tokens_grad):
if not unpermuted_tokens_grad.numel():
return unpermuted_tokens_grad, None, None, ctx.probs, None, None
restore_shape = ctx.restore_shape
drop_and_pad = ctx.drop_and_pad
act_grad = None
probs_grad = None
if ctx.needs_input_grad[0]:
permuted_tokens, probs = None, None
if ctx.with_probs:
permuted_tokens, probs = ctx.saved_tensors
sorted_indices = ctx.sorted_indices
routing_map = ctx.routing_map
out_index = ctx.out_index
permuted_token_id = ctx.permuted_token_id
if drop_and_pad:
act_grad, probs_grad = (
torch_npu.npu_moe_token_unpermute_with_routing_map_grad(
unpermuted_tokens_grad,
out_index,
permuted_token_id=permuted_token_id,
routing_map=routing_map,
permuted_tokens=permuted_tokens,
probs=probs,
drop_and_pad=drop_and_pad,
restore_shape=restore_shape
)
)
else:
act_grad, probs_grad = (
torch_npu.npu_moe_token_unpermute_with_routing_map_grad(
unpermuted_tokens_grad,
sorted_indices,
sorted_indices,
routing_map=routing_map,
permuted_tokens=permuted_tokens,
probs=probs,
drop_and_pad=drop_and_pad,
restore_shape=restore_shape
)
)
if not ctx.needs_input_grad[3]:
probs_grad = None
return act_grad, None, None, probs_grad, None, None