from dataclasses import dataclass, field
from typing import Any, Dict, List, Callable, Literal, Union, Optional
import torch
@dataclass
class FSDPPlanConfig:
    """Options describing how fully-sharded data parallelism is applied.

    Every list/dict field defaults to ``None`` (meaning "not specified"), so
    the annotations are ``Optional``. The interpretation of each option is up
    to the code that consumes this plan.
    """

    # Module-name patterns excluded from sharding, e.g. ['*mlp.experts*'].
    ignored_modules: Optional[List[str]] = None
    # Mapping of module-name pattern -> per-module options, e.g.
    # {'model.layers.*': {...}}. NOTE(review): the accepted option keys are
    # defined by the consumer of this plan -- confirm there.
    apply_modules: Optional[Dict[str, Any]] = None
    # NOTE(review): the three dtypes and cast_forward_inputs look like
    # mixed-precision settings -- confirm against the consumer.
    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    output_dtype: Optional[torch.dtype] = None
    cast_forward_inputs: bool = True
    # NOTE(review): presumably prefetch depths for forward/backward; 0 is the
    # default -- confirm semantics against the consumer.
    num_to_forward_prefetch: Optional[int] = 0
    num_to_backward_prefetch: Optional[int] = 0
@dataclass
class TPPlanConfig:
    """Module-name patterns selecting where tensor parallelism is applied.

    Each field defaults to ``None`` (meaning "not specified"), hence the
    ``Optional`` annotations.
    """

    # Patterns for column-wise parallel modules,
    # e.g. ['*.q_proj', '*.k_proj', '*.v_proj'].
    colwise_parallel: Optional[List[str]] = None
    # Patterns for row-wise parallel modules, e.g. ['*.o_proj'].
    rowwise_parallel: Optional[List[str]] = None
    # NOTE(review): presumably patterns for sequence-parallel modules --
    # confirm against the consumer of this plan.
    sequence_parallel: Optional[List[str]] = None
@dataclass
class CPPlanConfig:
    """Options for the context-parallel plan.

    ``context_parallel_type`` defaults to ``None`` (not specified), hence the
    ``Optional`` annotation.
    """

    # NOTE(review): presumably the context-parallel algorithm name -- confirm
    # the accepted values against the consumer of this plan.
    context_parallel_type: Optional[str] = None
    # NOTE(review): presumably toggles packed-sequence handling -- confirm.
    is_pack: bool = False
@dataclass
class EPPlanConfig:
    """Options describing how expert parallelism is applied.

    All fields default to ``None`` (meaning "not specified"), hence the
    ``Optional`` annotations.
    """

    # Module-name patterns that expert parallelism applies to,
    # e.g. ['*mlp.experts*'].
    apply_modules: Optional[List[str]] = None
    # One of the named dispatchers ("eager", "fused", "mc2") or a callable.
    # NOTE(review): the expected callable signature is not visible here.
    dispatcher: Optional[Union[Literal["eager", "fused", "mc2"], Callable]] = None
    # NOTE(review): presumably module patterns for expert fully-sharded
    # parallelism -- confirm against the consumer of this plan.
    apply_efsdp_modules: Optional[List[str]] = None
    # Leading underscore, yet still an ``__init__`` parameter.
    # ParallelEngineConfig.validate_ep_config overwrites it with the product
    # of the three expert parallel sizes.
    _gradient_divide_factor: Optional[float] = None
