from functools import wraps
from typing import Optional, Tuple
import torch
import torch_npu
from megatron.core.transformer.moe.moe_utils import maybe_move_tensor_to_cpu
from megatron.core.transformer.moe.moe_utils import permute as megatron_permute
from megatron.core.transformer.moe.moe_utils import unpermute as megatron_unpermute
from megatron.core import parallel_state, tensor_parallel
from megatron.core.transformer.moe.moe_utils import sort_chunks_by_idxs
from mindspeed.te.pytorch.permutation import MoePermuteMaskMap, MoeUnpermuteMaskMap
from mindspeed.utils import has_triton
def convert_tensors_to_fp32_if_needed(
    tensor1: Optional[torch.Tensor],
    tensor2: Optional[torch.Tensor]
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.dtype]]:
    """Promote a pair of tensors to float32 when exactly one of them is float32.

    Args:
        tensor1: First tensor, or a non-tensor value such as ``None``.
        tensor2: Second tensor, or a non-tensor value such as ``None``.

    Returns:
        ``(tensor1, tensor2, original_dtype)``. When exactly one input is
        float32, both are returned as float32 and ``original_dtype`` is the
        dtype of the input that was not float32. In every other case
        (either argument is not a tensor, both are float32, or neither is)
        the inputs are returned untouched and ``original_dtype`` is ``None``.
    """
    both_are_tensors = isinstance(tensor1, torch.Tensor) and isinstance(tensor2, torch.Tensor)
    if not both_are_tensors:
        return tensor1, tensor2, None

    first_is_fp32 = tensor1.dtype == torch.float32
    second_is_fp32 = tensor2.dtype == torch.float32
    # Only a mixed pair (one float32, one not) needs promoting.
    if first_is_fp32 == second_is_fp32:
        return tensor1, tensor2, None

    if first_is_fp32:
        other_dtype = tensor2.dtype
    else:
        other_dtype = tensor1.dtype
    return tensor1.to(torch.float32), tensor2.to(torch.float32), other_dtype
def restore_original_dtype(
    tensor1: Optional[torch.Tensor],
    tensor2: Optional[torch.Tensor],
    original_dtype: Optional[torch.dtype]
) -> Tuple[Optional[torch.Tensor], ...]:
    """Cast both tensors to ``original_dtype`` when one is given.

    Args:
        tensor1: First tensor, or ``None``.
        tensor2: Second tensor, or ``None``.
        original_dtype: Target dtype; ``None`` means "leave everything as is".

    Returns:
        A 2-tuple ``(tensor1, tensor2)``. ``None`` entries are passed through
        unchanged; non-``None`` entries are cast when ``original_dtype`` is set.
    """
    if original_dtype is None:
        return tensor1, tensor2

    def _cast(tensor):
        # A missing tensor stays missing; only real tensors are converted.
        return tensor if tensor is None else tensor.to(original_dtype)

    return _cast(tensor1), _cast(tensor2)
