import os
import torch
import torch.nn as nn
import pytest
from mindspeed_mm.fsdp.models.mimov2_5.modeling_mimo_v2 import PatchMiMoV2TopkRouter, PatchMiMoV2NaiveMoe, PatchMiMoV2MoE
from mindspeed_mm.fsdp.models.mimov2_5.configuration_mimo_v2 import MiMoV2Config
from tests.ut_fsdp.utils.utils import judge_expression
class TestPatchMiMoV2TopkRouter:
    """Unit tests for ``PatchMiMoV2TopkRouter`` construction and forward output shapes."""

    @staticmethod
    def _make_router(hidden, experts):
        """Build a router from a minimal ``MiMoV2Config`` with the given sizes."""
        cfg = MiMoV2Config(hidden_size=hidden, n_routed_experts=experts)
        return PatchMiMoV2TopkRouter(cfg)

    def test_init(self):
        """The router records its expert count and allocates per-expert parameters."""
        router = self._make_router(4096, 8)
        judge_expression(router.n_routed_experts == 8)
        # Weight is laid out as (n_routed_experts, hidden_size).
        judge_expression(router.weight.shape == (8, 4096))
        judge_expression(router.e_score_correction_bias.shape == (8,))

    def test_forward(self):
        """Logits come back flattened over (batch, seq) and in float32."""
        router = self._make_router(4096, 4)
        tokens = torch.randn(2, 10, 4096)
        logits = router(tokens)
        judge_expression(logits.shape == (2 * 10, 4))
        judge_expression(logits.dtype == torch.float32)

    def test_forward_with_different_batch_size(self):
        """Flattening also holds for another batch/seq/hidden/expert combination."""
        router = self._make_router(2048, 6)
        logits = router(torch.randn(5, 15, 2048))
        judge_expression(logits.shape == (5 * 15, 6))
class TestPatchMiMoV2NaiveMoe:
    """Unit tests for ``PatchMiMoV2NaiveMoe``: stacked expert weights and forward shapes."""

    def setup_method(self, method):
        """Seed torch's global RNG so each test's random inputs are reproducible."""
        torch.manual_seed(0)

    @staticmethod
    def _distinct_expert_indices(num_tokens, num_experts, top_k):
        """Return a ``(num_tokens, top_k)`` long tensor with no repeated expert per row.

        Real top-k routing never picks the same expert twice for one token,
        which ``torch.randint`` cannot guarantee; taking the first ``top_k``
        columns of a random permutation per row does.
        """
        return torch.rand(num_tokens, num_experts).argsort(dim=-1)[:, :top_k]

    def test_init(self):
        """Config sizes are recorded and expert weights are stacked per expert."""
        config = MiMoV2Config(
            hidden_size=4096,
            n_routed_experts=4,
            moe_intermediate_size=1024,
            hidden_act="silu",
        )
        moe = PatchMiMoV2NaiveMoe(config)
        judge_expression(moe.num_experts == 4)
        judge_expression(moe.hidden_size == 4096)
        judge_expression(moe.intermediate_dim == 1024)
        # Expected layout: (num_experts, hidden_size, 2 * moe_intermediate_size).
        judge_expression(moe.gate_up_proj.shape == (4, 4096, 2048))
        # Expected layout: (num_experts, moe_intermediate_size, hidden_size).
        judge_expression(moe.down_proj.shape == (4, 1024, 4096))

    def test_forward(self):
        """Top-2 routing over 4 experts preserves the token shape and dtype."""
        config = MiMoV2Config(
            hidden_size=512,
            n_routed_experts=4,
            moe_intermediate_size=256,
            hidden_act="silu",
        )
        moe = PatchMiMoV2NaiveMoe(config)
        hidden_states = torch.randn(10, 512)
        top_k_weights = torch.rand(10, 2)
        # Distinct experts per token, matching what a real top-k router emits.
        top_k_index = self._distinct_expert_indices(10, 4, 2)
        # Positional order (hidden_states, weights, indices) as used throughout this file.
        output = moe(hidden_states, top_k_weights, top_k_index)
        judge_expression(output.shape == (10, 512))
        judge_expression(output.dtype == hidden_states.dtype)

    def test_forward_with_single_expert(self):
        """Every token is routed (top-1) to expert 0, so expert 1 is never hit."""
        config = MiMoV2Config(
            hidden_size=256,
            n_routed_experts=2,
            moe_intermediate_size=128,
            hidden_act="silu",
        )
        moe = PatchMiMoV2NaiveMoe(config)
        hidden_states = torch.randn(5, 256)
        top_k_weights = torch.rand(5, 1)
        top_k_index = torch.zeros(5, 1, dtype=torch.long)
        output = moe(hidden_states, top_k_weights, top_k_index)
        judge_expression(output.shape == (5, 256))

    def test_forward_with_all_experts(self):
        """Every token is routed to all 3 experts (in rotating order) with equal weights."""
        config = MiMoV2Config(
            hidden_size=128,
            n_routed_experts=3,
            moe_intermediate_size=64,
            hidden_act="silu",
        )
        moe = PatchMiMoV2NaiveMoe(config)
        hidden_states = torch.randn(6, 128)
        top_k_weights = torch.ones(6, 3) / 3
        top_k_index = torch.tensor(
            [[0, 1, 2], [1, 2, 0], [2, 0, 1], [0, 1, 2], [1, 2, 0], [2, 0, 1]]
        )
        output = moe(hidden_states, top_k_weights, top_k_index)
        judge_expression(output.shape == (6, 128))
