from typing import Optional
import torch
import torch_npu
from megatron.training import get_args
from megatron.core import parallel_state
from megatron.core.transformer.moe.moe_utils import (reduce_aux_losses_tracker_across_ranks,
clear_aux_losses_tracker, group_limited_topk,
get_capacity)
from megatron.core.tensor_parallel.utils import divide
# One-shot hand-off slots: a ``set_*`` function in this module stores a value
# and the matching ``get_*`` function returns it and resets the slot to None.
AG_TP_HIDDEN_STATUS = None
GEMM_BACKWARD_NEED_TENSORS = None
RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE = None
SWAP_TENSOR = None
MATMUL_OUTPUT_GRAD = None
UNPERMUTED_TOKENS = None

# NOTE(review): not read or written by any function visible in this chunk.
AG_SHARED_EXPERTS_INPUTS = []

# NPU streams created lazily on first use (see get_swap_status / get_swap_stream).
SWAP_STREAM = None
SWAP_STREAM2 = None
def get_swap_stream():
    """Return the module's second swap stream, creating it on first call.

    The stream is cached in the module global ``SWAP_STREAM2`` and reused by
    every later call.
    """
    global SWAP_STREAM2
    if SWAP_STREAM2 is not None:
        return SWAP_STREAM2
    device = torch.npu.current_device()
    # NOTE(review): a throwaway stream is allocated before the one that is kept,
    # presumably so SWAP_STREAM2 does not coincide with another stream handed
    # out by the backend — confirm intent before removing.
    _unused_stream = torch_npu.npu.Stream(device=device)
    SWAP_STREAM2 = torch_npu.npu.Stream(device=device)
    return SWAP_STREAM2
def set_swap_status(tensor):
    """Stash ``tensor`` in ``SWAP_TENSOR`` until ``get_swap_status`` consumes it."""
    global SWAP_TENSOR
    SWAP_TENSOR = tensor
def get_swap_status():
    """Return ``(swap_stream, swap_tensor)`` and clear the stashed tensor.

    ``SWAP_STREAM`` is created lazily on the current NPU device the first time
    this is called. The tensor is whatever ``set_swap_status`` last stored (or
    None), and ``SWAP_TENSOR`` is reset to None afterwards.
    """
    global SWAP_STREAM, SWAP_TENSOR
    if SWAP_STREAM is None:
        SWAP_STREAM = torch_npu.npu.Stream(device=torch.npu.current_device())
    taken, SWAP_TENSOR = SWAP_TENSOR, None
    return SWAP_STREAM, taken
def set_prob_backward_need_tensors(matmul_output_grad, unpermuted_tokens):
    """Stash both tensors until ``get_prob_backward_need_tensors`` consumes them."""
    global MATMUL_OUTPUT_GRAD, UNPERMUTED_TOKENS
    MATMUL_OUTPUT_GRAD, UNPERMUTED_TOKENS = matmul_output_grad, unpermuted_tokens
def get_prob_backward_need_tensors():
    """Return ``(stream, matmul_output_grad, unpermuted_tokens)`` and clear the stash.

    The stream is the lazily created ``SWAP_STREAM2`` obtained through
    ``get_swap_stream()``. This previously duplicated that helper's creation
    logic inline. The two tensors are whatever ``set_prob_backward_need_tensors``
    last stored (or None); both slots are reset to None afterwards.
    """
    global MATMUL_OUTPUT_GRAD, UNPERMUTED_TOKENS
    stream = get_swap_stream()
    matmul_output_grad, MATMUL_OUTPUT_GRAD = MATMUL_OUTPUT_GRAD, None
    unpermuted_tokens, UNPERMUTED_TOKENS = UNPERMUTED_TOKENS, None
    return stream, matmul_output_grad, unpermuted_tokens
def set_ag_tp_hidden_status(_inputs):
    """Stash ``_inputs`` in ``AG_TP_HIDDEN_STATUS`` for ``get_ag_tp_hidden_status``."""
    global AG_TP_HIDDEN_STATUS
    AG_TP_HIDDEN_STATUS = _inputs
