import typing
import torch
from megatron.training import get_args
from .experts import Experts
from .gate import TopKGate
from .moe_layer import MOELayer
from .config import Config
from .utils import get_reshape_index_select
class MoE(torch.nn.Module):
    """Mixture-of-Experts layer combining a top-k gate with a set of local experts.

    The constructor builds a ``Config`` from its arguments, creates a ``TopKGate``
    and an ``Experts`` container holding ``num_experts // ep_size`` local experts,
    and wires both into a ``MOELayer`` that is bound to ``ep_group``.

    The expert-pipelining switches (``use_pipe_experts``,
    ``pipe_experts_multi_data``, ``pipe_experts_multi_stream``) are not constructor
    arguments; they are read from the global arguments returned by ``get_args()``.

    Arguments:
        hidden_size (int): the hidden dimension of the model, importantly this is also the input and output dimension.
        expert (torch.nn.Module): the torch module that defines the expert (e.g., MLP, torch.linear).
        num_experts (int, optional): default=1, the total number of experts per layer.
            Must be divisible by ``ep_size``.
        ep_size (int, optional): default=1, number of ranks in the expert parallel world or group.
            Must be at least 1.
        k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
        capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
        eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
        min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
        aux_loss_coef (float, optional): default=0.0, scaling coefficient for the aux loss.
        z_loss_coef (float, optional): default=0.0, scaling coefficient for the z loss.
        ep_group (optional): default=None, expert-parallel group handed to
            ``MOELayer.set_ep_group``. Presumably a torch.distributed process
            group -- confirm against MOELayer.
        noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.
        no_drop (bool, optional): default=False, forwarded unchanged to the gate ``Config``.
            Presumably disables token dropping -- confirm against TopKGate.
        dynamic_padding (bool, optional): default=False, forwarded unchanged to the gate ``Config``.
        use_sinkhorn (bool, optional): default=False, forwarded unchanged to the gate ``Config``.
            Presumably selects Sinkhorn-based routing -- confirm against TopKGate.
        sequence_parallel (bool, optional): default=False, forwarded to both the gate
            ``Config`` and the ``MOELayer``.

    Raises:
        AssertionError: if ``ep_size`` is smaller than 1, or if ``num_experts`` is
            not divisible by ``ep_size``.
    """

    def __init__(self,
                 hidden_size,
                 expert,
                 num_experts=1,
                 ep_size=1,
                 k=1,
                 capacity_factor=1.,
                 eval_capacity_factor=1.,
                 min_capacity=4,
                 aux_loss_coef=0.0,
                 z_loss_coef=0.0,
                 ep_group=None,
                 noisy_gate_policy: typing.Optional[str] = None,
                 no_drop=False,
                 dynamic_padding=False,
                 use_sinkhorn=False,
                 sequence_parallel=False):
        super().__init__()

        # Validate the caller's arguments before reading any global state, so a
        # bad configuration fails fast with a clear message instead of a bare
        # ZeroDivisionError (ep_size == 0) or a negative local-expert count.
        if ep_size < 1:
            raise AssertionError(
                f"Expert parallel size should be a positive integer, got {ep_size}")
        if num_experts % ep_size != 0:
            raise AssertionError("Number of experts should be divisible by expert parallel size")
        num_local_experts = num_experts // ep_size

        # Pipelining behaviour is configured globally rather than per layer.
        args = get_args()
        pipe_experts = args.use_pipe_experts
        pipe_experts_multi_data = args.pipe_experts_multi_data
        pipe_experts_multi_stream = args.pipe_experts_multi_stream

        # The reshape index is only computed when expert pipelining is enabled.
        reshape_index_select = (
            get_reshape_index_select(num_local_experts, ep_size)
            if pipe_experts else None
        )

        self.num_experts = num_experts
        self.ep_size = ep_size
        self.num_local_experts = num_local_experts
        self.ep_group = ep_group
        self.hidden_size = hidden_size

        config = Config(hidden_size=hidden_size,
                        num_experts=num_experts,
                        ep_size=ep_size,
                        topk=k,
                        capacity_factor=capacity_factor,
                        eval_capacity_factor=eval_capacity_factor,
                        min_capacity=min_capacity,
                        aux_loss_coef=aux_loss_coef,
                        z_loss_coef=z_loss_coef,
                        noisy_gate_policy=noisy_gate_policy,
                        no_drop=no_drop,
                        dynamic_padding=dynamic_padding,
                        use_sinkhorn=use_sinkhorn,
                        sequence_parallel=sequence_parallel,
                        reshape_index_select=reshape_index_select
                        )
        self.moe_layer = MOELayer(TopKGate(config),
                                  Experts(expert, num_local_experts),
                                  ep_size,
                                  num_local_experts,
                                  pipe_experts=pipe_experts,
                                  sequence_parallel=sequence_parallel,
                                  pipe_experts_multi_data=pipe_experts_multi_data,
                                  pipe_experts_multi_stream=pipe_experts_multi_stream,
                                  )
        self.moe_layer.set_ep_group(ep_group)

    def forward(self, hidden_states, used_token=None):
        """Run the wrapped ``MOELayer`` on ``hidden_states``.

        Arguments:
            hidden_states: input passed as the first argument to the ``MOELayer``.
            used_token (optional): default=None, passed through unchanged as the
                second argument to the ``MOELayer``.

        Returns:
            A tuple ``(output, l_aux, exp_counts)``: the ``MOELayer`` output,
            followed by the ``l_aux`` and ``exp_counts`` attributes read from the
            ``MOELayer`` after the call.
        """
        output = self.moe_layer(hidden_states, used_token)
        return output, self.moe_layer.l_aux, self.moe_layer.exp_counts