import os
import sys
from argparse import Namespace
from datetime import timedelta
from types import SimpleNamespace
from unittest.mock import MagicMock, patch, PropertyMock
import pytest
import torch
import torch_npu
from megatron.training import global_vars
from mindspeed.core.qos.adaptor import create_group_qos
from mindspeed.core.qos.domain_info import ParallelCommDomain, RankGenerator
from mindspeed.core.qos.qos import (
Qos,
_DEFAULT_QOS,
_DEFAULT_QOS_ROCE_HIGH,
_DEFAULT_QOS_ROCE_LOW,
_DEFAULT_QOS_ROCE_MIDDLE,
_DEFAULT_QOS_SDMA_HIGH,
_DEFAULT_QOS_SDMA_LOW,
_DEFAULT_QOS_SDMA_MIDDLE,
_PARALLEL_TYPES,
domains,
roce_qos_str_to_value,
sdma_qos_str_to_value,
)
# Expected rank layout for a 64-rank job (tp=2, pp=4, dp=8, ep=8, cp=1) with
# order 'tp-cp-ep-dp-pp'.  Each *_comm_domain entry is a list of rank groups;
# the tables are generated instead of typed out, but evaluate to plain nested
# lists of ints.
GLOBAL_RANK_PARAMS = Namespace(
    tp=2,
    pp=4,
    dp=8,
    ep=8,
    cp=1,
    order='tp-cp-ep-dp-pp',
    rank_offset=0,
    world_size=64,  # 2 * 4 * 8 * 1
    # Adjacent rank pairs: [0, 1], [2, 3], ..., [62, 63].
    tensor_parallel_comm_domain=[[first, first + 1] for first in range(0, 64, 2)],
    # Ranks 16 apart: [0, 16, 32, 48], ..., [15, 31, 47, 63].
    pipeline_parallel_comm_domain=[list(range(first, 64, 16)) for first in range(16)],
    # Within each block of 16 ranks: the even ranks, then the odd ranks.
    data_parallel_comm_domain=[
        list(range(block_start + parity, block_start + 16, 2))
        for block_start in range(0, 64, 16)
        for parity in (0, 1)
    ],
    # cp=1, so every rank forms its own group.
    context_parallel_comm_domain=[[rank] for rank in range(64)],
    # Same grouping as the data-parallel table, built as a separate list object.
    expert_parallel_comm_domain=[
        list(range(block_start + parity, block_start + 16, 2))
        for block_start in range(0, 64, 16)
        for parity in (0, 1)
    ],
)
def reset_quality_of_service_singleton():
    """Return the ``Qos`` class and Megatron's global args to a clean state.

    Clears the cached ``Qos`` instance and its initialised flag so the next
    ``Qos()`` call sees an un-initialised class, then sets
    ``global_vars._GLOBAL_ARGS`` to ``None`` when that attribute exists.
    """
    Qos._initialize = False
    Qos._instance = None
    # Leave global_vars alone if it does not define the attribute at all.
    if not hasattr(global_vars, '_GLOBAL_ARGS'):
        return
    global_vars._GLOBAL_ARGS = None