def get_ag_tp_hidden_status():
    """Return the stashed ``AG_TP_HIDDEN_STATUS`` value and reset the slot to None."""
    global AG_TP_HIDDEN_STATUS
    taken, AG_TP_HIDDEN_STATUS = AG_TP_HIDDEN_STATUS, None
    return taken
def set_gemm_backward_need_tensors(_inputs):
    """Stash ``_inputs`` until ``get_gemm_backward_need_tensors`` consumes it."""
    global GEMM_BACKWARD_NEED_TENSORS
    GEMM_BACKWARD_NEED_TENSORS = _inputs
def get_gemm_backward_need_tensors():
    """Return the stashed GEMM backward inputs and reset the slot to None."""
    global GEMM_BACKWARD_NEED_TENSORS
    taken, GEMM_BACKWARD_NEED_TENSORS = GEMM_BACKWARD_NEED_TENSORS, None
    return taken
def set_rs_global_hidden_states_grad_with_handle(_inputs):
    """Stash ``_inputs`` until ``get_rs_global_hidden_states_grad_with_handle`` consumes it."""
    global RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE
    RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE = _inputs
def get_rs_global_hidden_states_grad_with_handle():
    """Return the stashed value and reset ``RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE`` to None."""
    global RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE
    taken, RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE = RS_GLOBAL_HIDDEN_STATES_GRAD_WITH_HANDLE, None
    return taken
# Holds the value given to set_all2all_experts_output until
# get_all2all_experts_output returns it and resets the slot to None.
ALL2ALL_EXPERTS_OUTPUT = None


def set_all2all_experts_output(_input):
    """Stash ``_input`` in ``ALL2ALL_EXPERTS_OUTPUT``."""
    global ALL2ALL_EXPERTS_OUTPUT
    ALL2ALL_EXPERTS_OUTPUT = _input


def get_all2all_experts_output():
    """Return the stashed ``ALL2ALL_EXPERTS_OUTPUT`` and reset it to None."""
    global ALL2ALL_EXPERTS_OUTPUT
    taken, ALL2ALL_EXPERTS_OUTPUT = ALL2ALL_EXPERTS_OUTPUT, None
    return taken
def only_recompute_activation(layer_number):
    """Decide whether ``layer_number`` falls past the ``moe_zero_memory_num_layers`` budget.

    A recompute priority is computed as
    ``((layer_number - 1) % layers_per_chunk) * num_chunks + chunk_rank``.
    ``layers_per_chunk`` is ``num_layers_per_virtual_pipeline_stage`` when
    virtual pipelining is configured, otherwise ``num_layers // pp_size``, or
    ``num_layers`` when neither is set. ``num_chunks`` and ``chunk_rank``
    default to 1 and 0 without virtual pipelining.

    Returns:
        False when ``args.moe_zero_memory_num_layers`` is unset/zero or the
        priority is below it; True otherwise.

    NOTE(review): the name suggests True means "recompute only the activation
    for this layer" — confirm against callers.
    """
    args = get_args()
    chunk_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
    num_chunks = args.virtual_pipeline_model_parallel_size
    pp_size = args.pipeline_model_parallel_size

    if num_chunks is not None:
        layers_per_chunk = args.num_layers_per_virtual_pipeline_stage
    elif pp_size is not None:
        layers_per_chunk = args.num_layers // pp_size
    else:
        layers_per_chunk = args.num_layers

    chunk_rank = 0 if chunk_rank is None else chunk_rank
    num_chunks = 1 if num_chunks is None else num_chunks
    local_layer_index = (layer_number - 1) % layers_per_chunk
    recompute_priority = local_layer_index * num_chunks + chunk_rank

    budget = args.moe_zero_memory_num_layers
    if not budget:
        return False
    return recompute_priority >= budget
