from typing import Optional
import torch
def get_pg_size(group: Optional[torch.distributed.ProcessGroup]) -> int:
    """Report how many ranks ``group`` spans.

    Args:
        group: Process group to query, or ``None``.

    Returns:
        ``1`` when ``group`` is ``None`` (it is treated as a single-rank
        group); otherwise whatever ``torch.distributed.get_world_size``
        reports for ``group``.
    """
    return 1 if group is None else torch.distributed.get_world_size(group=group)
def get_pg_rank(group: Optional[torch.distributed.ProcessGroup]) -> int:
    """Report the caller's rank inside ``group``.

    Args:
        group: Process group to query, or ``None``.

    Returns:
        ``0`` when ``group`` is ``None`` (it is treated as a single-rank
        group); otherwise whatever ``torch.distributed.get_rank`` reports
        for ``group``.
    """
    return 0 if group is None else torch.distributed.get_rank(group=group)
class LegacyProcessGroupCollection:
    """Process-group collection with the field names used by Megatron dev.

    Every attribute is resolved once, at construction time, by calling
    getters on ``megatron.core.mpu``. Optional getters that are missing from
    the installed ``mpu`` module, or that raise ``AssertionError`` or
    ``RuntimeError`` when called, are replaced by a group that was already
    resolved.

    Attributes:
        mp: Result of ``mpu.get_model_parallel_group()``.
        tp: Result of ``mpu.get_tensor_model_parallel_group()``.
        dp_cp: Result of ``mpu.get_data_parallel_group(...)``, called with
            the richest keyword set the installed getter accepts.
        expt_tp: Result of ``mpu.get_expert_tensor_parallel_group()``, or
            ``tp`` when that getter is unavailable or fails.
        expt_dp: Result of ``mpu.get_expert_data_parallel_group()``, or
            ``dp_cp`` when that getter is unavailable or fails.
        tp_ep_pp: Result of the first usable getter among
            ``get_expert_tensor_model_pipeline_parallel_group`` and
            ``get_expert_tensor_and_model_parallel_group``, or ``mp`` when
            neither can be used.
    """

    def __init__(self) -> None:
        """Resolve every group from ``megatron.core.mpu``.

        Raises:
            ImportError: If ``megatron.core`` cannot be imported.
            Exception: Whatever the mandatory getters (``mp``, ``tp`` and the
                final ``dp_cp`` attempt) raise is propagated unchanged.
        """
        # Function-scope import: Megatron is only needed once an instance is
        # actually built, not when this module is imported.
        from megatron.core import mpu

        self.mp = mpu.get_model_parallel_group()
        self.tp = mpu.get_tensor_model_parallel_group()
        self.dp_cp = self._resolve_dp_cp(mpu)

        self.expt_tp = self._first_group(
            mpu, ("get_expert_tensor_parallel_group",), self.tp
        )
        self.expt_dp = self._first_group(
            mpu, ("get_expert_data_parallel_group",), self.dp_cp
        )
        self.tp_ep_pp = self._first_group(
            mpu,
            (
                "get_expert_tensor_model_pipeline_parallel_group",
                "get_expert_tensor_and_model_parallel_group",
            ),
            self.mp,
        )

    @staticmethod
    def _resolve_dp_cp(mpu):
        """Call ``mpu.get_data_parallel_group`` with the richest accepted keywords.

        A ``TypeError`` is taken to mean the installed getter does not accept
        the keyword set that was just tried, so the next, smaller set is used.
        The intermediate ``with_context_parallel=True`` attempt matters: going
        straight to the bare call would drop the context-parallel request on
        a getter that rejects only ``partial_data_parallel``.
        """
        getter = mpu.get_data_parallel_group
        try:
            return getter(with_context_parallel=True, partial_data_parallel=True)
        except TypeError:
            pass
        try:
            return getter(with_context_parallel=True)
        except TypeError:
            # Last resort: a getter that takes no keywords at all. Any error
            # raised here propagates to the caller.
            return getter()

    @staticmethod
    def _first_group(mpu, getter_names, fallback):
        """Return the result of the first getter in ``getter_names`` that works.

        Args:
            mpu: Module (or module-like object) that may expose the getters.
            getter_names: Attribute names to try, in priority order.
            fallback: Value returned when no getter exists or all of them
                raise ``AssertionError``/``RuntimeError``.

        Returns:
            The first successful getter result (even if it is ``None``), or
            ``fallback``.
        """
        for getter_name in getter_names:
            if not hasattr(mpu, getter_name):
                continue
            try:
                return getattr(mpu, getter_name)()
            except (AssertionError, RuntimeError):
                # Getter exists but cannot be used in this setup; try the next.
                continue
        return fallback