#!/usr/bin/env python
# coding=utf-8
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
# MindIE is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
# http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
# pylint: disable=too-many-lines,too-many-function-args
import unittest
import numpy as np
from mindiesd.eplb.eplb_scheduler import eplb_greedy
# Per-expert load vector used as every rank's simulated report in these tests.
# It has 320 entries, one per expert, matching expert_num=320 in the calls
# below. The values are presumably routed-token counts (the maximum is 4096);
# TODO confirm against the eplb_greedy contract. Ten values per row.
_RANK_EXPERT_LOAD = (
    450, 1, 3892, 4017, 226, 67, 1321, 3, 0, 214,
    1376, 722, 318, 428, 211, 13, 3, 86, 10, 8,
    44, 28, 39, 159, 0, 2787, 4083, 0, 219, 8,
    4018, 0, 2, 1, 53, 6, 3, 5, 19, 0,
    0, 12, 208, 3987, 2171, 0, 27, 0, 30, 0,
    1, 11, 0, 4096, 0, 0, 3, 1, 6, 1,
    281, 2, 46, 3982, 0, 0, 4093, 2, 1, 0,
    27, 3, 1, 4, 11, 3, 0, 0, 3, 1,
    0, 774, 0, 0, 89, 211, 1567, 0, 43, 3982,
    135, 4096, 15, 0, 325, 320, 1, 0, 4, 0,
    94, 333, 0, 3, 9, 26, 444, 0, 0, 0,
    13, 107, 3472, 106, 1, 4079, 0, 2, 0, 0,
    0, 1, 0, 0, 0, 0, 329, 1, 0, 55,
    0, 0, 0, 32, 4, 0, 1, 11, 2, 0,
    0, 0, 0, 4096, 0, 1, 0, 0, 0, 0,
    0, 2, 4096, 8, 0, 0, 2, 0, 0, 0,
    13, 69, 5, 0, 0, 0, 203, 112, 20, 7,
    0, 1, 0, 0, 0, 0, 9, 0, 21, 0,
    0, 0, 21, 11, 0, 40, 1, 39, 11, 0,
    2, 72, 1, 0, 5, 0, 0, 4001, 0, 2,
    1, 0, 0, 1, 0, 86, 0, 4096, 3, 12,
    2, 4096, 0, 0, 4096, 0, 0, 0, 3664, 1,
    0, 1, 5, 0, 14, 0, 0, 92, 0, 6,
    0, 0, 6, 3, 3982, 714, 0, 4096, 2, 0,
    0, 1993, 0, 6, 0, 0, 10, 0, 1, 0,
    1, 3, 96, 659, 4096, 55, 0, 0, 0, 1,
    0, 8, 129, 0, 2, 21, 0, 89, 19, 46,
    0, 0, 7, 3, 19, 0, 0, 108, 0, 0,
    0, 40, 4096, 1, 40, 36, 0, 2, 0, 0,
    71, 2840, 6, 1732, 0, 6, 0, 5, 2, 0,
    3, 0, 0, 1, 0, 0, 4, 4096, 4094, 66,
    1, 3, 0, 1, 37, 0, 30, 4096, 0, 0,
)
# Simulated per-rank responses, keyed 0..3 to match world_size=4 in the tests.
# Every rank reports the same load vector. Each rank still gets its OWN array,
# so any in-place modification the code under test makes to one rank's data
# cannot leak into another rank's entry.
RESPONSE = {rank: np.array(_RANK_EXPERT_LOAD) for rank in range(4)}
# Baseline placement without redundancy: group r (presumably rank r) lists the
# 80 consecutive expert indices [80*r, 80*r + 80), i.e. group 0: 0-79,
# group 1: 80-159, group 2: 160-239, group 3: 240-319. Together the four
# groups cover indices 0-319 exactly once.
EXPERT_DICT = {
    group: list(range(group * 80, (group + 1) * 80))
    for group in range(4)
}
# Redundant placement: group r lists the 84 consecutive indices
# [80*r, 80*r + 84), i.e. group 0: 0-83, group 1: 80-163, group 2: 160-243,
# group 3: 240-323. Each group therefore overlaps the first 4 indices of the
# next group.
# NOTE(review): group 3 lists 320-323, which lie outside expert_num=320 used by
# the tests. Confirm whether eplb_greedy treats these as slot indices, wraps
# them, or expects 0-3 here.
EXPERT_DICT_REDUNDANT = {
    group: list(range(group * 80, group * 80 + 84))
    for group in range(4)
}
# NOTE(review): the skip condition is hard-coded to True, so this class never
# runs. The reason string suggests it was meant to depend on MINDIE_TEST_MODE;
# confirm whether that is intentional before relying on these tests.
@unittest.skipIf(True, "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU.")
class TestEplbScheduler(unittest.TestCase):
    """Structural checks on ``eplb_greedy`` output for 4 ranks and 320 experts."""

    def _assert_no_duplicate_experts(self, local_expert_list, expected_count=None):
        """Assert that no rank's expert list contains the same expert twice.

        Each rank is checked inside its own ``subTest`` so a failure reports
        the offending rank and the remaining ranks are still checked.

        Args:
            local_expert_list: per-rank sequences of expert ids, taken from
                the 4th element of the ``eplb_greedy`` result.
            expected_count: if given, each rank must hold exactly this many
                distinct experts.
        """
        for rank, experts in enumerate(local_expert_list):
            with self.subTest(rank=rank):
                unique_experts = set(experts)
                self.assertEqual(len(experts), len(unique_experts))
                if expected_count is not None:
                    self.assertEqual(len(unique_experts), expected_count)

    def _assert_full_layout(self, algo):
        """Run ``eplb_greedy`` with the non-redundant placement and check the result.

        Asserts that the result is a 5-tuple whose update flag equals True.
        The device and local-index lists must have 4 entries (one per rank),
        and rank 0's local-index list must have 320 entries. No rank may hold
        a duplicate expert, and the transfer tensor must be 320 x 320.

        Args:
            algo: algorithm name forwarded to ``eplb_greedy``
                (e.g. "A2A", "AG", "EX").
        """
        result = eplb_greedy(RESPONSE, algo, EXPERT_DICT, world_size=4, expert_num=320)
        self.assertIsNotNone(result)
        (update, device_indices_list, local_expert_indices_list,
         local_expert_list, expert_trans_tensor) = result
        self.assertEqual(update, True)
        self.assertEqual(len(device_indices_list), 4)
        self.assertEqual(len(local_expert_indices_list), 4)
        self.assertEqual(len(local_expert_indices_list[0]), 320)
        self._assert_no_duplicate_experts(local_expert_list)
        # One row and one column per expert (expert_num=320).
        self.assertEqual(expert_trans_tensor.shape, (320, 320))

    def _assert_redundant_layout(self, algo):
        """Run ``eplb_greedy`` with the redundant placement and check each rank.

        The trailing positional arguments ``84, 4`` are passed through
        unchanged. They presumably mean 84 experts per rank and 4 redundant
        experts; TODO confirm against the eplb_greedy signature. Each rank
        must end up with 84 distinct experts.

        Args:
            algo: algorithm name forwarded to ``eplb_greedy``.
        """
        result = eplb_greedy(RESPONSE, algo, EXPERT_DICT_REDUNDANT, 4, 320, 84, 4)
        self.assertIsNotNone(result)
        _, _, _, local_expert_list, _ = result
        self._assert_no_duplicate_experts(local_expert_list, expected_count=84)

    def test_A2A_algo(self):
        self._assert_full_layout("A2A")

    def test_AG_algo(self):
        self._assert_full_layout("AG")

    def test_EX_algo(self):
        self._assert_full_layout("EX")

    def test_A2A_redundant_algo(self):
        self._assert_redundant_layout("A2A")

    def test_AG_redundant_algo(self):
        self._assert_redundant_layout("AG")
if __name__ == "__main__":
    # Executed directly as a script: hand control to the unittest runner.
    unittest.main()