def forward_func(func, inputs):
    """Run ``func`` on detached copies of ``inputs`` with autograd enabled.

    Floating-point tensors are detached and marked ``requires_grad=True``;
    leaf tensors that already require grad are passed through unchanged.
    Non-float tensors and non-tensor values are passed through as-is. The
    same rule applies to a bare tensor, to each element of a tuple, and to
    each element of a nested tuple (one level deep).

    Previously a bare non-float tensor was detached unconditionally, and
    setting ``requires_grad`` then raised a RuntimeError. That disagreed with
    the tuple branch.

    Args:
        func: callable invoked as ``func(*detached_inputs)``.
        inputs: a tensor or a tuple of tensors / nested tuples / other values.
            NOTE(review): any other type (e.g. a list or None) is dropped and
            ``func`` is called with no arguments.

    Returns:
        ``(output, *detached_inputs)``.
    """
    def detach_tensor(input_):
        # Leaf tensors already tracking grad can be used directly.
        if input_.requires_grad and input_.grad_fn is None:
            return input_
        new_input = input_.detach()
        new_input.requires_grad = True
        return new_input

    def maybe_detach(value):
        # Only floating-point tensors can require grad; leave everything else alone.
        if isinstance(value, torch.Tensor) and torch.is_floating_point(value):
            return detach_tensor(value)
        return value

    detach_inputs = []
    if isinstance(inputs, tuple):
        for input_ in inputs:
            if isinstance(input_, tuple):
                detach_inputs.append(tuple(maybe_detach(i) for i in input_))
            else:
                detach_inputs.append(maybe_detach(input_))
    elif isinstance(inputs, torch.Tensor):
        detach_inputs.append(maybe_detach(inputs))
    with torch.enable_grad():
        output = func(*detach_inputs)
    return output, *detach_inputs
def backward_func(func_tensor, gradinputs):
    """Backpropagate ``gradinputs`` through ``func_tensor`` when possible.

    Does nothing when ``gradinputs`` is None or ``func_tensor`` has no
    ``grad_fn`` (e.g. a leaf). A tensor is passed as the gradient. A tuple is
    unpacked positionally into ``Tensor.backward``.
    NOTE(review): elements after the first therefore land in ``retain_graph`` /
    ``create_graph``; callers presumably pass a 1-tuple — confirm.
    Any other type is ignored.
    """
    if gradinputs is None or func_tensor.grad_fn is None:
        return
    if isinstance(gradinputs, tuple):
        func_tensor.backward(*gradinputs)
    elif isinstance(gradinputs, torch.Tensor):
        func_tensor.backward(gradinputs)
def permute(tokens, routing_map, num_out_tokens: int = None):
    """Reorder token rows so that rows sharing a routing value are contiguous.

    ``routing_map`` is 1-D (one slot per token) or 2-D ``[num_tokens, topk]``.
    Its values appear to be expert indices. The flattened values are
    stable-sorted in ascending order; each sorted slot index ``s`` maps back
    to token row ``s // topk``.

    Args:
        tokens: tensor whose dim 0 indexes tokens.
        routing_map: routing values per (token, k) slot.
        num_out_tokens: if given, keep only the first ``num_out_tokens`` sorted slots.

    Returns:
        ``(permuted_tokens, sorted_indices)``; ``sorted_indices`` are positions
        in the flattened ``routing_map``.
    """
    topk = 1 if routing_map.dim() == 1 else routing_map.size(1)
    # NOTE(review): values are sorted as float32, presumably for backend sort
    # support; exact for integer ids below 2**24.
    _, sorted_indices = torch.sort(routing_map.view(-1).float(), stable=True)
    if num_out_tokens is not None:
        sorted_indices = sorted_indices[:num_out_tokens]
    source_rows = sorted_indices // topk
    return tokens.index_select(0, source_rows), sorted_indices