class TestQos:
    """Tests for the ``Qos`` class and the ``create_group_qos`` adaptor.

    ``Qos`` keeps class-level ``_instance`` / ``_initialize`` state, so every
    test resets it before constructing ``Qos`` and again on exit. Without that,
    a test can pick up an instance configured by an earlier test and the
    outcome depends on execution order.
    """

    def test_qos_manual(self):
        """Manual mode: ``aiqos_schedule`` is parsed into SDMA and RoCE tables."""
        # Start from a clean class state so Qos() reads the schedule below
        # rather than reusing whatever an earlier test configured.
        reset_quality_of_service_singleton()
        try:
            with patch('mindspeed.core.qos.qos.get_args') as mock_get_args:
                mock_get_args.return_value = SimpleNamespace(
                    aiqos_mode="manual",
                    aiqos_schedule="{tp:low,pp:high,dp-cp:high,ep:high,cp:middle,pos-embd:middle}"
                )
                qos = Qos()
                assert qos.aiqos_mode == "manual"
                # Expected mapping: low/middle/high -> 2/4/6 for SDMA and 3/4/5 for RoCE.
                assert qos.sdma_aiqos_schedule == {'tp': 2, 'pp': 6, 'dp-cp': 6, 'ep': 6, 'cp': 4, 'pos-embd': 4}
                assert qos.roce_aiqos_schedule == {'tp': 3, 'pp': 5, 'dp-cp': 5, 'ep': 5, 'cp': 4, 'pos-embd': 4}
                assert qos._initialize is True
        finally:
            # Do not leak the manual-mode instance into later tests.
            reset_quality_of_service_singleton()

    def test_qos_auto_enable_roce(self):
        """Auto mode with RoCE enabled: check the per-domain QoS values.

        First verifies that ``RankGenerator`` reproduces the rank groups in
        ``GLOBAL_RANK_PARAMS``, then builds one ``ParallelCommDomain`` per
        parallel type and feeds them to ``Qos`` through patched accessors.
        """
        generator = RankGenerator(tp=GLOBAL_RANK_PARAMS.tp, ep=GLOBAL_RANK_PARAMS.ep, dp=GLOBAL_RANK_PARAMS.dp,
                                  pp=GLOBAL_RANK_PARAMS.pp, cp=GLOBAL_RANK_PARAMS.cp, order='tp-cp-ep-dp-pp')
        ep_group_ranks = generator.get_ranks('ep', independent_ep=True)
        tp_group_ranks = generator.get_ranks('tp')
        pp_group_ranks = generator.get_ranks('pp')
        dp_group_ranks = generator.get_ranks('dp')
        cp_group_ranks = generator.get_ranks('cp')
        assert tp_group_ranks == GLOBAL_RANK_PARAMS.tensor_parallel_comm_domain
        assert pp_group_ranks == GLOBAL_RANK_PARAMS.pipeline_parallel_comm_domain
        assert dp_group_ranks == GLOBAL_RANK_PARAMS.data_parallel_comm_domain
        assert cp_group_ranks == GLOBAL_RANK_PARAMS.context_parallel_comm_domain
        assert ep_group_ranks == GLOBAL_RANK_PARAMS.expert_parallel_comm_domain

        tp_info = ParallelCommDomain(ip_list=None, rank_list=tp_group_ranks, world_size=GLOBAL_RANK_PARAMS.tp,
                                     parallel_type='tp', comm_amount=4096, comm_amount_no_overlap=2048)
        pp_info = ParallelCommDomain(ip_list=None, rank_list=pp_group_ranks, world_size=GLOBAL_RANK_PARAMS.pp,
                                     parallel_type='pp', comm_amount=40960, comm_amount_no_overlap=20480)
        dp_info = ParallelCommDomain(ip_list=None, rank_list=dp_group_ranks, world_size=GLOBAL_RANK_PARAMS.dp,
                                     parallel_type='dp', comm_amount=1314, comm_amount_no_overlap=520)
        cp_info = ParallelCommDomain(ip_list=None, rank_list=cp_group_ranks, world_size=GLOBAL_RANK_PARAMS.cp,
                                     parallel_type='cp', comm_amount=512, comm_amount_no_overlap=256)
        ep_info = ParallelCommDomain(ip_list=None, rank_list=ep_group_ranks, world_size=GLOBAL_RANK_PARAMS.ep,
                                     parallel_type='ep', comm_amount=131072, comm_amount_no_overlap=81920)

        with patch('mindspeed.core.qos.domain_info.get_args') as mock_domain_get_args, \
                patch('mindspeed.core.qos.qos.get_args') as mock_qos_get_args, \
                patch('mindspeed.core.qos.domain_info.is_a3_version', new=True), \
                patch('mindspeed.core.qos.qos.is_a3_version', new=True), \
                patch('mindspeed.core.qos.domain_info.get_overlap_space_dict') as mock_space_dict, \
                patch('mindspeed.core.qos.qos.get_tensor_parallel_comm_domain', return_value=tp_info), \
                patch('mindspeed.core.qos.qos.get_data_parallel_comm_domain', return_value=dp_info), \
                patch('mindspeed.core.qos.qos.get_pipeline_parallel_comm_domain', return_value=pp_info), \
                patch('mindspeed.core.qos.qos.get_expert_parallel_comm_domain', return_value=ep_info), \
                patch('mindspeed.core.qos.qos.get_context_parallel_comm_domain', return_value=cp_info), \
                patch('mindspeed.core.qos.qos.log_rank_0'):
            # Report zero spatial overlap for every ordered pair of domains.
            parallel_types = ('tp', 'dp', 'pp', 'ep', 'cp')
            mock_space_dict.return_value = {(x, y): 0 for x in parallel_types for y in parallel_types}

            mock_args = Namespace(
                aiqos_mode="auto",
                aiqos_enable_roce=True,
                num_experts=32,
                overlap_grad_reduce=True,
                overlap_param_gather=True,
            )
            mock_domain_get_args.return_value = mock_args
            mock_qos_get_args.return_value = mock_args

            # Reset first: the helper also sets global_vars._GLOBAL_ARGS to
            # None, so installing the args before it would be undone at once.
            reset_quality_of_service_singleton()
            try:
                global_vars._GLOBAL_ARGS = mock_args
                assert Qos._initialize is False
                assert Qos._instance is None

                qos = Qos()
                assert qos is not None
                assert qos.aiqos_mode == "auto"
                assert qos.set_parallel_roce_qos('tp') == 4
                assert qos.set_parallel_roce_qos('pp') == 4
                assert qos.set_parallel_roce_qos('dp') == 4
                assert qos.set_parallel_roce_qos('cp') == 4
                assert qos.set_parallel_roce_qos('ep') == 4
                assert qos.set_parallel_sdma_qos('tp') == 6
                assert qos.set_parallel_sdma_qos('pp') == 4
                assert qos.set_parallel_sdma_qos('dp') == 6
                assert qos.set_parallel_sdma_qos('cp') == 2
                assert qos.set_parallel_sdma_qos('ep') == 2
            finally:
                # Clears both the auto-mode instance and the global args.
                reset_quality_of_service_singleton()

    def test_create_qos_group(self):
        """``create_group_qos`` returns a group for a one-rank HCCL job.

        Calls ``torch.npu.set_device(0)`` and initialises an ``hccl`` process
        group, so it needs an NPU device and MASTER_ADDR/MASTER_PORT (set here).
        """
        timeout = timedelta(seconds=120)
        sdma_qos = 6
        roce_qos = 5
        pg_initialized = False
        with patch.dict(os.environ, {'MASTER_ADDR': "localhost", 'MASTER_PORT': "6666"}, clear=False), \
                patch('mindspeed.core.qos.adaptor.get_args') as mock_adaptor_get_args, \
                patch('mindspeed.core.qos.qos.get_args') as mock_qos_get_args:
            mock_args = SimpleNamespace(
                aiqos_mode="manual",
                aiqos_enable_roce=True,
                aiqos_schedule="{tp:low,pp:high,dp-cp:high,ep:high,cp:middle,pos-embd:middle}"
            )
            mock_adaptor_get_args.return_value = mock_args
            mock_qos_get_args.return_value = mock_args

            # Drop any Qos instance left by another test before installing the
            # manual-mode args (the helper also clears _GLOBAL_ARGS).
            reset_quality_of_service_singleton()
            try:
                global_vars._GLOBAL_ARGS = mock_args
                torch.npu.set_device(0)
                torch.distributed.init_process_group(backend='hccl', rank=0, world_size=1, timeout=timeout)
                pg_initialized = True

                pg_options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
                group = create_group_qos(
                    [0],
                    timeout=timeout,
                    pg_options=pg_options,
                    group_desc='DATA_PARALLEL_GROUP',
                    parallel_type='dp'
                )
                # NOTE(review): hccl_config is assigned here, after the group
                # was created, so the checks below only confirm that the
                # options object stores what this test wrote. If
                # create_group_qos is meant to fill pg_options itself, drop
                # this assignment so the assertions test the adaptor - confirm
                # against the adaptor implementation first.
                pg_options.hccl_config = {'hccl_sdma_qos': sdma_qos, 'qos_service_level': roce_qos,
                                          'qos_traffic_class': roce_qos * 32}

                assert group is not None, "Group creation failed"
                hccl_cfg = pg_options.hccl_config
                assert 'hccl_sdma_qos' in hccl_cfg, "Missing key: hccl_sdma_qos"
                assert hccl_cfg['hccl_sdma_qos'] == sdma_qos
                assert 'qos_service_level' in hccl_cfg, "Missing key: qos_service_level"
                assert hccl_cfg['qos_service_level'] == roce_qos
                assert 'qos_traffic_class' in hccl_cfg, "Missing key: qos_traffic_class"
                assert hccl_cfg['qos_traffic_class'] == roce_qos * 32
            finally:
                try:
                    if pg_initialized:
                        torch.distributed.destroy_process_group()
                finally:
                    # Restore global state even if tearing down the process
                    # group raised.
                    global_vars._GLOBAL_ARGS = None
                    reset_quality_of_service_singleton()