class TestPatchMiMoV2MoE:
    """Unit tests for ``PatchMiMoV2MoE``: config wiring, routing outputs, forward shapes."""

    def setup_method(self, method):
        """Seed torch's global RNG so each test's random inputs are reproducible."""
        torch.manual_seed(0)

    @staticmethod
    def _check_routing(topk_indices, topk_weights, num_tokens, top_k, num_experts):
        """Assert routing outputs are ``(num_tokens, top_k)`` with valid, distinct expert ids."""
        judge_expression(topk_indices.shape == (num_tokens, top_k))
        judge_expression(topk_weights.shape == (num_tokens, top_k))
        judge_expression(bool(((topk_indices >= 0) & (topk_indices < num_experts)).all()))
        # Top-k selection never returns the same expert twice for one token.
        sorted_ids = topk_indices.sort(dim=-1).values
        judge_expression(bool((sorted_ids[:, 1:] != sorted_ids[:, :-1]).all()))

    def test_init(self):
        """Routing hyper-parameters from the config are exposed on the module."""
        config = MiMoV2Config(
            hidden_size=4096,
            n_routed_experts=8,
            moe_intermediate_size=1024,
            num_experts_per_tok=2,
            n_group=2,
            topk_group=1,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
            hidden_act="silu",
        )
        moe = PatchMiMoV2MoE(config)
        judge_expression(moe.n_routed_experts == 8)
        judge_expression(moe.n_group == 2)
        judge_expression(moe.topk_group == 1)
        # Identity check: `== True` would also accept 1 (PEP 8 / E712).
        judge_expression(moe.norm_topk_prob is True)
        # num_experts_per_tok is exposed as top_k.
        judge_expression(moe.top_k == 2)

    def test_route_tokens_to_experts(self):
        """Routing 10 tokens over 4 experts returns top-2 ids and non-negative weights."""
        config = MiMoV2Config(
            hidden_size=512,
            n_routed_experts=4,
            num_experts_per_tok=2,
            n_group=2,
            topk_group=1,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
        )
        moe = PatchMiMoV2MoE(config)
        router_logits = torch.randn(10, 4)
        topk_indices, topk_weights = moe.route_tokens_to_experts(router_logits)
        self._check_routing(topk_indices, topk_weights, num_tokens=10, top_k=2, num_experts=4)
        judge_expression(bool((topk_weights >= 0).all()))

    def test_route_tokens_to_experts_with_norm(self):
        """Routing with norm_topk_prob=True and a 2.0 scaling factor yields valid top-3 ids.

        NOTE(review): only shapes and expert ids are checked; the normalisation
        itself (e.g. per-token weight sums) is not asserted because its exact
        contract is not visible from this file.
        """
        config = MiMoV2Config(
            hidden_size=256,
            n_routed_experts=6,
            num_experts_per_tok=3,
            n_group=3,
            topk_group=2,
            norm_topk_prob=True,
            routed_scaling_factor=2.0,
        )
        moe = PatchMiMoV2MoE(config)
        router_logits = torch.randn(5, 6)
        topk_indices, topk_weights = moe.route_tokens_to_experts(router_logits)
        self._check_routing(topk_indices, topk_weights, num_tokens=5, top_k=3, num_experts=6)

    def test_route_tokens_to_experts_without_norm(self):
        """Routing with norm_topk_prob=False still yields valid ids and non-negative weights."""
        config = MiMoV2Config(
            hidden_size=256,
            n_routed_experts=4,
            num_experts_per_tok=2,
            n_group=2,
            topk_group=1,
            norm_topk_prob=False,
            routed_scaling_factor=1.0,
        )
        moe = PatchMiMoV2MoE(config)
        router_logits = torch.randn(8, 4)
        topk_indices, topk_weights = moe.route_tokens_to_experts(router_logits)
        self._check_routing(topk_indices, topk_weights, num_tokens=8, top_k=2, num_experts=4)
        judge_expression(bool((topk_weights >= 0).all()))

    def test_forward(self):
        """The output keeps the (batch, seq, hidden) shape of the input."""
        config = MiMoV2Config(
            hidden_size=512,
            n_routed_experts=4,
            moe_intermediate_size=256,
            num_experts_per_tok=2,
            n_group=2,
            topk_group=1,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
            hidden_act="silu",
        )
        moe = PatchMiMoV2MoE(config)
        hidden_states = torch.randn(2, 10, 512)
        output = moe(hidden_states)
        judge_expression(output.shape == (2, 10, 512))

    def test_forward_with_different_config(self):
        """Shape is preserved with 8 experts, 4 groups, top-3 and a 0.5 scaling factor."""
        config = MiMoV2Config(
            hidden_size=1024,
            n_routed_experts=8,
            moe_intermediate_size=512,
            num_experts_per_tok=3,
            n_group=4,
            topk_group=2,
            norm_topk_prob=True,
            routed_scaling_factor=0.5,
            hidden_act="silu",
        )
        moe = PatchMiMoV2MoE(config)
        hidden_states = torch.randn(3, 15, 1024)
        output = moe(hidden_states)
        judge_expression(output.shape == (3, 15, 1024))

    def test_forward_with_single_batch(self):
        """Shape is preserved when the batch dimension is 1."""
        config = MiMoV2Config(
            hidden_size=256,
            n_routed_experts=4,
            moe_intermediate_size=128,
            num_experts_per_tok=2,
            n_group=2,
            topk_group=1,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
            hidden_act="silu",
        )
        moe = PatchMiMoV2MoE(config)
        hidden_states = torch.randn(1, 5, 256)
        output = moe(hidden_states)
        judge_expression(output.shape == (1, 5, 256))

    def test_forward_with_scaling_factor(self):
        """Shape is preserved with a non-unit routed_scaling_factor (2.5)."""
        config = MiMoV2Config(
            hidden_size=512,
            n_routed_experts=4,
            moe_intermediate_size=256,
            num_experts_per_tok=2,
            n_group=2,
            topk_group=1,
            norm_topk_prob=True,
            routed_scaling_factor=2.5,
            hidden_act="silu",
        )
        moe = PatchMiMoV2MoE(config)
        hidden_states = torch.randn(2, 8, 512)
        output = moe(hidden_states)
        judge_expression(output.shape == (2, 8, 512))