import os
import types
import logging
from dataclasses import dataclass, asdict, fields
from functools import reduce
import torch
from torch.distributed.device_mesh import init_device_mesh, DeviceMesh
from mindspeed_llm.fsdp2.distributed.parallel_engine_config import ParallelEngineConfig
from mindspeed.lite.utils.log import print_rank
# Module-level logger named after this module; its ``info`` method is handed to
# ``print_rank`` when the parallel state reports the device meshes it built.
logger = logging.getLogger(__name__)
class Singleton(type):
    """Metaclass that gives every class using it exactly one instance.

    The first call to the class constructs the instance with whatever
    arguments were passed; every later call returns that same object and
    ignores its arguments.
    """

    # Shared registry mapping each class to the single instance built for it.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        """Return the cached instance of ``cls``, building it on first use.

        Args:
            *args: Positional arguments forwarded to the constructor on the
                first call only.
            **kwargs: Keyword arguments forwarded to the constructor on the
                first call only.

        Returns:
            The one instance registered for ``cls``.
        """
        if cls in cls._instances:
            return cls._instances[cls]
        # Nothing is cached if construction raises, so a later call retries.
        created = super().__call__(*args, **kwargs)
        cls._instances[cls] = created
        return created
def get_last_mesh_dim(mesh_shape):
    """Return the size of the one mesh dimension not listed in ``mesh_shape``.

    The distributed world size is divided by every entry of ``mesh_shape`` in
    turn; whatever remains is the size of the remaining dimension, so that
    the product of all dimensions equals the world size.

    Args:
        mesh_shape: Iterable of positive integer group sizes.

    Returns:
        int: The world size divided by the product of ``mesh_shape``.

    Raises:
        AssertionError: If an entry of ``mesh_shape`` is not positive, or if
            the remaining world size is not divisible by an entry.
    """
    remaining = torch.distributed.get_world_size()
    for dim_size in mesh_shape:
        # A zero would crash with ZeroDivisionError, and a negative value
        # (e.g. an unset ``-1`` default) passes the modulo test below and
        # silently produces a negative dimension, so reject both explicitly.
        if dim_size <= 0:
            raise AssertionError(
                "Mesh group sizes must be positive integers, but got {}".format(mesh_shape))
        if remaining % dim_size != 0:
            raise AssertionError("World size is not divisible by mesh group {}".format(mesh_shape))
        remaining //= dim_size
    return remaining
def init_parallel_state(config: ParallelEngineConfig):
    """Build the ``ParallelState`` from the matching entries of ``config``.

    ``config`` is converted to a dictionary and only the keys that are also
    dataclass fields of ``ParallelState`` are forwarded; everything else in
    the configuration is ignored.

    Args:
        config: Dataclass instance holding the parallel engine settings.

    Returns:
        ParallelState: The instance produced by calling ``ParallelState``
        with the selected settings.
    """
    accepted_names = set()
    for state_field in fields(ParallelState):
        accepted_names.add(state_field.name)

    selected_settings = {}
    for key, value in asdict(config).items():
        if key in accepted_names:
            selected_settings[key] = value

    return ParallelState(**selected_settings)
@dataclass
class ParallelState(metaclass=Singleton):
    """Registry of the device meshes backing every parallel dimension.

    Constructing the object builds three device meshes, each spanning the
    whole distributed world, and stores them in ``device_mesh_map`` keyed by
    dimension name:

    * ``('dp', 'fsdp', 'tp')``
    * ``('edp', 'efsdp', 'ep')``
    * ``('dp_fsdp', 'cp')``

    For every dimension name ``<n>`` the instance additionally gains the
    no-argument accessors ``is_<n>_enable()``, ``get_<n>_group()``,
    ``get_<n>_group_size()``, ``get_<n>_rank()`` and
    ``get_<n>_device_mesh()``.

    The class is created through the ``Singleton`` metaclass, so the first
    construction wins: later constructor calls return that same instance.
    """

    # Derived in ``__post_init__`` from the world size; any value passed in
    # is overwritten.
    data_parallel_size: int = -1
    # Must be supplied; a negative value is treated as "not configured".
    fully_shard_parallel_size: int = -1
    tensor_parallel_size: int = -1
    context_parallel_size: int = -1
    expert_parallel_size: int = -1
    expert_fully_shard_parallel_size: int = -1
    # Derived in ``__post_init__`` from the world size; any value passed in
    # is overwritten.
    expert_data_parallel_size: int = -1
    # Dimension name -> device mesh containing that dimension. ``None`` is
    # replaced by an empty dict before the meshes are built.
    device_mesh_map: dict[str, DeviceMesh] = None

    def __post_init__(self):
        """Validate the configured sizes and build all device meshes.

        Initializes the default process group with the ``hccl`` backend when
        ``torch.distributed`` is not initialized yet.

        Raises:
            Exception: If ``fully_shard_parallel_size`` is negative, if the
                derived ``data_parallel_size`` is not 1, or if
                ``fully_shard_parallel_size`` differs from the world size.
            ValueError: If a configurable size is zero or negative.
            AssertionError: If the world size is not divisible by a mesh
                shape, or a mesh shape does not multiply up to the world
                size.
        """
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend='hccl')
        if self.fully_shard_parallel_size < 0:
            raise Exception("ParallelState has not been initialized .")
        self._validate_configured_sizes()
        if self.device_mesh_map is None:
            self.device_mesh_map = dict()

        world_size = torch.distributed.get_world_size()

        dense_shape = (
            self.fully_shard_parallel_size,
            self.tensor_parallel_size,
        )
        self.data_parallel_size = get_last_mesh_dim(dense_shape)

        # Reject unsupported layouts before any device mesh (and its process
        # groups) is created, instead of after all of them were built.
        if self.data_parallel_size != 1:
            raise Exception(
                f"Currently, only `data_parallel_size` = 1 is supported, but `self.data_parallel_size` = {self.data_parallel_size} .")
        if self.fully_shard_parallel_size != world_size:
            raise Exception("Currently, `fully_shard_parallel_size` must be equal to the distributed world size.")

        self.add_device_mesh_groups(('dp', 'fsdp', 'tp'), (self.data_parallel_size,) + dense_shape)

        expert_shape = (self.expert_fully_shard_parallel_size, self.expert_parallel_size,)
        self.expert_data_parallel_size = get_last_mesh_dim(expert_shape)
        self.add_device_mesh_groups(('edp', 'efsdp', 'ep'), (self.expert_data_parallel_size,) + expert_shape)

        # The combined dp * fsdp ranks are split by the context parallel
        # size; a size that does not divide them evenly is caught by the
        # product check in ``add_device_mesh_groups``.
        context_shape = (
            self.fully_shard_parallel_size * self.data_parallel_size // self.context_parallel_size,
            self.context_parallel_size,
        )
        self.add_device_mesh_groups(('dp_fsdp', 'cp'), context_shape)

        print_rank(logger.info, f'Parallel state initialized:\n {self}')

    def _validate_configured_sizes(self):
        """Reject zero or negative values for the caller-configurable sizes.

        Without this check a zero leads to ``ZeroDivisionError`` and a
        negative value (such as an unset ``-1`` default) produces negative
        mesh dimensions whose product can still equal the world size.

        Raises:
            ValueError: If any configurable size is not positive.
        """
        configurable_names = (
            'fully_shard_parallel_size',
            'tensor_parallel_size',
            'context_parallel_size',
            'expert_parallel_size',
            'expert_fully_shard_parallel_size',
        )
        for size_name in configurable_names:
            size = getattr(self, size_name)
            if size <= 0:
                raise ValueError(f"`{size_name}` must be a positive integer, but got {size}.")

    def __str__(self):
        """Return one line per registered dimension: enabled flag, size, mesh."""
        return ''.join(
            f'[{name}] = {self.is_group_enable(name)} | Group size: {self.get_group_size(name)} '
            f'| device mesh:{self.get_device_mesh(name)} \n'
            for name in self.device_mesh_map
        )

    @property
    def is_initialized(self) -> bool:
        """Whether ``torch.distributed`` has been initialized."""
        return torch.distributed.is_initialized()

    @property
    def world_size(self) -> int:
        """Distributed world size, or 1 when distributed is not initialized."""
        return 1 if not self.is_initialized else torch.distributed.get_world_size()

    @property
    def local_rank(self) -> int:
        """Value of the ``LOCAL_RANK`` environment variable, or -1 if unset."""
        return int(os.getenv("LOCAL_RANK", "-1"))

    @property
    def global_rank(self) -> int:
        """Rank in the default group, or -1 when distributed is not initialized."""
        return -1 if not self.is_initialized else torch.distributed.get_rank()

    def _lookup_mesh(self, mesh_name: str):
        """Return the device mesh registered under ``mesh_name``.

        Raises:
            RuntimeError: If no mesh was registered under that name.
        """
        try:
            return self.device_mesh_map[mesh_name]
        except KeyError:
            raise RuntimeError(f"Mesh group {mesh_name} not found.") from None

    def is_group_enable(self, mesh_name: str) -> bool:
        """Return True when ``mesh_name`` is registered and has more than one rank."""
        if mesh_name not in self.device_mesh_map:
            return False
        return self.get_group_size(mesh_name) > 1

    def get_group(self, mesh_name: str):
        """Return the process group of dimension ``mesh_name``.

        Raises:
            RuntimeError: If ``mesh_name`` is not registered.
        """
        return self._lookup_mesh(mesh_name).get_group(mesh_name)

    def get_group_size(self, mesh_name: str):
        """Return the number of ranks in the group of dimension ``mesh_name``.

        Raises:
            RuntimeError: If ``mesh_name`` is not registered.
        """
        return torch.distributed.get_world_size(self._lookup_mesh(mesh_name).get_group(mesh_name))

    def get_rank(self, mesh_name: str):
        """Return this process's local rank along dimension ``mesh_name``.

        Raises:
            RuntimeError: If ``mesh_name`` is not registered.
        """
        return self._lookup_mesh(mesh_name).get_local_rank(mesh_name)

    def get_device_mesh(self, mesh_name: str):
        """Return the sub-mesh obtained by indexing the mesh with ``mesh_name``.

        Raises:
            RuntimeError: If ``mesh_name`` is not registered.
        """
        return self._lookup_mesh(mesh_name)[mesh_name]

    def add_device_mesh_groups(self, mesh_dim_names, mesh_shape):
        """Create a device mesh and register each of its dimensions.

        Every name in ``mesh_dim_names`` is mapped to the new mesh in
        ``device_mesh_map`` (replacing any earlier entry of the same name),
        and the per-dimension accessors ``is_<n>_enable``, ``get_<n>_group``,
        ``get_<n>_group_size``, ``get_<n>_rank`` and ``get_<n>_device_mesh``
        are attached to this instance.

        Args:
            mesh_dim_names: Names of the mesh dimensions, one per entry of
                ``mesh_shape``.
            mesh_shape: Sizes of the mesh dimensions; their product must
                equal the distributed world size.

        Raises:
            AssertionError: If the product of ``mesh_shape`` differs from
                the world size.
        """
        def make_accessors(name):
            # ``name`` is bound per call, so each set of accessors keeps its
            # own dimension name rather than the loop's last value.
            def is_enable_method(self):
                return self.is_group_enable(name)

            def get_group_method(self):
                return self.get_group(name)

            def get_size_method(self):
                return self.get_group_size(name)

            def get_rank_method(self):
                return self.get_rank(name)

            def get_mesh_method(self):
                return self.get_device_mesh(name)

            return is_enable_method, get_group_method, get_size_method, get_rank_method, get_mesh_method

        world_size = torch.distributed.get_world_size()
        mesh_numel = reduce(lambda a, b: a * b, mesh_shape)
        if mesh_numel != world_size:
            raise AssertionError(f"Mesh groups {mesh_shape}({mesh_numel}) "
                                 f"!= world size({world_size})")

        device_mesh = init_device_mesh(
            device_type=torch.accelerator.current_accelerator().type,
            mesh_shape=mesh_shape,
            mesh_dim_names=mesh_dim_names,
        )
        for mesh_name in mesh_dim_names:
            self.device_mesh_map[mesh_name] = device_mesh
            is_enable, get_group, get_size, get_rank, get_mesh = make_accessors(mesh_name)
            setattr(self, 'is_{}_enable'.format(mesh_name), types.MethodType(is_enable, self))
            setattr(self, 'get_{}_group'.format(mesh_name), types.MethodType(get_group, self))
            setattr(self, 'get_{}_group_size'.format(mesh_name), types.MethodType(get_size, self))
            setattr(self, 'get_{}_rank'.format(mesh_name), types.MethodType(get_rank, self))
            setattr(self, 'get_{}_device_mesh'.format(mesh_name), types.MethodType(get_mesh, self))