from functools import wraps
from functools import partial
import torch
import torch.nn.functional as F
from einops import rearrange
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.tensor_parallel.mappings import _split_along_first_dim
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.transformer.moe.moe_utils import save_to_aux_losses_tracker
from megatron.core import parallel_state
from megatron.training import get_args
from megatron.core.transformer.moe.moe_utils import topk_softmax_with_capacity
from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput
from mindspeed_llm.tasks.models.common.pai_megatron import pai_megatron_aux_loss
from mindspeed_llm.core.transformer.moe.moe_utils import topk_softmax_with_capacity_and_hash
def group_limited_greedy_topKgating(self, logits: torch.Tensor):
    """Group-limited greedy top-k gating with optional auxiliary balance losses.

    Softmax scores are computed over ``dim=1``, the experts are viewed as
    ``self.n_group`` equally sized groups, the ``self.topk_group`` groups with
    the highest per-group maximum score are kept for every row, and the final
    ``args.moe_router_topk`` experts are picked among the kept groups only.

    When ``self.training`` is true, up to three auxiliary terms are summed
    into ``self.l_aux``: an expert-level term (``self.l_expert_aux``), a
    device/group-level term (``self.l_device_aux``) and a communication-level
    term (``self.l_comm_aux``). Each term is only computed, and its attribute
    only written, when its coefficient is positive. In evaluation mode
    ``self.l_aux`` is set to ``None``.

    Args:
        logits (torch.Tensor): Router logits, used as a 2-D tensor here
            (softmax and scatter along ``dim=1``).
            NOTE(review): the ``view`` calls below use
            ``args.micro_batch_size * logits.shape[0]`` as the number of rows —
            confirm the layout the caller is expected to pass in.

    Returns:
        tuple: ``(topk_masked_gates, topk_map)``. ``topk_masked_gates`` has the
        shape of ``logits`` and holds the selected gate weights (zero
        elsewhere); ``topk_map`` is a boolean tensor of the same shape marking
        the selected experts.
    """
    args = get_args()
    seq_length = logits.shape[0]
    scores = F.softmax(logits, dim=1)
    # Per-row score of each expert group = the best expert score inside it.
    group_scores = (
        scores.view(args.micro_batch_size * seq_length, self.n_group, -1).max(dim=-1).values
    )
    # Indices of the groups each row is allowed to route to.
    group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    # Broadcast the group mask back to one entry per expert.
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(
            args.micro_batch_size * seq_length, self.n_group, args.num_experts // self.n_group
        )
        .reshape(args.micro_batch_size * seq_length, -1)
    )
    # Experts outside the selected groups get score 0 so top-k ignores them.
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)
    topk_weight, topk_idx = torch.topk(
        tmp_scores, k=args.moe_router_topk, dim=-1, sorted=False
    )
    # Either renormalize the kept weights to sum to 1 (epsilon avoids 0/0),
    # or apply the configured scaling factor.
    if args.moe_router_topk > 1 and args.norm_topk_prob:
        denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
        topk_weight = topk_weight / denominator
    else:
        topk_weight = topk_weight * args.moe_router_topk_scaling_factor
    # Dense outputs shaped like ``logits``: weights and boolean selection mask.
    topk_masked_gates = torch.zeros_like(logits).scatter(1, topk_idx, topk_weight)
    topk_map = torch.zeros_like(logits).int().scatter(1, topk_idx, 1).bool()
    # No auxiliary loss outside of training.
    if not self.training:
        l_aux = None
        self.l_aux = l_aux
        return topk_masked_gates, topk_map
    scores_for_aux = scores
    topk_idx_for_aux_loss = topk_idx.view(args.micro_batch_size, -1)
    # NOTE(review): not referenced again in this function.
    topk_group_idx_for_aux_loss = group_idx.view(args.micro_batch_size, -1)
    # fi / Pi are computed lazily by whichever loss term needs them first and
    # are reused by the later terms.
    fi, Pi, l_aux = None, None, 0
    # ---- Expert-level balance loss ----
    if self.config.moe_aux_loss_coeff > 0:
        l_expert_aux = 0
        if args.seq_aux:
            # Sequence-wise variant: statistics are kept per micro-batch row.
            scores_for_seq_aux = scores_for_aux.view(args.micro_batch_size, seq_length, -1)
            # ce[b, e] = number of (token, slot) assignments to expert e in row b.
            ce = torch.zeros(
                args.micro_batch_size, args.num_experts, device=logits.device
            )
            ce.scatter_add_(
                1,
                topk_idx_for_aux_loss,
                torch.ones(args.micro_batch_size, seq_length * args.moe_router_topk, device=logits.device),
            )
            # Sum the counts across the context-parallel group when one exists.
            num_sub_sequence = 1
            sequence_partition_group = parallel_state.get_context_parallel_group()
            if sequence_partition_group is not None:
                num_sub_sequence = torch.distributed.get_world_size(sequence_partition_group)
                torch.distributed.all_reduce(ce, group=sequence_partition_group)
            num_tokens = seq_length * num_sub_sequence
            fi = ce.div(num_sub_sequence * num_tokens * args.moe_router_topk / args.num_experts)
            Pi = scores_for_seq_aux.mean(dim=1)
            l_expert_aux = (Pi * fi).sum(dim=1).mean() * self.config.moe_aux_loss_coeff
        else:
            # Batch-wise variant: statistics are pooled over all rows.
            mask_ce = F.one_hot(
                topk_idx_for_aux_loss.view(-1), num_classes=args.num_experts
            )
            ce = mask_ce.to(logits.dtype).mean(0)
            Pi = scores_for_aux.mean(0)
            fi = ce * args.num_experts
            l_expert_aux = (Pi * fi).sum() * self.config.moe_aux_loss_coeff
        self.l_expert_aux = l_expert_aux
        l_aux += l_expert_aux
    P_devi = None
    # ---- Device (group) level balance loss ----
    if args.moe_device_level_aux_loss_coeff > 0:
        l_device_aux = 0
        if args.seq_aux:
            # Recompute fi / Pi only if the expert-level term was skipped.
            # NOTE(review): unlike the expert-level branch, this fallback does
            # not all-reduce ``ce`` over the context-parallel group.
            if fi is None:
                scores_for_seq_aux = scores_for_aux.view(args.micro_batch_size, seq_length, -1)
                ce = torch.zeros(
                    args.micro_batch_size, args.num_experts, device=logits.device
                )
                ce.scatter_add_(
                    1,
                    topk_idx_for_aux_loss,
                    torch.ones(args.micro_batch_size, seq_length * args.moe_router_topk, device=logits.device),
                )
                fi = ce.div(seq_length * args.moe_router_topk / args.num_experts)
                Pi = scores_for_seq_aux.mean(dim=1)
            # Aggregate expert statistics per group: probabilities are summed,
            # load fractions are averaged.
            P_devi = Pi.view(args.micro_batch_size, self.n_group, -1).sum(-1)
            f_devi = fi.view(args.micro_batch_size, self.n_group, -1).mean(-1)
            l_device_aux = (f_devi * P_devi).sum(dim=1).mean() * args.moe_device_level_aux_loss_coeff
        else:
            if fi is None:
                mask_ce = F.one_hot(
                    topk_idx_for_aux_loss.view(-1), num_classes=args.num_experts
                )
                ce = mask_ce.to(logits.dtype).mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * args.num_experts
            P_devi = Pi.view(self.n_group, -1).sum(-1)
            f_devi = fi.view(self.n_group, -1).mean(-1)
            l_device_aux = (f_devi * P_devi).sum() * args.moe_device_level_aux_loss_coeff
        self.l_device_aux = l_device_aux
        l_aux += l_device_aux
    # ---- Communication level balance loss ----
    if args.moe_comm_aux_loss_coeff > 0:
        l_comm_aux = 0
        if args.seq_aux:
            # Reuse P_devi / Pi from the earlier terms when available.
            if P_devi is None:
                if Pi is None:
                    scores_for_seq_aux = scores_for_aux.view(args.micro_batch_size, seq_length, -1)
                    Pi = scores_for_seq_aux.mean(dim=1)
                P_devi = Pi.view(args.micro_batch_size, self.n_group, -1).sum(-1)
            # ge[b, s, e] = 1 for every expert selected by token (b, s).
            ge = torch.zeros(
                args.micro_batch_size, seq_length, args.num_experts, device=logits.device
            )
            ge.scatter_add_(
                2,
                topk_idx_for_aux_loss.view(args.micro_batch_size, seq_length, -1),
                torch.ones(args.micro_batch_size, seq_length, args.moe_router_topk, device=logits.device),
            )
            # Per row and group: number of tokens that send at least one slot
            # to that group, then scaled by seq_length * topk_group / n_group.
            ge = (ge.view(args.micro_batch_size, seq_length, self.n_group, -1).sum(-1) > 0).to(logits.dtype).sum(dim=1)
            ge.div_(seq_length * self.topk_group / self.n_group)
            l_comm_aux = (ge * P_devi).sum(dim=1).mean() * args.moe_comm_aux_loss_coeff
        else:
            if P_devi is None:
                if Pi is None:
                    Pi = scores_for_aux.mean(0)
                P_devi = Pi.view(self.n_group, -1).sum(-1)
            ge = torch.zeros(
                args.micro_batch_size, seq_length, args.num_experts, device=logits.device
            )
            ge.scatter_add_(
                2,
                topk_idx_for_aux_loss.view(args.micro_batch_size, seq_length, -1),
                torch.ones(args.micro_batch_size, seq_length, args.moe_router_topk, device=logits.device),
            )
            # Flatten tokens, then take the fraction of tokens that touch each
            # group, scaled by topk_group / n_group.
            ge = rearrange(ge, 'b s (ng gs) -> (b s) ng gs', ng=self.n_group, gs=args.num_experts // self.n_group)
            ge = (ge.sum(dim=-1) > 0).to(logits.dtype).mean(0).div(self.topk_group / self.n_group)
            l_comm_aux = (ge * P_devi).sum() * args.moe_comm_aux_loss_coeff
        self.l_comm_aux = l_comm_aux
        l_aux += l_comm_aux
    # Total auxiliary loss (0 when every coefficient is non-positive).
    self.l_aux = l_aux
    return topk_masked_gates, topk_map
class custom_multiplier(torch.autograd.Function):
    """Autograd function that rescales a gathered gate value with a custom gradient.

    Forward returns ``multiplier * mask_for_one``. Backward sends a gradient
    only to ``scores``: the incoming gradient is weighted by ``multiplier``,
    spread negatively over all experts in proportion to ``masked_gates`` and
    added back at the positions given by ``selected_experts``. The value of
    ``mask_for_one`` is not used in the backward pass.
    """

    @staticmethod
    def forward(
        ctx,
        scores: torch.Tensor,
        multiplier: torch.Tensor,
        selected_experts: torch.Tensor,
        masked_gates: torch.Tensor,
        mask_for_one: torch.Tensor,
    ):
        """Stash the tensors needed by ``backward`` and return the scaled multiplier."""
        ctx.save_for_backward(multiplier, selected_experts, masked_gates)
        scaled_multiplier = multiplier * mask_for_one
        return scaled_multiplier

    @staticmethod
    def backward(ctx, grad_at_output: torch.Tensor):
        """Return the gradient w.r.t. ``scores``; the other four inputs get ``None``."""
        gate_value, expert_index, gate_probs = ctx.saved_tensors

        # Upstream gradient weighted by the selected gate value.
        weighted_grad = grad_at_output * gate_value

        # Negative share for every expert, plus the full weighted gradient at
        # the selected expert's position.
        grad_scores = gate_probs * weighted_grad.mul(-1)
        grad_scores.scatter_add_(dim=-1, index=expert_index, src=weighted_grad)

        return grad_scores, None, None, None, None
def sparsemixer_top2(self, scores, jitter_eps=0.01):
    """Select two experts per row and build their gate weights.

    The first expert is chosen from ``scores`` after masking out entries whose
    relative gap to the row maximum exceeds ``2 * jitter_eps``; the second is
    chosen the same way after the first expert has been removed. In training
    the choice is sampled (noise ``-log(Exponential)`` is added before the
    argmax) and the gate value goes through ``custom_multiplier``; in
    evaluation the argmax is used and the plain softmax value is returned.

    Args:
        scores: Router logits; indexed along the last dimension for experts and
            scattered along ``dim=1`` at the end, so a 2-D tensor is assumed —
            TODO confirm with the caller.
        jitter_eps (float): Half of the relative-gap threshold used for masking.

    Returns:
        tuple: ``(multiplier, selected_experts)`` — a tensor shaped like
        ``scores`` with the two gate weights scattered in (zero elsewhere), and
        a boolean tensor of the same shape marking the two chosen experts.

    Raises:
        ValueError: If ``self.topk`` is not 2.
    """
    if self.topk != 2:
        raise ValueError(f"Expected topk to be 2, but got {self.topk}.")
    # ---- First expert ----
    with torch.no_grad():
        # Mask entries whose gap to the row maximum, relative to
        # max(|score|, row max), is larger than 2 * jitter_eps.
        mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True)
        factor = scores.abs().clamp(min=mask_logits_threshold)
        mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (2 * jitter_eps)
    masked_gates = scores.masked_fill(mask_logits_threshold, float('-inf'))
    if self.training:
        # Sample an expert: perturb the logits with -log(E), E ~ Exponential,
        # and take the argmax.
        selected_experts = (masked_gates - torch.empty_like(
            masked_gates, memory_format=torch.legacy_contiguous_format
        ).exponential_().log()).max(dim=-1)[1].unsqueeze(-1)
    else:
        selected_experts = max_ind
    # Softmax over the surviving entries; gate value of the chosen expert.
    masked_gates = torch.softmax(masked_gates, dim=-1)
    multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)
    if self.training:
        max_scores, max_ind = masked_gates.max(dim=-1, keepdim=True)
        # 1 where the sampled expert is the argmax, or with probability 0.25
        # otherwise (uniform draw > 0.75).
        mask_for_one = torch.logical_or(
            selected_experts == max_ind,
            torch.rand_like(max_scores) > 0.75
        ).int()
        # Maps 1 -> 0.3333 + 0.6667 and 0 -> 0.3333.
        mask_for_one = torch.add(0.3333, mask_for_one, alpha=0.6667).type_as(masked_gates)
        multiplier = custom_multiplier.apply(
            scores,
            multiplier_o,
            selected_experts,
            masked_gates,
            mask_for_one,
        )
    else:
        multiplier = multiplier_o
    # ---- Second expert: same procedure with the first expert removed ----
    masked_scores = torch.scatter(
        scores,
        -1,
        selected_experts,
        float('-inf'),
    )
    with torch.no_grad():
        mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True)
        factor = scores.abs().clamp(min=mask_logits_threshold)
        mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (2 * jitter_eps)
    masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, float('-inf'))
    if self.training:
        selected_experts_top2 = (masked_gates_top2 - torch.empty_like(
            masked_gates_top2, memory_format=torch.legacy_contiguous_format
        ).exponential_().log()
        ).max(dim=-1)[1].unsqueeze(-1)
    else:
        selected_experts_top2 = max_ind
    masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1)
    multiplier_top2_o = masked_gates_top2.gather(dim=-1, index=selected_experts_top2)
    if self.training:
        max_scores, max_ind = masked_gates_top2.max(dim=-1, keepdim=True)
        # NOTE(review): unlike the first expert, the random tensor is re-filled
        # with ``uniform_()`` here, which draws from the RNG a second time.
        mask_for_one_top2 = torch.logical_or(
            selected_experts_top2 == max_ind,
            torch.rand_like(max_scores).uniform_() > 0.75
        ).int()
        mask_for_one_top2 = torch.add(0.3333, mask_for_one_top2, alpha=0.6667).type_as(masked_gates_top2)
        multiplier_top2 = custom_multiplier.apply(
            scores,
            multiplier_top2_o,
            selected_experts_top2,
            masked_gates_top2,
            mask_for_one_top2,
        )
    else:
        multiplier_top2 = multiplier_top2_o
    # ---- Combine both picks into dense outputs shaped like ``scores`` ----
    multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
    selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1)
    multiplier = torch.zeros_like(scores).scatter(1, selected_experts, multiplier)
    selected_experts = torch.zeros_like(scores).int().scatter(1, selected_experts, 1).bool()
    return (
        multiplier,
        selected_experts,
    )