def permute(
    tokens,
    routing_map,
    probs: Optional[torch.Tensor] = None,
    num_out_tokens: Optional[int] = None,
    fused: bool = False,
    drop_and_pad: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
    """Permute tokens (and optionally probs) according to ``routing_map``.

    With ``fused=True`` the work goes through ``MoePermuteMaskMap``. If exactly
    one of ``tokens``/``probs`` is float32, both are promoted to float32 for
    the call and both outputs are then cast to the dtype of the input that was
    not float32. With ``fused=False`` the call is forwarded unchanged to
    Megatron's own ``permute``.

    Args:
        tokens: Input tokens to permute.
        routing_map: Token-to-expert routing map.
        probs: Optional probabilities permuted alongside the tokens.
        num_out_tokens: Optional number of output tokens.
        fused: Select the fused ``MoePermuteMaskMap`` implementation.
        drop_and_pad: Forwarded to the selected implementation.

    Returns:
        ``(permuted_input, permuted_probs, sorted_indices)``.
    """
    if not fused:
        return megatron_permute(
            tokens,
            routing_map,
            probs=probs,
            num_out_tokens=num_out_tokens,
            fused=fused,
            drop_and_pad=drop_and_pad,
        )

    # Harmonise dtypes before the fused call; original_dtype is None when no
    # promotion took place.
    tokens, probs, original_dtype = convert_tensors_to_fp32_if_needed(tokens, probs)
    permuted_input, permuted_probs, sorted_indices = MoePermuteMaskMap.apply(
        tokens, routing_map, probs, num_out_tokens, drop_and_pad
    )
    permuted_input, permuted_probs = restore_original_dtype(
        permuted_input, permuted_probs, original_dtype
    )
    return permuted_input, permuted_probs, sorted_indices
def unpermute(
    permuted_tokens: torch.Tensor,
    sorted_indices: torch.Tensor,
    restore_shape: torch.Size,
    probs: torch.Tensor = None,
    routing_map: torch.Tensor = None,
    fused: bool = False,
    drop_and_pad: bool = False,
) -> torch.Tensor:
    """Undo a previous ``permute`` call.

    Args:
        permuted_tokens: Tokens in permuted order.
        sorted_indices: Index tensor produced by the matching ``permute``.
        restore_shape: Shape the result should be restored to.
        probs: Optional probabilities forwarded to the implementation.
        routing_map: Optional routing map forwarded to the implementation.
        fused: Use ``MoeUnpermuteMaskMap`` instead of Megatron's ``unpermute``.
        drop_and_pad: Forwarded to the selected implementation.

    Returns:
        Whatever the selected implementation returns.
    """
    if not fused:
        return megatron_unpermute(
            permuted_tokens,
            sorted_indices,
            restore_shape,
            probs=probs,
            routing_map=routing_map,
            fused=fused,
            drop_and_pad=drop_and_pad,
        )

    fused_args = (permuted_tokens, sorted_indices, restore_shape, probs, routing_map, drop_and_pad)
    return MoeUnpermuteMaskMap.apply(*fused_args)
def sort_chunks_by_idxs_wrapper(fn):
    """Wrap ``fn`` so that it is always invoked with ``fused=False``.

    The returned callable accepts a ``fused`` argument for signature
    compatibility but ignores its value.
    """

    @wraps(fn)
    def wrapper(
        input: torch.Tensor,
        split_sizes: torch.Tensor,
        sorted_idxs: torch.Tensor,
        probs: Optional[torch.Tensor] = None,
        fused: bool = False,
    ) -> torch.Tensor:
        # The caller's `fused` choice is deliberately dropped here.
        forced_kwargs = {"probs": probs, "fused": False}
        return fn(input, split_sizes, sorted_idxs, **forced_kwargs)

    return wrapper
def moe_alltoall_token_dispatcher_init_wrapper(fn):
    """Wrap a dispatcher ``__init__`` to (re)build its chunk-sorting indices.

    After the wrapped initializer runs, the wrapper sets
    ``permute_idx_device``, ``sort_input_by_local_experts`` and
    ``restore_output_by_local_experts`` on the instance.
    """

    @wraps(fn)
    def wrapper(self, num_local_experts, local_expert_indices, config) -> None:
        fn(self, num_local_experts, local_expert_indices, config)

        # Index tensors are placed on the NPU only when triton is available and
        # permute fusion is enabled; otherwise torch's default device is used.
        indices_on_npu = has_triton() and self.config.moe_permute_fusion
        self.permute_idx_device = torch.device("npu") if indices_on_npu else None

        chunk_ids = torch.arange(
            self.num_experts * self.tp_size, device=self.permute_idx_device
        )
        # Rows of width num_local_experts, transposed and flattened.
        by_local_expert = chunk_ids.reshape(-1, self.num_local_experts)
        self.sort_input_by_local_experts = by_local_expert.T.ravel()
        # num_local_experts rows, transposed and flattened.
        by_source_chunk = chunk_ids.reshape(self.num_local_experts, -1)
        self.restore_output_by_local_experts = by_source_chunk.T.ravel()

    return wrapper
def maybe_dtoh_and_synchronize(
    self, point: str, tokens_per_expert: torch.Tensor = None
) -> torch.Tensor:
    """Copy dispatcher metadata to the host and/or synchronize at ``point``.

    Nothing happens when ``self.drop_and_pad`` is set. Otherwise:

    * if ``point`` equals ``self.cuda_dtoh_point``, the split sizes, the number
      of output tokens and (with more than one local expert) the per-expert
      global token counts are passed through ``maybe_move_tensor_to_cpu``
      under ``self.cuda_dtoh_stream``;
    * if ``point`` equals ``self.cuda_sync_point``, that stream is synchronized.

    Args:
        point: Name of the current point in the dispatch pipeline.
        tokens_per_expert: Value handed back to the caller unchanged.

    Returns:
        ``tokens_per_expert``, untouched.
    """
    if self.drop_and_pad:
        return tokens_per_expert

    if point == self.cuda_dtoh_point:
        # When running on a different stream, the copy stream must first wait
        # for the work already queued on the current one.
        on_side_stream = torch.cuda.current_stream() != self.cuda_dtoh_stream
        if on_side_stream:
            self.cuda_dtoh_stream.wait_stream(torch.cuda.current_stream())

        with torch.cuda.stream(self.cuda_dtoh_stream):
            for split_attr in ("input_splits", "output_splits", "output_splits_tp"):
                moved = maybe_move_tensor_to_cpu(
                    getattr(self, split_attr), as_numpy=True, record_stream=on_side_stream
                )
                setattr(self, split_attr, moved)

            self.num_out_tokens = maybe_move_tensor_to_cpu(
                self.num_out_tokens, record_stream=on_side_stream
            )
            if self.num_local_experts > 1:
                self.num_global_tokens_per_local_expert = maybe_move_tensor_to_cpu(
                    self.num_global_tokens_per_local_expert, record_stream=on_side_stream
                )

    if point == self.cuda_sync_point:
        self.cuda_dtoh_stream.synchronize()

    return tokens_per_expert
def transformer_config_post_init_wrapper(fn):
    """Wrap a config ``__post_init__`` to hide permute fusion for "alltoall_seq".

    When ``moe_token_dispatcher_type`` is ``"alltoall_seq"``,
    ``moe_permute_fusion`` is forced to ``False`` while ``fn`` runs and is then
    put back to the caller's value. For any other dispatcher type ``fn`` runs
    unmodified.

    The dispatcher type is checked once, before ``fn`` runs. This keeps the
    save and restore steps paired even if ``fn`` changes
    ``moe_token_dispatcher_type``, and the ``finally`` clause restores the flag
    when ``fn`` raises.
    """

    @wraps(fn)
    def wrapper(self):
        if self.moe_token_dispatcher_type != "alltoall_seq":
            fn(self)
            return

        saved_permute_fusion = self.moe_permute_fusion
        self.moe_permute_fusion = False
        try:
            fn(self)
        finally:
            # Always hand the user's setting back, even on failure.
            self.moe_permute_fusion = saved_permute_fusion

    return wrapper
def alltoall_seq_token_permutation(
    self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Route tokens to their local experts through AlltoAll communication.

    Args:
        hidden_states (torch.Tensor): Input token embeddings.
        probs (torch.Tensor): Probs of tokens assigned to experts; must be 2D.
        routing_map (torch.Tensor): Mapping of tokens assigned to experts;
            must be 2D.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            - Permuted token embeddings for local experts.
            - Number of tokens per expert, as returned by ``self.preprocess``.
            - Permuted probs that travelled alongside the tokens.
    """
    # Remembered for the matching unpermutation call.
    self.hidden_shape = hidden_states.shape
    self.routing_map = routing_map
    assert probs.dim() == 2, "Expected 2D tensor for probs"
    assert routing_map.dim() == 2, "Expected 2D tensor for routing map"

    def _sync_if_requested(point):
        # Block on the current stream only at the configured sync point.
        if self.cuda_sync_point == point:
            torch.cuda.current_stream().synchronize()

    hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
    tokens_per_expert = self.preprocess(routing_map)

    if parallel_state.get_tensor_model_parallel_world_size() > 1:
        hidden_states = tensor_parallel.all_to_all_sp2hp(hidden_states)

    self.hidden_shape_before_permute = hidden_states.shape
    _sync_if_requested("before_permutation_1")

    # Step 1: local permutation.
    local_tokens, local_probs, local_mapping = permute(
        hidden_states,
        routing_map,
        probs=probs,
        num_out_tokens=self.num_out_tokens,
        fused=self.config.moe_permute_fusion,
    )
    self.reversed_local_input_permutation_mapping = local_mapping

    _sync_if_requested("before_ep_alltoall")

    # Step 2: exchange tokens, then probs, across the expert-parallel group.
    dispatched_tokens = tensor_parallel.all_to_all(
        parallel_state.get_expert_model_parallel_group(),
        local_tokens,
        self.output_splits,
        self.input_splits,
    )
    dispatched_probs = tensor_parallel.all_to_all(
        parallel_state.get_expert_model_parallel_group(),
        local_probs,
        self.output_splits,
        self.input_splits,
    )

    # Step 3: reorder received chunks when there are several local experts.
    if self.num_local_experts > 1:
        chunk_sizes = self.num_global_tokens_per_local_expert_cpu.ravel()
        dispatched_tokens, dispatched_probs = sort_chunks_by_idxs(
            dispatched_tokens,
            chunk_sizes,
            self.sort_input_by_local_experts,
            probs=dispatched_probs,
        )

    if parallel_state.get_tensor_model_parallel_world_size() > 1:
        dispatched_tokens = tensor_parallel.all_gather_last_dim_from_tensor_parallel_region(
            dispatched_tokens
        )

    _sync_if_requested("before_finish")
    return dispatched_tokens, tokens_per_expert, dispatched_probs
def alltoall_seq_token_unpermutation(
    self, hidden_states: torch.Tensor, bias: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Send expert outputs back and restore the original token order.

    Args:
        hidden_states (torch.Tensor): Output from local experts.
        bias (torch.Tensor, optional): Must be ``None``; bias is not supported.

    Returns:
        Tuple[torch.Tensor, Optional[torch.Tensor]]:
            - Unpermuted token embeddings, viewed as ``self.hidden_shape``.
            - Always ``None`` (bias is not supported).
    """
    assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"

    expert_output = hidden_states
    if parallel_state.get_tensor_model_parallel_world_size() > 1:
        expert_output = tensor_parallel.reduce_scatter_last_dim_to_tensor_parallel_region(
            expert_output
        )

    # Undo the per-local-expert chunk ordering applied during dispatch.
    if self.num_local_experts > 1:
        chunk_sizes = self.num_global_tokens_per_local_expert_cpu.T.ravel()
        expert_output, _ = sort_chunks_by_idxs(
            expert_output,
            chunk_sizes,
            self.restore_output_by_local_experts,
        )

    # Reverse AlltoAll: the split arguments are swapped relative to dispatch.
    returned_tokens = tensor_parallel.all_to_all(
        parallel_state.get_expert_model_parallel_group(),
        expert_output,
        self.input_splits,
        self.output_splits,
    )

    restored = unpermute(
        returned_tokens,
        self.reversed_local_input_permutation_mapping,
        restore_shape=self.hidden_shape_before_permute,
        routing_map=self.routing_map,
        fused=self.config.moe_permute_fusion,
    )

    if parallel_state.get_tensor_model_parallel_world_size() > 1:
        restored = tensor_parallel.all_to_all_hp2sp(restored)

    return restored.view(self.hidden_shape), None
def preprocess_sync_wrapper(fn):
    """Wrap a dispatcher ``preprocess`` to register an extra sync point.

    After ``fn`` runs, ``self._maybe_update_cuda_sync_point`` is called with
    ``"before_permutation_2"`` when the dispatcher has more than one local
    expert and permute fusion is enabled. The result of ``fn`` is returned
    unchanged.
    """

    @wraps(fn)
    def wrapper(self, routing_map):
        result = fn(self, routing_map)
        needs_sync_point = self.num_local_experts > 1 and self.config.moe_permute_fusion
        if needs_sync_point:
            self._maybe_update_cuda_sync_point("before_permutation_2")
        return result

    return wrapper