import logging
import types
from functools import partial
from typing import Callable
import torch
from torch.distributed import DeviceMesh
from torch.distributed.tensor import Shard, DTensor, Replicate, distribute_tensor, distribute_module
from mindspeed.fsdp.utils.log import print_rank
from mindspeed.fsdp.utils.str_match import module_name_match
from mindspeed_mm.fsdp.params.parallel_args import EPPlanConfig
logger = logging.getLogger(__name__)
def expert_parallelize_modules(modules: torch.nn.Module, ep_mesh: DeviceMesh, plan: EPPlanConfig):
    """Apply expert parallelism to every sub-module selected by ``plan``.

    Each selected module has its parameters distributed over ``ep_mesh`` through
    ``distribute_experts_module`` and then gets its ``forward`` rebound:

    * if the module exposes a callable ``ep_forward``, that callable is used with
      ``ep_group`` and ``ep_plan`` pre-bound as keyword arguments;
    * otherwise the function built by ``get_experts_forward_fn_for_qwen`` is
      bound to the module as a method.

    Args:
        modules: Root module that is searched for expert modules.
        ep_mesh: Device mesh used to distribute the expert parameters; its
            process group is handed to the replacement forward functions.
        plan: Expert-parallel plan. ``use_npu_fused_ops`` and ``dispatcher`` are
            read here; the module selection is delegated to ``get_ep_modules``.

    Returns:
        The same ``modules`` object, modified in place.

    Raises:
        RuntimeError: Propagated from ``get_ep_modules`` when no module matches.
    """
    ep_modules = get_ep_modules(modules, plan)
    ep_group = ep_mesh.get_group()

    for module in ep_modules:
        distribute_experts_module(module, ep_mesh)

        if callable(getattr(module, 'ep_forward', None)):
            # The module ships its own expert-parallel forward: pre-bind the
            # group and the plan so it can be called like a regular forward.
            module.forward = partial(module.ep_forward, ep_group=ep_group, ep_plan=plan)
        else:
            experts_forward_fn = get_experts_forward_fn_for_qwen(
                ep_group,
                use_npu_fused_ops=plan.use_npu_fused_ops,
                dispatcher=plan.dispatcher,
            )
            # Bind as a method so the generated function receives the module as ``self``.
            module.forward = types.MethodType(experts_forward_fn, module)

    return modules
def get_ep_modules(modules: torch.nn.Module, plan: EPPlanConfig):
    """Collect the sub-modules of ``modules`` selected by ``plan.apply_modules``.

    Every name pattern in ``plan.apply_modules`` is tested against every entry
    of ``modules.named_modules()`` with ``module_name_match``. A module object
    is returned at most once even when several patterns match it, because the
    caller distributes and rebinds each returned module and must not process
    the same module twice.

    Args:
        modules: Root module whose ``named_modules()`` are searched.
        plan: Plan providing the iterable of name patterns ``apply_modules``.

    Returns:
        List of matching modules, in first-match order, without duplicates.

    Raises:
        RuntimeError: If no module matches any pattern.
    """
    ep_modules = []
    seen_ids = set()  # identities of modules already selected
    for plan_name in plan.apply_modules:
        for name, module in modules.named_modules():
            if not module_name_match(plan_name, name):
                continue
            if id(module) in seen_ids:
                # Already selected through an earlier pattern.
                continue
            seen_ids.add(id(module))
            print_rank(logger.debug, f'[Expert Parallel]: Apply ep to module <{name}>')
            ep_modules.append(module)

    if not ep_modules:
        raise RuntimeError(f'[Expert Parallel] No module named {plan} or not be ModuleList')
    return ep_modules
def prepare_distribute_input_fn(module, inputs, device_mesh):
inputs = list(inputs)
for idx, input_tensor in enumerate(inputs):
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(input_tensor, device_mesh, (Replicate(),), run_check=False)
inputs[idx] = input_tensor
return *inputs,
def prepare_distribute_output_fn(module, outputs, device_mesh):
    """Return the result of calling ``to_local()`` on ``outputs``.

    Args:
        module: Unused; present to match the hook signature.
        outputs: Object exposing ``to_local()`` (presumably a ``DTensor`` -
            confirm against the caller).
        device_mesh: Unused; present to match the hook signature.

    Returns:
        Whatever ``outputs.to_local()`` returns.
    """
    local_outputs = outputs.to_local()
    return local_outputs