def topk_router_build_hash_module(self):
    """Set up the token-id based (hash) routing state on this router.

    ``self.hash`` records whether ``self.layer_number`` is within the first
    ``n_hash_layers`` layers. For such layers a frozen ``tid2eid`` parameter is
    created with one random permutation of ``range(moe_router_topk)`` per
    entry of the padded vocabulary, and ``self.expert_bias`` is cleared.
    """
    mg_args = get_args()
    self.hash = self.layer_number <= mg_args.n_hash_layers
    if not self.hash:
        return

    per_token_experts = [
        torch.randperm(mg_args.moe_router_topk)
        for _ in range(mg_args.padded_vocab_size)
    ]
    self.tid2eid = torch.nn.Parameter(torch.stack(per_token_experts), requires_grad=False)
    self.expert_bias = None
def topk_router_init_wrapper(function):
    """Wrap a router ``__init__`` to add zero-expert and group-routing state.

    After the wrapped initializer has run, the wrapper:

    * when ``num_zero_experts`` is configured, enlarges ``self.num_experts``,
      re-creates ``self.weight`` for the new expert count and re-creates (or
      clears) the expert-bias buffers;
    * sets ``self.n_group``, ``self.topk_group`` and ``self.norm_topk_prob``
      from the global arguments;
    * installs ``topk_router_build_hash_module`` on the router class as
      ``build_hash_module``.
    """
    @wraps(function)
    def topk_router_init(self, *args, **kwargs):
        function(self, *args, **kwargs)
        mg_args = get_args()

        if mg_args.num_zero_experts is not None:
            self.num_experts = mg_args.num_experts + mg_args.num_zero_experts

            # Allocate in fp32 for initialization, then cast to the param dtype.
            weight_shape = (self.num_experts, self.config.hidden_size)
            self.weight = torch.nn.Parameter(torch.empty(weight_shape, dtype=torch.float32))
            if self.config.perform_initialization:
                self.config.init_method(self.weight)
            self.weight.data = self.weight.data.to(dtype=self.config.params_dtype)

            if not self.enable_expert_bias:
                self.local_tokens_per_expert = None
                self.expert_bias = None
            else:
                self.register_buffer(
                    'local_tokens_per_expert',
                    torch.zeros(self.num_experts, dtype=torch.float32),
                    persistent=False,
                )
                self.register_buffer(
                    'expert_bias', torch.zeros(self.num_experts, dtype=torch.float32)
                )

        # Fall back to the expert-parallel size when no group count is given.
        num_groups = mg_args.moe_router_num_groups
        if num_groups is None:
            num_groups = mg_args.expert_model_parallel_size
        self.n_group = num_groups
        self.topk_group = mg_args.moe_router_group_topk
        self.norm_topk_prob = mg_args.norm_topk_prob

        setattr(self.__class__, 'build_hash_module', topk_router_build_hash_module)

    return topk_router_init