def unpermute(
    permuted_tokens: torch.Tensor,
    sorted_indices: torch.Tensor,
    restore_shape: torch.Size,
    probs: torch.Tensor = None,
    routing_map: torch.Tensor = None,
):
    """Scatter permuted rows back to their slots and reduce over the top-k axis.

    Row ``j`` of ``permuted_tokens`` is written to flat slot ``sorted_indices[j]``
    of a zero buffer. The buffer is viewed as ``[-1, topk, hidden]``, optionally
    weighted by ``probs``, and summed over the top-k axis. Slots not covered by
    ``sorted_indices`` stay zero.

    Args:
        permuted_tokens: ``[num_permuted, hidden]`` rows.
        sorted_indices: target slot for each permuted row (same length as dim 0).
        restore_shape: unused here.
        probs: optional ``[num_tokens, topk]`` weights. When given, the buffer
            has ``probs.numel()`` slots; otherwise it has one slot per permuted
            row and ``topk`` is 1.
        routing_map: unused here.

    Returns:
        ``[num_tokens, hidden]`` tensor.
    """
    assert sorted_indices.numel() == permuted_tokens.size(0)
    hidden = permuted_tokens.size(-1)
    if probs is None:
        num_slots, topk = permuted_tokens.size(0), 1
    else:
        num_slots, topk = probs.numel(), probs.size(1)

    buffer = torch.zeros(
        [num_slots, permuted_tokens.shape[-1]],
        dtype=permuted_tokens.dtype,
        device=permuted_tokens.device,
    )
    buffer.index_copy_(0, sorted_indices, permuted_tokens)
    grouped = buffer.reshape(-1, topk, hidden)
    if probs is not None:
        grouped = grouped * probs.unsqueeze(-1)
    return grouped.sum(dim=1)
def get_mean(tensor):
    """Average ``tensor`` over the non-no-op layers.

    If ``args.noop_layers`` is a non-empty ``set``, this returns
    ``tensor.sum() / (args.num_layers - len(args.noop_layers))``. It sums all
    entries of ``tensor``; presumably the no-op layers' entries are zero —
    confirm. Otherwise it returns ``tensor.mean()``.

    Args:
        tensor (torch.Tensor): per-layer values (expected 1-D).

    Returns:
        torch.Tensor: a 0-dim tensor holding the mean.
    """
    args = get_args()
    noop_layers = getattr(args, 'noop_layers', None)
    if not (isinstance(noop_layers, set) and noop_layers):
        return tensor.mean()
    return tensor.sum() / (args.num_layers - len(noop_layers))
def track_moe_metrics(
    loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None, per_layer_logging=False
):
    """Log the MoE aux-loss tracker, then clear it.

    Calls ``reduce_aux_losses_tracker_across_ranks()`` first. When ``writer`` is
    given, each tracked entry's ``'values'`` tensor is cast to float and scaled
    by ``loss_scale``. Its mean (via ``get_mean``) is then:

    - accumulated into ``total_loss_dict`` (if provided; in place via ``+=``);
    - written to ``writer``;
    - logged to ``wandb_writer`` (if truthy).

    With ``per_layer_logging`` each element is also logged as
    ``moe/<name>_layer_<i>``. When ``writer`` is None nothing is logged or
    accumulated. ``clear_aux_losses_tracker()`` runs in every case.
    """
    reduce_aux_losses_tracker_across_ranks()
    tracker = parallel_state.get_moe_layer_wise_logging_tracker()
    if writer is not None:
        scaled_losses = {
            key: entry['values'].float() * loss_scale for key, entry in tracker.items()
        }
        for name, loss_list in scaled_losses.items():
            mean_loss = get_mean(loss_list)
            if total_loss_dict is not None:
                if name in total_loss_dict:
                    total_loss_dict[name] += mean_loss
                else:
                    total_loss_dict[name] = mean_loss
            writer.add_scalar(name, mean_loss, iteration)
            if per_layer_logging:
                per_layer = loss_list.tolist()
                for layer_idx, layer_loss in enumerate(per_layer):
                    writer.add_scalar(f"moe/{name}_layer_{layer_idx}", layer_loss, iteration)
            if wandb_writer:
                wandb_writer.log({f"{name}": mean_loss}, iteration)
                if per_layer_logging:
                    layer_metrics = {
                        f"moe/{name}_layer_{layer_idx}": layer_loss
                        for layer_idx, layer_loss in enumerate(per_layer)
                    }
                    wandb_writer.log(layer_metrics, iteration)
    clear_aux_losses_tracker()
