from dataclasses import dataclass
from typing import Any
import torch
@dataclass(frozen=True)
class MoEPrepareInput:
    """Input consumed by the dispatcher prepare stage.

    Instances are immutable (``frozen=True``): rebinding a field raises
    ``dataclasses.FrozenInstanceError``. Only the attribute bindings are
    frozen; the referenced tensors are neither copied nor protected from
    in-place modification.

    Attributes:
        hidden_states: Hidden-state tensor handed to the prepare stage.
        router_logits: Router-logit tensor handed to the prepare stage.
        tokens_full: Boolean flag interpreted by the consuming dispatcher;
            its exact meaning is not defined in this module.
    """

    hidden_states: torch.Tensor
    router_logits: torch.Tensor
    tokens_full: bool
@dataclass(frozen=True)
class MoEPrepareOutput:
"""Output produced by the dispatcher prepare stage."""
hidden_states: torch.Tensor
router_logits: torch.Tensor
original_shape: Any
mlp_output_dtype: torch.dtype
dynamic_scale: torch.Tensor | None = None
@dataclass(frozen=True)
class MoERoutingInput:
    """Input consumed by expert selection.

    Instances are immutable (``frozen=True``): rebinding a field raises
    ``dataclasses.FrozenInstanceError``. Only ``hidden_states``,
    ``router_logits`` and ``top_k`` are required; every other field has a
    default.

    Attributes:
        hidden_states: Hidden-state tensor passed to expert selection.
        router_logits: Router-logit tensor passed to expert selection.
        top_k: Integer top-k setting (presumably experts selected per
            token -- confirm against the routing implementation).
        renormalize: Boolean flag; defaults to ``False``.
        k_group: Integer group-level top-k setting; defaults to ``1``.
        group_count: Integer number of groups; defaults to ``1``.
        group_select_mode: Integer mode selector; defaults to ``0``. Valid
            values are defined by the routing implementation, not here.
        norm_type: Integer normalization selector; defaults to ``0``.
            Valid values are defined by the routing implementation.
        routed_scaling_factor: Float scaling factor; defaults to ``1.0``.
        eps: Small float constant; defaults to ``1e-20``.
        custom_routing_function: Optional object typed ``Any``; defaults
            to ``None``. Presumably a callable overriding the built-in
            routing -- confirm.
    """

    hidden_states: torch.Tensor
    router_logits: torch.Tensor
    top_k: int
    renormalize: bool = False
    k_group: int = 1
    group_count: int = 1
    group_select_mode: int = 0
    norm_type: int = 0
    routed_scaling_factor: float = 1.0
    eps: float = 1e-20
    custom_routing_function: Any = None
@dataclass(frozen=True)
class MoEWeights:
"""Dense and quantized weights consumed by grouped expert MLP computation."""
w13_weight: torch.Tensor
w2_weight: torch.Tensor
w13_bias: torch.Tensor | None = None
w2_bias: torch.Tensor | None = None
w13_weight_scale: torch.Tensor | None = None
w2_weight_scale: torch.Tensor | None = None
@dataclass(frozen=True)
class MoEStaticCombineMetadata:
    """Metadata required to restore token order in static MoE.

    Instances are immutable (``frozen=True``): rebinding a field raises
    ``dataclasses.FrozenInstanceError``. All three fields are required.

    Attributes:
        topk_weights: Tensor of top-k routing weights.
        expanded_row_idx: Index tensor (presumably mapping expanded rows
            back to their source tokens -- confirm against the combine
            step).
        restore_shape: ``torch.Size`` describing the shape to restore.
    """

    topk_weights: torch.Tensor
    expanded_row_idx: torch.Tensor
    restore_shape: torch.Size
@dataclass(frozen=True)
class MoEDynamicCombineMetadata:
"""Metadata required to restore token order after dynamic MoE exchange."""
input_splits: Any
output_splits: Any
topk_weights: torch.Tensor
local_unpermute_indices: torch.Tensor
global_unpermute_indices: torch.Tensor | None
hidden_shape: torch.Size
@dataclass(frozen=True)
class MoETokenDispatchInput:
"""Input consumed by the dispatcher token-routing stage."""
hidden_states: torch.Tensor
topk_weights: torch.Tensor
topk_ids: torch.Tensor
num_experts: int
top_k: int
local_num_experts: int
dynamic_scale: torch.Tensor | None = None
@dataclass(frozen=True)
class MoETokenDispatchOutput:
"""Output produced by the dispatcher token-routing stage."""
hidden_states: torch.Tensor
group_list: torch.Tensor
group_list_type: int
combine_metadata: Any
dynamic_scale: torch.Tensor | None = None
@dataclass(frozen=True)
class MoEMlpComputeInput:
    """Input consumed by grouped expert MLP computation.

    Instances are immutable (``frozen=True``): rebinding a field raises
    ``dataclasses.FrozenInstanceError``. The referenced tensors and the
    ``weights`` bundle are held by reference, not copied.

    Attributes:
        hidden_states: Hidden-state tensor fed to the expert MLPs.
        group_list: Tensor describing the per-expert grouping of the
            rows; its encoding is selected by ``group_list_type``.
        group_list_type: Integer selector for the ``group_list`` encoding.
            Valid values are defined by the consuming kernel, not here.
        weights: ``MoEWeights`` bundle holding the expert weights.
        mlp_output_dtype: ``torch.dtype`` requested for the MLP output.
        dynamic_scale: Optional scale tensor; defaults to ``None``.
            Presumably only set on quantized paths -- confirm.
    """

    hidden_states: torch.Tensor
    group_list: torch.Tensor
    group_list_type: int
    weights: MoEWeights
    mlp_output_dtype: torch.dtype
    dynamic_scale: torch.Tensor | None = None