def topk_router_forward_patch(self, input: torch.Tensor, input_ids: torch.Tensor = None):
    """Replacement forward for TopKRouter that also forwards ``input_ids``.

    Args:
        input (torch.Tensor): Input tensor fed to the gating projection.
        input_ids (torch.Tensor): Input ids, passed on to ``self.routing``.

    Returns:
        tuple: ``(scores, routing_map)`` as produced by ``self.routing``.
    """
    # Keep the expert bias in fp32 before it is used for routing.
    self._maintain_float32_expert_bias()

    jittered_input = self.apply_input_jitter(input)
    gate_logits = self.gating(jittered_input)
    probs, expert_map = self.routing(gate_logits, input_ids)
    return probs, expert_map
def apply_seq_aux_loss(self, activation, logits, topk_idx):
    """Apply the complementary sequence-wise auxiliary load-balancing loss.

    The loss is attached to ``activation`` through ``MoEAuxLossAutoScaler`` and
    the unscaled value is reported to the aux-loss tracker under
    ``"load_balancing_loss"``.

    Args:
        activation: Tensor the auxiliary loss is attached to.
        logits: Router logits, 2-D (``num_tokens`` x experts).
        topk_idx: Either a boolean routing mask with one row per token and one
            column per expert, or an integer tensor of selected expert
            indices. The two are told apart by dtype.

    Returns:
        ``activation`` unchanged when the effective coefficient is 0,
        otherwise the output of ``MoEAuxLossAutoScaler.apply``.

    Raises:
        ValueError: If ``self.score_function`` is not one of ``"softmax"``,
            ``"sigmoid"`` or ``"sqrtsoftplus"``.
    """
    args = get_args()
    moe_aux_loss_coeff = self.config.moe_aux_loss_coeff / parallel_state.get_tensor_model_parallel_world_size()
    if moe_aux_loss_coeff == 0:
        return activation

    num_tokens, _ = logits.shape
    seq_length = num_tokens // args.micro_batch_size

    if self.score_function == "softmax":
        scores = torch.softmax(logits, dim=-1)
    elif self.score_function == "sigmoid":
        scores = torch.sigmoid(logits)
        if self.expert_bias is not None:
            scores = scores + self.expert_bias
        # Epsilon guards against an all-zero row.
        scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
    elif self.score_function == "sqrtsoftplus":
        scores = F.softplus(logits).sqrt()
        if self.expert_bias is not None:
            scores = scores + self.expert_bias
        scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
    else:
        raise ValueError(f"Invalid score_function: {self.score_function}")

    scores_for_seq_aux = scores.view(args.micro_batch_size, seq_length, -1)

    # ce[b, e] = number of tokens of micro-batch row b routed to expert e.
    if topk_idx.dtype == torch.bool:
        # A boolean routing mask holds only 0/1 values, so a histogram over its
        # flattened entries would fill bins 0 and 1 only. Count the selections
        # per expert column instead.
        ce = topk_idx.view(args.micro_batch_size, seq_length, -1).sum(dim=1).to(torch.int32)
    else:
        # Integer expert indices: histogram of the indices per micro-batch row.
        topk_idx_for_aux_loss = topk_idx.view(args.micro_batch_size, -1)
        ce = torch.stack([torch.histc(x.to(torch.int32), bins=args.num_experts, min=0, max=args.num_experts) for x in
                          topk_idx_for_aux_loss])

    # With context parallelism each rank sees only a slice of the sequence:
    # sum the counts over the group and scale the coefficient accordingly.
    num_sub_sequence = 1
    sequence_partition_group = parallel_state.get_context_parallel_group()
    if sequence_partition_group is not None:
        num_sub_sequence = torch.distributed.get_world_size(sequence_partition_group)
        moe_aux_loss_coeff /= num_sub_sequence
        torch.distributed.all_reduce(ce, group=sequence_partition_group)

    num_tokens = seq_length * num_sub_sequence
    fi = ce.div(num_sub_sequence * num_tokens * args.moe_router_topk / args.num_experts)
    Pi = scores_for_seq_aux.mean(dim=1)
    aux_loss = (Pi * fi).sum(dim=1).mean() * moe_aux_loss_coeff

    # Log the loss without the coefficient applied.
    save_to_aux_losses_tracker(
        "load_balancing_loss",
        aux_loss / moe_aux_loss_coeff,
        self.layer_number,
        self.config.num_layers,
        reduce_group=sequence_partition_group,
    )
    activation = MoEAuxLossAutoScaler.apply(activation, aux_loss)
    return activation