def get_grouped_expert_params(model, num_local_experts, tp_size, config):
    """Split grouped-MLP weights into per-expert views.

    ``model.weight1`` is viewed as ``[num_local_experts, hidden, fc1_per_tp]``.
    ``fc1_per_tp`` is ``moe_ffn_hidden_size`` (doubled when
    ``config.gated_linear_unit``) divided by ``tp_size``. ``model.weight2`` is
    viewed as ``[num_local_experts, fc2_per_tp, hidden]``, where ``fc2_per_tp``
    is ``moe_ffn_hidden_size / tp_size``.

    Returns:
        dict mapping expert index to ``{'weight1', 'weight2', 'total_params'}``.
        The weights are views (``[idx]``) into the model tensors.
    """
    hidden_size = config.hidden_size
    ffn_hidden_size = config.moe_ffn_hidden_size
    fc1_output_size = ffn_hidden_size * 2 if config.gated_linear_unit else ffn_hidden_size
    fc1_per_tp = divide(fc1_output_size, tp_size)
    fc2_per_tp = divide(ffn_hidden_size, tp_size)
    weight1_by_expert = model.weight1.view(num_local_experts, hidden_size, fc1_per_tp)
    weight2_by_expert = model.weight2.view(num_local_experts, fc2_per_tp, hidden_size)

    def _expert_entry(expert_idx):
        # Index (not iterate) so each entry is a plain select-view of the weights.
        w1 = weight1_by_expert[expert_idx]
        w2 = weight2_by_expert[expert_idx]
        return {'weight1': w1, 'weight2': w2, 'total_params': w1.numel() + w2.numel()}

    return {expert_idx: _expert_entry(expert_idx) for expert_idx in range(num_local_experts)}
def get_expert_param_data(group_mlp_expert_params, params, idx):
    """Copy expert ``idx``'s weights, flattened, into the front of ``params``.

    ``weight1`` is written first and ``weight2`` immediately after it. Entries
    of ``params`` past their combined size are left untouched.
    """
    expert_param = group_mlp_expert_params[idx]
    start = 0
    for key in ('weight1', 'weight2'):
        weight = expert_param[key]
        end = start + weight.numel()
        params[start:end].copy_(weight.data.flatten())
        start = end
def set_expert_param_data(group_mlp_expert_params, params, idx):
    """Load expert ``idx``'s weights from the flat ``params`` buffer, in place.

    ``weight1`` is read from the front of ``params`` and ``weight2`` right
    after it, each reshaped to the weight's shape. Each weight's ``.grad`` is
    reset to None. Runs under ``torch.no_grad()``.
    """
    expert_param = group_mlp_expert_params[idx]
    start = 0
    with torch.no_grad():
        for key in ('weight1', 'weight2'):
            weight = expert_param[key]
            end = start + weight.numel()
            weight.copy_(params[start:end].reshape(weight.shape))
            weight.grad = None
            start = end
def get_expert_param_dtype(experts, group_mlp_expert_params, idx):
    """Return the dtype of expert ``idx``'s ``weight1`` (``experts`` is unused)."""
    return group_mlp_expert_params[idx]["weight1"].dtype
def get_expert_param_size(experts, group_mlp_expert_params, idx):
    """Return the stored ``total_params`` for expert ``idx`` (``experts`` is unused)."""
    return group_mlp_expert_params[idx]["total_params"]
