import os
from logging import getLogger
from typing import (
Optional,
Callable,
List
)
from functools import partial
from datetime import timedelta
import torch
import torch_npu
from megatron.core.utils import GlobalMemoryBuffer, is_torch_min_version
from megatron.core.parallel_state import (
default_embedding_ranks,
default_position_embedding_ranks,
RankGenerator,
create_group,
get_nccl_options,
get_data_parallel_group,
create_hierarchical_parallel_groups,
_set_global_memory_buffer
)
import megatron.core.parallel_state as mcps
from megatron.training import get_args
from mindspeed.core.qos.qos import Qos
from mindspeed.log_config import log_rank_0
from mindspeed.core.qos.domain_info import is_a3_version
# Module-level logger; QoS settings chosen for each process group are reported through it.
LOG = getLogger(name=__name__)
def create_group_qos(
    ranks=None,
    timeout=None,
    backend=None,
    pg_options=None,
    use_local_synchronization=False,
    group_desc=None,
    parallel_type=None
):
    """Create a process group whose HCCL options carry QoS settings.

    Wraps ``torch.distributed.new_group``. Before the group is created, the
    RoCE and SDMA QoS levels for ``parallel_type`` are obtained from ``Qos``
    and written into ``pg_options.hccl_config``.

    Args:
        ranks: Global ranks that form the new group.
        timeout: Optional timeout; not forwarded to ``new_group`` when None.
        backend: Backend name forwarded to ``new_group``.
        pg_options: Process-group options object. When None, a fresh
            ``ProcessGroupHCCL.Options()`` is created. Its ``hccl_config`` is
            overwritten with the QoS settings.
        use_local_synchronization: Forwarded to ``new_group``.
        group_desc: Group description; dropped when torch < 2.4.0.
        parallel_type: Key used to look up the QoS levels (e.g. ``'dp'``,
            ``'tp'``, ``'pp'``); also used in the log message.

    Returns:
        The process group returned by ``torch.distributed.new_group``.

    Raises:
        ValueError: If the RoCE or SDMA QoS level is outside ``0..7``.
    """
    kwargs = {
        'ranks': ranks,
        'timeout': timeout,
        'backend': backend,
        'pg_options': pg_options,
        'use_local_synchronization': use_local_synchronization,
        'group_desc': group_desc,
    }
    if not is_torch_min_version('2.4.0'):
        # Presumably older `new_group` does not accept `group_desc`.
        kwargs.pop('group_desc')
    if timeout is None:
        # Leave the timeout to `new_group`'s own default.
        kwargs.pop('timeout')
    if pg_options is None:
        kwargs['pg_options'] = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()

    ai_qos = Qos()
    roce_qos = ai_qos.set_parallel_roce_qos(parallel_type)
    sdma_qos = ai_qos.set_parallel_sdma_qos(parallel_type)

    # Both QoS levels must lie in 0..7; report every offending value at once.
    error_msg_parts = [
        f"{name}={value} (valid range: 0-7)"
        for name, value in (("roce_qos", roce_qos), ("sdma_qos", sdma_qos))
        if not 0 <= value <= 7
    ]
    if error_msg_parts:
        raise ValueError(
            f"Invalid QoS value for parallel type '{parallel_type}'! "
            + " | ".join(error_msg_parts)
        )

    args = get_args()
    # `is_a3_version` reads like a predicate. Testing a bare function object is
    # always truthy and would force the A3 branch on every device, so call it
    # when it is callable; if it is a plain flag, use its value directly.
    on_a3 = is_a3_version() if callable(is_a3_version) else bool(is_a3_version)
    hccl_options = kwargs['pg_options']
    if on_a3:
        if args.aiqos_enable_roce:
            # A3 with RoCE QoS enabled: set both SDMA and RoCE levels.
            hccl_options.hccl_config = {'hccl_sdma_qos': sdma_qos, 'qos_service_level': roce_qos,
                                        'qos_traffic_class': roce_qos * 32}
            log_rank_0(LOG.info, f"{parallel_type} roce_qos: {roce_qos}, sdma_qos: {sdma_qos}")
        else:
            # A3 without RoCE QoS: only the SDMA level is configured.
            hccl_options.hccl_config = {'hccl_sdma_qos': sdma_qos}
            log_rank_0(LOG.info, f"{parallel_type} sdma_qos: {sdma_qos}")
    else:
        # Non-A3: only RoCE QoS; the traffic class is derived as service level * 32.
        hccl_options.hccl_config = {'qos_service_level': roce_qos, 'qos_traffic_class': roce_qos * 32}
        log_rank_0(LOG.info, f"{parallel_type} roce_qos: {roce_qos}")
    return torch.distributed.new_group(**kwargs)