def topk_router_gating_func(self, input: torch.Tensor):
    """Project ``input`` onto the router weight and return the logits.

    Three paths exist:

    * ``router_gating_in_fp32`` with a frozen weight: a plain fp32 linear.
    * ``router_gating_in_fp32`` with a trainable weight: the fp32 casts run
      under ``CheckpointWithoutOutput`` so their outputs can be discarded and
      recomputed from a hook on the logits.
    * otherwise: ``moe_router_dtype`` picks an fp8 matmul, an fp32/fp64 linear
      or a linear in the input's own dtype.

    Args:
        input (torch.Tensor): Router input.

    Returns:
        torch.Tensor: Router logits.
    """
    _args = get_args()

    if not _args.router_gating_in_fp32:
        dtype_name = self.config.moe_router_dtype
        if dtype_name == 'fp8':
            from mindspeed.te.pytorch.fp8.recipes import matmul_fp8
            return matmul_fp8(input, self.weight)
        # Any other value falls back to the input's dtype.
        named_dtypes = {'fp32': torch.float32, 'fp64': torch.float64}
        router_dtype = named_dtypes.get(dtype_name, input.dtype)
        return torch.nn.functional.linear(input.to(router_dtype), self.weight.to(router_dtype))

    if not self.weight.requires_grad:
        return F.linear(input.type(torch.float32), self.weight.type(torch.float32))

    def to_fp32(_input, weight):
        return _input.type(torch.float32), weight.type(torch.float32)

    self.fp32_checkpoint_manager = CheckpointWithoutOutput()
    input, weight = self.fp32_checkpoint_manager.checkpoint(to_fp32, False, input, self.weight)
    logits = torch.nn.functional.linear(input, weight)
    # Discard the fp32 copies now; they are recomputed via the hook below.
    self.fp32_checkpoint_manager.discard_output()
    if logits.requires_grad:
        logits.register_hook(self.fp32_checkpoint_manager.recompute)
    return logits
