import os
import unittest
from unittest.mock import MagicMock, patch
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch_npu
from mindiesd.layers.moe import moe
from mindiesd.layers.moe.moe import resolve_dispatcher_class
from mindiesd.layers.moe.moe_context import set_moe_comm_context
from mindiesd.layers.moe.moe_dataclass import MoEPrepareOutput, MoETokenDispatchOutput
from mindiesd.layers.moe.token_dispatcher import DynamicDispatcher, StaticDispatcher
from mindiesd.utils import ParametersInvalid
from mindiesd.utils.get_platform import NPUDevice, get_npu_device
from .common import make_moe_kwargs, make_mxfp8_ones, make_w8a8_dynamic_quant_config, make_w8a8_mxfp8_quant_config
def mock_dispatch_result(num_tokens=3, hidden_size=4):
    """Build a fake ``MoETokenDispatchOutput`` used as a patched ``dispatch`` return value.

    The dispatched hidden states are random ``(num_tokens, hidden_size)`` rows.
    ``group_list`` splits those rows across two groups so that its entries sum to
    ``num_tokens`` (assuming ``group_list_type=1`` means per-group counts, as in
    ``torch_npu.npu_grouped_matmul``). This keeps the mock self-consistent for any
    token count; for the default ``num_tokens=3`` it is ``[2, 1]``.

    Args:
        num_tokens: Number of dispatched token rows.
        hidden_size: Width of each dispatched row.

    Returns:
        A ``MoETokenDispatchOutput`` with ``dynamic_scale=None`` and an opaque
        ``object()`` sentinel as ``combine_metadata``.
    """
    # Previously hard-coded to [2, 1], which only matched num_tokens == 3.
    first_group = (num_tokens + 1) // 2
    return MoETokenDispatchOutput(
        hidden_states=torch.randn(num_tokens, hidden_size),
        dynamic_scale=None,
        group_list=torch.tensor([first_group, num_tokens - first_group]),
        group_list_type=1,
        combine_metadata=object(),
    )
