import torch
def load_balancing_loss_func_optimized(
gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
num_experts: int | None = None,
top_k: int = 2,
attention_mask: torch.Tensor | None = None,
context_parallel_group=None,
) -> torch.Tensor | int:
if gate_logits is None or not isinstance(gate_logits, tuple):
return 0
num_layers = len(gate_logits)
if num_layers == 0:
return 0
compute_device = gate_logits[0].device
tokens_selected = torch.zeros(top_k, num_experts, device=compute_device)
if attention_mask is None:
tokens_total = torch.zeros(top_k, num_experts, device=compute_device)
else:
tokens_total = torch.zeros(top_k, num_experts, device=compute_device, dtype=attention_mask.dtype)
expert_attention_mask = None
for layer_gate in gate_logits:
routing_weights = torch.nn.functional.softmax(layer_gate, dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
if attention_mask is None:
num_tokens = layer_gate.shape[0]
if expert_attention_mask is None or expert_attention_mask.shape[0] != num_tokens:
expert_attention_mask = torch.ones(
num_tokens, top_k, num_experts,
device=compute_device, dtype=torch.float32
).reshape(-1, top_k, num_experts)
layer_tokens_total = torch.sum(expert_attention_mask, dim=0)
layer_tokens_selected = torch.sum(expert_mask.float(), dim=0)
else:
batch_size, sequence_length = attention_mask.shape
if expert_attention_mask is None:
expert_attention_mask = (
attention_mask[None, :, :, None, None]
.expand((1, batch_size, sequence_length, top_k, num_experts))
.reshape(-1, top_k, num_experts)
.to(compute_device)
)
layer_tokens_total = torch.sum(expert_attention_mask, dim=0)
layer_tokens_selected = torch.sum(expert_mask.float() * expert_attention_mask, dim=0)
tokens_selected += layer_tokens_selected
tokens_total += layer_tokens_total
if context_parallel_group is not None and torch.distributed.get_world_size(group=context_parallel_group) > 1:
torch.distributed.all_reduce(
tokens_total,
op=torch.distributed.ReduceOp.SUM,
group=context_parallel_group
)
torch.distributed.all_reduce(
tokens_selected,
op=torch.distributed.ReduceOp.SUM,
group=context_parallel_group
)
tokens_per_expert = tokens_selected / tokens_total
compute_device = gate_logits[0].device
routing_weights = torch.cat([torch.nn.functional.softmax(layer_gate, dim=-1).to(compute_device) for layer_gate in gate_logits], dim=0)
if attention_mask is not None:
router_per_expert_attention_mask = (
attention_mask[None, :, :, None]
.expand((num_layers, batch_size, sequence_length, num_experts))
.reshape(-1, num_experts)
.to(compute_device)
)
router_selected = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0)
router_total = torch.sum(router_per_expert_attention_mask, dim=0)
else:
num_tokens = gate_logits[0].shape[0]
router_per_expert_attention_mask = torch.ones(
num_layers, num_tokens, num_experts,
device=compute_device, dtype=torch.float32
).reshape(-1, num_experts)
router_selected = torch.sum(routing_weights, dim=0)
router_total = torch.sum(router_per_expert_attention_mask, dim=0)
if context_parallel_group is not None and torch.distributed.get_world_size(group=context_parallel_group) > 1:
torch.distributed.all_reduce(
router_total,
op=torch.distributed.ReduceOp.SUM,
group=context_parallel_group
)
router_prob_per_expert = router_selected / router_total
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
return overall_loss * num_experts