def topk_router_routing(self, logits: torch.Tensor, input_ids: torch.Tensor = None):
    """Top-k routing dispatch for all supported ``routing_type`` values.

    Args:
        logits (torch.Tensor): The logits tensor after gating; its first two
            dimensions are read as ``(seq_length, bsz)`` before it is flattened
            to ``[num_tokens, num_experts]``.
        input_ids (torch.Tensor): Input ids, only forwarded to the hash-based
            top-k when ``n_hash_layers >= 1`` and ``routing_type == "none"``.

    Returns:
        probs (torch.Tensor): The probabilities of token to experts assignment.
        routing_map (torch.Tensor): The mask of token to experts assignment.

    Raises:
        ValueError: If ``self.routing_type`` is not supported.
    """
    args = get_args()
    seq_length, bsz = logits.shape[:2]
    logits = self.apply_z_loss(logits.view(-1, self.num_experts))

    gather_sequence_parallel = (
        self.config.tensor_model_parallel_size > 1
        and self.config.moe_token_dispatcher_type == "alltoall_seq"
    )
    if gather_sequence_parallel:
        logits = gather_from_sequence_parallel_region(logits)

    routing_type = self.routing_type
    if routing_type == "sinkhorn":
        scores, routing_map = self.sinkhorn_load_balancing(logits)
    elif routing_type == "aux_loss":
        scores, routing_map = self.aux_loss_load_balancing(logits)
        if args.norm_topk_prob:
            scores = scores / scores.sum(dim=-1, keepdim=True)
        if args.topk_softmax_in_fp32:
            scores = scores.type_as(logits)
    elif routing_type == "seq_aux_loss":
        scores, routing_map = self.seq_aux_loss_load_balancing(logits, bsz, seq_length)
    elif routing_type == "softmax_topk":
        # Softmax in fp32; optionally cast back to the logits dtype right away.
        logits_ = torch.softmax(logits, dim=-1, dtype=torch.float32)
        if not args.moe_revert_type_after_topk:
            logits_ = logits_.type_as(logits)

        if self.expert_bias is None:
            scores, indices = torch.topk(logits_, k=self.topk, dim=1)
        else:
            # The bias only influences which experts are picked; the gate
            # values come from the unbiased probabilities.
            logits_for_routing = logits_ + self.expert_bias
            _, indices = torch.topk(logits_for_routing, k=self.topk, dim=1)
            scores = torch.gather(logits_, dim=1, index=indices).type_as(logits_)

        if args.norm_topk_prob:
            scores = scores / scores.sum(dim=-1, keepdim=True)
        if self.config.moe_router_topk_scaling_factor is not None:
            scores *= self.config.moe_router_topk_scaling_factor

        scores = torch.zeros_like(logits_).scatter(1, indices, scores)
        routing_map = torch.zeros_like(logits_).int().scatter(1, indices, 1).bool()
    elif routing_type == "group_limited_greedy":
        scores, routing_map = group_limited_greedy_topKgating(self, logits)
    elif routing_type == "pai_megatron_aux_loss":
        scores, routing_map = pai_megatron_aux_loss(self, logits)
    elif routing_type == "sparsemixer_topk":
        scores, routing_map = sparsemixer_top2(self, logits)
    elif routing_type == "none":
        # Keyword arguments shared by both top-k implementations.
        shared_topk_kwargs = dict(
            capacity_factor=self.config.moe_expert_capacity_factor,
            pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
            drop_policy=self.config.moe_token_drop_policy,
            use_pre_softmax=self.config.moe_router_pre_softmax,
            num_groups=self.config.moe_router_num_groups,
            group_topk=self.config.moe_router_group_topk,
            scaling_factor=self.config.moe_router_topk_scaling_factor,
            deterministic_mode=self.config.deterministic_mode,
            score_function=self.score_function,
            expert_bias=self.expert_bias,
        )
        if args.n_hash_layers >= 1:
            scores, routing_map, _ = topk_softmax_with_capacity_and_hash(
                logits,
                self.topk,
                token_hash=getattr(self, "hash", None),
                tid2eid=getattr(self, "tid2eid", None),
                input_ids=input_ids,
                **shared_topk_kwargs,
            )
        else:
            scores, routing_map, _ = topk_softmax_with_capacity(
                logits,
                self.topk,
                **shared_topk_kwargs,
            )
        if self.training and args.seq_aux:
            scores = apply_seq_aux_loss(
                self,
                activation=scores,
                logits=logits,
                topk_idx=routing_map,
            )
    else:
        raise ValueError(f"Unsupported MoE routing type: {self.routing_type}")

    if args.moe_tp_extend_ep:
        scores = _split_along_first_dim(scores)
        routing_map = _split_along_first_dim(routing_map)

    # Accumulate per-expert token counts for the expert-bias update.
    if self.enable_expert_bias and torch.is_grad_enabled():
        with torch.no_grad():
            self.local_tokens_per_expert += routing_map.sum(dim=0)

    if args.fix_router:
        def fix_indices(index_tensor, logits_shape, router_topk):
            """Build a deterministic map assigning experts round-robin by row."""
            num_rows = index_tensor.shape[0]
            flat_slots = torch.arange(
                num_rows * router_topk, device=index_tensor.device, dtype=torch.int64
            )
            expert_select = flat_slots.view(num_rows, router_topk) % logits_shape[-1]
            fixed_map = torch.zeros(index_tensor.shape, device=index_tensor.device, dtype=torch.bool)
            fixed_map.scatter_(1, expert_select, True)
            return fixed_map

        if isinstance(routing_map, tuple):
            # Only the first element of a tuple is replaced.
            fixed_head = fix_indices(routing_map[0], logits.shape, args.moe_router_topk)
            routing_map = (fixed_head,) + tuple(routing_map[1:])
        else:
            routing_map = fix_indices(routing_map, logits.shape, args.moe_router_topk)

    return scores, routing_map