def torch_moe_reference(
    hidden_states,
    w13_weight,
    w2_weight,
    topk_weights,
    topk_ids,
    w13_bias=None,
    w2_bias=None,
):
    """Eager PyTorch reference for a SwiGLU mixture-of-experts forward pass.

    Each token is replicated ``top_k`` times and routed to the experts named in
    ``topk_ids``. It then goes through ``silu(gate) * up`` and the down
    projection. Finally, the ``top_k`` expert outputs are summed with
    ``topk_weights``.

    Args:
        hidden_states: ``(num_tokens, hidden_size)`` input activations.
        w13_weight: ``(num_experts, hidden_size, 2 * intermediate_size)`` fused
            gate/up projection; the first half of the last dim is the gate.
        w2_weight: ``(num_experts, intermediate_size, out_size)`` down projection.
        topk_weights: ``(num_tokens, top_k)`` routing weights.
        topk_ids: ``(num_tokens, top_k)`` routed expert indices.
        w13_bias: Optional per-expert bias added after the gate/up projection.
        w2_bias: Optional per-expert bias added after the down projection.

    Returns:
        ``(num_tokens, out_size)`` tensor on ``hidden_states.device``.
    """
    num_tokens, hidden_size = hidden_states.shape
    top_k = topk_ids.shape[-1]
    out_size = w2_weight.shape[-1]
    # Row i * top_k + k is token i's copy for its k-th routed expert, matching
    # the row-major flattening of topk_ids below.
    expanded_hidden = hidden_states.view(num_tokens, 1, hidden_size).repeat(1, top_k, 1).reshape(-1, hidden_size)
    # Allocate next to the inputs so boolean-mask indexing works for non-CPU
    # tensors too (a CPU-only buffer cannot be indexed by a device mask).
    expert_out = torch.zeros(
        num_tokens * top_k,
        out_size,
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    flat_topk_ids = topk_ids.reshape(-1)
    for expert_id in range(w13_weight.shape[0]):
        mask = flat_topk_ids == expert_id
        if not mask.any():
            continue
        gate_up = expanded_hidden[mask] @ w13_weight[expert_id]
        if w13_bias is not None:
            gate_up = gate_up + w13_bias[expert_id]
        gate, up = gate_up.chunk(2, dim=-1)
        down = (F.silu(gate) * up) @ w2_weight[expert_id]
        if w2_bias is not None:
            down = down + w2_bias[expert_id]
        # Single masked write instead of write-then-read-modify-write.
        expert_out[mask] = down
    # Spell out out_size instead of -1: an inferred dim is ambiguous (and
    # raises) when the tensor has zero elements, i.e. num_tokens == 0.
    weighted = expert_out.view(num_tokens, top_k, out_size) * topk_weights.view(num_tokens, top_k, 1)
    return weighted.sum(dim=1)
def torch_select_experts(router_logits, top_k, renormalize, routed_scaling_factor=1.0):
    """Reference softmax top-k router.

    Args:
        router_logits: Router scores; upcast to float32 before the softmax.
        top_k: Number of experts to keep per token.
        renormalize: If true, rescale the kept probabilities to sum to 1.
        routed_scaling_factor: Multiplier applied to the final weights
            (skipped when exactly 1.0).

    Returns:
        ``(topk_weights, topk_ids)``: float32 weights and int32 expert ids.
    """
    probs = torch.softmax(router_logits.float(), dim=-1)
    weights, indices = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    if routed_scaling_factor != 1.0:
        weights = weights * routed_scaling_factor
    return weights, indices.to(torch.int32)
def run_static_moe(**kwargs):
    """Call ``moe`` with the static-dispatcher pipeline fully mocked out.

    ``StaticDispatcher.dispatch``/``combine``, ``unified_apply_mlp`` and
    ``select_experts`` are patched, so only ``moe``'s orchestration logic runs.
    The optional keys ``dispatch_result``, ``combine_output`` and ``mlp_output``
    are removed from ``kwargs`` and used as the patched return values. All
    remaining keys are forwarded to ``moe``.

    Returns:
        ``(output, mock_dispatch)``: ``moe``'s result and the mock that replaced
        ``StaticDispatcher.dispatch``.
    """
    hidden = kwargs["hidden_states"]
    n_tokens, width = hidden.shape[0], hidden.shape[-1]
    # dict.pop evaluates its default eagerly; keep that so the RNG draw order is fixed.
    fake_dispatch = kwargs.pop("dispatch_result", mock_dispatch_result(n_tokens, width))
    fake_combine = kwargs.pop("combine_output", torch.randn_like(hidden))
    fake_mlp = kwargs.pop("mlp_output", torch.randn_like(fake_dispatch.hidden_states))
    routing = (
        torch.ones(n_tokens, kwargs["top_k"]),
        torch.zeros(n_tokens, kwargs["top_k"], dtype=torch.int32),
    )
    with (
        patch.object(StaticDispatcher, "dispatch", return_value=fake_dispatch) as dispatch_mock,
        patch.object(StaticDispatcher, "combine", return_value=fake_combine),
        patch("mindiesd.layers.moe.moe.unified_apply_mlp", return_value=fake_mlp),
        patch("mindiesd.layers.moe.moe.select_experts", return_value=routing),
    ):
        result = moe(**kwargs)
    return result, dispatch_mock
def run_dynamic_moe(**kwargs):
    """Call ``moe`` with the dynamic-dispatcher pipeline fully mocked out.

    ``DynamicDispatcher.prepare``/``dispatch``/``combine``/``finalize``,
    ``unified_apply_mlp`` and ``select_experts`` are patched. An optional
    ``dispatch_result`` key is removed from ``kwargs`` and used as the patched
    ``dispatch`` return value. All remaining keys are forwarded to ``moe``.

    Returns:
        ``(output, mock_prepare, mock_dispatch)``: ``moe``'s result and the
        mocks that replaced ``DynamicDispatcher.prepare`` and ``dispatch``.
    """
    hidden = kwargs["hidden_states"]
    n_tokens, width = hidden.shape[0], hidden.shape[-1]
    # Eager default (dict.pop semantics) keeps the RNG draw order unchanged.
    fake_dispatch = kwargs.pop("dispatch_result", mock_dispatch_result(n_tokens, width))
    fake_prepare = MoEPrepareOutput(
        hidden_states=hidden,
        router_logits=kwargs["router_logits"],
        original_shape=hidden.shape,
        mlp_output_dtype=hidden.dtype,
    )
    routing = (
        torch.ones(n_tokens, kwargs["top_k"]),
        torch.zeros(n_tokens, kwargs["top_k"], dtype=torch.int32),
    )
    with (
        patch.object(DynamicDispatcher, "prepare", return_value=fake_prepare) as prepare_mock,
        patch.object(DynamicDispatcher, "dispatch", return_value=fake_dispatch) as dispatch_mock,
        patch.object(DynamicDispatcher, "combine", return_value=torch.randn_like(hidden)),
        patch("mindiesd.layers.moe.moe.unified_apply_mlp"),
        patch.object(DynamicDispatcher, "finalize", return_value=torch.randn_like(hidden)),
        patch("mindiesd.layers.moe.moe.select_experts", return_value=routing),
    ):
        result = moe(**kwargs)
    return result, prepare_mock, dispatch_mock
class TestMoeFunction(unittest.TestCase):
    """Numerical-reference and dispatcher-selection tests for ``moe``."""

    def _reset_moe_state(self):
        """Clear DynamicDispatcher's class-level split caches and reset the comm context."""
        DynamicDispatcher._split_cpu_buffers.clear()
        DynamicDispatcher._split_copy_events.clear()
        set_moe_comm_context()

    def setUp(self):
        self._reset_moe_state()

    def tearDown(self):
        # Some tests install an EP group in the global comm context
        # (set_moe_comm_context(ep_group=...); presumably moe(ep_group=...) too).
        # Reset on exit as well so that state cannot leak into tests outside
        # this class that do not reset it themselves.
        self._reset_moe_state()

    @unittest.skipIf(
        os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU",
        "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU.",
    )
    def test_static_moe_matches_torch_reference(self):
        torch.manual_seed(2026)
        # Loop-invariant setup, hoisted out of the subTest loop.
        device = torch.device("npu")
        num_tokens = 5
        hidden_size = 4
        intermediate_size = 6
        num_experts = 3
        cases = (
            dict(top_k=1, renormalize=False, routed_scaling_factor=1.0, with_bias=False, dtype=torch.bfloat16),
            dict(top_k=2, renormalize=True, routed_scaling_factor=0.5, with_bias=True, dtype=torch.bfloat16),
            dict(top_k=1, renormalize=False, routed_scaling_factor=1.0, with_bias=False, dtype=torch.float16),
            dict(top_k=2, renormalize=True, routed_scaling_factor=0.5, with_bias=True, dtype=torch.float16),
        )
        for case in cases:
            with self.subTest(**case):
                dtype = case["dtype"]
                hidden_states = torch.randn(num_tokens, hidden_size) / 10
                router_logits = torch.randn(num_tokens, num_experts) / 10
                w13_weight = torch.randn(num_experts, hidden_size, 2 * intermediate_size) / 10
                w2_weight = torch.randn(num_experts, intermediate_size, hidden_size) / 10
                w13_bias = torch.randn(num_experts, 2 * intermediate_size) / 10 if case["with_bias"] else None
                w2_bias = torch.randn(num_experts, hidden_size) / 10 if case["with_bias"] else None
                # Route on the same low-precision logits the NPU kernel sees,
                # but compute the expert math in float32 on CPU.
                topk_weights, topk_ids = torch_select_experts(
                    router_logits.to(dtype=dtype),
                    case["top_k"],
                    renormalize=case["renormalize"],
                    routed_scaling_factor=case["routed_scaling_factor"],
                )
                expected = torch_moe_reference(
                    hidden_states=hidden_states,
                    w13_weight=w13_weight,
                    w2_weight=w2_weight,
                    topk_weights=topk_weights,
                    topk_ids=topk_ids,
                    w13_bias=w13_bias,
                    w2_bias=w2_bias,
                )
                actual = moe(
                    hidden_states=hidden_states.to(device=device, dtype=dtype),
                    router_logits=router_logits.to(device=device, dtype=dtype),
                    num_experts=num_experts,
                    top_k=case["top_k"],
                    w13_weight=w13_weight.to(device=device, dtype=dtype),
                    w2_weight=w2_weight.to(device=device, dtype=dtype),
                    w13_bias=w13_bias.to(device=device, dtype=dtype) if w13_bias is not None else None,
                    w2_bias=w2_bias.to(device=device, dtype=dtype) if w2_bias is not None else None,
                    dispatcher_type="static",
                    tokens_full=True,
                    renormalize=case["renormalize"],
                    routed_scaling_factor=case["routed_scaling_factor"],
                    reduce_results=False,
                )
                torch.testing.assert_close(actual.cpu().float(), expected.float(), atol=5e-2, rtol=5e-2)

    @unittest.skipIf(
        os.environ.get("MINDIE_TEST_MODE", "ALL") == "NPU",
        "Skip CPU-compatible tests when MINDIE_TEST_MODE is NPU.",
    )
    def test_default_dispatcher_selection_and_static_override(self):
        ep_group = MagicMock(spec=dist.ProcessGroup)
        with patch("torch.distributed.get_world_size", return_value=2):
            # No dispatcher_type: with an EP group of size 2 the dynamic path is expected.
            _, dynamic_prepare, dynamic_dispatch = run_dynamic_moe(
                **make_moe_kwargs(num_experts=4, tokens_full=True),
                ep_group=ep_group,
            )
            # Explicit "static" must bypass DynamicDispatcher entirely.
            with patch.object(DynamicDispatcher, "dispatch") as unused_dynamic_dispatch:
                with patch("torch.distributed.all_reduce"):
                    _, static_dispatch = run_static_moe(
                        **make_moe_kwargs(num_experts=4, tokens_full=True, dispatcher_type="static"),
                        ep_group=ep_group,
                    )
        dynamic_prepare.assert_called_once()
        dynamic_dispatch.assert_called_once()
        static_dispatch.assert_called_once()
        unused_dynamic_dispatch.assert_not_called()

    @unittest.skipIf(
        os.environ.get("MINDIE_TEST_MODE", "ALL") == "NPU",
        "Skip CPU-compatible tests when MINDIE_TEST_MODE is NPU.",
    )
    def test_resolve_dispatcher_class_by_topk_and_override(self):
        ep_group = MagicMock(spec=dist.ProcessGroup)
        self.assertIs(resolve_dispatcher_class(top_k=1), StaticDispatcher)
        with patch("torch.distributed.get_world_size", return_value=2):
            set_moe_comm_context(ep_group=ep_group)
            self.assertIs(resolve_dispatcher_class(top_k=1), DynamicDispatcher)
            self.assertIs(resolve_dispatcher_class(top_k=2), StaticDispatcher)
            self.assertIs(resolve_dispatcher_class("static"), StaticDispatcher)
            self.assertIs(resolve_dispatcher_class("dynamic"), DynamicDispatcher)

    @unittest.skipIf(
        os.environ.get("MINDIE_TEST_MODE", "ALL") == "NPU",
        "Skip CPU-compatible tests when MINDIE_TEST_MODE is NPU.",
    )
    def test_resolve_dispatcher_class_rejects_dynamic_without_ep(self):
        with self.assertRaises(ParametersInvalid):
            resolve_dispatcher_class("dynamic")

    @unittest.skipIf(
        os.environ.get("MINDIE_TEST_MODE", "ALL") == "NPU",
        "Skip CPU-compatible tests when MINDIE_TEST_MODE is NPU.",
    )
    def test_dispatcher_type_overrides_default_routing_to_dynamic(self):
        kwargs = make_moe_kwargs(num_experts=4, tokens_full=True, dispatcher_type="dynamic")
        ep_group = MagicMock(spec=dist.ProcessGroup)
        with patch("torch.distributed.get_world_size", return_value=2):
            with patch.object(StaticDispatcher, "dispatch") as static_dispatch:
                _, _, dynamic_dispatch = run_dynamic_moe(**kwargs, ep_group=ep_group)
        dynamic_dispatch.assert_called_once()
        static_dispatch.assert_not_called()
@unittest.skipIf(
    os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU",
    "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU.",
)
@unittest.skipIf(
    get_npu_device() not in (NPUDevice.A2, NPUDevice.A3),
    "Skip INT8 MoE tests when device is not A2 or A3.",
)
class TestMoeW8A8Dynamic(unittest.TestCase):
    """Smoke test for the W8A8 dynamic-quant MoE path on A2/A3 devices."""

    def setUp(self):
        set_moe_comm_context()

    def test_w8a8_dynamic_moe_produces_correct_output_shape_and_dtype(self):
        torch.manual_seed(2026)
        npu = torch.device("npu")
        act_dtype = torch.bfloat16
        tokens, hidden, inter, experts = 4, 32, 32, 2
        # Storage-format id handed to npu_format_cast; presumably
        # ACL_FORMAT_FRACTAL_NZ in CANN's aclFormat enum — confirm.
        nz_format = 29

        def int8_nz_weight(*shape):
            # Random small int8 weights, recast into the NPU storage format above.
            raw = torch.randint(-8, 8, shape, dtype=torch.int8, device=npu)
            return torch_npu.npu_format_cast(raw, nz_format)

        x = (torch.randn(tokens, hidden, device=npu, dtype=act_dtype) / 10).contiguous()
        logits = torch.randn(tokens, experts, device=npu, dtype=act_dtype) / 10
        w13 = int8_nz_weight(experts, hidden, 2 * inter)
        w2 = int8_nz_weight(experts, inter, hidden)
        w13_scale = torch.rand(experts, 2 * inter, device=npu, dtype=act_dtype)
        w2_scale = torch.rand(experts, hidden, device=npu, dtype=act_dtype)

        out = moe(
            hidden_states=x,
            router_logits=logits,
            num_experts=experts,
            top_k=1,
            w13_weight=w13,
            w2_weight=w2,
            quant_config=make_w8a8_dynamic_quant_config(),
            w13_weight_scale=w13_scale,
            w2_weight_scale=w2_scale,
            dispatcher_type="static",
            tokens_full=True,
            reduce_results=False,
        )
        self.assertEqual(tuple(out.shape), (tokens, hidden))
        self.assertEqual(out.dtype, act_dtype)
@unittest.skipIf(
    os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU",
    "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU.",
)
@unittest.skipIf(get_npu_device() != NPUDevice.A5, "Skip MXFP8 MoE tests when device is not A5.")
class TestMoeW8A8Mxfp8(unittest.TestCase):
    """Smoke test for the W8A8 MXFP8 MoE path on A5 devices."""

    def setUp(self):
        set_moe_comm_context()

    def test_w8a8_mxfp8_moe_produces_correct_output_shape_and_dtype(self):
        # Seed the random activations/logits so a failure is reproducible.
        # This matches the sibling W8A8-dynamic test.
        torch.manual_seed(2026)
        device = torch.device("npu")
        num_tokens = 4
        hidden_size = 128
        intermediate_size = 64
        num_experts = 2
        dtype = torch.bfloat16
        hidden_states = (torch.randn(num_tokens, hidden_size, device=device, dtype=dtype) / 10).contiguous()
        router_logits = torch.randn(num_tokens, num_experts, device=device, dtype=dtype) / 10
        w13_weight, w13_weight_scale = make_mxfp8_ones(
            num_experts,
            hidden_size,
            2 * intermediate_size,
            device=device,
        )
        w2_weight, w2_weight_scale = make_mxfp8_ones(
            num_experts,
            intermediate_size,
            hidden_size,
            device=device,
        )
        output = moe(
            hidden_states=hidden_states,
            router_logits=router_logits,
            num_experts=num_experts,
            top_k=1,
            w13_weight=w13_weight,
            w2_weight=w2_weight,
            quant_config=make_w8a8_mxfp8_quant_config(),
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
            dispatcher_type="static",
            tokens_full=True,
            reduce_results=False,
        )
        self.assertEqual(tuple(output.shape), (num_tokens, hidden_size))
        self.assertEqual(output.dtype, dtype)
if __name__ == "__main__":
    # Allow running this test module directly in addition to via a test runner.
    unittest.main()