import os
from datetime import timedelta
from functools import partial, wraps
from typing import Union, Callable, List, Optional
import torch
import numpy as np
from megatron.training import get_args, print_rank_0
from megatron.core import mpu
from megatron.core.parallel_state import (
RankGenerator,
default_embedding_ranks,
default_position_embedding_ranks,
get_nccl_options,
)
from megatron.core.utils import is_torch_min_version
# ---------------------------------------------------------------------------
# Module-level parallel state.
# Everything below starts out "unset"; the names that are assigned in this
# file are filled in by the model-parallel initialization routines.
# ---------------------------------------------------------------------------

# Pipeline / embedding bookkeeping (None until initialized).
_PIPELINE_MODEL_PARALLEL_GROUP = None
_EMBEDDING_GROUP = None
_POSITION_EMBEDDING_GROUP = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
_PIPELINE_MODEL_PARALLEL_DECODER_START = None
_EMBEDDING_GLOBAL_RANKS = None
_POSITION_EMBEDDING_GLOBAL_RANKS = None
_PIPELINE_GLOBAL_RANKS = None

# Additional pipeline process groups created by the LDT initializers.
# Each holds either a single group or a list of groups.
_PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE = None
_PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST = None
_PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST = None

# Virtual tensor parallel (VTP) state, written by _init_vtp_state and
# _create_vtp_groups.
_VTP_ENABLED = False
_VTP_SIZE_LIST = None
_VTP_STAGE_RANKS = None
_VTP_INTRA_STAGE_GROUP = None
_VTP_MY_STAGE_IDX = None

# TP size of the first ("edge") stage, written by _auto_detect_vtp_sizes.
_EDGE_TP_SIZE = 1

# Virtual data parallel (VDP) cross-stage groups and flags.
_PIPELINE_MODEL_PARALLEL_GROUP_FOR_VDP_CROSS_CLOUD_TP = None
_PIPELINE_MODEL_PARALLEL_GROUP_FOR_VDP_CROSS_EDGE_CLOUD = None
_LAYERWISE_DISAGGREGATED_TRAINING = False
_VDP_SIZE = 1
# NOTE(review): not assigned anywhere in the visible part of this file;
# presumably set by _init_vdp_state — confirm.
_VDP_ENABLED = False
def install_ldt_init_model_parallel_patch():
    """Swap ``mpu.initialize_model_parallel`` for the LDT/VTP-aware wrapper.

    The patch is idempotent: the installed wrapper is tagged with a
    ``_ldt_wrapped`` attribute, and a function carrying that tag is never
    wrapped a second time.

    Per the original author's note, this is meant to be called from
    validate_args (LDT branch only), i.e. inside initialize_megatron before
    _initialize_distributed invokes the initializer.
    """
    current = mpu.initialize_model_parallel
    already_patched = getattr(current, '_ldt_wrapped', False)
    if not already_patched:
        patched = initialize_model_parallel_wrapper(current)
        # Tag the wrapper so a repeated install is a no-op.
        patched._ldt_wrapped = True
        mpu.initialize_model_parallel = patched