@dataclass
class QuantizeConfig:
    """Options for the quantization plan.

    The list-valued fields default to ``None`` (meaning "not specified"),
    hence the ``Optional`` annotations. The interpretation of each option is
    up to the code that consumes this config.
    """

    # NOTE(review): accepted format/recipe names are defined by the consumer.
    quant_format: Optional[str] = None
    quant_recipe: Optional[str] = None
    # NOTE(review): presumably the quantization block size -- confirm units.
    block_size: int = 32
    # NOTE(review): presumably module-name patterns to quantize / to skip.
    quant_apply_modules: Optional[List[str]] = None
    quant_ignored_modules: Optional[List[str]] = None
    # NOTE(review): presumably names of converters to run -- confirm.
    converters: Optional[List[str]] = None
    quant_gmm: bool = False
    gemm_gradient_accumulation_fusion: bool = False
    # A fresh dict per instance, so instances never share the mapping.
    extra_args: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ParallelEngineConfig:
    """Parallel sizes plus the per-strategy plans of the parallel engine.

    Plans left as ``None`` are replaced during ``__post_init__``: ``fsdp_plan``,
    ``tp_plan``, ``ep_plan`` and ``quantization_plan`` become default plan
    instances and ``recompute_plan`` becomes ``[]``. ``cp_plan`` is not
    touched and stays ``None`` unless the caller provides one.

    Raises:
        ValueError: from ``__post_init__`` if ``context_parallel_type`` is not
            a supported value.
    """

    data_parallel_size: int = 1
    fully_shard_parallel_size: int = 1
    fsdp_plan: Optional[FSDPPlanConfig] = None
    tensor_parallel_size: int = 1
    tp_plan: Optional[TPPlanConfig] = None
    context_parallel_size: int = 1
    context_parallel_type: Literal["ulysses"] = "ulysses"
    cp_plan: Optional[CPPlanConfig] = None
    expert_parallel_size: int = 1
    expert_fully_shard_parallel_size: int = 1
    expert_data_parallel_size: int = 1
    ep_plan: Optional[EPPlanConfig] = None
    recompute: bool = False
    recompute_plan: Optional[List[str]] = None
    quantization_plan: Optional[QuantizeConfig] = None

    def __post_init__(self):
        """Normalize and validate every plan right after construction."""
        self.validate_tp_config()
        self.validate_ep_config()
        self.validate_cp_config()
        self.validate_recompute_config()
        self.validate_quantization_config()
        # Must run after the TP/EP validators: it reads their normalized
        # module lists.
        self.validate_fsdp_config()

    def validate_fsdp_config(self):
        """Default the fully-shard plan and derive its ignored modules.

        Example::

            config = ParallelEngineConfig(
                fsdp_plan=FSDPPlanConfig(
                    ignored_modules=['*mlp.experts*'],
                    apply_modules={
                        'model.layers.*': {
                            'reshard_after_forward': None,
                            'shard_placement_fn': None,
                        },
                    },
                )
            )

        When ``fully_shard_parallel_size > 1`` the expert-parallel modules (if
        ``expert_parallel_size > 1``) and the column-/row-wise tensor-parallel
        modules (if ``tensor_parallel_size > 1``) are appended to
        ``fsdp_plan.ignored_modules`` and duplicates are dropped. Otherwise
        ``ignored_modules`` is left exactly as provided.

        Expects ``tp_plan`` and ``ep_plan`` to be populated already, which
        ``__post_init__`` guarantees.
        """
        if self.fsdp_plan is None:
            self.fsdp_plan = FSDPPlanConfig()
        if self.fully_shard_parallel_size <= 1:
            return

        # Work on a copy: a missing list (the dataclass default is None) is
        # treated as empty, and the caller's own list is never mutated.
        ignored = list(self.fsdp_plan.ignored_modules or [])
        if self.expert_parallel_size > 1:
            ignored.extend(self.ep_plan.apply_modules or [])
        if self.tensor_parallel_size > 1:
            ignored.extend(self.tp_plan.colwise_parallel or [])
            ignored.extend(self.tp_plan.rowwise_parallel or [])
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # result is the same in every process (a set gives arbitrary order).
        self.fsdp_plan.ignored_modules = list(dict.fromkeys(ignored))

    def validate_tp_config(self):
        """Default the tensor-parallel plan and its module lists.

        Example::

            config = ParallelEngineConfig(
                tp_plan=TPPlanConfig(
                    colwise_parallel=['*.q_proj', '*.k_proj', '*.v_proj'],
                    rowwise_parallel=['*.o_proj'],
                )
            )

        A missing plan becomes ``TPPlanConfig()`` and every ``None`` list in
        the plan becomes ``[]``.
        """
        if self.tp_plan is None:
            self.tp_plan = TPPlanConfig()
        if self.tp_plan.colwise_parallel is None:
            self.tp_plan.colwise_parallel = []
        if self.tp_plan.rowwise_parallel is None:
            self.tp_plan.rowwise_parallel = []
        if self.tp_plan.sequence_parallel is None:
            self.tp_plan.sequence_parallel = []

    def validate_ep_config(self):
        """Default the expert-parallel plan and derive its dependent fields.

        Example::

            config = ParallelEngineConfig(
                ep_plan=EPPlanConfig(
                    apply_modules=['*mlp.experts*'],
                    dispatcher='eager',  # or 'fused' / 'mc2'
                )
            )

        A missing plan becomes one with no modules and the ``'eager'``
        dispatcher. ``_gradient_divide_factor`` is always overwritten with the
        product of the three expert parallel sizes. When
        ``apply_efsdp_modules`` is ``None`` it is derived from the entries of
        ``apply_modules`` that end in ``'.experts'``, with that suffix removed.
        """
        if self.ep_plan is None:
            self.ep_plan = EPPlanConfig(apply_modules=[], dispatcher='eager')
        # A caller-supplied plan may leave apply_modules unset; treat that as
        # "no modules" instead of failing when iterating over None below.
        if self.ep_plan.apply_modules is None:
            self.ep_plan.apply_modules = []
        self.ep_plan._gradient_divide_factor = (
            self.expert_parallel_size
            * self.expert_fully_shard_parallel_size
            * self.expert_data_parallel_size
        )
        if self.ep_plan.apply_efsdp_modules is None:
            self.ep_plan.apply_efsdp_modules = [
                ep_module.removesuffix('.experts')
                for ep_module in self.ep_plan.apply_modules
                if ep_module.endswith('.experts')
            ]

    def validate_recompute_config(self):
        """Replace a missing ``recompute_plan`` with an empty list."""
        if self.recompute_plan is None:
            self.recompute_plan = []

    def validate_cp_config(self):
        """Reject unsupported context-parallel types.

        Checks the top-level ``context_parallel_type`` field only; ``cp_plan``
        is not inspected.

        Raises:
            ValueError: if ``context_parallel_type`` is not ``"ulysses"``.
        """
        if self.context_parallel_type not in ["ulysses"]:
            raise ValueError("context parallel type must be `ulysses`.")

    def validate_quantization_config(self):
        """Replace a missing ``quantization_plan`` with ``QuantizeConfig()``."""
        if self.quantization_plan is None:
            self.quantization_plan = QuantizeConfig()