import os
import unittest
from unittest.mock import MagicMock, patch
import torch
import torch.distributed as dist
from mindiesd.layers.moe.moe_dataclass import (
MoEMlpComputeInput,
MoETokenDispatchOutput,
MoEWeights,
MoEPrepareInput,
MoERoutingInput,
MoETokenDispatchInput,
)
from mindiesd.layers.moe.moe_context import (
MoECommType,
build_mlp_compute_input,
build_moe_weights,
build_prepare_input,
build_routing_input,
build_token_dispatch_input,
get_moe_comm_type,
get_moe_group,
get_moe_quant_algo,
is_moe_int_quant,
is_moe_mxfp_quant,
is_moe_quant,
set_moe_comm_context,
set_moe_context,
validate_moe_inputs,
)
from mindiesd.quantization.config import QuantConfig
from mindiesd.quantization.mode import QuantAlgorithm
from mindiesd.utils import ParametersInvalid
from .common import make_moe_kwargs, make_w8a8_dynamic_quant_config, make_w8a8_mxfp8_quant_config
@unittest.skipIf(
    os.environ.get("MINDIE_TEST_MODE", "ALL") == "NPU",
    "Skip CPU-compatible tests when MINDIE_TEST_MODE is NPU.",
)
class TestMoEContext(unittest.TestCase):
    """CPU-side tests for the MoE context helpers.

    Covers ``validate_moe_inputs`` (accepted and rejected configurations),
    communication-context resolution, the quantization query helpers and the
    ``build_*`` input-wrapper constructors.
    """

    def setUp(self):
        # Reset to the default context before every test; calling
        # set_moe_context() with no arguments yields NO_QUANT, so a quant
        # algorithm selected by one test does not leak into the next.
        set_moe_context()

    @staticmethod
    def _supported_quant_cases():
        """Build one known-valid ``make_moe_kwargs`` override set per supported quantization.

        Returns:
            A tuple of ``(name, quant_algo, overrides)`` triples. ``overrides``
            contains the quant config together with weights and weight scales
            whose dtypes match that quantization. The acceptance test and the
            missing-scale test both use these, so the latter only ever differs
            from a configuration proven valid by the former.
        """
        return (
            (
                "w8a8_dynamic",
                QuantAlgorithm.W8A8_DYNAMIC,
                dict(
                    quant_config=make_w8a8_dynamic_quant_config(),
                    w13_weight=torch.randint(-8, 8, (2, 4, 16), dtype=torch.int8),
                    w2_weight=torch.randint(-8, 8, (2, 8, 4), dtype=torch.int8),
                    w13_weight_scale=torch.randn(2, 16),
                    w2_weight_scale=torch.randn(2, 4),
                ),
            ),
            (
                "w8a8_mxfp8",
                QuantAlgorithm.W8A8_MXFP8,
                dict(
                    quant_config=make_w8a8_mxfp8_quant_config(),
                    w13_weight=torch.empty(2, 4, 16, dtype=torch.float8_e4m3fn),
                    w2_weight=torch.empty(2, 8, 4, dtype=torch.float8_e4m3fn),
                    w13_weight_scale=torch.empty(2, 16, dtype=torch.uint8),
                    w2_weight_scale=torch.empty(2, 4, dtype=torch.uint8),
                ),
            ),
        )

    def test_validate_moe_inputs_accepts_valid_inputs(self):
        """The default (unquantized) kwargs validate and resolve to NO_QUANT."""
        self.assertEqual(validate_moe_inputs(**make_moe_kwargs()), QuantAlgorithm.NO_QUANT)

    def test_validate_moe_inputs_resolves_empty_quant_config(self):
        """An empty ``QuantConfig()`` is treated as NO_QUANT."""
        self.assertEqual(validate_moe_inputs(**make_moe_kwargs(quant_config=QuantConfig())), QuantAlgorithm.NO_QUANT)

    def test_validate_moe_inputs_accepts_supported_quant_scales(self):
        """Each supported quantized config validates and resolves to its quant algorithm."""
        for _, quant_algo, overrides in self._supported_quant_cases():
            with self.subTest(quant_algo=quant_algo):
                self.assertEqual(validate_moe_inputs(**make_moe_kwargs(**overrides)), quant_algo)

    def test_validate_moe_inputs_rejects_invalid_parameters(self):
        """Wrongly-typed or out-of-range scalar parameters raise ParametersInvalid."""
        invalid_cases = (
            dict(name="reduce_results", kwargs=dict(reduce_results="false")),
            dict(name="tokens_full", kwargs=dict(tokens_full="true")),
            dict(name="renormalize", kwargs=dict(renormalize="true")),
            dict(name="num_experts", kwargs=dict(num_experts="2")),
            dict(name="top_k", kwargs=dict(top_k=3)),
            dict(name="k_group", kwargs=dict(k_group=0)),
            dict(name="group_count", kwargs=dict(group_count=0)),
            dict(name="group_select_mode", kwargs=dict(group_select_mode=2)),
            dict(name="routing_method", kwargs=dict(routing_method="relu")),
            dict(name="routing_method_type", kwargs=dict(routing_method=0)),
            # An int is rejected even though it is numerically valid: a float is required.
            dict(name="routed_scaling_factor_int", kwargs=dict(routed_scaling_factor=1)),
            dict(name="routed_scaling_factor", kwargs=dict(routed_scaling_factor="1.0")),
            dict(name="quant_config_type", kwargs=dict(quant_config="int8")),
            dict(
                name="unsupported_quant_config", kwargs=dict(quant_config=QuantConfig(quant_algo=QuantAlgorithm.W4A16))
            ),
            # A weight scale without any quantization is inconsistent input.
            dict(name="none_quant_with_scale", kwargs=dict(w13_weight_scale=torch.randn(2, 16))),
            dict(name="dispatcher_type", kwargs=dict(dispatcher_type="auto")),
            dict(name="custom_routing_function", kwargs=dict(custom_routing_function=object())),
        )
        for case in invalid_cases:
            with self.subTest(case=case["name"]):
                kwargs = make_moe_kwargs()
                kwargs.update(case["kwargs"])
                with self.assertRaises(ParametersInvalid):
                    validate_moe_inputs(**kwargs)

    def test_validate_moe_inputs_rejects_missing_quant_scales(self):
        """Removing either weight scale from a valid quantized config raises ParametersInvalid."""
        missing_scale_cases = (
            ("missing_w13_scale", "w13_weight_scale"),
            ("missing_w2_scale", "w2_weight_scale"),
        )
        for quant_name, _, overrides in self._supported_quant_cases():
            base_kwargs = make_moe_kwargs(**overrides)
            for case_name, scale_key in missing_scale_cases:
                with self.subTest(quant=quant_name, case=case_name):
                    # Shallow copy so each sub-case starts from the intact base kwargs.
                    kwargs = dict(base_kwargs)
                    kwargs[scale_key] = None
                    with self.assertRaises(ParametersInvalid):
                        validate_moe_inputs(**kwargs)

    def test_validate_moe_inputs_rejects_invalid_shapes(self):
        """Tensors whose shapes disagree with the default kwargs raise ParametersInvalid."""
        invalid_cases = (
            dict(router_logits=torch.randn(4, 2)),
            dict(w13_bias=torch.randn(2, 8)),
            dict(w2_bias=torch.randn(2, 8)),
            dict(w2_weight=torch.randn(2, 7, 5)),
        )
        for overrides in invalid_cases:
            # Label each sub-test by the names of the overridden arguments.
            with self.subTest(overrides=tuple(overrides)):
                with self.assertRaises(ParametersInvalid):
                    validate_moe_inputs(**make_moe_kwargs(**overrides))

    def test_validate_moe_inputs_rejects_invalid_ep_partition(self):
        """4 experts with an EP group whose (patched) world size is 3 are expected to be rejected."""
        ep_group = MagicMock(spec=dist.ProcessGroup)
        # Patching get_world_size avoids needing an initialised process group.
        with self.assertRaises(ParametersInvalid), patch("torch.distributed.get_world_size", return_value=3):
            validate_moe_inputs(**make_moe_kwargs(num_experts=4, ep_group=ep_group))

    def test_validate_moe_inputs_rejects_invalid_expert_grouping(self):
        """Inconsistent num_experts / group_count / k_group / top_k combinations are rejected."""
        cases = (
            dict(name="uneven_groups", kwargs=dict(num_experts=5, group_count=2)),
            dict(name="too_many_selected_groups", kwargs=dict(num_experts=4, k_group=3, group_count=2)),
            dict(
                name="topk_exceeds_selected_experts",
                kwargs=dict(num_experts=4, top_k=3, k_group=1, group_count=2),
            ),
            dict(
                name="top2_group_score_without_two_experts",
                kwargs=dict(num_experts=4, group_count=4, group_select_mode=1),
            ),
        )
        for case in cases:
            with self.subTest(case=case["name"]):
                with self.assertRaises(ParametersInvalid):
                    validate_moe_inputs(**make_moe_kwargs(**case["kwargs"]))

    def test_set_moe_comm_context_resolves_active_group(self):
        """EP is chosen when an EP group is given; otherwise the TP group is used."""
        # NOTE(review): set_moe_comm_context stores state that get_moe_comm_type()/
        # get_moe_group() read back without arguments, i.e. shared state. It is not
        # obvious that setUp's set_moe_context() resets it — confirm whether a
        # tearDown reset is needed to avoid leaking the mock groups into other tests.
        tp_group = MagicMock(spec=dist.ProcessGroup)
        ep_group = MagicMock(spec=dist.ProcessGroup)
        cases = (
            dict(
                name="prefers_ep",
                tp_group=tp_group,
                ep_group=ep_group,
                expected_type=MoECommType.EP,
                expected_group=ep_group,
            ),
            dict(
                name="uses_tp_without_ep",
                tp_group=tp_group,
                ep_group=None,
                expected_type=MoECommType.TP,
                expected_group=tp_group,
            ),
        )
        for case in cases:
            with self.subTest(case=case["name"]):
                # Presumably the world size of the groups is queried; patched so
                # no real process group is required.
                with patch("torch.distributed.get_world_size", return_value=2):
                    set_moe_comm_context(tp_group=case["tp_group"], ep_group=case["ep_group"])
                    self.assertEqual(get_moe_comm_type(), case["expected_type"])
                    self.assertIs(get_moe_group(), case["expected_group"])

    def test_is_moe_quant_resolves_quantization(self):
        """The quant query helpers reflect the algorithm passed to set_moe_context."""
        set_moe_context()
        self.assertFalse(is_moe_quant())
        self.assertFalse(is_moe_int_quant())
        self.assertFalse(is_moe_mxfp_quant())
        self.assertEqual(get_moe_quant_algo(), QuantAlgorithm.NO_QUANT)

        set_moe_context(quant_algo=QuantAlgorithm.W8A8_DYNAMIC)
        self.assertTrue(is_moe_quant())
        self.assertTrue(is_moe_int_quant())
        self.assertFalse(is_moe_mxfp_quant())
        self.assertEqual(get_moe_quant_algo(), QuantAlgorithm.W8A8_DYNAMIC)

        set_moe_context(quant_algo=QuantAlgorithm.W8A8_MXFP8)
        self.assertTrue(is_moe_quant())
        self.assertFalse(is_moe_int_quant())
        self.assertTrue(is_moe_mxfp_quant())
        self.assertEqual(get_moe_quant_algo(), QuantAlgorithm.W8A8_MXFP8)

    def test_build_input_wrappers(self):
        """Each build_* helper returns its dataclass and carries the given fields through."""
        hidden_states = torch.randn(3, 4)
        router_logits = torch.randn(3, 2)
        topk_weights = torch.randn(3, 1)
        topk_ids = torch.zeros(3, 1, dtype=torch.int32)
        w13_weight = torch.randn(2, 4, 16)
        w2_weight = torch.randn(2, 8, 4)
        # Values are arbitrary: only identity pass-through of group_list is checked below.
        group_list = torch.tensor([2, 3])
        w13_weight_scale = torch.randn(2, 16)
        w2_weight_scale = torch.randn(2, 4)

        self.assertIsInstance(build_prepare_input(hidden_states, router_logits), MoEPrepareInput)

        routing_input = build_routing_input(
            hidden_states,
            router_logits,
            top_k=1,
            k_group=2,
            group_count=2,
            group_select_mode=1,
            routing_method="sigmoid",
            routed_scaling_factor=0.5,
        )
        self.assertIsInstance(routing_input, MoERoutingInput)
        self.assertEqual(routing_input.k_group, 2)
        self.assertEqual(routing_input.group_count, 2)
        self.assertEqual(routing_input.group_select_mode, 1)
        # routing_method="sigmoid" is expected to be translated to norm_type 1.
        self.assertEqual(routing_input.norm_type, 1)
        self.assertEqual(routing_input.routed_scaling_factor, 0.5)

        weights = build_moe_weights(
            w13_weight,
            w2_weight,
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
        )
        self.assertIsInstance(weights, MoEWeights)
        self.assertIs(weights.w13_weight_scale, w13_weight_scale)
        self.assertIs(weights.w2_weight_scale, w2_weight_scale)

        self.assertIsInstance(
            build_token_dispatch_input(
                hidden_states=hidden_states,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                num_experts=2,
                top_k=1,
                weights=weights,
            ),
            MoETokenDispatchInput,
        )

        dispatch_output = MoETokenDispatchOutput(
            hidden_states=hidden_states,
            group_list=group_list,
            group_list_type=1,
            combine_metadata=object(),
        )
        mlp_input = build_mlp_compute_input(
            dispatch_output=dispatch_output,
            weights=weights,
            mlp_output_dtype=torch.float32,
        )
        self.assertIsInstance(mlp_input, MoEMlpComputeInput)
        # The MLP input must reuse the dispatch output's tensors and the same weights object.
        self.assertIs(mlp_input.hidden_states, hidden_states)
        self.assertIs(mlp_input.group_list, group_list)
        self.assertIs(mlp_input.weights, weights)
if __name__ == "__main__":
    # Allow running this test module directly as a script; "__main__" is
    # unittest.main's default module and is passed explicitly for clarity.
    unittest.main(module="__main__")