def create_hierarchical_parallel_groups_qos(
    rank, ranks, group_size, hierarchical_group_sizes, pg_options, parallel_type
):
    """Build QoS-tagged sub-groups for every level of one hierarchical parallelism.

    ``ranks`` holds the ``group_size`` global ranks of one parallel group. At
    each level, a sub-group of ``hierarchical_group_sizes[level]`` members is
    formed from ranks spaced ``stride`` apart, where ``stride`` is the product
    of the sizes of all earlier levels. Every sub-group is created on every
    call, but only those containing ``rank`` are returned (ordered by level).
    The ``group_desc`` suffix ``L{level}`` uses a 0-based level index.

    Example: ``group_size=16`` (ranks g0 ... g15), sizes ``[2, 2, 4]``:

        8 level-1 sub-groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        8 level-2 sub-groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        4 level-3 sub-groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]

    Args:
        rank: Global rank of the calling process.
        ranks: Global ranks of the parent group, indexed ``0 .. group_size - 1``.
        group_size: Number of ranks in the parent group.
        hierarchical_group_sizes: Sub-group size for each level.
        pg_options: Options forwarded to ``create_group_qos``.
        parallel_type: QoS lookup key forwarded to ``create_group_qos``.

    Returns:
        The sub-groups, across all levels, that contain ``rank``.
    """
    own_groups = []
    stride = 1  # product of the sizes of all levels processed so far
    for level, level_size in enumerate(hierarchical_group_sizes):
        span = stride * level_size  # number of ranks covered by one block at this level
        for block in range(group_size // span):
            base = block * span
            for offset in range(stride):
                members = [ranks[base + offset + i * stride] for i in range(level_size)]
                sub_group = create_group_qos(
                    members,
                    pg_options=pg_options,
                    group_desc=f'HIERARCHICAL_CONTEXT_PARALLEL_GROUP_L{level}',
                    parallel_type=parallel_type
                )
                if rank in members:
                    own_groups.append(sub_group)
        stride = span
    return own_groups
def initialize_model_parallel_qos(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    pipeline_model_parallel_split_rank: Optional[int] = None,
    pipeline_model_parallel_comm_backend: Optional[str] = None,
    use_sharp: bool = False,
    context_parallel_size: int = 1,
    hierarchical_context_parallel_sizes: Optional[List[int]] = None,
    expert_model_parallel_size: int = 1,
    num_distributed_optimizer_instances: int = 1,
    expert_tensor_parallel_size: Optional[int] = None,
    nccl_communicator_config_path: Optional[str] = None,
    distributed_timeout_minutes: int = 30,
    order: str = "tp-cp-ep-dp-pp",
    encoder_tensor_model_parallel_size: int = 0,
    encoder_pipeline_model_parallel_size: Optional[int] = 0,
    get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
    get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
    create_gloo_process_groups: bool = True,
) -> None:
    """Initialize Megatron model-parallel groups using QoS-tagged HCCL groups.

    Follows the structure of Megatron-Core's ``initialize_model_parallel``.
    Device groups are built through ``create_group_qos``, so each one gets the
    QoS level configured for its ``parallel_type``. Gloo groups are still built
    with Megatron's ``create_group``. Resulting groups and rank lists are stored
    as module attributes of ``megatron.core.parallel_state`` (``mcps``).

    ``torch.distributed.new_group`` must be entered by all processes in the
    same order, so every rank has to call this with identical arguments; the
    creation order below must not depend on the local rank.

    Args:
        tensor_model_parallel_size: Decoder tensor-parallel size.
        pipeline_model_parallel_size: Decoder pipeline-parallel size.
        virtual_pipeline_model_parallel_size: Interleaved-schedule virtual
            stages; requires ``pipeline_model_parallel_size > 1``.
        pipeline_model_parallel_split_rank: Stored on ``mcps`` and passed to
            the default embedding-rank helpers.
        pipeline_model_parallel_comm_backend: ``None``, ``'nccl'`` or ``'ucc'``.
        use_sharp: Run a barrier on the DP(+CP) group and set
            ``NCCL_COLLNET_ENABLE=0`` afterwards.
        context_parallel_size: Context-parallel size.
        hierarchical_context_parallel_sizes: Per-level sizes for hierarchical
            CP sub-groups; skipped when empty/None.
        expert_model_parallel_size: Expert-parallel size.
        num_distributed_optimizer_instances: Number of partial distributed
            optimizer instances; must divide ``dp * cp``.
        expert_tensor_parallel_size: Expert TP size; defaults to
            ``tensor_model_parallel_size``.
        nccl_communicator_config_path: Optional YAML file with per-group
            communicator options (read via ``yaml.safe_load``).
        distributed_timeout_minutes: Timeout applied to the created groups.
        order: Rank ordering string passed to ``RankGenerator``.
        encoder_tensor_model_parallel_size: Encoder TP size (0 = no encoder,
            or inherit decoder TP when an encoder PP size is given).
        encoder_pipeline_model_parallel_size: Encoder PP size (None -> 0).
        get_embedding_ranks: Maps PP ranks to embedding-group ranks.
        get_position_embedding_ranks: Maps PP ranks to position-embedding ranks.
        create_gloo_process_groups: Also build gloo counterparts of DP groups.

    Raises:
        RuntimeError: If torch.distributed is not initialized, sizes do not fit
            the world size, a group is already initialized, or the PP
            communication backend is unsupported.
        ValueError: Propagated from ``create_group_qos`` for invalid QoS levels.
    """
    # ---- Normalize arguments -------------------------------------------------
    if encoder_pipeline_model_parallel_size is None:
        encoder_pipeline_model_parallel_size = 0

    if encoder_tensor_model_parallel_size == 0 and encoder_pipeline_model_parallel_size > 0:
        encoder_tensor_model_parallel_size = tensor_model_parallel_size

    if get_embedding_ranks is None:
        get_embedding_ranks = partial(
            default_embedding_ranks, split_rank=pipeline_model_parallel_split_rank
        )

    if get_position_embedding_ranks is None:
        get_position_embedding_ranks = partial(
            default_position_embedding_ranks, split_rank=pipeline_model_parallel_split_rank
        )

    if encoder_pipeline_model_parallel_size > 0:
        mcps._PIPELINE_MODEL_PARALLEL_DECODER_START = encoder_pipeline_model_parallel_size

    # ---- Validate sizes against the world size -------------------------------
    if not torch.distributed.is_initialized():
        raise RuntimeError("torch.distributed has not been initialized yet.")
    world_size: int = torch.distributed.get_world_size()

    if encoder_tensor_model_parallel_size > 0:
        if not (encoder_tensor_model_parallel_size <= tensor_model_parallel_size):
            raise RuntimeError(
                "encoder_tensor_model_parallel_size must be less than or equal to tensor_model_parallel_size.")

    encoder_model_size = (
        encoder_tensor_model_parallel_size
        * encoder_pipeline_model_parallel_size
        * context_parallel_size
    )
    decoder_model_size = (
        tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
    )
    total_model_size = encoder_model_size + decoder_model_size

    if world_size % total_model_size != 0:
        raise RuntimeError(f"world_size ({world_size}) is not divisible by {total_model_size}")

    data_parallel_size: int = world_size // total_model_size

    encoder_world_size = encoder_model_size * data_parallel_size
    decoder_world_size = decoder_model_size * data_parallel_size

    if not (encoder_world_size + decoder_world_size == world_size):
        raise RuntimeError(f"{encoder_world_size=} + {decoder_world_size=} != {world_size=}")

    if virtual_pipeline_model_parallel_size is not None:
        if not pipeline_model_parallel_size > 1:
            raise RuntimeError(
                "pipeline-model-parallel size should be greater than 1 with interleaved schedule"
            )
        mcps._VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        mcps._VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size

    if pipeline_model_parallel_split_rank is not None:
        mcps._PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank

    rank = torch.distributed.get_rank()

    # ---- Optional per-group communicator options -----------------------------
    nccl_comm_cfgs = {}
    if nccl_communicator_config_path is not None:
        try:
            import yaml
        except ImportError as e:
            raise RuntimeError(
                "Cannot import `yaml`. Setting custom nccl communicator configs "
                "requires the yaml package."
            ) from e

        with open(nccl_communicator_config_path, "r") as stream:
            nccl_comm_cfgs = yaml.safe_load(stream)

    # ---- Rank generators -------------------------------------------------------
    if encoder_world_size > 0:
        encoder_rank_generator = RankGenerator(
            tp=encoder_tensor_model_parallel_size,
            ep=1,
            dp=data_parallel_size,
            pp=encoder_pipeline_model_parallel_size,
            cp=context_parallel_size,
            order=order,
            rank_offset=0,
        )
    else:
        encoder_rank_generator = None

    # Decoder ranks start right after the encoder ranks.
    decoder_rank_generator = RankGenerator(
        tp=tensor_model_parallel_size,
        ep=1,
        dp=data_parallel_size,
        pp=pipeline_model_parallel_size,
        cp=context_parallel_size,
        order=order,
        rank_offset=encoder_world_size,
    )

    if expert_tensor_parallel_size is None:
        expert_tensor_parallel_size = tensor_model_parallel_size
    expert_tensor_model_pipeline_parallel_size = (
        expert_tensor_parallel_size * expert_model_parallel_size * pipeline_model_parallel_size
    )
    expert_data_parallel_size = decoder_world_size // expert_tensor_model_pipeline_parallel_size
    if decoder_world_size % expert_tensor_model_pipeline_parallel_size != 0:
        raise RuntimeError(
            f"decoder world_size ({decoder_world_size}) is not divisible by expert_tensor_model_pipeline_parallel size ({expert_tensor_model_pipeline_parallel_size})"
        )

    expert_decoder_rank_generator = RankGenerator(
        tp=expert_tensor_parallel_size,
        ep=expert_model_parallel_size,
        dp=expert_data_parallel_size,
        pp=pipeline_model_parallel_size,
        cp=1,
        order=order,
        rank_offset=encoder_world_size,
    )

    condition = (
        order.endswith("pp")
        or pipeline_model_parallel_size == 1
        or expert_data_parallel_size == data_parallel_size
    )
    if not condition:
        raise RuntimeError(
            "When not using pp-last rank ordering, the data parallel size of the attention and moe layers must be the same"
        )

    decoder_pp_ranks = decoder_rank_generator.get_ranks("pp")
    expert_decoder_pp_ranks = expert_decoder_rank_generator.get_ranks("pp")
    if not (decoder_pp_ranks == expert_decoder_pp_ranks):
        raise RuntimeError(
            f"Pipeline parallel groups are expected to be the same for Non-Expert and Expert part, "
            f"but got {decoder_pp_ranks} and {expert_decoder_pp_ranks}"
        )

    def generator_wrapper(group_type, is_expert=False, **kwargs):
        """Yield rank lists for `group_type`, merging encoder and decoder ranks."""
        if is_expert:
            d_ranks = expert_decoder_rank_generator.get_ranks(group_type, **kwargs)
        else:
            d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs)

        if encoder_rank_generator is None:
            for x in d_ranks:
                yield x
            return
        e_ranks = encoder_rank_generator.get_ranks(group_type, **kwargs)
        if group_type == 'pp':
            # Spread encoder PP groups over decoder PP groups as evenly as
            # possible, e.g. e_ranks=[[0], [1]], d_ranks=[[2], [3], [4], [5]]
            # yields [0, 2], [0, 3], [1, 4], [1, 5].
            rep = len(d_ranks) // len(e_ranks)
            remain = len(d_ranks) % len(e_ranks)
            e_ind = 0
            e_rep = rep + int(e_ind < remain)
            for _, y in enumerate(d_ranks):
                x = e_ranks[e_ind]
                e_rep -= 1
                if e_rep == 0:
                    e_ind += 1
                    e_rep = rep + int(e_ind < remain)
                yield x + y
        elif group_type == 'tp-pp':
            # Encoder and decoder model-parallel groups are paired one-to-one.
            if not (len(e_ranks) == len(d_ranks)):
                raise RuntimeError(
                    f"The length of e_ranks ({len(e_ranks)}) does not match the length of d_ranks ({len(d_ranks)}).")
            for x, y in zip(e_ranks, d_ranks):
                yield x + y
        else:
            for x in e_ranks:
                yield x
            for x in d_ranks:
                yield x

    timeout = timedelta(minutes=distributed_timeout_minutes)

    # ---- Data-parallel groups --------------------------------------------------
    if not (mcps._DATA_PARALLEL_GROUP is None):
        raise RuntimeError('data parallel group is already initialized')

    for ranks in generator_wrapper('dp'):
        group = create_group_qos(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('dp', nccl_comm_cfgs),
            group_desc='DATA_PARALLEL_GROUP',
            parallel_type='dp'
        )
        if create_gloo_process_groups:
            group_gloo = create_group(
                ranks, timeout=timeout, backend="gloo", group_desc='DATA_PARALLEL_GROUP_GLOO'
            )
        else:
            group_gloo = None
        if rank in ranks:
            mcps._DATA_PARALLEL_GROUP = group
            mcps._DATA_PARALLEL_GROUP_GLOO = group_gloo
            mcps._DATA_PARALLEL_GLOBAL_RANKS = ranks

    if not ((data_parallel_size * context_parallel_size) % num_distributed_optimizer_instances == 0):
        raise RuntimeError(
            'Data parallel size should be divisible by partial DistOpt shard factor'
        )
    intra_partial_data_parallel_size = (
        data_parallel_size * context_parallel_size
    ) // num_distributed_optimizer_instances

    for ranks_with_cp in generator_wrapper('dp-cp'):
        group_with_cp = create_group_qos(
            ranks_with_cp,
            timeout=timeout,
            pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs),
            group_desc='DATA_PARALLEL_GROUP_WITH_CP',
            parallel_type='dp-cp'
        )
        if create_gloo_process_groups:
            group_with_cp_gloo = create_group(
                ranks_with_cp,
                timeout=timeout,
                backend="gloo",
                group_desc='DATA_PARALLEL_GROUP_WITH_CP_GLOO',
            )
        else:
            group_with_cp_gloo = None
        if rank in ranks_with_cp:
            mcps._DATA_PARALLEL_GROUP_WITH_CP = group_with_cp
            mcps._DATA_PARALLEL_GROUP_WITH_CP_GLOO = group_with_cp_gloo
            mcps._DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp

        if num_distributed_optimizer_instances > 1:
            # Partial distributed optimizer: contiguous slices of the DP+CP
            # ranks form the intra-instance groups ...
            for i in range(num_distributed_optimizer_instances):
                intra_partial_data_parallel_ranks_with_cp = ranks_with_cp[
                    (i * intra_partial_data_parallel_size): (
                        (i + 1) * intra_partial_data_parallel_size
                    )
                ]

                intra_partial_data_parallel_group_with_cp = create_group_qos(
                    intra_partial_data_parallel_ranks_with_cp,
                    timeout=timeout,
                    pg_options=get_nccl_options('intra_dp_cp', nccl_comm_cfgs),
                    group_desc='INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP',
                    parallel_type='intra-dp-cp'
                )
                if create_gloo_process_groups:
                    intra_partial_data_parallel_group_with_cp_gloo = create_group(
                        intra_partial_data_parallel_ranks_with_cp,
                        timeout=timeout,
                        backend="gloo",
                        group_desc='INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO',
                    )
                else:
                    intra_partial_data_parallel_group_with_cp_gloo = None

                if rank in intra_partial_data_parallel_ranks_with_cp:
                    mcps._INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP = (
                        intra_partial_data_parallel_group_with_cp
                    )
                    mcps._INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO = (
                        intra_partial_data_parallel_group_with_cp_gloo
                    )

            # ... and strided slices form the inter-instance groups.
            for i in range(intra_partial_data_parallel_size):
                inter_partial_data_parallel_ranks_with_cp = ranks_with_cp[
                    i::intra_partial_data_parallel_size
                ]

                inter_partial_data_parallel_group_with_cp = create_group_qos(
                    inter_partial_data_parallel_ranks_with_cp,
                    timeout=timeout,
                    pg_options=get_nccl_options('inter_dp_cp', nccl_comm_cfgs),
                    group_desc='INTER_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP',
                    parallel_type='inter-dp-cp'
                )

                if rank in inter_partial_data_parallel_ranks_with_cp:
                    mcps._INTER_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP = (
                        inter_partial_data_parallel_group_with_cp
                    )
        else:
            mcps._INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP = mcps._DATA_PARALLEL_GROUP_WITH_CP
            mcps._INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO = mcps._DATA_PARALLEL_GROUP_WITH_CP_GLOO

    if use_sharp:
        if rank == 0:
            print(
                "The number of process groups to use SHARP with depends on the type "
                "of the network switch. Nvidia QM1 switch supports SAHRP up to 8 "
                "process groups and QM2 supports up to 256 process groups. We apply "
                "SHARP to the communications of the data-parallel domain. If the "
                "number of data-parallel process groups is larger than the max "
                "process groups that the network switch supports, the communication "
                "will fall back to non-SHARP operators. To enable SHARP, "
                "`#SBATCH_NETWORK=sharp` should be set in the sbatch script."
            )
        torch.distributed.barrier(
            group=get_data_parallel_group(with_context_parallel=True),
            device_ids=[torch.cuda.current_device()],
        )
        os.environ["NCCL_COLLNET_ENABLE"] = "0"

    # ---- Context-parallel groups -----------------------------------------------
    if not (mcps._CONTEXT_PARALLEL_GROUP is None):
        raise RuntimeError('context parallel group is already initialized')
    for ranks in generator_wrapper('cp'):
        group = create_group_qos(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('cp', nccl_comm_cfgs),
            group_desc='CONTEXT_PARALLEL_GROUP',
            parallel_type='cp'
        )
        if rank in ranks:
            mcps._CONTEXT_PARALLEL_GROUP = group
            mcps._CONTEXT_PARALLEL_GLOBAL_RANKS = ranks
        if hierarchical_context_parallel_sizes:
            # Only sub-groups containing `rank` are returned, but every
            # sub-group is created for every CP group.
            mcps._HIERARCHICAL_CONTEXT_PARALLEL_GROUPS += create_hierarchical_parallel_groups_qos(
                rank,
                ranks,
                context_parallel_size,
                hierarchical_context_parallel_sizes,
                get_nccl_options('hcp', nccl_comm_cfgs),
                parallel_type='hcp'
            )

    # ---- Model-parallel (TP x PP) groups -----------------------------------------
    if not (mcps._MODEL_PARALLEL_GROUP is None):
        raise RuntimeError('model parallel group is already initialized')
    for ranks in generator_wrapper('tp-pp'):
        group = create_group_qos(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('mp', nccl_comm_cfgs),
            group_desc='MODEL_PARALLEL_GROUP',
            parallel_type='mp'
        )
        if rank in ranks:
            mcps._MODEL_PARALLEL_GROUP = group
            mcps._MODEL_PARALLEL_GLOBAL_RANKS = ranks

    # ---- Tensor-parallel groups --------------------------------------------------
    if not (mcps._TENSOR_MODEL_PARALLEL_GROUP is None):
        raise RuntimeError('tensor model parallel group is already initialized')
    for ranks in generator_wrapper('tp'):
        group = create_group_qos(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('tp', nccl_comm_cfgs),
            group_desc='TENSOR_MODEL_PARALLEL_GROUP',
            parallel_type='tp'
        )
        if rank in ranks:
            mcps._TENSOR_MODEL_PARALLEL_GROUP = group
            mcps._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = ranks

    # ---- Pipeline-parallel and embedding groups ----------------------------------
    if not (mcps._PIPELINE_MODEL_PARALLEL_GROUP is None):
        raise RuntimeError('pipeline model parallel group is already initialized')
    if not (mcps._EMBEDDING_GROUP is None):
        raise RuntimeError('embedding group is already initialized')
    if not (mcps._POSITION_EMBEDDING_GROUP is None):
        raise RuntimeError('position embedding group is already initialized')
    # Validate the backend before touching the environment or creating any PP
    # group, instead of after the first group has already been built with it.
    if pipeline_model_parallel_comm_backend not in (None, 'nccl', 'ucc'):
        raise RuntimeError(
            f'"{pipeline_model_parallel_comm_backend}" backend for PP communication is currently not supported')
    if pipeline_model_parallel_comm_backend == 'ucc':
        if 'CUDA_DEVICE_MAX_CONNECTIONS' in os.environ:
            if not (os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] != '1'):
                raise RuntimeError("UCC-backend requires CUDA_DEVICE_MAX_CONNECTIONS > 1")

        # Keep user-provided values; otherwise fall back to the defaults below.
        os.environ['TORCH_UCC_BLOCKING_WAIT'] = (
            os.environ['TORCH_UCC_BLOCKING_WAIT']
            if "TORCH_UCC_BLOCKING_WAIT" in os.environ
            else 'none'
        )
        os.environ['UCC_EC_CUDA_STREAM_TASK_MODE'] = (
            os.environ['UCC_EC_CUDA_STREAM_TASK_MODE']
            if "UCC_EC_CUDA_STREAM_TASK_MODE" in os.environ
            else 'driver'
        )
        os.environ['UCX_TLS'] = (
            os.environ['UCX_TLS'] if "UCX_TLS" in os.environ else 'ib,cuda_copy'
        )
        os.environ['NSYS_UCP_COMM_PARAMS'] = '1'
        os.environ['UCX_RNDV_THRESH'] = '0'
        os.environ['UCX_NET_DEVICES'] = 'all'
        os.environ['UCC_CL_BASIC_TLS'] = '^sharp,nccl'

    for ranks in generator_wrapper('pp'):
        group = create_group_qos(
            ranks,
            timeout=timeout,
            backend=pipeline_model_parallel_comm_backend,
            pg_options=(
                None
                if pipeline_model_parallel_comm_backend == 'ucc'
                else get_nccl_options('pp', nccl_comm_cfgs)
            ),
            group_desc='PIPELINE_MODEL_PARALLEL_GROUP',
            parallel_type='pp'
        )
        if rank in ranks:
            # With an encoder a rank can sit in several PP groups; the first
            # one is stored directly, further ones turn the globals into lists.
            if mcps._PIPELINE_MODEL_PARALLEL_GROUP is None:
                mcps._PIPELINE_MODEL_PARALLEL_GROUP = group
                mcps._PIPELINE_GLOBAL_RANKS = ranks
            elif isinstance(mcps._PIPELINE_GLOBAL_RANKS[0], list):
                mcps._PIPELINE_MODEL_PARALLEL_GROUP.append(group)
                mcps._PIPELINE_GLOBAL_RANKS.append(ranks)
            else:
                mcps._PIPELINE_MODEL_PARALLEL_GROUP = [mcps._PIPELINE_MODEL_PARALLEL_GROUP, group]
                mcps._PIPELINE_GLOBAL_RANKS = [mcps._PIPELINE_GLOBAL_RANKS, ranks]

        embedding_ranks = get_embedding_ranks(ranks)
        group = create_group_qos(
            embedding_ranks,
            timeout=timeout,
            pg_options=get_nccl_options('embd', nccl_comm_cfgs),
            group_desc='EMBEDDING_GROUP',
            parallel_type='embd'
        )
        if rank in embedding_ranks:
            mcps._EMBEDDING_GROUP = group
            mcps._EMBEDDING_GLOBAL_RANKS = embedding_ranks

        position_embedding_ranks = get_position_embedding_ranks(ranks)
        group = create_group_qos(
            position_embedding_ranks,
            timeout=timeout,
            pg_options=get_nccl_options('pos_embd', nccl_comm_cfgs),
            group_desc='POSITION_EMBEDDING_GROUP',
            parallel_type='pos-embd'
        )
        if rank in position_embedding_ranks:
            mcps._POSITION_EMBEDDING_GROUP = group
            mcps._POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

    # ---- Tensor + data / tensor + context groups ---------------------------------
    if not (mcps._TENSOR_AND_DATA_PARALLEL_GROUP is None):
        raise RuntimeError('Tensor + data parallel group is already initialized')
    for ranks in generator_wrapper('tp-dp-cp'):
        group = create_group_qos(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('tp_dp_cp', nccl_comm_cfgs),
            group_desc='TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP',
            # Previously omitted, so QoS was looked up for `None`.
            parallel_type='tp-dp-cp'
        )
        if rank in ranks:
            mcps._TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = group
    for ranks in generator_wrapper('tp-dp'):
        group = create_group_qos(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('tp_dp', nccl_comm_cfgs),
            group_desc='TENSOR_AND_DATA_PARALLEL_GROUP',
            parallel_type='tp-dp'
        )
        if rank in ranks:
            mcps._TENSOR_AND_DATA_PARALLEL_GROUP = group

    if not (mcps._TENSOR_AND_CONTEXT_PARALLEL_GROUP is None):
        raise RuntimeError('Tensor + context parallel group is already initialized')
    for ranks in generator_wrapper('tp-cp'):
        group = create_group_qos(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('tp_cp', nccl_comm_cfgs),
            group_desc='TENSOR_AND_CONTEXT_PARALLEL_GROUP',
            parallel_type='tp-cp'
        )
        if rank in ranks:
            mcps._TENSOR_AND_CONTEXT_PARALLEL_GROUP = group

    # ---- Expert groups (built from the expert rank generator) --------------------
    if not (mcps._EXPERT_MODEL_PARALLEL_GROUP is None):
        raise RuntimeError('Expert parallel group is already initialized')
    for ranks in generator_wrapper('ep', is_expert=True):
        group = create_group_qos(
            ranks,
            # Honour distributed_timeout_minutes like every other group.
            timeout=timeout,
            pg_options=get_nccl_options('ep', nccl_comm_cfgs),
            group_desc='EXPERT_MODEL_PARALLEL_GROUP',
            parallel_type='ep'
        )
        if rank in ranks:
            mcps._EXPERT_MODEL_PARALLEL_GROUP = group

    if not (mcps._EXPERT_TENSOR_PARALLEL_GROUP is None):
        raise RuntimeError('Expert tensor model parallel group is already initialized')
    for ranks in generator_wrapper('tp', is_expert=True):
        group = create_group_qos(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('ep_tp', nccl_comm_cfgs),
            group_desc='EXPERT_TENSOR_PARALLEL_GROUP',
            parallel_type='ep-tp'
        )
        if rank in ranks:
            mcps._EXPERT_TENSOR_PARALLEL_GROUP = group

    if not (mcps._EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP is None):
        raise RuntimeError('Expert tensor + model parallel group is already initialized')
    for ranks in generator_wrapper('tp-ep', is_expert=True):
        group = create_group_qos(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('tp_ep_mp', nccl_comm_cfgs),
            group_desc='EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP',
            parallel_type='tp-ep-mp'
        )
        if rank in ranks:
            mcps._EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP = group

    if not (mcps._EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP is None):
        raise RuntimeError('The expert_tensor_model_pipeline parallel group is already initialized')
    for ranks in generator_wrapper('tp-ep-pp', is_expert=True):
        group = create_group_qos(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('tp_ep_pp', nccl_comm_cfgs),
            group_desc='EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP',
            parallel_type='tp-ep-pp'
        )
        if rank in ranks:
            mcps._EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP = group

    if not (mcps._EXPERT_DATA_PARALLEL_GROUP is None):
        raise RuntimeError('Expert data group is already initialized')
    if not (mcps._EXPERT_DATA_PARALLEL_GROUP_GLOO is None):
        raise RuntimeError('Expert data group-gloo is already initialized')
    for ranks in generator_wrapper('dp', is_expert=True):
        group = create_group_qos(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('ep_dp', nccl_comm_cfgs),
            group_desc='EXPERT_DATA_PARALLEL_GROUP',
            parallel_type='ep-dp'
        )
        if create_gloo_process_groups:
            group_gloo = create_group(
                ranks, timeout=timeout, backend="gloo", group_desc='EXPERT_DATA_PARALLEL_GROUP_GLOO'
            )
        else:
            group_gloo = None
        if rank in ranks:
            mcps._EXPERT_DATA_PARALLEL_GROUP = group
            mcps._EXPERT_DATA_PARALLEL_GROUP_GLOO = group_gloo

    _set_global_memory_buffer()