def _maintain_float32_expert_bias(self):
    """Cast ``self.expert_bias`` back to float32 if it has another dtype.

    With bf16/fp16 training the expert bias gets converted to lower precision
    by Float16Module; it is kept in float32 to avoid routing errors when the
    bias is updated. Nothing happens when the attribute is missing or ``None``.
    """
    bias = getattr(self, 'expert_bias', None)
    if bias is None:
        return
    if bias.dtype == torch.float32:
        return
    bias.data = bias.data.to(torch.float32)
def global_aux_loss_load_balancing(self, logits: torch.Tensor):
    """Run top-k routing for the global aux-loss path.

    Only the top-k selection is performed here; no auxiliary loss is attached
    to the returned probabilities in this function.

    Args:
        logits (torch.Tensor): The logits tensor after gating, shape: [num_tokens, num_experts].

    Returns:
        probs (torch.Tensor): The probabilities of token to experts assignment.
        routing_map (torch.Tensor): The mask of token to experts assignment.
    """
    cfg = self.config
    topk_options = dict(
        capacity_factor=cfg.moe_expert_capacity_factor,
        pad_to_capacity=cfg.moe_pad_expert_input_to_capacity,
        drop_policy=cfg.moe_token_drop_policy,
        use_pre_softmax=cfg.moe_router_pre_softmax,
        num_groups=cfg.moe_router_num_groups,
        group_topk=cfg.moe_router_group_topk,
        scaling_factor=cfg.moe_router_topk_scaling_factor,
        deterministic_mode=cfg.deterministic_mode,
        score_function=self.score_function,
        expert_bias=self.expert_bias,
    )
    # The third return value is not needed here.
    probs, routing_map, _ = topk_softmax_with_capacity(logits, self.topk, **topk_options)
    return probs, routing_map