def topk_softmax_with_capacity(
    logits: torch.Tensor,
    topk: int,
    capacity_factor: Optional[float] = None,
    pad_to_capacity: bool = False,
    drop_policy: str = "probs",
    use_pre_softmax: bool = False,
    num_groups: Optional[int] = None,
    group_topk: Optional[int] = None,
    scaling_factor: Optional[float] = None,
    deterministic_mode: bool = False,
    score_function: str = "softmax",
    expert_bias: Optional[torch.Tensor] = None,
):
    """Apply capacity and padding to the top-k selection.

    Args:
        logits (torch.Tensor): Logits tensor of shape [num_tokens, num_experts].
        topk (int): The number of experts to select for each token.
        capacity_factor (float): The capacity factor of each expert. Will drop tokens if the number
            of tokens exceeds the capacity.
        pad_to_capacity (bool): Whether to need padding in token drop mode. The probs for padded
            tokens will be 0.
        drop_policy (str): The policy to drop tokens. Can be either "probs" or "position".
            If "probs", the tokens with the lowest probabilities will be dropped.
            If "position", tokens at the end of each batch will be dropped.
        use_pre_softmax (bool): Whether to apply softmax before top-k selection (softmax scoring only).
        num_groups (int): Number of groups for routed experts.
        group_topk (int): Number of selected groups for each token. When set, selection goes
            through ``group_limited_topk``.
        scaling_factor (float): Scaling factor of routing score in top-k selection.
        deterministic_mode (bool): Deprecated; unused.
        score_function (str): The score function to use. Can be either "softmax" or "sigmoid".
        expert_bias (torch.Tensor): Bias added to sigmoid scores for expert selection only; the
            returned probs use the unbiased scores.

    If ``args.fix_router`` is set, the selected expert indices are replaced by a fixed
    round-robin pattern (``i % num_experts`` over the flattened top-k slots) that does not
    depend on the logits. The probs are left unchanged.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            - routing_probs (torch.Tensor): A tensor of shape [num_tokens, num_experts] containing
              the routing probabilities for each token to each expert.
            - routing_map (torch.Tensor): A mask tensor of shape [num_tokens, num_experts]
              indicating which experts were selected for each token. True values represent
              the selected experts.
            - tokens_per_expert (torch.Tensor): A tensor of shape [num_experts] containing
              the number of local tokens assigned to each expert before dropping and padding.

    Raises:
        ValueError: if ``score_function`` or ``drop_policy`` is not one of the supported values.
    """
    assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
    num_tokens, num_experts = logits.shape

    def compute_topk(scores, topk, num_groups=None, group_topk=None):
        if group_topk:
            return group_limited_topk(
                scores=scores,
                topk=topk,
                num_tokens=num_tokens,
                num_experts=num_experts,
                num_groups=num_groups,
                group_topk=group_topk,
            )
        return torch.topk(scores, k=topk, dim=1)

    if score_function == "softmax":
        if use_pre_softmax:
            scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
            probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
        else:
            scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
            probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
    elif score_function == "sigmoid":
        scores = torch.sigmoid(logits.float()).type_as(logits)
        if expert_bias is not None:
            # The bias only influences which experts are picked; gather the unbiased scores.
            scores_for_routing = scores + expert_bias
            _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk)
            scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
        else:
            scores, top_indices = compute_topk(scores, topk, num_groups, group_topk)
        probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
    else:
        raise ValueError(f"Invalid score_function: {score_function}")

    if scaling_factor:
        probs = probs * scaling_factor

    # getattr: tolerate argument namespaces that do not define `fix_router`.
    if getattr(get_args(), "fix_router", False):
        top_indices = torch.arange(top_indices.numel(), device=top_indices.device,
                                   dtype=torch.int64).view(top_indices.shape) % num_experts

    topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
    topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
    tokens_per_expert = topk_map.sum(dim=0)

    if capacity_factor is None:
        return topk_masked_gates, topk_map, tokens_per_expert

    expert_capacity = get_capacity(
        num_tokens=num_tokens * topk, num_experts=num_experts, capacity_factor=capacity_factor
    )
    # Per expert (dim 0), keep the `expert_capacity` tokens ranked highest by the policy.
    if drop_policy == "probs":
        _, capacity_indices = torch.topk(
            topk_masked_gates, k=expert_capacity, dim=0, sorted=False
        )
    elif drop_policy == "position":
        _, capacity_indices = torch.topk(topk_map.int(), k=expert_capacity, dim=0, sorted=False)
    else:
        raise ValueError(f"Invalid drop_policy: {drop_policy}")
    capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1).bool()

    # With padding every capacity slot is kept (padded slots carry prob 0);
    # otherwise only slots that were both selected and within capacity survive.
    final_map = capacity_mask if pad_to_capacity else torch.logical_and(topk_map, capacity_mask)
    final_probs = topk_masked_gates * final_map
    return final_probs, final_map, tokens_per_expert