def distribute_expert_weight(module_name, module, ep_mesh):
for name, param in module.named_parameters(recurse=False):
dist_param = torch.nn.Parameter(distribute_tensor(param, ep_mesh, [Shard(0)]))
module.register_parameter(name, dist_param)
for name, children_module in module.named_children():
distribute_expert_weight(name, children_module, ep_mesh)
def distribute_experts_module(module: torch.nn.Module, ep_mesh: DeviceMesh):
    """Distribute ``module`` over ``ep_mesh`` using ``distribute_expert_weight``.

    Thin wrapper around ``distribute_module`` that supplies
    ``distribute_expert_weight`` as the partition function; no input or output
    hooks are passed.

    Args:
        module: Expert module to distribute.
        ep_mesh: Device mesh to distribute over.

    Returns:
        The value returned by ``distribute_module``.
    """
    distributed = distribute_module(
        module=module,
        device_mesh=ep_mesh,
        partition_fn=distribute_expert_weight,
    )
    return distributed
def get_grad_division_hook(param, ep_size):
    """Build a hook that scales ``param.grad`` in place by ``1 / ep_size``.

    Args:
        param: Object whose ``grad`` is scaled when the hook fires.
        ep_size: Divisor; the factor ``1 / ep_size`` is computed on each call.

    Returns:
        A callable that accepts and ignores any positional arguments, multiplies
        ``param.grad`` in place and returns the result of that ``mul_`` call.
    """
    def _scale_grad(*_ignored):
        grad = param.grad
        return grad.mul_(1 / ep_size)

    return _scale_grad
def apply_grad_division_hook(module, ep_size):
    """Register a gradient-scaling hook for every trainable parameter of ``module``.

    For each parameter with ``requires_grad`` set, the node reached through
    ``param.expand_as(param).grad_fn.next_functions[0][0]`` receives the hook
    built by ``get_grad_division_hook(param, ep_size)``.

    Args:
        module: Module whose ``parameters()`` are visited.
        ep_size: Divisor forwarded to ``get_grad_division_hook``.
    """
    for param in module.parameters():
        if not param.requires_grad:
            continue
        # ``expand_as`` yields a non-leaf view whose grad_fn links back to the
        # parameter's own autograd node.
        expanded = param.expand_as(param)
        grad_acc = expanded.grad_fn.next_functions[0][0]
        # NOTE(review): no reference to ``grad_acc`` is kept after this
        # iteration - confirm the registered hook stays alive in the targeted
        # torch version, or store the node somewhere long-lived.
        grad_acc.register_hook(get_grad_division_hook(param, ep_size))
def get_experts_forward_fn_for_qwen(ep_group, use_npu_fused_ops=True, dispatcher="alltoall"):
    """Build an expert-parallel ``forward`` to be bound onto an experts module.

    Args:
        ep_group: Process group forwarded to the dispatcher as ``ep_group``.
        use_npu_fused_ops: Forwarded to the dispatcher as ``fused``.
        dispatcher: Name of the dispatch implementation: ``"alltoall"`` selects
            ``ep_forward`` and ``"mc2"`` selects ``ep_mc2_forward``.

    Returns:
        A function ``experts_forward(self, hidden_states, routing_weights,
        router_indices)`` meant to be bound as a method. It reads
        ``self.hidden_size``, ``self.num_experts``, ``self.gate_up_proj`` and
        ``self.down_proj``, and raises ``NotImplementedError`` when called with
        an unknown ``dispatcher`` configured.
    """
    from .ep_dispatcher import ep_forward, ep_mc2_forward

    # Resolve the dispatcher once here instead of rebuilding the table and
    # looking it up on every forward pass.
    ep_dispatcher_dict = {
        "alltoall": ep_forward,
        "mc2": ep_mc2_forward,
    }
    dispatcher_func = ep_dispatcher_dict.get(dispatcher)

    def experts_forward(self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor):
        # Unknown dispatchers are still reported at call time, as before.
        if dispatcher_func is None:
            raise NotImplementedError(f"ep dispatcher {dispatcher} is not implenmented now.")

        batch_size = hidden_states.shape[0]
        # Flatten all leading dimensions so each row has ``hidden_size`` features.
        hidden_states = hidden_states.reshape(-1, self.hidden_size)

        # The dispatcher is given plain local tensors, not DTensor wrappers.
        gate_up_proj = self.gate_up_proj.to_local() if isinstance(self.gate_up_proj, DTensor) else self.gate_up_proj
        down_proj = self.down_proj.to_local() if isinstance(self.down_proj, DTensor) else self.down_proj

        hidden_states = dispatcher_func(
            self.num_experts,
            routing_weights,
            router_indices,
            hidden_states,
            fc1_weight=gate_up_proj,
            fc2_weight=down_proj,
            ep_group=ep_group,
            fused=use_npu_fused_ops,
        )

        # ``reshape`` instead of ``view`` so a non-contiguous dispatcher result
        # does not raise; it still returns a view whenever one is possible.
        return hidden_states.reshape(batch_size, -1, self.hidden_size)

    return experts_forward