def initialize_model_parallel_wrapper(initialize_model_parallel):
    """Build the LDT-aware replacement for ``initialize_model_parallel``.

    The returned function keeps the wrapped function's metadata (via
    ``functools.wraps``) and does the following:

    1. Reads the parsed CLI args, records the LDT/VDP module flags and derives
       the per-stage TP sizes with ``_auto_detect_vtp_sizes``.
    2. If the per-stage TP sizes are not all equal, hands over to
       ``_initialize_vtp_static_vtp_vdp`` or ``_initialize_vtp_static_only_vtp``
       and returns.
    3. Otherwise calls the wrapped initializer while
       ``torch.distributed.get_world_size``, ``mpu.RankGenerator`` and
       ``mpu.create_group`` are temporarily replaced, then rebuilds the rank
       generators locally and creates the extra pipeline groups (VDP cross
       groups, alternate, last-to-first, first-to-last) that are stored in
       this module's globals.

    Args:
        initialize_model_parallel: The original initializer to wrap.

    Returns:
        The wrapping function ``initialize_model_parallel_impl``.

    Raises (from the returned function):
        RuntimeError: on inconsistent parallel sizes, an uninitialized
            ``torch.distributed``, a missing ``yaml`` package, or an
            unsupported pipeline communication backend.
        ValueError: if the VDP cross groups were already created.
    """
    @wraps(initialize_model_parallel)
    def initialize_model_parallel_impl(
        tensor_model_parallel_size: int = 1,
        pipeline_model_parallel_size: int = 1,
        virtual_pipeline_model_parallel_size: Optional[int] = None,
        pipeline_model_parallel_split_rank: Optional[int] = None,
        pipeline_model_parallel_comm_backend: Optional[str] = None,
        use_sharp: bool = False,
        context_parallel_size: int = 1,
        hierarchical_context_parallel_sizes: Optional[List[int]] = None,
        expert_model_parallel_size: int = 1,
        num_distributed_optimizer_instances: int = 1,
        expert_tensor_parallel_size: Optional[int] = None,
        nccl_communicator_config_path: Optional[str] = None,
        distributed_timeout_minutes: int = 30,
        order: str = "tp-cp-ep-dp-pp",
        encoder_tensor_model_parallel_size: int = 0,
        encoder_pipeline_model_parallel_size: Optional[int] = 0,
        get_embedding_ranks: Optional[
            Callable[[List[int], Optional[int]], List[int]]
        ] = None,
        get_position_embedding_ranks: Optional[
            Callable[[List[int], Optional[int]], List[int]]
        ] = None,
        create_gloo_process_groups: bool = True,
    ) -> None:
        # ---- 1. Record LDT / VDP state from the parsed CLI arguments -------
        cli_args = get_args()
        global _LAYERWISE_DISAGGREGATED_TRAINING, _VDP_SIZE, _VDP_ENABLED
        _LAYERWISE_DISAGGREGATED_TRAINING = True
        _VDP_SIZE = cli_args.data_parallel_size
        vtp_sizes = _auto_detect_vtp_sizes(cli_args)
        vdp_size = cli_args.data_parallel_size
        _init_vdp_state(tensor_model_parallel_size, context_parallel_size, vdp_size, vtp_sizes)
        if _VDP_ENABLED:
            print_rank_0(f'Virtual DP enabled with dp size: {vdp_size}')
        # ---- 2. Non-uniform per-stage TP sizes: use the static VTP paths ---
        if vtp_sizes and len(set(vtp_sizes)) > 1:
            # The first four parameters travel positionally, the rest by name.
            org_args = (
                tensor_model_parallel_size,
                pipeline_model_parallel_size,
                virtual_pipeline_model_parallel_size,
                pipeline_model_parallel_split_rank,
            )
            org_kwargs = {
                "pipeline_model_parallel_comm_backend": pipeline_model_parallel_comm_backend,
                "use_sharp": use_sharp,
                "context_parallel_size": context_parallel_size,
                "hierarchical_context_parallel_sizes": hierarchical_context_parallel_sizes,
                "expert_model_parallel_size": expert_model_parallel_size,
                "num_distributed_optimizer_instances": num_distributed_optimizer_instances,
                "expert_tensor_parallel_size": expert_tensor_parallel_size,
                "nccl_communicator_config_path": nccl_communicator_config_path,
                "distributed_timeout_minutes": distributed_timeout_minutes,
                "order": order,
                "encoder_tensor_model_parallel_size": encoder_tensor_model_parallel_size,
                "encoder_pipeline_model_parallel_size": encoder_pipeline_model_parallel_size,
                "get_embedding_ranks": get_embedding_ranks,
                "get_position_embedding_ranks": get_position_embedding_ranks,
                "create_gloo_process_groups": create_gloo_process_groups,
            }
            if is_vdp_enabled():
                _initialize_vtp_static_vtp_vdp(initialize_model_parallel, vtp_sizes, org_args, org_kwargs)
            else:
                _initialize_vtp_static_only_vtp(initialize_model_parallel, vtp_sizes, org_args, org_kwargs)
            return
        # ---- 3. Uniform TP: run the wrapped initializer under patches ------
        # While the wrapped initializer runs, get_world_size reports the real
        # world size plus (_VDP_SIZE - 1) * cp * tp additional ranks.
        ori_get_world_size = torch.distributed.get_world_size
        def ldt_get_world_size(*args, **kwargs):
            real_world_size = ori_get_world_size()
            return real_world_size + (_VDP_SIZE - 1) * context_parallel_size * tensor_model_parallel_size
        torch.distributed.get_world_size = ldt_get_world_size
        ori_rank_generator = mpu.RankGenerator
        mpu.RankGenerator = LDTRankGenerator
        ori_create_group = mpu.create_group
        mpu.create_group = create_group
        # NOTE(review): the three patches are not restored if this call
        # raises (no try/finally).
        initialize_model_parallel(
            tensor_model_parallel_size,
            pipeline_model_parallel_size,
            virtual_pipeline_model_parallel_size,
            pipeline_model_parallel_split_rank,
            pipeline_model_parallel_comm_backend,
            use_sharp,
            context_parallel_size,
            hierarchical_context_parallel_sizes,
            expert_model_parallel_size,
            num_distributed_optimizer_instances,
            expert_tensor_parallel_size,
            nccl_communicator_config_path,
            distributed_timeout_minutes,
            order,
            encoder_tensor_model_parallel_size,
            encoder_pipeline_model_parallel_size,
            get_embedding_ranks,
            get_position_embedding_ranks,
            create_gloo_process_groups,
        )
        torch.distributed.get_world_size = ori_get_world_size
        mpu.RankGenerator = ori_rank_generator
        mpu.create_group = ori_create_group
        # Defined elsewhere in this file; presumably propagates mpu's globals
        # to other module copies — TODO confirm.
        _sync_all_global_variables(mpu)
        # ---- 4. Re-derive the rank layout locally --------------------------
        # Normalize the encoder settings: None means "no encoder pipeline",
        # and an encoder pipeline without an explicit TP size inherits the
        # decoder's TP size.
        if encoder_pipeline_model_parallel_size is None:
            encoder_pipeline_model_parallel_size = 0
        if (
            encoder_tensor_model_parallel_size == 0
            and encoder_pipeline_model_parallel_size > 0
        ):
            encoder_tensor_model_parallel_size = tensor_model_parallel_size
        if get_embedding_ranks is None:
            get_embedding_ranks = partial(
                default_embedding_ranks, split_rank=pipeline_model_parallel_split_rank
            )
        if get_position_embedding_ranks is None:
            get_position_embedding_ranks = partial(
                default_position_embedding_ranks,
                split_rank=pipeline_model_parallel_split_rank,
            )
        if encoder_pipeline_model_parallel_size > 0:
            global _PIPELINE_MODEL_PARALLEL_DECODER_START
            _PIPELINE_MODEL_PARALLEL_DECODER_START = encoder_pipeline_model_parallel_size
        if not torch.distributed.is_initialized():
            raise RuntimeError("torch.distributed is not initialized")
        # get_world_size is the real one again here, so the same VDP term used
        # by ldt_get_world_size is added explicitly.
        world_size = torch.distributed.get_world_size() + (_VDP_SIZE - 1) * context_parallel_size * tensor_model_parallel_size
        if encoder_tensor_model_parallel_size > 0:
            if not (
                encoder_tensor_model_parallel_size <= tensor_model_parallel_size
            ):
                raise RuntimeError("We do not support encoders with more TP than the decoder.")
        encoder_model_size = (
            encoder_tensor_model_parallel_size
            * encoder_pipeline_model_parallel_size
            * context_parallel_size
        )
        decoder_model_size = (
            tensor_model_parallel_size
            * pipeline_model_parallel_size
            * context_parallel_size
        )
        total_model_size = encoder_model_size + decoder_model_size
        if world_size % total_model_size != 0:
            raise RuntimeError(
                f"world_size ({world_size}) is not divisible by {total_model_size}"
            )
        data_parallel_size: int = world_size // total_model_size
        encoder_world_size = encoder_model_size * data_parallel_size
        decoder_world_size = decoder_model_size * data_parallel_size
        if not (
            encoder_world_size + decoder_world_size == world_size
        ):
            raise RuntimeError(f"{encoder_world_size=} + {decoder_world_size=} != {world_size=}")
        # Virtual pipeline (interleaved schedule) requires PP > 1.
        if virtual_pipeline_model_parallel_size is not None:
            if not pipeline_model_parallel_size > 1:
                raise RuntimeError(
                    "pipeline-model-parallel size should be greater than 1 with interleaved schedule"
                )
            global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
            global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
            _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
            _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = (
                virtual_pipeline_model_parallel_size
            )
        if pipeline_model_parallel_split_rank is not None:
            global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
            _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
        rank = torch.distributed.get_rank()
        # Optional per-communicator NCCL settings loaded from a YAML file.
        nccl_comm_cfgs = {}
        if nccl_communicator_config_path is not None:
            try:
                import yaml
            except ImportError as e:
                raise RuntimeError(
                    "Cannot import `yaml`. Setting custom nccl communicator configs "
                    "requires the yaml package."
                ) from e
            with open(nccl_communicator_config_path, "r") as stream:
                nccl_comm_cfgs = yaml.safe_load(stream)
        # Rank generators: encoder (optional), decoder, and the expert view of
        # the decoder. The decoder generators start after the encoder ranks.
        if encoder_world_size > 0:
            encoder_rank_generator = LDTRankGenerator(
                tp=encoder_tensor_model_parallel_size,
                ep=1,
                dp=data_parallel_size,
                pp=encoder_pipeline_model_parallel_size,
                cp=context_parallel_size,
                order=order,
                rank_offset=0,
            )
        else:
            encoder_rank_generator = None
        decoder_rank_generator = LDTRankGenerator(
            tp=tensor_model_parallel_size,
            ep=1,
            dp=data_parallel_size,
            pp=pipeline_model_parallel_size,
            cp=context_parallel_size,
            order=order,
            rank_offset=encoder_world_size,
        )
        if expert_tensor_parallel_size is None:
            expert_tensor_parallel_size = tensor_model_parallel_size
        expert_tensor_model_pipeline_parallel_size = (
            expert_tensor_parallel_size
            * expert_model_parallel_size
            * pipeline_model_parallel_size
        )
        expert_data_parallel_size = (
            decoder_world_size // expert_tensor_model_pipeline_parallel_size
        )
        if decoder_world_size % expert_tensor_model_pipeline_parallel_size != 0:
            raise RuntimeError(
                f"decoder world_size ({decoder_world_size}) is not divisible by expert_tensor_model_pipeline_parallel size ({expert_tensor_model_pipeline_parallel_size})"
            )
        expert_decoder_rank_generator = LDTRankGenerator(
            tp=expert_tensor_parallel_size,
            ep=expert_model_parallel_size,
            dp=expert_data_parallel_size,
            pp=pipeline_model_parallel_size,
            cp=1,
            order=order,
            rank_offset=encoder_world_size,
        )
        if not (
            order.endswith("pp")
            or pipeline_model_parallel_size == 1
            or expert_data_parallel_size == data_parallel_size
        ):
            raise RuntimeError("When not using pp-last rank ordering, the data parallel size of the attention and moe layers must be the same")
        # The second line of the message below continues a backslash-continued
        # string literal, so its leading whitespace is part of the message and
        # is kept exactly as in the original.
        if not (decoder_rank_generator.get_ranks(
            "pp"
        ) == expert_decoder_rank_generator.get_ranks(
            "pp"
        )):
            raise RuntimeError(f"Pipeline parallel groups are expected to be the same for Non-Expert and Expert part, \
but got {decoder_rank_generator.get_ranks('pp')} and {expert_decoder_rank_generator.get_ranks('pp')}")
        def generator_wrapper(group_type, is_expert=False, **kwargs):
            """Yield the rank lists for ``group_type``, stitching encoder and decoder.

            The `RankGenerator` class produces a hyper-rectangle for a given set of
            tensor, pipeline, data, expert, and context parallelism. If we have an encoder,
            in addition to the default decoder, we essentially instantiate two `RankGenerator`
            classes to construct the parallelism for each module separately, and we then have
            to stitch them together for the right groups. For now, this means pp and tp-pp.

            Let's say we have a total of 6 GPUs denoted by g0 ... g5.
            For encoder_tp=1, encoder_pp=1, decoder_tp=2, decoder_pp=1, dp=2,
            g0, g1 belong to encoder and g2, ..., g5 belong to decoder.
            The present function will create with "tp-dp-pp":
            3 data-parallel groups: [g0, g1], [g2, g4], [g3, g5]
            4 tensor model-parallel groups: [g0], [g1], [g2, g3], [g4, g5]
            4 pipeline model-parallel groups: [g0, g2], [g0, g3], [g1, g4], [g1, g5]
            """
            if is_expert:
                d_ranks = expert_decoder_rank_generator.get_ranks(group_type, **kwargs)
            else:
                d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs)
            # No encoder: the decoder groups are used as they are.
            if encoder_rank_generator is None:
                for x in d_ranks:
                    yield x
                return
            e_ranks = encoder_rank_generator.get_ranks(group_type, **kwargs)
            if group_type == "pp":
                # Spread the decoder groups over the encoder groups: each
                # encoder group is reused `rep` times, and the first `remain`
                # encoder groups once more.
                rep = len(d_ranks) // len(e_ranks)
                remain = len(d_ranks) % len(e_ranks)
                e_ind = 0
                e_rep = rep + int(e_ind < remain)
                for i, y in enumerate(d_ranks):
                    x = e_ranks[e_ind]
                    e_rep -= 1
                    if e_rep == 0:
                        e_ind += 1
                        e_rep = rep + int(e_ind < remain)
                    yield x + y
            elif group_type == "tp-pp":
                if len(e_ranks) != len(d_ranks):
                    raise RuntimeError("Length of encoder ranks and decoder ranks must be the same for tp-pp group")
                for x, y in zip(e_ranks, d_ranks):
                    yield x + y
            else:
                # Every other group type: encoder groups first, then decoder.
                for x in e_ranks:
                    yield x
                for x in d_ranks:
                    yield x
        timeout = timedelta(minutes=distributed_timeout_minutes)
        # ---- 5. VDP cross groups -------------------------------------------
        global _PIPELINE_MODEL_PARALLEL_GROUP_FOR_VDP_CROSS_CLOUD_TP
        global _PIPELINE_MODEL_PARALLEL_GROUP_FOR_VDP_CROSS_EDGE_CLOUD
        if _PIPELINE_MODEL_PARALLEL_GROUP_FOR_VDP_CROSS_CLOUD_TP is not None:
            raise ValueError("VDP cross cloud tp group is already initialized")
        if _PIPELINE_MODEL_PARALLEL_GROUP_FOR_VDP_CROSS_EDGE_CLOUD is not None:
            raise ValueError("VDP cross edge cloud group is already initialized")
        # Every tensor_model_parallel_size-th rank of the real world size,
        # starting at rank 0. The first two of them form one group, and all
        # but the first form the other.
        vdp_mp_ar_ranks = [r for r in range(0, torch.distributed.get_world_size(), tensor_model_parallel_size)]
        vdp_cross_cloud_tp_ranks = vdp_mp_ar_ranks[:2]
        vdp_cross_edge_cloud_ranks = vdp_mp_ar_ranks[1:]
        vdp_cross_cloud_tp_group = create_group(ranks=vdp_cross_cloud_tp_ranks, timeout=timeout,
            pg_options=get_nccl_options('ctpg', nccl_comm_cfgs), group_desc='PIPELINE_MODEL_PARALLEL_GROUP_FOR_VDP_CROSS_CLOUD_TP')
        if rank in vdp_cross_cloud_tp_ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP_FOR_VDP_CROSS_CLOUD_TP = vdp_cross_cloud_tp_group
        vdp_cross_edge_cloud_group = create_group(ranks=vdp_cross_edge_cloud_ranks, timeout=timeout,
            pg_options=get_nccl_options('ecg', nccl_comm_cfgs), group_desc='PIPELINE_MODEL_PARALLEL_GROUP_FOR_VDP_CROSS_EDGE_CLOUD')
        if rank in vdp_cross_edge_cloud_ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP_FOR_VDP_CROSS_EDGE_CLOUD = vdp_cross_edge_cloud_group
        # ---- 6. Extra pipeline groups --------------------------------------
        global _PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE
        global _PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST
        global _PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST
        # UCC backend: refuse CUDA_DEVICE_MAX_CONNECTIONS == "1", keep any
        # user-provided values for the first three variables (filling in
        # defaults otherwise) and force the remaining ones.
        if pipeline_model_parallel_comm_backend == "ucc":
            if "CUDA_DEVICE_MAX_CONNECTIONS" in os.environ:
                if os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] == "1":
                    raise RuntimeError("UCC-backend requires CUDA_DEVICE_MAX_CONNECTIONS > 1")
            os.environ["TORCH_UCC_BLOCKING_WAIT"] = (
                os.environ["TORCH_UCC_BLOCKING_WAIT"]
                if "TORCH_UCC_BLOCKING_WAIT" in os.environ
                else "none"
            )
            os.environ["UCC_EC_CUDA_STREAM_TASK_MODE"] = (
                os.environ["UCC_EC_CUDA_STREAM_TASK_MODE"]
                if "UCC_EC_CUDA_STREAM_TASK_MODE" in os.environ
                else "driver"
            )
            os.environ["UCX_TLS"] = (
                os.environ["UCX_TLS"] if "UCX_TLS" in os.environ else "ib,cuda_copy"
            )
            os.environ["NSYS_UCP_COMM_PARAMS"] = "1"
            os.environ["UCX_RNDV_THRESH"] = "0"
            os.environ["UCX_NET_DEVICES"] = "all"
            os.environ["UCC_CL_BASIC_TLS"] = "^sharp,nccl"
        # For every pipeline rank list, three further groups over the same
        # ranks are created, in a fixed order. A rank that appears in several
        # pipeline rank lists ends up with a list of groups instead of a
        # single group.
        # NOTE(review): each backend validity check below runs only after
        # create_group has already been called with that backend.
        # NOTE(review): the _PIPELINE_GLOBAL_RANKS_* names used below are not
        # declared global, so they are locals of this function.
        for ranks in generator_wrapper("pp"):
            group_new = create_group(
                ranks,
                timeout=timeout,
                backend=pipeline_model_parallel_comm_backend,
                pg_options=(
                    None
                    if pipeline_model_parallel_comm_backend == "ucc"
                    else get_nccl_options("pp", nccl_comm_cfgs)
                ),
                group_desc="PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE",
            )
            if not (
                pipeline_model_parallel_comm_backend is None
                or pipeline_model_parallel_comm_backend == "nccl"
                or pipeline_model_parallel_comm_backend == "ucc"
            ):
                raise RuntimeError(f'"{pipeline_model_parallel_comm_backend}" backend for PP communication is currently not supported')
            if rank in ranks:
                if _PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE is None:
                    _PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE = group_new
                    _PIPELINE_GLOBAL_RANKS_NEW_STREAM = ranks
                elif isinstance(_PIPELINE_GLOBAL_RANKS_NEW_STREAM[0], list):
                    _PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE.append(group_new)
                    _PIPELINE_GLOBAL_RANKS_NEW_STREAM.append(ranks)
                else:
                    _PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE = [
                        _PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE,
                        group_new,
                    ]
                    _PIPELINE_GLOBAL_RANKS_NEW_STREAM = [
                        _PIPELINE_GLOBAL_RANKS_NEW_STREAM,
                        ranks,
                    ]
            group_last_to_first = create_group(
                ranks,
                timeout=timeout,
                backend=pipeline_model_parallel_comm_backend,
                pg_options=(
                    None
                    if pipeline_model_parallel_comm_backend == "ucc"
                    else get_nccl_options("pp", nccl_comm_cfgs)
                ),
                group_desc="PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST",
            )
            if not (
                pipeline_model_parallel_comm_backend is None
                or pipeline_model_parallel_comm_backend == "nccl"
                or pipeline_model_parallel_comm_backend == "ucc"
            ):
                raise RuntimeError(f'"{pipeline_model_parallel_comm_backend}" backend for PP communication is currently not supported')
            if rank in ranks:
                if _PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST is None:
                    _PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST = group_last_to_first
                    _PIPELINE_GLOBAL_RANKS_LAST_TO_FIRST = ranks
                elif isinstance(_PIPELINE_GLOBAL_RANKS_LAST_TO_FIRST[0], list):
                    _PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST.append(
                        group_last_to_first
                    )
                    _PIPELINE_GLOBAL_RANKS_LAST_TO_FIRST.append(ranks)
                else:
                    _PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST = [
                        _PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST,
                        group_last_to_first,
                    ]
                    _PIPELINE_GLOBAL_RANKS_LAST_TO_FIRST = [
                        _PIPELINE_GLOBAL_RANKS_LAST_TO_FIRST,
                        ranks,
                    ]
            group_first_to_last = create_group(
                ranks,
                timeout=timeout,
                backend=pipeline_model_parallel_comm_backend,
                pg_options=(
                    None
                    if pipeline_model_parallel_comm_backend == "ucc"
                    else get_nccl_options("pp", nccl_comm_cfgs)
                ),
                group_desc="PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST",
            )
            if not (
                pipeline_model_parallel_comm_backend is None
                or pipeline_model_parallel_comm_backend == "nccl"
                or pipeline_model_parallel_comm_backend == "ucc"
            ):
                raise RuntimeError(f'"{pipeline_model_parallel_comm_backend}" backend for PP communication is currently not supported')
            if rank in ranks:
                if _PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST is None:
                    _PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST = group_first_to_last
                    _PIPELINE_GLOBAL_RANKS_FIRST_TO_LAST = ranks
                elif isinstance(_PIPELINE_GLOBAL_RANKS_FIRST_TO_LAST[0], list):
                    _PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST.append(
                        group_first_to_last
                    )
                    _PIPELINE_GLOBAL_RANKS_FIRST_TO_LAST.append(ranks)
                else:
                    _PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST = [
                        _PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST,
                        group_first_to_last,
                    ]
                    _PIPELINE_GLOBAL_RANKS_FIRST_TO_LAST = [
                        _PIPELINE_GLOBAL_RANKS_FIRST_TO_LAST,
                        ranks,
                    ]
        # Uniform TP path: record VTP as disabled with empty stage info.
        _init_vtp_state(False, [], [])
    return initialize_model_parallel_impl
