import torch
from mindspeed.fsdp.distributed.expert_parallel.expert_fully_shard_parallel import expert_fully_shard_modules
from mindspeed.fsdp.distributed.fully_shard_parallel.fully_shard_parallel import \
fully_shard_parallel_modules
from mindspeed.fsdp.distributed.parallel_state import init_parallel_state
from mindspeed.fsdp.distributed.tensor_parallel.tensor_parallel import tensor_parallel_modules
from mindspeed.fsdp.memory.recompute.recompute import recompute_modules
from mindspeed.fsdp.parallel_engine_config import ParallelEngineConfig
from mindspeed.fsdp.distributed.expert_parallel.expert_parallel import expert_parallelize_modules
class MindSpeedParallelEngine(torch.nn.Module):
    """Wrap a model and apply the configured parallel/memory transforms to it.

    On construction the parallel state is initialised from ``config`` and the
    wrapped model is passed, in this order, through: quantization conversion,
    tensor parallelism, expert parallelism, recomputation and FSDP sharding.
    Each step is delegated to the corresponding ``mindspeed.fsdp`` helper and
    the result replaces ``self.model``.

    Args:
        config: Engine configuration; the plans and sizes read from it are
            documented on the individual ``apply_*`` methods.
        model: The module to transform. ``forward`` delegates to it.
    """

    def __init__(self, config: ParallelEngineConfig, model: torch.nn.Module):
        super().__init__()
        self.config = config
        self.model = model
        self.parallel_state = init_parallel_state(self.config)

        # The order of these calls is part of the engine's behaviour; keep it.
        self.apply_quantization_modules()
        self.apply_tp_modules()
        self.apply_ep_modules()
        self.apply_recompute_modules()
        self.apply_fsdp_modules()

    def _active_quantization_plan(self):
        """Return the quantization plan if quantization is requested, else None.

        Quantization counts as requested only when ``config.quantization_plan``
        exists and has a truthy ``recipe_name``. A missing or ``None`` plan is
        treated as "disabled" instead of raising ``AttributeError``.
        """
        plan = getattr(self.config, "quantization_plan", None)
        if plan is None or not getattr(plan, "recipe_name", None):
            return None
        return plan

    def apply_fsdp_modules(self):
        """Shard the model using the FSDP device mesh and ``config.fsdp_plan``."""
        self.model = fully_shard_parallel_modules(
            self.model,
            self.parallel_state.get_fsdp_device_mesh(),
            self.config.fsdp_plan,
        )

    def apply_tp_modules(self):
        """Apply ``config.tp_plan`` unless ``tensor_parallel_size`` is 1."""
        if self.config.tensor_parallel_size == 1:
            return
        self.model = tensor_parallel_modules(
            self.model,
            self.parallel_state.get_tp_device_mesh(),
            self.config.tp_plan,
        )

    def apply_ep_modules(self):
        """Apply expert parallelism and expert FSDP sharding from ``config.ep_plan``."""
        if self.config.expert_parallel_size > 1:
            self.model = expert_parallelize_modules(
                self.model,
                self.parallel_state.get_ep_device_mesh(),
                self.config.ep_plan,
            )
        # NOTE(review): expert FSDP wrapping runs for every expert-parallel
        # size, including 1 -- confirm this is intended.
        self.model = expert_fully_shard_modules(
            self.model,
            self.parallel_state.get_efsdp_device_mesh(),
            self.config.ep_plan,
        )

    def apply_recompute_modules(self):
        """Apply ``config.recompute_plan`` when ``config.recompute`` is truthy."""
        if not self.config.recompute:
            return
        self.model = recompute_modules(self.model, self.config.recompute_plan)

    def apply_quantization_modules(self):
        """Convert the model according to ``config.quantization_plan``.

        Does nothing when no quantization recipe is configured.

        Raises:
            RuntimeError: If building or running the model converter fails
                for any reason; the original exception is chained as the cause.
        """
        plan = self._active_quantization_plan()
        if plan is None:
            return

        try:
            if self.config.recompute:
                # NOTE(review): recompute forces the low-precision all-gather
                # mode to "all"; the rationale is not visible in this file.
                plan.fsdp_low_precision_all_gather_mode = "all"

            # Imported lazily so the quantization package is only required
            # when a recipe is actually configured.
            from mindspeed.fsdp.quantization.converter.model_converter import build_model_converter

            model_converters = build_model_converter(plan)
            # The return value is not used; presumably converts in place.
            model_converters.convert(self.model)
        except Exception as e:
            raise RuntimeError(
                f"Failed to convert quantization plan (recipe_name={plan.recipe_name!r})"
            ) from e

    def apply_optimizer_hook(self, optimizer: torch.optim.Optimizer):
        """Hook ``optimizer`` for quantization; no-op without a configured recipe."""
        if self._active_quantization_plan() is None:
            return
        from mindspeed.fsdp.quantization.core.cache import hook_optimizer_step

        hook_optimizer_step(self.model, optimizer)

    def forward(self, *args, **kwargs):
        """Call the wrapped model with the given arguments and return its output."""
        return self.model(*args, **kwargs)