import torch
from mindspeed.fsdp.distributed.tensor_parallel.tensor_parallel import tensor_parallel_modules
from .expert_parallel.expert_parallel import expert_parallelize_modules
from .expert_parallel.expert_fully_shard_parallel import expert_fully_shard_modules
from .fully_shard_parallel import fully_shard_parallel_modules, set_modules_to_prefetch
from .parallel_state import get_parallel_state
from ..params.parallel_args import ParallelArguments
from ..params.training_args import TrainingArguments
class ParallelApplier:
def __init__(self, parallel_config: ParallelArguments, training_config: TrainingArguments):
self.config = parallel_config
self.training_config = training_config
self.parallel_state = get_parallel_state()
def apply_fsdp_modules(self, model, training_config):
model = fully_shard_parallel_modules(model, self.parallel_state.get_fsdp_device_mesh(), self.config.fsdp_plan,
training_config)
return model
def apply_tp_modules(self, model):
if self.config.tensor_parallel_size == 1:
return
model = tensor_parallel_modules(model, self.parallel_state.get_tp_device_mesh(), self.config.tp_plan)
def apply_ep_modules(self, model):
if not self.config.ep_plan.apply_efsdp_modules:
self.config.ep_plan.apply_efsdp_modules = self.config.ep_plan.apply_modules
if self.config.ep_plan._gradient_divide_factor is None:
self.config.ep_plan._gradient_divide_factor = torch.distributed.get_world_size()
if self.config.expert_parallel_size > 1 and self.config.ep_plan.apply_modules:
model = expert_parallelize_modules(model, self.parallel_state.get_ep_device_mesh(), self.config.ep_plan)
model = expert_fully_shard_modules(model, self.parallel_state.get_efsdp_device_mesh(), self.config.ep_plan, self.config.fsdp_plan)
self.config.fsdp_plan.apply_modules = [x for x in self.config.fsdp_plan.apply_modules if x not in self.config.ep_plan.apply_efsdp_modules]
def set_modules_to_prefetch(self, model):
if self.config.fsdp_plan.num_to_forward_prefetch > 0 or self.config.fsdp_plan.num_to_backward_prefetch > 0:
ep_plan = self.config.ep_plan if self.config.expert_parallel_size > 1 and self.config.ep_plan.apply_modules else None
set_modules_to_prefetch(model, fsdp_plan=self.config.fsdp_plan, ep_plan=ep_plan)
return model
def __call__(self, model):
self.apply_tp_modules(model=model)
self.apply_ep_modules(model=model)
model = self.apply_fsdp_modules(model=model, training_config=self.training_config)
model = self.set_modules_to_prefetch(model=model)
return model