def pre_validate_args_for_vtp(args):
    """Temporarily make ``args.world_size`` divisible by tp * pp * cp.

    If ``args.world_size`` is missing/None, or already a multiple of
    ``tensor_model_parallel_size * pipeline_model_parallel_size *
    context_parallel_size``, nothing is changed. Otherwise the real value is
    stashed in ``args._vtp_orig_world_size`` and ``args.world_size`` is
    replaced by that product, so a later divisibility check passes.
    ``post_validate_args_for_vtp`` undoes the substitution.

    A missing or falsy ``context_parallel_size`` counts as 1.

    Caller is responsible for guarding with
    ``args.layerwise_disaggregated_training`` before invoking.
    """
    real_world_size = getattr(args, 'world_size', None)
    if real_world_size is not None:
        tp_size = args.tensor_model_parallel_size
        pp_size = args.pipeline_model_parallel_size
        cp_size = getattr(args, 'context_parallel_size', 1) or 1
        model_size = tp_size * pp_size * cp_size
        if real_world_size % model_size != 0:
            # Remember the real value so it can be restored afterwards.
            args._vtp_orig_world_size = real_world_size
            args.world_size = model_size
def post_validate_args_for_vtp(args):
    """Restore the real world size and recompute ``args.data_parallel_size``.

    First undoes ``pre_validate_args_for_vtp``: if ``_vtp_orig_world_size``
    was stashed on ``args`` it is written back to ``args.world_size`` and the
    stash attribute is removed.

    Then the data parallel size is derived from environment variables
    (``GROUP_RANK``, ``RANK``, ``WORLD_SIZE``, ``LOCAL_WORLD_SIZE``,
    ``GROUP_WORLD_SIZE``):

    * group rank 0 or global rank 0: the ranks outside the local node are
      divided by ``cp * tp * (pp - 1)``;
    * otherwise, when ``pp`` equals ``GROUP_WORLD_SIZE``: the local ranks are
      divided by ``cp * tp``;
    * otherwise: the local ranks are divided by ``cp * tp`` times the ratio
      ``(pp - 1) / (GROUP_WORLD_SIZE - 1)``.

    Raises:
        KeyError: if a required environment variable is missing.
        ZeroDivisionError: if a divisor above evaluates to zero.
    """
    def _env_int(name):
        # Read lazily so only the variables a branch needs are required.
        return int(os.environ[name])

    stashed_world_size = getattr(args, '_vtp_orig_world_size', None)
    if stashed_world_size is not None:
        args.world_size = stashed_world_size
        del args._vtp_orig_world_size

    is_rank0_of_group_or_world = _env_int('GROUP_RANK') == 0 or _env_int('RANK') == 0
    if is_rank0_of_group_or_world:
        non_local_ranks = _env_int('WORLD_SIZE') - _env_int('LOCAL_WORLD_SIZE')
        divisor = (
            args.context_parallel_size
            * args.tensor_model_parallel_size
            * (args.pipeline_model_parallel_size - 1)
        )
        dp_size = non_local_ranks // divisor
    elif args.pipeline_model_parallel_size == _env_int('GROUP_WORLD_SIZE'):
        local_ranks = _env_int('LOCAL_WORLD_SIZE')
        dp_size = local_ranks // (args.context_parallel_size * args.tensor_model_parallel_size)
    else:
        local_ranks = _env_int('LOCAL_WORLD_SIZE')
        # True division on purpose: the ratio may be fractional; the final
        # int() below truncates the floor-divided float result.
        pp_ratio = (args.pipeline_model_parallel_size - 1) / (_env_int('GROUP_WORLD_SIZE') - 1)
        dp_size = local_ranks // (
            args.context_parallel_size * args.tensor_model_parallel_size * pp_ratio
        )
    args.data_parallel_size = int(dp_size)