def global_aux_loss_topk_router_forward(self, input: torch.Tensor):
    """Forward pass of the router that also exposes the raw logits.

    Args:
        input (torch.Tensor): Input tensor.

    Returns:
        tuple: ``(scores, routing_map, logits)`` where ``logits`` is detached
        from the autograd graph.
    """
    self._maintain_float32_expert_bias()

    jittered_input = self.apply_input_jitter(input)
    gate_logits = self.gating(jittered_input)
    probs, expert_map = self.routing(gate_logits)

    detached_logits = gate_logits.detach()
    return probs, expert_map, detached_logits
def global_load_balancing_loss_func(router_logits, attention_mask, config):
    """
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        router_logits:
            Logits from the `router`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [sequence_length, batch_size, num_experts].
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None.
        config: config arguments; only `config.moe_router_topk` is read.

    Returns:
        The auxiliary loss. `0` is returned when `router_logits` is `None`, is not a tuple,
        or is an empty tuple.
    """
    # Nothing to balance. The empty-tuple case is handled here as well, since
    # indexing `router_logits[0]` below would otherwise raise IndexError.
    if not isinstance(router_logits, tuple) or len(router_logits) == 0:
        return 0

    compute_device = router_logits[0].device
    # [s, b, e] -> [b, s, e] -> [b * s, e] per layer, then stacked over layers so
    # the token order matches the [batch_size, sequence_length] attention mask.
    concatenated_gate_logits = torch.cat(
        [layer_gate.to(compute_device).transpose(0, 1).reshape(-1, layer_gate.shape[2])
         for layer_gate in router_logits], dim=0)

    top_k = config.moe_router_topk
    num_experts = concatenated_gate_logits.shape[1]

    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    # [tokens, top_k, num_experts] one-hot of the selected experts.
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Fraction of tokens routed to each expert, per top-k slot.
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
        # Average routing probability of each expert.
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Mask broadcast to the shape of `expert_mask`.
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )
        # Same statistics as above, restricted to unmasked tokens.
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Mask broadcast to the shape of `routing_weights`.
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts
def transformer_config_post_init_wrapper(fn):
    """Wrap a config ``__post_init__`` so it runs with a sigmoid score function.

    When expert bias is enabled and ``moe_router_score_function`` is
    ``"softmax"`` or ``"sqrtsoftplus"``, the attribute is temporarily set to
    ``"sigmoid"`` for the duration of ``fn`` and restored afterwards, even if
    ``fn`` raises. In every other case ``fn`` is called unchanged.
    """
    allowed_score_function = {"softmax", "sigmoid", "sqrtsoftplus"}
    needs_override = allowed_score_function - {"sigmoid"}

    @wraps(fn)
    def wrapper(self):
        if not getattr(self, "moe_router_enable_expert_bias", False):
            return fn(self)

        configured_score_fn = getattr(self, "moe_router_score_function", None)
        if configured_score_fn not in needs_override:
            return fn(self)

        try:
            self.moe_router_score_function = "sigmoid"
            return fn(self)
        finally:
            # Always put the user's choice back.
            self.moe_router_score_function = configured_score_fn

    return wrapper