Fused MoE
Overview
fused_moe is the MindIE-SD entry point for MoE, responsible for expert selection, token dispatch, expert computation, and result combining in MoE forward inference on NPU. This API targets open-source framework integration scenarios: callers provide activations, router logits, expert weights, and communication configurations, and the unified entry handles the entire routed experts forward computation.
In MoE models, each token selects a small number of experts based on router output. Compared to dense MLP, MoE can expand model capacity while controlling the actual computation per inference step, but introduces additional flows including token-to-expert routing, reordering, cross-card communication, and result recovery. fused_moe encapsulates these flows in a unified API, reducing repetitive adaptation costs on the framework side and enabling reuse of a single MoE computation entry across different parallelism strategies.
Currently fused_moe includes two paths: a fused operator path and a non-fused operator path. The fused operator path is not yet supported in the current version; setting use_fused_op=True falls back to the non-fused operator path. The default use_fused_op=False directly uses the non-fused operator path.
API Reference
from mindiesd import fused_moe
Function Signature
fused_moe(
    hidden_states,
    router_logits,
    num_experts,
    top_k,
    w13_weight,
    w2_weight,
    w13_bias=None,
    w2_bias=None,
    quant_config=None,
    w13_weight_scale=None,
    w2_weight_scale=None,
    tp_group=None,
    ep_group=None,
    dispatcher_type=None,
    tokens_full=True,
    k_group=1,
    group_count=1,
    group_select_mode=0,
    routing_method="softmax",
    renormalize=False,
    routed_scaling_factor=1.0,
    custom_routing_function=None,
    reduce_results=True,
    use_fused_op=False,
) -> torch.Tensor
Parameters
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| hidden_states | torch.Tensor | Yes | - | Input activations, shape [..., hidden_size], at least 2 dimensions. |
| router_logits | torch.Tensor | Yes | - | Router logits, shape [..., num_experts], at least 2 dimensions; leading dimensions must match hidden_states. |
| num_experts | int | Yes | - | Total number of experts, must match the last dimension of router_logits. Must be divisible by EP group size when using EP. |
| top_k | int | Yes | - | Number of experts selected per token, range [1, num_experts]. |
| w13_weight | torch.Tensor | Yes | - | Fused gate/up projection weight, shape [local_experts, hidden_size, 2 * intermediate_size], must be a 3D tensor. |
| w2_weight | torch.Tensor | Yes | - | Down projection weight, shape [local_experts, intermediate_size, hidden_size], must have the same local_experts as w13_weight. |
| w13_bias | torch.Tensor / None | No | None | Gate/up projection bias, shape [local_experts, 2 * intermediate_size], must match w13_weight expert and output dimensions. |
| w2_bias | torch.Tensor / None | No | None | Down projection bias, shape [local_experts, hidden_size], must match w2_weight expert and output dimensions. |
| quant_config | QuantConfig / None | No | None | MindIE-SD quantization config for enabling quantization in the MoE forward flow. |
| w13_weight_scale | torch.Tensor / None | No | None | Quantization scale for w13_weight. |
| w2_weight_scale | torch.Tensor / None | No | None | Quantization scale for w2_weight. |
| tp_group | dist.ProcessGroup / None | No | None | TP communication group. Takes effect when EP is not enabled and TP group size > 1. |
| ep_group | dist.ProcessGroup / None | No | None | EP communication group. Takes priority when EP group size > 1. |
| dispatcher_type | str / None | No | None | Token dispatch strategy. Options: "static", "dynamic"; None auto-selects based on platform and communication config. "dynamic" is only supported in EP scenarios. |
| tokens_full | bool | No | True | Input token layout marker. Only two layouts supported: True means each rank inputs full tokens; False means each rank inputs evenly split local token shards per communication group. |
| k_group | int | No | 1 | Number of expert groups selected per token in grouped routing, range [1, group_count]. |
| group_count | int | No | 1 | Total number of expert groups; num_experts must be divisible by group_count. |
| group_select_mode | int | No | 0 | Expert group scoring method. 0 uses max score within group, 1 uses sum of top-2 scores within group. Each group must have at least 2 experts when using 1. |
| routing_method | str | No | "softmax" | Router logit scoring method, options: "softmax" or "sigmoid". |
| renormalize | bool | No | False | Whether to re-normalize top-k routing weights selected by softmax routing. Sigmoid routing outputs normalized weights per NPU gating top-k operator semantics. |
| routed_scaling_factor | float | No | 1.0 | Routing weight scaling factor, applied during expert selection. |
| custom_routing_function | callable / None | No | None | Custom routing function, called as custom_routing_function(hidden_states, gating_output, topk, renormalize), must return (topk_weights, topk_ids). |
| reduce_results | bool | No | True | Whether to perform communication reduction on full token output in static MoE; unused in dynamic MoE path. |
| use_fused_op | bool | No | False | Whether to prefer the fused operator path. Currently unsupported; setting True falls back to the non-fused path. Default False directly uses the non-fused path. |
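The shape constraints in the table can be checked on the framework side before calling the API. The following is a minimal sketch in plain Python that works on shape tuples only; the helper name check_moe_shapes is hypothetical and not part of MindIE-SD, and the checks actually performed inside fused_moe may differ.

```python
def check_moe_shapes(hidden_shape, logits_shape, w13_shape, w2_shape,
                     num_experts, top_k, ep_size=1):
    """Hypothetical pre-flight check of the documented shape constraints.

    Returns (local_experts, hidden_size, intermediate_size) on success and
    raises ValueError otherwise. ep_size=1 means EP is not enabled.
    """
    hidden_shape, logits_shape = tuple(hidden_shape), tuple(logits_shape)
    w13_shape, w2_shape = tuple(w13_shape), tuple(w2_shape)

    if len(hidden_shape) < 2 or len(logits_shape) < 2:
        raise ValueError("hidden_states and router_logits need at least 2 dimensions")
    if hidden_shape[:-1] != logits_shape[:-1]:
        raise ValueError("leading dimensions of router_logits must match hidden_states")
    if logits_shape[-1] != num_experts:
        raise ValueError("last dimension of router_logits must equal num_experts")
    if not 1 <= top_k <= num_experts:
        raise ValueError("top_k must be in [1, num_experts]")
    if num_experts % ep_size != 0:
        raise ValueError("num_experts must be divisible by the EP group size")
    if len(w13_shape) != 3 or len(w2_shape) != 3:
        raise ValueError("w13_weight and w2_weight must be 3D tensors")

    local_experts, hidden_size, fused_size = w13_shape
    intermediate_size = w2_shape[1]
    if w2_shape[0] != local_experts:
        raise ValueError("w2_weight must have the same local_experts as w13_weight")
    if hidden_shape[-1] != hidden_size or w2_shape[2] != hidden_size:
        raise ValueError("hidden_size is inconsistent across inputs and weights")
    if fused_size != 2 * intermediate_size:
        raise ValueError("w13_weight last dimension must be 2 * intermediate_size")
    return local_experts, hidden_size, intermediate_size
```

With torch tensors, the shape tuples can be passed directly, for example check_moe_shapes(hidden_states.shape, router_logits.shape, w13_weight.shape, w2_weight.shape, num_experts, top_k).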
Return Value
torch.Tensor: MoE forward computation result, same shape as hidden_states.
Fused Operator Path (Reserved)
The fused operator path is not yet supported in the current version. Setting use_fused_op=True falls back to the non-fused operator path.
Non-Fused Operator Path (Current Default)
This path splits the MoE forward pass into multiple stages, supporting both non-quantized and quantized MoE inference, covering single-card, TP, and EP scenarios, and providing static and dynamic token dispatch methods. The static dispatcher is suitable for single-card, TP, and some EP scenarios; the dynamic dispatcher targets all-to-all token exchange in EP scenarios.
Execution Flow
This path divides the MoE forward process into the following stages:
- prepare: Organize input activations and router logits, and prepare data needed for subsequent computation based on input layout.
- select_experts: Select top-k experts for each token based on router output and generate routing weights.
- dispatch: Route and reorder tokens by expert according to expert selection results, generating inputs for expert computation.
- mlp: Execute expert-side grouped MLP computation to complete routed experts' feed-forward computation.
- combine: Merge expert outputs back to the original token order to recover the routed MoE output.
- finalize: Complete post-processing after output recovery and communication.
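The core stages listed above (select_experts, dispatch, mlp, combine) can be illustrated with a toy reference in plain Python. This is only a conceptual sketch under simplifying assumptions: tokens are lists of floats, each "expert" is an arbitrary Python callable rather than a gated MLP, routing is plain softmax top-k, and there is no communication, so prepare and finalize are omitted. It does not reflect how the NPU operators are implemented.

```python
import math


def select_experts(logits, top_k):
    """Softmax over experts, then keep top_k (weight, expert_id) pairs per token."""
    selected = []
    for row in logits:
        peak = max(row)
        exps = [math.exp(x - peak) for x in row]
        total = sum(exps)
        probs = [e / total for e in exps]
        ids = sorted(range(len(row)), key=lambda e: -probs[e])[:top_k]
        selected.append([(probs[e], e) for e in ids])
    return selected


def dispatch(tokens, selected):
    """Expand tokens to (expert_id, token_idx, weight) entries sorted by expert."""
    entries = [(expert_id, token_idx, weight)
               for token_idx, picks in enumerate(selected)
               for weight, expert_id in picks]
    entries.sort(key=lambda entry: entry[0])  # stable sort keeps token order per expert
    expert_inputs = [tokens[token_idx] for _, token_idx, _ in entries]
    return entries, expert_inputs


def mlp(entries, expert_inputs, experts):
    """Apply each entry's expert to its input row (stand-in for grouped MLP)."""
    return [experts[expert_id](x) for (expert_id, _, _), x in zip(entries, expert_inputs)]


def combine(entries, expert_outputs, num_tokens, hidden_size):
    """Weighted sum of expert outputs, scattered back to the original token order."""
    out = [[0.0] * hidden_size for _ in range(num_tokens)]
    for (_, token_idx, weight), y in zip(entries, expert_outputs):
        for d in range(hidden_size):
            out[token_idx][d] += weight * y[d]
    return out


def moe_forward(tokens, logits, experts, top_k):
    selected = select_experts(logits, top_k)
    entries, expert_inputs = dispatch(tokens, selected)
    expert_outputs = mlp(entries, expert_inputs, experts)
    return combine(entries, expert_outputs, len(tokens), len(tokens[0]))
```

Sorting by expert ID in dispatch is what lets the expert computation run on contiguous groups of tokens; combine then undoes that reordering, so the output has the same layout as the input.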
Dispatcher Strategy
Two token dispatch strategies are supported:
- static dispatcher: Uses a static token dispatch path, suitable for single-card, TP, and some EP scenarios. This path completes token sorting, expert token statistics, and result recovery via NPU MoE routing operators.
- dynamic dispatcher: Uses a dynamic token dispatch path, suitable for EP scenarios. This path performs all-to-all communication based on token-to-expert distribution and restores token order before and after expert computation.
When dispatcher_type=None, the API auto-selects based on communication mode and NPU model: A3/A5 hardware uses dynamic dispatcher in EP scenarios, A2 hardware uses static dispatcher; non-EP scenarios (single-card/TP) always use static dispatcher. You can also explicitly specify via dispatcher_type="static" or dispatcher_type="dynamic".
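The auto-selection rule can be written down as a small decision function. The sketch below is an illustration only: resolve_dispatcher and the npu_model labels ("A2", "A3", "A5") are hypothetical names chosen for this example, hardware not mentioned above is assumed to fall back to static, and the error behavior of the real API is not specified here.

```python
def resolve_dispatcher(dispatcher_type, ep_size, npu_model):
    """Return "static" or "dynamic" following the documented selection rules.

    ep_size:   EP group size; 1 means EP is not enabled (single-card or TP).
    npu_model: hypothetical hardware label such as "A2", "A3", or "A5".
    """
    ep_enabled = ep_size > 1
    if dispatcher_type is None:
        # Auto-selection: dynamic only for EP on A3/A5, static everywhere else.
        if ep_enabled and npu_model in ("A3", "A5"):
            return "dynamic"
        return "static"
    if dispatcher_type not in ("static", "dynamic"):
        raise ValueError(f"unknown dispatcher_type: {dispatcher_type!r}")
    if dispatcher_type == "dynamic" and not ep_enabled:
        raise ValueError('"dynamic" is only supported in EP scenarios')
    return dispatcher_type
```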
Communication Configuration
Distributed MoE inference is supported through tp_group and ep_group. The current API enables only one MoE communication strategy at a time:
- When a valid ep_group is provided, the EP communication strategy takes priority.
- When EP is not enabled and a valid tp_group is provided, the TP communication strategy is used.
- When neither is enabled, the single-card path is used.
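The priority between EP, TP, and single-card can be summarized as follows. This is a minimal sketch that takes group sizes instead of process groups; the function name resolve_comm_strategy and the returned labels are hypothetical, and a group is treated as "valid" when its size is greater than 1, as stated in the parameter descriptions.

```python
def resolve_comm_strategy(tp_size=1, ep_size=1):
    """Pick one MoE communication strategy; EP wins over TP when both are set."""
    if ep_size > 1:
        return "ep"
    if tp_size > 1:
        return "tp"
    return "single"
```

In a real integration, the sizes would typically come from torch.distributed.get_world_size(group) for each group that is not None.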
Callers must ensure that input activation and router logit layouts are consistent with the tokens_full configuration.
Quantization Configuration
When quant_config is not provided, or quant_config.quant_algo is None / NO_QUANT, the MoE forward flow runs in non-quantized mode. The following quantization configuration is currently supported:
- QuantConfig(quant_algo=QuantAlgorithm.W8A8_DYNAMIC): W8A8 dynamic quantization with INT8 weights.
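For readers unfamiliar with the term, the general idea behind W8A8 dynamic quantization can be sketched in plain Python. The example assumes the common scheme of symmetric per-output-channel INT8 weight scales computed offline and per-token activation scales computed at run time; whether the MindIE-SD kernels use exactly this scheme is not stated in this document, and all function names below are hypothetical.

```python
def quantize_weight_per_channel(weight):
    """weight: [in_features][out_features] floats -> (int8 values, per-column scales)."""
    n_in, n_out = len(weight), len(weight[0])
    scales = []
    for j in range(n_out):
        max_abs = max(abs(weight[i][j]) for i in range(n_in))
        scales.append(max_abs / 127.0 if max_abs > 0 else 1.0)
    q_weight = [[max(-127, min(127, round(weight[i][j] / scales[j])))
                 for j in range(n_out)] for i in range(n_in)]
    return q_weight, scales


def quantize_token_dynamic(x):
    """Quantize one activation row with a scale derived from its own values."""
    max_abs = max(abs(v) for v in x)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    return [max(-127, min(127, round(v / scale))) for v in x], scale


def w8a8_matvec(x, q_weight, w_scales):
    """Integer matmul on quantized operands, then dequantize with both scales."""
    q_x, x_scale = quantize_token_dynamic(x)
    out = []
    for j in range(len(w_scales)):
        acc = sum(q_x[i] * q_weight[i][j] for i in range(len(q_x)))  # integer accumulation
        out.append(acc * x_scale * w_scales[j])
    return out
```

Under this assumed scheme, the weight scale has one entry per expert and output channel, which matches the [num_experts, 2 * intermediate_size] and [num_experts, hidden_size] scale shapes used in the INT8 usage example.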
Routing Selection
Default routing uses NPU gating top-k operators for expert selection. routing_method="softmax" applies softmax to router logits; routing_method="sigmoid" applies sigmoid. Softmax routing can re-normalize selected top-k weights via renormalize=True; sigmoid routing outputs normalized weights per NPU gating top-k operator semantics.
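The interaction between routing_method, renormalize, and routed_scaling_factor can be illustrated for a single token in plain Python. This is a reference sketch, not the NPU gating top-k operator: the sigmoid branch always normalizes the selected weights to mirror the statement above, and applying routed_scaling_factor after normalization is an assumption of this example.

```python
import math


def route(logits_row, top_k, routing_method="softmax",
          renormalize=False, routed_scaling_factor=1.0):
    """Return (topk_weights, topk_ids) for one token's router logits."""
    if routing_method == "softmax":
        peak = max(logits_row)
        exps = [math.exp(x - peak) for x in logits_row]
        total = sum(exps)
        scores = [e / total for e in exps]
    elif routing_method == "sigmoid":
        scores = [1.0 / (1.0 + math.exp(-x)) for x in logits_row]
    else:
        raise ValueError(f"unsupported routing_method: {routing_method!r}")

    ids = sorted(range(len(scores)), key=lambda e: -scores[e])[:top_k]
    weights = [scores[e] for e in ids]

    # Softmax weights are renormalized only on request; sigmoid weights are
    # normalized unconditionally in this sketch (assumption, see lead-in).
    if routing_method == "sigmoid" or renormalize:
        total = sum(weights)
        weights = [w / total for w in weights]

    # Assumption: scaling is applied last.
    return [w * routed_scaling_factor for w in weights], ids
```

Without renormalize, the selected softmax weights sum to less than 1 whenever top_k is smaller than num_experts, because the probability mass of the unselected experts is dropped.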
When grouped routing is not enabled, keep the defaults k_group=1, group_count=1. When grouped routing is enabled, expert groups are first selected by group_select_mode, then top-k experts are selected per token within the selected groups.
Grouped routing requires the following constraints:
- num_experts must be divisible by group_count.
- k_group range is [1, group_count].
- top_k must not exceed the total number of experts within the selected expert groups.
- When group_select_mode=1, each expert group must contain at least 2 experts.
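Grouped routing and its constraints can be sketched for one token as follows. The example assumes that expert groups are contiguous ranges of expert IDs and takes already-computed scores (for example, softmax or sigmoid outputs) as input; both are assumptions of this sketch, and grouped_select is a hypothetical helper rather than a MindIE-SD function.

```python
def grouped_select(scores, top_k, k_group, group_count, group_select_mode=0):
    """Select k_group expert groups, then top_k experts inside those groups."""
    num_experts = len(scores)
    if num_experts % group_count != 0:
        raise ValueError("num_experts must be divisible by group_count")
    if not 1 <= k_group <= group_count:
        raise ValueError("k_group must be in [1, group_count]")
    group_size = num_experts // group_count
    if top_k > k_group * group_size:
        raise ValueError("top_k exceeds the number of experts in the selected groups")
    if group_select_mode == 1 and group_size < 2:
        raise ValueError("group_select_mode=1 needs at least 2 experts per group")

    # Score each group: mode 0 uses the max score, mode 1 the sum of the top-2 scores.
    group_scores = []
    for g in range(group_count):
        members = sorted(scores[g * group_size:(g + 1) * group_size], reverse=True)
        group_scores.append(members[0] if group_select_mode == 0
                            else members[0] + members[1])

    kept_groups = sorted(range(group_count), key=lambda g: -group_scores[g])[:k_group]
    candidates = [e for g in kept_groups
                  for e in range(g * group_size, (g + 1) * group_size)]
    ids = sorted(candidates, key=lambda e: -scores[e])[:top_k]
    return [scores[e] for e in ids], ids
```

The two scoring modes can pick different groups for the same scores: a group with one very strong expert wins under mode 0, while a group with two moderately strong experts can win under mode 1.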
Usage Examples
Usage Notes
In actual framework integration, hidden_states, router_logits, and expert weights typically come from model forward and weight loading flows.
Distributed examples assume tp_group or ep_group have already been initialized.
Single-Card MoE
Single-card scenarios do not require communication groups; default strategy is used.
import torch
from mindiesd import fused_moe
num_tokens = 8
hidden_size = 4096
intermediate_size = 14336
num_experts = 8
top_k = 2
dtype = torch.bfloat16
device = "npu"
hidden_states = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype)
router_logits = torch.randn(num_tokens, num_experts, device=device, dtype=dtype)
w13_weight = torch.randn(num_experts, hidden_size, 2 * intermediate_size, device=device, dtype=dtype)
w2_weight = torch.randn(num_experts, intermediate_size, hidden_size, device=device, dtype=dtype)
w13_bias = torch.randn(num_experts, 2 * intermediate_size, device=device, dtype=dtype)
w2_bias = torch.randn(num_experts, hidden_size, device=device, dtype=dtype)
out = fused_moe(
    hidden_states=hidden_states,
    router_logits=router_logits,
    num_experts=num_experts,
    top_k=top_k,
    w13_weight=w13_weight,
    w2_weight=w2_weight,
    w13_bias=w13_bias,
    w2_bias=w2_bias,
    renormalize=True,
)
Grouped Routing MoE
When the model needs to first select expert groups and then select top-k experts from the selected groups, configure the grouped routing parameters. The example below uses sigmoid routing and scales routing weights via routed_scaling_factor.
import torch
from mindiesd import fused_moe
num_tokens = 8
hidden_size = 4096
intermediate_size = 14336
num_experts = 16
top_k = 2
dtype = torch.bfloat16
device = "npu"
hidden_states = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype)
router_logits = torch.randn(num_tokens, num_experts, device=device, dtype=dtype)
w13_weight = torch.randn(num_experts, hidden_size, 2 * intermediate_size, device=device, dtype=dtype)
w2_weight = torch.randn(num_experts, intermediate_size, hidden_size, device=device, dtype=dtype)
out = fused_moe(
    hidden_states=hidden_states,
    router_logits=router_logits,
    num_experts=num_experts,
    top_k=top_k,
    w13_weight=w13_weight,
    w2_weight=w2_weight,
    k_group=1,
    group_count=4,
    group_select_mode=1,
    routing_method="sigmoid",
    routed_scaling_factor=0.5,
)
INT8 Dynamic Quant MoE
The INT8 path requires w13_weight and w2_weight to be torch.int8, along with corresponding quantization scales. MindIE-SD checks weight format before MLP computation; if weights are not in NPU NZ format, they are automatically converted to NZ before calling INT8 grouped MLP operators.
import torch
from mindiesd import fused_moe
from mindiesd.quantization.config import QuantConfig
from mindiesd.quantization.mode import QuantAlgorithm
num_tokens = 8
hidden_size = 4096
intermediate_size = 14336
num_experts = 8
top_k = 2
dtype = torch.bfloat16
device = "npu"
hidden_states = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype)
router_logits = torch.randn(num_tokens, num_experts, device=device, dtype=dtype)
# Random INT8 weights for illustration; the upper bound of torch.randint is
# exclusive, so values fall in [-128, 126].
w13_weight = torch.randint(
    -128,
    127,
    (num_experts, hidden_size, 2 * intermediate_size),
    device=device,
    dtype=torch.int8,
)
w2_weight = torch.randint(
    -128,
    127,
    (num_experts, intermediate_size, hidden_size),
    device=device,
    dtype=torch.int8,
)
w13_weight_scale = torch.rand(num_experts, 2 * intermediate_size, device=device, dtype=dtype)
w2_weight_scale = torch.rand(num_experts, hidden_size, device=device, dtype=dtype)
out = fused_moe(
    hidden_states=hidden_states,
    router_logits=router_logits,
    num_experts=num_experts,
    top_k=top_k,
    w13_weight=w13_weight,
    w2_weight=w2_weight,
    quant_config=QuantConfig(quant_algo=QuantAlgorithm.W8A8_DYNAMIC),
    w13_weight_scale=w13_weight_scale,
    w2_weight_scale=w2_weight_scale,
    dispatcher_type="static",
    tokens_full=True,
    reduce_results=False,
)
TP Static MoE
TP scenarios require tp_group. When tokens_full=True, each rank inputs full tokens.
import torch
from mindiesd import fused_moe
num_tokens = 8
hidden_size = 4096
intermediate_size = 14336
num_experts = 8
top_k = 2
dtype = torch.bfloat16
device = "npu"
hidden_states = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype)
router_logits = torch.randn(num_tokens, num_experts, device=device, dtype=dtype)
w13_weight = torch.randn(num_experts, hidden_size, 2 * intermediate_size, device=device, dtype=dtype)
w2_weight = torch.randn(num_experts, intermediate_size, hidden_size, device=device, dtype=dtype)
out = fused_moe(
    hidden_states=hidden_states,
    router_logits=router_logits,
    num_experts=num_experts,
    top_k=top_k,
    w13_weight=w13_weight,
    w2_weight=w2_weight,
    tp_group=tp_group,  # process group initialized by the caller beforehand
    dispatcher_type="static",
    tokens_full=True,
    reduce_results=True,
)
EP Static MoE
EP scenarios can also explicitly specify the static dispatcher. Here, w13_weight and w2_weight represent the local expert weights held by the current rank, and ep_group is used for MoE communication.
import torch
from mindiesd import fused_moe
num_tokens = 8
hidden_size = 4096
intermediate_size = 14336
num_experts = 8
top_k = 2
dtype = torch.bfloat16
device = "npu"
ep_world_size = torch.distributed.get_world_size(ep_group)
local_experts = num_experts // ep_world_size
hidden_states = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype)
router_logits = torch.randn(num_tokens, num_experts, device=device, dtype=dtype)
w13_weight = torch.randn(local_experts, hidden_size, 2 * intermediate_size, device=device, dtype=dtype)
w2_weight = torch.randn(local_experts, intermediate_size, hidden_size, device=device, dtype=dtype)
out = fused_moe(
    hidden_states=hidden_states,
    router_logits=router_logits,
    num_experts=num_experts,  # global expert count, not local_experts
    top_k=top_k,
    w13_weight=w13_weight,
    w2_weight=w2_weight,
    ep_group=ep_group,  # process group initialized by the caller beforehand
    dispatcher_type="static",
    tokens_full=True,
    reduce_results=True,
    renormalize=True,
)
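When loading weights for EP, each rank needs to know which global experts its local_experts slots correspond to. The sketch below assumes a contiguous partition (rank 0 holds the first num_experts // ep_size experts, rank 1 the next block, and so on). This document does not state the mapping used by fused_moe, so treat the layout as an assumption to verify; local_expert_ids is a hypothetical helper.

```python
def local_expert_ids(num_experts, ep_size, ep_rank):
    """Global expert IDs held by ep_rank under an assumed contiguous partition."""
    if num_experts % ep_size != 0:
        raise ValueError("num_experts must be divisible by the EP group size")
    if not 0 <= ep_rank < ep_size:
        raise ValueError("ep_rank must be in [0, ep_size)")
    local_experts = num_experts // ep_size
    start = ep_rank * local_experts
    return list(range(start, start + local_experts))
```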
EP Dynamic MoE
The EP dynamic path requires ep_group. When each rank inputs locally sharded tokens evenly split by the communication group, set tokens_full=False. Here, local_hidden_states and local_router_logits represent the current rank's local token input, while w13_weight and w2_weight still represent the current rank's local expert weights. With this layout, the token count must be divisible by the communication group size.
import torch
from mindiesd import fused_moe
num_tokens = 8
hidden_size = 4096
intermediate_size = 14336
num_experts = 8
top_k = 2
dtype = torch.bfloat16
device = "npu"
ep_world_size = torch.distributed.get_world_size(ep_group)
local_experts = num_experts // ep_world_size
local_num_tokens = num_tokens // ep_world_size
local_hidden_states = torch.randn(local_num_tokens, hidden_size, device=device, dtype=dtype)
local_router_logits = torch.randn(local_num_tokens, num_experts, device=device, dtype=dtype)
w13_weight = torch.randn(local_experts, hidden_size, 2 * intermediate_size, device=device, dtype=dtype)
w2_weight = torch.randn(local_experts, intermediate_size, hidden_size, device=device, dtype=dtype)
out = fused_moe(
    hidden_states=local_hidden_states,
    router_logits=local_router_logits,
    num_experts=num_experts,
    top_k=top_k,
    w13_weight=w13_weight,
    w2_weight=w2_weight,
    ep_group=ep_group,
    dispatcher_type="dynamic",
    tokens_full=False,
    renormalize=True,
)
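If a framework starts from full tokens and wants to feed the tokens_full=False layout, each rank has to take its own evenly sized shard. A minimal sketch of that slicing is shown below; it assumes that rank r takes the r-th contiguous chunk along the token dimension, which is an assumption of this example, and local_token_slice is a hypothetical helper.

```python
def local_token_slice(num_tokens, group_size, rank):
    """Slice selecting this rank's token shard under an even, contiguous split."""
    if num_tokens % group_size != 0:
        raise ValueError("token count must be divisible by the communication group size")
    if not 0 <= rank < group_size:
        raise ValueError("rank must be in [0, group_size)")
    local_num_tokens = num_tokens // group_size
    return slice(rank * local_num_tokens, (rank + 1) * local_num_tokens)
```

The same slice has to be applied to both hidden_states and router_logits so that the two inputs stay aligned, for example hidden_states[local_token_slice(num_tokens, ep_world_size, rank)].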
Custom Routing Function
When the model already has custom routing logic, it can be integrated via custom_routing_function. This function must return (topk_weights, topk_ids).
import torch
from mindiesd import fused_moe
num_tokens = 8
hidden_size = 4096
intermediate_size = 14336
num_experts = 8
top_k = 2
dtype = torch.bfloat16
device = "npu"
hidden_states = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype)
router_logits = torch.randn(num_tokens, num_experts, device=device, dtype=dtype)
w13_weight = torch.randn(num_experts, hidden_size, 2 * intermediate_size, device=device, dtype=dtype)
w2_weight = torch.randn(num_experts, intermediate_size, hidden_size, device=device, dtype=dtype)
def custom_routing_function(hidden_states, gating_output, topk, renormalize):
    # Softmax over experts, then keep the top-k scores and their expert IDs.
    topk_result = gating_output.softmax(dim=-1).topk(topk, dim=-1)
    topk_weights = topk_result.values
    topk_ids = topk_result.indices
    if renormalize:
        # Rescale the selected weights so they sum to 1 per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
out = fused_moe(
    hidden_states=hidden_states,
    router_logits=router_logits,
    num_experts=num_experts,
    top_k=top_k,
    w13_weight=w13_weight,
    w2_weight=w2_weight,
    custom_routing_function=custom_routing_function,
    renormalize=True,
)
Notes
- The current API only supports forward inference; backward gradient computation is not supported.
- Currently supports non-quantized MoE and INT8 dynamic quant MoE paths.
- w13_weight and w2_weight must use the weight layout required by grouped matmul.
- The last dimension of router_logits must equal num_experts.
- The dynamic dispatcher can only be used in EP communication scenarios.
- When using distributed communication, callers must initialize communication groups in advance and ensure consistent input layout across all ranks.