def _init_vtp_state(vtp_enabled, vtp_size_list, stage_ranks):
    """Store the VTP configuration in module globals and find this rank's stage.

    Args:
        vtp_enabled: Whether virtual TP is active; logged on rank 0 when true.
        vtp_size_list: Per-stage TP sizes, stored as given.
        stage_ranks: One collection of global ranks per stage, stored as given.

    ``_VTP_MY_STAGE_IDX`` is set to the index of the first stage containing
    the calling rank; it is left untouched when no stage contains it.
    """
    global _VTP_ENABLED, _VTP_SIZE_LIST, _VTP_STAGE_RANKS, _VTP_MY_STAGE_IDX
    _VTP_ENABLED, _VTP_SIZE_LIST, _VTP_STAGE_RANKS = (
        vtp_enabled,
        vtp_size_list,
        stage_ranks,
    )
    if vtp_enabled:
        print_rank_0(f'Virtual TP enabled with tp size list: {vtp_size_list}')
    my_rank = torch.distributed.get_rank()
    owning_stage = next(
        (idx for idx, members in enumerate(stage_ranks) if my_rank in members),
        None,
    )
    if owning_stage is not None:
        _VTP_MY_STAGE_IDX = owning_stage
def _create_vtp_groups(stage_ranks, timeout, backend):
    """Create one process group per multi-rank stage and keep this rank's own.

    Only the intra-stage group is created here; per the original author's
    note, the rank0-only PP groups (main, alternate, last-to-first,
    first-to-last) are created by the static VTP initializers.

    Args:
        stage_ranks: One list of global ranks per stage. Stages with a single
            rank get no group.
        timeout: Timeout passed to ``torch.distributed.new_group``.
        backend: Accepted for signature compatibility; not used by the body.
    """
    global _VTP_INTRA_STAGE_GROUP
    my_rank = torch.distributed.get_rank()
    for members in stage_ranks:
        if len(members) <= 1:
            continue
        stage_group = torch.distributed.new_group(
            ranks=members,
            timeout=timeout,
            pg_options=get_nccl_options('tp', {}),
            group_desc='TENSOR_MODEL_PARALLEL_GROUP',
        )
        if my_rank in members:
            _VTP_INTRA_STAGE_GROUP = stage_group
def is_vtp_enabled():
    """Return the module-level ``_VTP_ENABLED`` flag (written by ``_init_vtp_state``)."""
    return _VTP_ENABLED
def get_vtp_size_list():
    """Return ``_VTP_SIZE_LIST`` as stored by ``_init_vtp_state`` (``None`` before that)."""
    return _VTP_SIZE_LIST
def get_vtp_stage_ranks():
    """Return ``_VTP_STAGE_RANKS`` as stored by ``_init_vtp_state`` (``None`` before that)."""
    return _VTP_STAGE_RANKS
def get_vtp_intra_stage_group():
    """Return the group stored by ``_create_vtp_groups``, or ``None`` if none was stored.

    ``_create_vtp_groups`` only stores a group for stages with more than one rank.
    """
    return _VTP_INTRA_STAGE_GROUP
def get_edge_tp_size():
    """Return ``_EDGE_TP_SIZE`` (initially 1; written by ``_auto_detect_vtp_sizes``)."""
    return _EDGE_TP_SIZE
def vtp_allreduce(tensor, op=torch.distributed.ReduceOp.SUM):
    """VTP-aware hierarchical allreduce, performed in place on ``tensor``.

    Replaces a flat cross-network allreduce on the model-parallel group with
    a 3-step hierarchical reduction:

    1. Intra-stage TP allreduce (skipped when the TP world size is 1).
    2. Cross-stage PP allreduce, done only by ranks for which
       ``is_vtp_stage_rank0()`` is true.
    3. Intra-stage broadcast from the stage's first rank, so the other ranks
       of the stage receive the result.

    Per the original author's note this is mathematically correct for SUM,
    MAX and MIN, since all of them are decomposable.

    When no VTP stage layout is recorded (empty stage list or unknown stage
    index), ``is_vtp_stage_rank0()`` returns True for every rank, so every
    rank has already taken part in step 2 and step 3 is skipped.
    Previously this state failed while indexing the stage list.

    Args:
        tensor: Tensor to reduce in place.
        op: Reduction operation, ``ReduceOp.SUM`` by default.
    """
    # Step 1: reduce inside the stage.
    if mpu.get_tensor_model_parallel_world_size() > 1:
        torch.distributed.all_reduce(
            tensor, op=op, group=mpu.get_tensor_model_parallel_group()
        )
    # Step 2: reduce across stages on the stage leaders only.
    if is_vtp_stage_rank0():
        torch.distributed.all_reduce(
            tensor, op=op, group=mpu.get_pipeline_model_parallel_group()
        )
    # Step 3: hand the result to the remaining ranks of the stage.
    stage_ranks = get_vtp_stage_ranks()
    my_stage = get_vtp_my_stage_idx()
    if not stage_ranks or my_stage is None:
        # No stage layout: there is no leader to broadcast from.
        return
    intra_group = mpu.get_tensor_model_parallel_group()
    if intra_group is not None:
        torch.distributed.broadcast(
            tensor, src=stage_ranks[my_stage][0], group=intra_group
        )
def vtp_hierarchical_barrier():
    """Synchronize in three steps: within the stage, across stages, within the stage.

    1. Barrier on the TP group when the TP world size is greater than 1.
    2. Barrier on the PP group for ranks where ``is_vtp_stage_rank0()`` holds.
    3. Barrier on the VTP intra-stage group when one has been stored.
    """
    wait = torch.distributed.barrier

    tp_is_parallel = mpu.get_tensor_model_parallel_world_size() > 1
    if tp_is_parallel:
        wait(group=mpu.get_tensor_model_parallel_group())

    if is_vtp_stage_rank0():
        wait(group=mpu.get_pipeline_model_parallel_group())

    stage_group = get_vtp_intra_stage_group()
    if stage_group is None:
        return
    wait(group=stage_group)
def get_vtp_my_stage_idx():
    """Return ``_VTP_MY_STAGE_IDX``: the stage index found by ``_init_vtp_state``, else ``None``."""
    return _VTP_MY_STAGE_IDX
def is_vtp_stage_rank0():
    """Tell whether the calling rank is the first rank of its VTP stage.

    Returns True as well when no stage layout is recorded (empty or unset
    stage list, or unknown stage index).
    """
    has_stage_info = bool(_VTP_STAGE_RANKS) and _VTP_MY_STAGE_IDX is not None
    if not has_stage_info:
        return True
    my_rank = torch.distributed.get_rank()
    stage_leader = _VTP_STAGE_RANKS[_VTP_MY_STAGE_IDX][0]
    return my_rank == stage_leader
def _auto_detect_vtp_sizes(args):
    """Derive the per-stage TP sizes from the world size and the parsed args.

    All stages after the first are assumed to occupy
    ``tp * cp * ep * dp`` ranks each; whatever remains of the world size
    belongs to the first stage. If that remainder equals one regular stage,
    the first stage uses ``tp`` too; otherwise its TP size is the remainder
    floor-divided by ``cp`` and then by ``ep``.

    Side effect: stores the first stage's TP size in ``_EDGE_TP_SIZE``.

    Returns:
        ``[first_stage_tp, tp, tp, ...]`` with ``pp`` entries in total.
    """
    global _EDGE_TP_SIZE
    total_ranks = torch.distributed.get_world_size()
    tp = args.tensor_model_parallel_size
    cp = args.context_parallel_size
    ep = args.expert_model_parallel_size
    dp = args.data_parallel_size
    pp = args.pipeline_model_parallel_size

    later_stage_count = pp - 1
    ranks_per_regular_stage = tp * cp * ep * dp
    first_stage_ranks = total_ranks - ranks_per_regular_stage * later_stage_count

    if first_stage_ranks == ranks_per_regular_stage:
        first_stage_tp = tp
    else:
        first_stage_tp = first_stage_ranks // cp // ep

    _EDGE_TP_SIZE = first_stage_tp
    return [first_stage_tp, *([tp] * later_stage_count)]
def _transform_3d_list(data):
if not data or not data[0]:
return []
fixed_col = list(data[0][0])
R = len(data)
C = len(data[0]) - 1
lengths = [[len(data[r][c + 1]) for c in range(C)] for r in range(R)]
start_val = data[0][1][0] if lengths[0][0] > 0 else 0
total_len = sum(sum(row_lens) for row_lens in lengths)
full_seq = list(range(start_val, start_val + total_len))
filled_grid = [[None] * C for _ in range(R)]
seq_idx = 0
for c in range(C):
for r in range(R):
length_val = lengths[r][c]
filled_grid[r][c] = full_seq[seq_idx: seq_idx + length_val]
seq_idx += length_val
result = []
for r in range(R):
new_row = [list(fixed_col)]
new_row.extend(filled_grid[r])
result.append(new_row)
return result
def find_3d_indices(data, target):
    """
    Locate the first occurrence of a target element in a 3D list.

    Entries of ``data`` that are not lists are skipped, as are innermost
    entries that are neither lists nor tuples, so irregular nesting is allowed.

    :param data: 3D list (allows irregular nesting)
    :param target: Target element to find
    :return: Tuple (first-level index, second-level index, innermost index);
        ``(None, None, None)`` if ``data`` is not a list or the target is not found
    """
    not_found = (None, None, None)
    if not isinstance(data, list):
        return not_found
    for outer_idx, layer in enumerate(data):
        if not isinstance(layer, list):
            continue
        for mid_idx, inner in enumerate(layer):
            if not isinstance(inner, (list, tuple)):
                continue
            if target in inner:
                return outer_idx, mid_idx, inner.index(target)
    return not_found
def _initialize_vtp_static_vtp_vdp(fn, vtp_sizes, orig_args, orig_kwargs):
    """Initialize parallel state for static VTP (non-uniform TP) with virtual DP.

    When per-stage GPU counts differ (e.g., [1, 2] for edge+cloud), the real
    world size does not match TP * PP * DP, so Megatron's standard
    initializer would fail its divisibility check.

    Strategy:
    1. Call the wrapped initializer with TP=max(vtp_sizes), PP=len(vtp_sizes)
       while ``torch.distributed.get_world_size`` reports
       ``vdp * max(vtp_sizes) * cp * pp`` and ``mpu.RankGenerator`` /
       ``mpu.create_group`` are replaced by the LDT versions. The patches are
       always undone, even if the initializer raises.
    2. Override TP/PP/model-parallel groups to match the actual VTP layout.
    3. Create the LDT alternate PP groups (alternate, last-to-first,
       first-to-last).
    4. Initialize VTP state and the intra-stage communication group.

    Args:
        fn: The original ``initialize_model_parallel``.
        vtp_sizes: TP size of each pipeline stage.
        orig_args: The caller's positional arguments (tp, pp, vpp, split rank).
        orig_kwargs: The caller's keyword arguments.

    Raises:
        RuntimeError: if the calling rank is not part of the computed layout.
    """
    rank = torch.distributed.get_rank()
    vtp_model_size = sum(vtp_sizes)
    pp_size = len(vtp_sizes)
    data_parallel_size = get_vdp_size()
    modified_args = (max(vtp_sizes), pp_size, None) + orig_args[3:]
    modified_kwargs = dict(orig_kwargs)
    if 'expert_tensor_parallel_size' in modified_kwargs:
        modified_kwargs['expert_tensor_parallel_size'] = None
    # The context parallel size must be the one handed to ``fn``. The patched
    # get_world_size is called without keyword arguments, so it cannot be
    # taken from that call's kwargs.
    cp_size = modified_kwargs.get('context_parallel_size', 1)
    ori_get_world_size = torch.distributed.get_world_size
    ori_rank_generator = mpu.RankGenerator
    ori_create_group = mpu.create_group
    def ldt_get_world_size(*args, **kwargs):
        stats_world_size = get_vdp_size() * max(vtp_sizes) * cp_size * pp_size
        return stats_world_size
    torch.distributed.get_world_size = ldt_get_world_size
    mpu.RankGenerator = LDTRankGenerator
    mpu.create_group = create_group
    try:
        fn(*modified_args, **modified_kwargs)
    finally:
        # Never leave the process-wide patches behind.
        torch.distributed.get_world_size = ori_get_world_size
        mpu.RankGenerator = ori_rank_generator
        mpu.create_group = ori_create_group
    _sync_all_global_variables(mpu)
    # Build the [domain][stage][rank] layout from contiguous rank ranges, then
    # renumber it so that all domains share the first stage.
    all_domain_stages = []
    for dp in range(data_parallel_size):
        offset = dp * vtp_model_size
        stages = []
        for tp_size in vtp_sizes:
            stages.append(list(range(offset, offset + tp_size)))
            offset += tp_size
        all_domain_stages.append(stages)
    all_domain_stages = _transform_3d_list(all_domain_stages)
    my_dp, my_stage_idx, my_intra_rank = find_3d_indices(all_domain_stages, rank)
    if my_stage_idx is None:
        raise RuntimeError(
            f"VTP static init: rank {rank} not found in any stage of domain {my_dp}. "
            f"stages={all_domain_stages}"
        )
    my_stages = all_domain_stages[my_dp]
    actual_tp = vtp_sizes[my_stage_idx]
    nccl_comm_cfgs = {}
    nccl_config_path = orig_kwargs.get('nccl_communicator_config_path', None)
    if nccl_config_path:
        import yaml
        with open(nccl_config_path, 'r') as f:
            nccl_comm_cfgs = yaml.safe_load(f)
    timeout = timedelta(
        minutes=orig_kwargs.get('distributed_timeout_minutes', 30)
    )
    backend = orig_kwargs.get('pipeline_model_parallel_comm_backend', None)
    # ---- Tensor-parallel groups: one per stage of every domain ------------
    for domain_stages in all_domain_stages:
        for stage in domain_stages:
            group = torch.distributed.new_group(stage, timeout=timeout)
            if rank in stage:
                mpu._TENSOR_MODEL_PARALLEL_GROUP = group
                mpu._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = stage
    mpu._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = actual_tp
    mpu._MPU_TENSOR_MODEL_PARALLEL_RANK = my_intra_rank
    pg_options = (
        get_nccl_options('pp', nccl_comm_cfgs)
        if backend != 'ucc' else None
    )
    # ---- Pipeline groups ---------------------------------------------------
    global _PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE
    global _PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST
    global _PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST
    mpu._PIPELINE_MODEL_PARALLEL_GROUP = None
    mpu._PIPELINE_GLOBAL_RANKS = None
    for domain_stages in all_domain_stages:
        rank0_list = [s[0] for s in domain_stages]
        all_domain_ranks = [r for stage in domain_stages for r in stage]
        max_intra = max(len(stage) for stage in domain_stages)
        for intra in range(max_intra):
            # Chain the intra-th rank of every stage; stages that are too
            # small contribute their first rank instead.
            pp_chain = []
            for stage in domain_stages:
                if intra < len(stage):
                    pp_chain.append(stage[intra])
                else:
                    pp_chain.append(stage[0])
            group = torch.distributed.new_group(
                pp_chain, timeout=timeout,
                backend=backend, pg_options=pg_options,
            )
            group_alt = torch.distributed.new_group(
                pp_chain, timeout=timeout,
                backend=backend, pg_options=pg_options,
            )
            if rank in pp_chain:
                is_rank0 = rank in rank0_list
                if intra == 0 or not is_rank0:
                    # Processes with GROUP_RANK 0 or RANK 0 collect a list of
                    # groups; all other processes keep a single group.
                    if int(os.environ['GROUP_RANK']) == 0 or int(os.environ['RANK']) == 0:
                        if mpu._PIPELINE_MODEL_PARALLEL_GROUP is None:
                            mpu._PIPELINE_MODEL_PARALLEL_GROUP = group
                            mpu._PIPELINE_GLOBAL_RANKS = pp_chain
                            _PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE = group_alt
                        elif isinstance(mpu._PIPELINE_MODEL_PARALLEL_GROUP, list):
                            mpu._PIPELINE_MODEL_PARALLEL_GROUP.append(group)
                            mpu._PIPELINE_GLOBAL_RANKS.append(pp_chain)
                            _PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE.append(group_alt)
                        else:
                            mpu._PIPELINE_MODEL_PARALLEL_GROUP = [mpu._PIPELINE_MODEL_PARALLEL_GROUP, group]
                            mpu._PIPELINE_GLOBAL_RANKS = [mpu._PIPELINE_GLOBAL_RANKS, pp_chain]
                            _PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE = [_PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE, group_alt]
                    else:
                        mpu._PIPELINE_MODEL_PARALLEL_GROUP = group
                        mpu._PIPELINE_GLOBAL_RANKS = pp_chain
                        _PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE = group_alt
        # Last-to-first / first-to-last groups contain only the first rank of
        # each stage, but every rank of the domain records them.
        group_l2f = torch.distributed.new_group(
            rank0_list, timeout=timeout, backend=backend, pg_options=pg_options,
        )
        group_f2l = torch.distributed.new_group(
            rank0_list, timeout=timeout, backend=backend, pg_options=pg_options,
        )
        if rank in all_domain_ranks:
            if int(os.environ['GROUP_RANK']) == 0 or int(os.environ['RANK']) == 0:
                if _PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST is None:
                    _PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST = group_l2f
                    _PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST = group_f2l
                elif isinstance(_PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST, list):
                    _PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST.append(group_l2f)
                    _PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST.append(group_f2l)
                else:
                    _PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST = [_PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST, group_l2f]
                    _PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST = [_PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST, group_f2l]
            else:
                _PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST = group_l2f
                _PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST = group_f2l
    mpu._MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = pp_size
    mpu._MPU_PIPELINE_MODEL_PARALLEL_RANK = my_stage_idx
    # ---- Model-parallel groups: all ranks of a domain ----------------------
    for domain_stages in all_domain_stages:
        all_ranks = [r for stage in domain_stages for r in stage]
        group = torch.distributed.new_group(all_ranks, timeout=timeout)
        if rank in all_ranks:
            mpu._MODEL_PARALLEL_GROUP = group
            mpu._MODEL_PARALLEL_GLOBAL_RANKS = all_ranks
    # Publish the per-rank TP size and the virtual DP size to the parsed args.
    args = get_args()
    args.tensor_model_parallel_size = actual_tp
    args.data_parallel_size = data_parallel_size
    # The wrapped initializer was called with vpp=None; restore the caller's
    # virtual pipeline setting afterwards.
    orig_vpp = orig_args[2] if len(orig_args) > 2 else None
    if orig_vpp is not None:
        mpu._VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = orig_vpp
        mpu._VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
    _init_vtp_state(True, vtp_sizes, my_stages)
    # NOTE(review): only this rank's own domain is passed here, so ranks of
    # different domains request groups for different rank lists — confirm this
    # is compatible with new_group's requirement that all processes take part.
    _create_vtp_groups(my_stages, timeout, backend)
def _initialize_vtp_static_only_vtp(fn, vtp_sizes, orig_args, orig_kwargs):
    """Initialize parallel state for static VTP with non-uniform TP sizes.

    When per-node GPU counts differ (e.g., [1, 2] for edge+cloud),
    world_size = sum(tp_sizes) * DP, which != max_tp * PP * DP.
    Megatron's standard init fails the world_size % (TP*PP) == 0 check.

    Strategy:
    1. Call Megatron's init with TP=sum(sizes), PP=1 to pass validation
    2. Override TP/PP/DP/model-parallel groups to match actual VTP layout
    3. Create LDT alternate PP groups (ping/pang, last-to-first, first-to-last)
    4. Initialize VTP state and communication groups

    Args:
        fn: Callable invoked exactly once with the rewritten arguments
            (presumably the unpatched ``initialize_model_parallel`` --
            confirm against the caller).
        vtp_sizes: Sequence with one TP size per pipeline stage; its length
            becomes the pipeline world size and its sum the number of ranks
            per data-parallel replica.
        orig_args: Positional arguments originally passed to the wrapper.
            The first three are replaced; the rest are forwarded unchanged.
        orig_kwargs: Keyword arguments originally passed to the wrapper.
            Copied before modification, then forwarded to ``fn``.

    Raises:
        RuntimeError: If the world size is not a multiple of
            ``sum(vtp_sizes)``, or if the current rank cannot be located in
            any stage of its own data-parallel domain.
    """
    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    vtp_model_size = sum(vtp_sizes)
    pp_size = len(vtp_sizes)
    if world_size % vtp_model_size != 0:
        raise RuntimeError(
            f"VTP static: world_size ({world_size}) is not divisible by "
            f"sum(vtp_sizes) ({vtp_model_size})"
        )
    data_parallel_size = world_size // vtp_model_size
    # Step 1: run the stock initializer with TP=sum(vtp_sizes), PP=1 and no
    # virtual pipeline, so that its groups exist before being overridden below.
    # NOTE(review): the three leading positionals are always supplied here; if a
    # caller passed the same parameters by keyword they are still present in
    # modified_kwargs and would collide -- confirm callers pass them positionally.
    modified_args = (vtp_model_size, 1, None) + orig_args[3:]
    modified_kwargs = dict(orig_kwargs)
    if 'expert_tensor_parallel_size' in modified_kwargs:
        modified_kwargs['expert_tensor_parallel_size'] = None
    fn(*modified_args, **modified_kwargs)
    # Build the rank layout: all_domain_stages[dp][stage] is a contiguous list
    # of global ranks; each data-parallel domain occupies vtp_model_size ranks.
    all_domain_stages = []
    for dp in range(data_parallel_size):
        offset = dp * vtp_model_size
        stages = []
        for tp_size in vtp_sizes:
            stages.append(list(range(offset, offset + tp_size)))
            offset += tp_size
        all_domain_stages.append(stages)
    # Locate this rank: its domain, its stage index and its position in the stage.
    my_dp = rank // vtp_model_size
    my_stages = all_domain_stages[my_dp]
    my_stage_idx = None
    my_intra_rank = None
    for idx, stage in enumerate(my_stages):
        if rank in stage:
            my_stage_idx = idx
            my_intra_rank = stage.index(rank)
            break
    if my_stage_idx is None:
        raise RuntimeError(
            f"VTP static init: rank {rank} not found in any stage of domain {my_dp}. "
            f"stages={my_stages}"
        )
    actual_tp = vtp_sizes[my_stage_idx]
    # Optional NCCL communicator configuration, loaded only when a path is given.
    nccl_comm_cfgs = {}
    nccl_config_path = orig_kwargs.get('nccl_communicator_config_path', None)
    if nccl_config_path:
        import yaml
        with open(nccl_config_path, 'r') as f:
            nccl_comm_cfgs = yaml.safe_load(f)
    timeout = timedelta(
        minutes=orig_kwargs.get('distributed_timeout_minutes', 30)
    )
    backend = orig_kwargs.get('pipeline_model_parallel_comm_backend', None)
    # Step 2a: tensor-parallel groups, one per stage of every domain. Every rank
    # executes every new_group call and keeps only the handle of its own stage.
    for domain_stages in all_domain_stages:
        for stage in domain_stages:
            group = torch.distributed.new_group(stage, timeout=timeout)
            if rank in stage:
                mpu._TENSOR_MODEL_PARALLEL_GROUP = group
                mpu._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = stage
    mpu._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = actual_tp
    mpu._MPU_TENSOR_MODEL_PARALLEL_RANK = my_intra_rank
    pg_options = (
        get_nccl_options('pp', nccl_comm_cfgs)
        if backend != 'ucc' else None
    )
    global _PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE
    global _PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST
    global _PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST
    # Steps 2b/3: pipeline groups. For each intra-stage position a chain with
    # one rank per stage is built; two groups are created over every chain (the
    # primary and the alternate one).
    for domain_stages in all_domain_stages:
        rank0_list = [s[0] for s in domain_stages]
        all_domain_ranks = [r for stage in domain_stages for r in stage]
        max_intra = max(len(stage) for stage in domain_stages)
        for intra in range(max_intra):
            pp_chain = []
            for stage in domain_stages:
                if intra < len(stage):
                    pp_chain.append(stage[intra])
                else:
                    # Stage is narrower than the widest one: pad the chain
                    # with the stage's first rank.
                    pp_chain.append(stage[0])
            group = torch.distributed.new_group(
                pp_chain, timeout=timeout,
                backend=backend, pg_options=pg_options,
            )
            group_alt = torch.distributed.new_group(
                pp_chain, timeout=timeout,
                backend=backend, pg_options=pg_options,
            )
            if rank in pp_chain:
                is_rank0 = rank in rank0_list
                # A stage-leading rank may appear in several padded chains;
                # it keeps only the intra == 0 chain so later chains do not
                # overwrite its groups.
                if intra == 0 or not is_rank0:
                    mpu._PIPELINE_MODEL_PARALLEL_GROUP = group
                    mpu._PIPELINE_GLOBAL_RANKS = pp_chain
                    _PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE = group_alt
        # Two further groups over the first rank of every stage; every rank of
        # the domain stores the handles, member of the group or not.
        group_l2f = torch.distributed.new_group(
            rank0_list, timeout=timeout, backend=backend, pg_options=pg_options,
        )
        group_f2l = torch.distributed.new_group(
            rank0_list, timeout=timeout, backend=backend, pg_options=pg_options,
        )
        if rank in all_domain_ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST = group_l2f
            _PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST = group_f2l
    mpu._MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = pp_size
    mpu._MPU_PIPELINE_MODEL_PARALLEL_RANK = my_stage_idx
    # Step 2c: model-parallel group = every rank of one data-parallel domain.
    for domain_stages in all_domain_stages:
        all_ranks = [r for stage in domain_stages for r in stage]
        group = torch.distributed.new_group(all_ranks, timeout=timeout)
        if rank in all_ranks:
            mpu._MODEL_PARALLEL_GROUP = group
            mpu._MODEL_PARALLEL_GLOBAL_RANKS = all_ranks
    # Step 2d: data-parallel groups link the same (stage, intra) position across
    # domains. With a single domain the groups created by fn() are left as is.
    if data_parallel_size > 1:
        create_gloo = orig_kwargs.get('create_gloo_process_groups', True)
        for stage_idx in range(pp_size):
            for intra in range(vtp_sizes[stage_idx]):
                dp_ranks = [
                    all_domain_stages[dp][stage_idx][intra]
                    for dp in range(data_parallel_size)
                ]
                g_nccl = torch.distributed.new_group(
                    dp_ranks, timeout=timeout
                )
                g_gloo = (
                    torch.distributed.new_group(
                        dp_ranks, timeout=timeout, backend='gloo'
                    )
                    if create_gloo else None
                )
                if rank in dp_ranks:
                    mpu._DATA_PARALLEL_GROUP = g_nccl
                    mpu._DATA_PARALLEL_GROUP_GLOO = g_gloo
                    mpu._DATA_PARALLEL_GLOBAL_RANKS = dp_ranks
        mpu._MPU_DATA_PARALLEL_WORLD_SIZE = data_parallel_size
        mpu._MPU_DATA_PARALLEL_RANK = my_dp
    # Publish the per-rank TP size and the DP size on the global args object.
    args = get_args()
    args.tensor_model_parallel_size = actual_tp
    args.data_parallel_size = data_parallel_size
    # Restore the virtual-pipeline size that was masked with None in step 1.
    # NOTE(review): only the positional form is inspected; a keyword-supplied
    # virtual_pipeline_model_parallel_size is not restored here -- confirm.
    orig_vpp = orig_args[2] if len(orig_args) > 2 else None
    if orig_vpp is not None:
        mpu._VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = orig_vpp
        mpu._VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
    # Step 4: VTP bookkeeping and groups (helpers not visible in this section).
    _init_vtp_state(True, vtp_sizes, my_stages)
    _create_vtp_groups(my_stages, timeout, backend)
def create_group(
    ranks=None,
    timeout=None,
    backend=None,
    pg_options=None,
    use_local_synchronization=False,
    group_desc=None,
):
    """Create a ProcessGroup through ``torch.distributed.new_group``.

    Args:
        ranks: Global ranks of the group members, or ``None`` to let
            ``new_group`` use its own default. When layer-wise disaggregated
            training is active, duplicate entries are removed first.
        timeout: Optional timeout; omitted from the call when ``None`` so the
            backend default applies.
        backend: Optional backend name forwarded unchanged.
        pg_options: Optional process-group options forwarded unchanged.
        use_local_synchronization: Forwarded unchanged.
        group_desc: Optional group description; dropped on torch versions
            older than 2.4.0.

    Returns:
        The value returned by ``torch.distributed.new_group``.
    """
    if _LAYERWISE_DISAGGREGATED_TRAINING and ranks is not None:
        # De-duplicate the rank list in LDT mode. The ``None`` guard keeps the
        # signature's default usable: ``set(None)`` would raise TypeError.
        ranks = list(set(ranks))
    kwargs = {
        'ranks': ranks,
        'timeout': timeout,
        'backend': backend,
        'pg_options': pg_options,
        'use_local_synchronization': use_local_synchronization,
        'group_desc': group_desc,
    }
    if not is_torch_min_version('2.4.0'):
        kwargs.pop('group_desc')
    if timeout is None:
        kwargs.pop('timeout')
    return torch.distributed.new_group(**kwargs)
def transform_x_dimension(data: Union[List[List[int]], np.ndarray], x_dim: int) -> np.ndarray:
    """Re-chunk a 2D array into ``x_dim``-wide pieces and tile them per row.

    The input is read in row-major order and cut into consecutive chunks of
    ``x_dim`` elements. Row ``i`` of the result is chunk ``i`` repeated
    ``cols // x_dim`` times, so the output has the same shape as the input.

    Args:
        data: 2D nested list or array.
        x_dim: Chunk width; must be positive, at most the column count, and
            a divisor of the column count.

    Returns:
        A new array with the input's shape.

    Raises:
        ValueError: If the input is not 2D or ``x_dim`` violates the
            constraints above.
    """
    matrix = np.asarray(data)
    if matrix.ndim != 2:
        raise ValueError("Input must be a 2D array")

    n_rows, cols = matrix.shape
    if not 0 < x_dim <= cols:
        raise ValueError(f"x_dim must be greater than 0 and not exceed the original column count {cols}")

    repeats, remainder = divmod(cols, x_dim)
    if remainder != 0:
        raise ValueError(f"To maintain periodic tiling, the original column count {cols} must be divisible by x_dim {x_dim}")

    # Row-major chunks of width x_dim; only the first n_rows chunks are kept.
    leading_chunks = matrix.reshape(-1, x_dim)[:n_rows]
    return np.tile(leading_chunks, (1, repeats))
class LDTRankGenerator(RankGenerator):
    """RankGenerator variant that lays out edge and cloud ranks separately.

    When layer-wise disaggregated training is off, it behaves exactly like
    the parent class.
    """

    def __init__(self, tp: int, ep: int, dp: int, pp: int, cp: int, order: str, rank_offset: int = 0):
        """Forward all sizes unchanged to ``RankGenerator``."""
        super().__init__(tp, ep, dp, pp, cp, order, rank_offset)

    def get_ranks(self, token):
        """Get rank group by input token.

        Args:
            token (str):
                Specify the ranks type that want to get. If we want
                to obtain multiple parallel types, we can use a hyphen
                '-' to separate them. For example, if we want to obtain
                the TP_DP group, the token should be 'tp-dp'.

        Returns:
            A nested list of ranks. With layer-wise disaggregated training
            enabled, it is the edge layout (one pipeline stage, ranks taken
            modulo the edge card count) joined with the cloud layout
            (``pp - 1`` stages, shifted by the edge card count): along
            axis 1 when the token contains 'pp', along axis 0 otherwise.

        Raises:
            ValueError: If the broadcast edge card count is 0.
        """
        if not _LAYERWISE_DISAGGREGATED_TRAINING:
            return super().get_ranks(token)

        def broadcast_edge_card_size():
            """Return the edge card count as broadcast from global rank 0.

            Every caller takes part in the broadcast; the value each process
            prepares locally is replaced by whatever rank 0 sent.
            """
            local_value = 0
            if int(os.environ.get("GROUP_RANK", -1)) == 0 or int(os.environ.get("RANK", -1)) == 0:
                local_value = int(os.environ.get("LOCAL_WORLD_SIZE", 0))
                if is_vtp_enabled():
                    local_value = local_value * self.tp
            payload = [local_value]
            torch.distributed.broadcast_object_list(payload, src=0)
            return payload[0]

        edge_card_size = broadcast_edge_card_size()

        # Edge side: a single pipeline stage, folded into [0, edge_card_size).
        edge_layout = RankGenerator(self.tp, self.ep, self.dp, 1, self.cp, self.order).get_ranks(token)
        if edge_card_size == 0:
            raise ValueError("mod_value must be greater than 0.")
        edge_ranks = np.mod(np.array(edge_layout), edge_card_size)

        # NOTE(review): edge_ranks_vtp is computed (and may raise through
        # transform_x_dimension) but is not used when joining below -- confirm
        # whether it was meant to replace edge_ranks.
        edge_ranks_vtp = (
            transform_x_dimension(edge_ranks, get_edge_tp_size())
            if "tp" in token
            else edge_ranks
        )

        # Cloud side: the remaining pp - 1 stages, placed after the edge cards.
        cloud_layout = RankGenerator(self.tp, self.ep, self.dp, self.pp - 1, self.cp, self.order).get_ranks(token)
        cloud_ranks = np.array(cloud_layout) + edge_card_size

        join_axis = 1 if "pp" in token else 0
        return np.concatenate((edge_ranks, cloud_ranks), axis=join_axis).tolist()
def _init_vdp_state(tensor_model_parallel_size, context_parallel_size, vdp_size, vtp_sizes):
    """Decide whether virtual data parallelism is needed and set ``_VDP_ENABLED``.

    The edge-side rank count is divided by ``context_parallel_size * edge_tp``
    to get the data-parallel size the edge would naturally have. VDP is
    enabled unless that division is exact and the result equals ``vdp_size``.

    Args:
        tensor_model_parallel_size: TP size used for the edge when
            ``vtp_sizes`` is empty or ``None``.
        context_parallel_size: Context-parallel size.
        vdp_size: Requested virtual data-parallel size.
        vtp_sizes: Optional per-stage TP sizes; the first entry is the edge
            TP size.

    Raises:
        KeyError: If a required launcher environment variable is missing.
    """
    global _VDP_ENABLED
    edge_tp_size = vtp_sizes[0] if vtp_sizes else tensor_model_parallel_size

    env = os.environ
    if int(env['GROUP_RANK']) == 0 or int(env['RANK']) == 0:
        # This process reads the edge size directly from its local world size.
        edge_world_size = int(env['LOCAL_WORLD_SIZE'])
    else:
        # Other processes derive it: total ranks minus the local world size
        # multiplied by (number of groups - 1).
        edge_world_size = int(env['WORLD_SIZE']) - int(env['LOCAL_WORLD_SIZE']) * (int(env['GROUP_WORLD_SIZE']) - 1)

    edge_dp_size, leftover = divmod(edge_world_size, context_parallel_size * edge_tp_size)
    _VDP_ENABLED = leftover != 0 or edge_dp_size != vdp_size
def get_pipeline_model_parallel_group_alternate():
    """Return the alternate pipeline-model-parallel communication group.

    The alternate group is a second process group created over the same
    pipeline ranks as the default one, intended to be alternated with it
    for double-buffered pipeline communication.

    Returns:
        torch.distributed.ProcessGroup or list[torch.distributed.ProcessGroup]:
            The stored alternate group; a list when the initializer stored
            more than one group for this rank.

    Raises:
        RuntimeError: If no alternate group has been stored yet.
    """
    alternate = _PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE
    if alternate is None:
        raise RuntimeError("pipeline_model parallel group is not initialized")
    return alternate
def get_pipeline_model_parallel_group_last_to_first():
    """Return the pipeline group reserved for last-to-first communication.

    This group is created over the first rank of every pipeline stage and is
    meant for traffic from the last stage back to the first one (primarily
    relevant when the pipeline world size is odd).

    Returns:
        torch.distributed.ProcessGroup or list[torch.distributed.ProcessGroup]:
            The stored group; a list when the initializer stored more than
            one group for this rank.

    Raises:
        RuntimeError: If the group has not been stored yet.
    """
    last_to_first = _PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST
    if last_to_first is None:
        raise RuntimeError("pipeline_model parallel group is not initialized")
    return last_to_first
def get_pipeline_model_parallel_group_first_to_last():
    """Return the pipeline group reserved for first-to-last communication.

    This group is created over the first rank of every pipeline stage and is
    meant for traffic from the first stage to the last one (primarily
    relevant when the pipeline world size is odd).

    Returns:
        torch.distributed.ProcessGroup or list[torch.distributed.ProcessGroup]:
            The stored group; a list when the initializer stored more than
            one group for this rank.

    Raises:
        RuntimeError: If the group has not been stored yet.
    """
    first_to_last = _PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST
    if first_to_last is None:
        raise RuntimeError("pipeline_model parallel group is not initialized")
    return first_to_last
def _sync_all_global_variables(megatron_mpu):
    """Copy parallel-state globals from ``megatron_mpu`` onto the LDT patch module.

    Each listed attribute that exists on ``megatron_mpu`` is assigned, under
    the same name, on ``parallel_state_patch``. Attributes that are absent
    from ``megatron_mpu`` are skipped.

    Args:
        megatron_mpu: Object (normally the Megatron parallel-state module)
            to read the attributes from.
    """
    from mindspeed_mm.patchs.layerwise_disaggregated_training import parallel_state_patch as ldt_mpu

    mirrored_names = (
        # Process groups.
        '_TENSOR_MODEL_PARALLEL_GROUP',
        '_PIPELINE_MODEL_PARALLEL_GROUP',
        '_MODEL_PARALLEL_GROUP',
        '_EMBEDDING_GROUP',
        '_POSITION_EMBEDDING_GROUP',
        '_DATA_PARALLEL_GROUP',
        '_DATA_PARALLEL_GROUP_GLOO',
        '_TENSOR_AND_DATA_PARALLEL_GROUP',
        # Expert-parallel groups and sizes.
        '_EXPERT_MODEL_PARALLEL_GROUP',
        '_EXPERT_TENSOR_PARALLEL_GROUP',
        '_EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP',
        '_EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP',
        '_EXPERT_DATA_PARALLEL_GROUP',
        '_EXPERT_DATA_PARALLEL_GROUP_GLOO',
        '_MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE',
        '_MPU_EXPERT_MODEL_PARALLEL_RANK',
        '_MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE',
        '_MPU_EXPERT_TENSOR_PARALLEL_RANK',
        # Pipeline bookkeeping.
        '_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK',
        '_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE',
        '_PIPELINE_MODEL_PARALLEL_SPLIT_RANK',
        '_PIPELINE_MODEL_PARALLEL_DECODER_START',
        # Overridden world sizes and ranks.
        '_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE',
        '_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE',
        '_MPU_DATA_PARALLEL_WORLD_SIZE',
        '_MPU_DATA_PARALLEL_RANK',
        '_MPU_TENSOR_MODEL_PARALLEL_RANK',
        '_MPU_PIPELINE_MODEL_PARALLEL_RANK',
        # Global rank lists.
        '_EMBEDDING_GLOBAL_RANKS',
        '_POSITION_EMBEDDING_GLOBAL_RANKS',
        '_PIPELINE_GLOBAL_RANKS',
        '_DATA_PARALLEL_GLOBAL_RANKS',
        '_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS',
        '_MODEL_PARALLEL_GLOBAL_RANKS',
        # Context-parallel state.
        '_CONTEXT_PARALLEL_GROUP',
        '_CONTEXT_PARALLEL_GLOBAL_RANKS',
        '_HIERARCHICAL_CONTEXT_PARALLEL_GROUP',
        '_DATA_PARALLEL_GROUP_WITH_CP',
        '_DATA_PARALLEL_GROUP_WITH_CP_GLOO',
        '_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP',
        '_INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP',
        '_INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO',
        '_INTER_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP',
        '_TENSOR_AND_CONTEXT_PARALLEL_GROUP',
        '_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP',
        # LDT-specific pipeline groups.
        '_PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE',
        '_PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST',
        '_PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST',
        # Miscellaneous.
        '_GLOBAL_MEMORY_BUFFER',
        '_MOE_LAYER_WISE_LOGGING_TRACKER',
    )

    # A private sentinel distinguishes "attribute absent" from a stored None.
    absent = object()
    for name in mirrored_names:
        value = getattr(megatron_mpu, name, absent)
        if value is not absent:
            setattr(ldt_mpu, name, value)
def get_layerwise_disaggregated_training():
    """Return the module-level ``_LAYERWISE_DISAGGREGATED_TRAINING`` flag (``False`` until set elsewhere)."""
    return _LAYERWISE_DISAGGREGATED_TRAINING
def is_vdp_enabled():
    """Return the module-level ``_VDP_ENABLED`` flag, which ``_init_vdp_state`` assigns."""
    return _VDP_ENABLED
def get_vdp_size():
    """Return the size of the virtual data parallel group (module-level ``_VDP_SIZE``, initially 1)."""
    return _VDP_SIZE
def get_pipeline_model_parallel_group_for_vdp_cross_cloud_tp():
    """Return the stored VDP cross-cloud-TP pipeline group.

    No initialization check is performed: the value is returned as stored,
    which is ``None`` until some initializer assigns it.
    """
    return _PIPELINE_MODEL_PARALLEL_GROUP_FOR_VDP_CROSS_CLOUD_TP
def get_pipeline_model_parallel_group_for_vdp_cross_edge_cloud():
    """Return the stored VDP cross-edge-cloud pipeline group.

    No initialization check is performed: the value is returned as stored,
    which is ``None`` until some initializer assigns it.
    """
    return _PIPELINE_MODEL_PARALLEL_GROUP_FOR_VDP_CROSS_EDGE_CLOUD