import contextlib
import torch
from functools import wraps
from megatron.core.enums import ModelType
from megatron.training import get_args
from megatron.core import parallel_state
from megatron.core.utils import get_attr_wrapped_model, get_model_config, get_model_type
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.pipeline_parallel.schedules import set_current_microbatch
from mindspeed.core.pipeline_parallel import flexible_schedules
from mindspeed.core.pipeline_parallel.ripipe_schedules import forward_backward_ripipe_pipelining
from mindspeed.core.pipeline_parallel import multiparameter_schedules
# Scalar tensor holding 1.0, created once at import time. `forward_step`
# passes it to `config.grad_scale_func` to obtain the loss scale that is
# handed to `MoEAuxLossAutoScaler` for the MoE auxiliary loss.
LOSS_BACKWARD_SCALE = torch.tensor(1.0)
def get_forward_backward_func_wrapper(get_forward_backward_func):
    """Wrap ``get_forward_backward_func`` so MindSpeed schedules take priority.

    The returned callable reads the global training arguments and, checking
    the conditions below in order, returns the first schedule that applies:

    1. ``optimize_send_recv_comm`` set and no
       ``num_layers_per_virtual_pipeline_stage``, or ``automated_pipeline_perf``
       set together with a non-empty ``pp_schedule_list``: the flexible
       non-interleaved schedule.
    2. ``recompute_in_bubble`` or ``recompute_in_advance`` set while autograd
       is enabled: the ripipe schedule.
    3. ``use_nanopipe`` set with pipeline parallel size > 1 and a virtual
       pipeline size configured: the nano-pipe interleaved schedule.
    4. ``use_multiparameter_pipeline_model_parallel`` set with pipeline
       parallel size > 1 and a virtual pipeline size configured: the
       multi-parameter interleaved schedule.

    If none applies, the call is forwarded to the wrapped function.

    Args:
        get_forward_backward_func (callable): The original schedule selector.

    Returns:
        callable: The wrapping selector, carrying the wrapped function's
        metadata via ``functools.wraps``.
    """

    def _interleaving_enabled():
        # True when there is more than one pipeline stage and a virtual
        # pipeline world size has been set.
        return (
            parallel_state.get_pipeline_model_parallel_world_size() > 1
            and parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None
        )

    @wraps(get_forward_backward_func)
    def wrapper(*args, **kwargs):
        opts = get_args()

        # Both of these configurations map onto the same flexible schedule.
        if (
            (opts.optimize_send_recv_comm and opts.num_layers_per_virtual_pipeline_stage is None)
            or (opts.automated_pipeline_perf and opts.pp_schedule_list)
        ):
            return flexible_schedules.forward_backward_pipelining_without_interleaving

        # The ripipe schedule is only selected when gradients are being tracked.
        recompute_requested = opts.recompute_in_bubble or opts.recompute_in_advance
        if recompute_requested and torch.is_grad_enabled():
            return forward_backward_ripipe_pipelining

        if _interleaving_enabled() and opts.use_nanopipe:
            return flexible_schedules.forward_backward_pipelining_with_interleaving_nano_pipe

        if opts.use_multiparameter_pipeline_model_parallel and _interleaving_enabled():
            return multiparameter_schedules.forward_backward_pipelining_with_interleaving

        # No MindSpeed-specific schedule applies; defer to the original selector.
        return get_forward_backward_func(*args, **kwargs)

    return wrapper
def forward_step(
    forward_step_func,
    data_iterator,
    model,
    num_microbatches,
    input_tensor,
    forward_data_store,
    config,
    collect_non_loss_data=False,
    checkpoint_activations_microbatch=None,
    is_first_microbatch=False,
    current_microbatch=None,
    encoder_decoder_xattn=False,
):
    """Run the forward pass of one microbatch through ``model``.

    ``input_tensor`` is handed to the model through its ``set_input_tensor``
    hook before ``forward_step_func`` is invoked. On the first pipeline stage
    the model input is expected to come from ``data_iterator`` instead.

    Args:
        forward_step_func (callable):
            User forward function, called as
            ``forward_step_func(data_iterator, model)`` or, when
            ``checkpoint_activations_microbatch`` is given, with that value as
            a third positional argument. It must return a pair:

            1. The model output: a tensor or a collection of tensors. The
               only hard requirement is that it is acceptable as input to the
               second element.
            2. A reduction function applied to that output on the last
               pipeline stage. Its return value selects the handling:

               a. ``(loss, extra)``: the loss is multiplied by the context
                  parallel world size and divided by ``num_microbatches``.
               b. ``(loss, num_tokens, extra)``: unless
                  ``config.calculate_per_token_loss`` is set, the loss is
                  divided by ``num_tokens``, multiplied by the context
                  parallel world size and divided by ``num_microbatches``.
               c. Arbitrary data (e.g. a dict or list of tensors for
                  inference). Selected with ``collect_non_loss_data=True``;
                  the function is then called with ``non_loss_data=True``.
        data_iterator (iterator):
            The data iterator passed through to ``forward_step_func``.
        model (nn.Module):
            The model to run.
        num_microbatches (int):
            Number of microbatches; used to scale the loss and the MoE
            auxiliary loss scale.
        input_tensor (Tensor or list[Tensor]):
            Input tensor(s) for this stage. A non-list value is wrapped in a
            list, and the output is then returned unwrapped.
        forward_data_store (list):
            Receives one entry per call on the last pipeline stage: the
            trailing ``extra`` element for patterns 2.a / 2.b, or the whole
            return value of the reduction function for pattern 2.c.
        config (object):
            Configuration object. This function reads ``timers``,
            ``enable_autocast``, ``autocast_dtype``,
            ``calculate_per_token_loss``, ``grad_scale_func`` and, if
            present, ``num_moe_experts``.
        collect_non_loss_data (bool, optional):
            Use pattern 2.c instead of treating the output as a loss.
            Defaults to False.
        checkpoint_activations_microbatch (int, optional):
            Forwarded to ``forward_step_func`` when not None.
            Defaults to None.
        is_first_microbatch (bool, optional):
            When true and the model defines ``set_is_first_microbatch``, that
            method is called. Defaults to False.
        current_microbatch (int, optional):
            When not None, passed to ``set_current_microbatch``.
            Defaults to None.
        encoder_decoder_xattn (bool, optional):
            For encoder-and-decoder models running inside the decoder, also
            return the last input tensor next to the output.
            Defaults to False.

    Returns:
        tuple: ``(output, num_tokens)``. ``output`` is
        ``[output_tensor, input_tensor[-1]]`` in the encoder/decoder
        cross-attention case, the bare ``output_tensor`` when
        ``input_tensor`` was not a list, and ``[output_tensor]`` otherwise.
        ``num_tokens`` is an int tensor holding 0 unless the reduction
        function supplied a token count (pattern 2.b).
    """
    if config.timers is not None:
        config.timers('forward-compute', log_level=2).start()

    if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
        model.set_is_first_microbatch()
    if current_microbatch is not None:
        set_current_microbatch(model, current_microbatch)

    # Remember whether the caller passed a bare tensor so the same shape can
    # be handed back at the end.
    unwrap_output_tensor = not isinstance(input_tensor, list)
    if unwrap_output_tensor:
        input_tensor = [input_tensor]

    get_attr_wrapped_model(model, "set_input_tensor")(input_tensor)

    autocast_ctx = (
        torch.autocast("cuda", dtype=config.autocast_dtype)
        if config.enable_autocast
        else contextlib.nullcontext()
    )
    with autocast_ctx:
        step_args = (data_iterator, model)
        if checkpoint_activations_microbatch is not None:
            step_args = step_args + (checkpoint_activations_microbatch,)
        output_tensor, loss_func = forward_step_func(*step_args)

    num_tokens = torch.tensor(0, dtype=torch.int)
    if parallel_state.is_pipeline_last_stage():
        if collect_non_loss_data:
            # Pattern 2.c: store whatever the reduction function returns.
            data = loss_func(output_tensor, non_loss_data=True)
            forward_data_store.append(data)
        else:
            outputs = loss_func(output_tensor)
            if len(outputs) == 3:
                # Pattern 2.b: loss, token count, extra data.
                output_tensor, num_tokens, loss_reduced = outputs
                if not config.calculate_per_token_loss:
                    # In-place ops, applied in this exact order.
                    output_tensor /= num_tokens
                    output_tensor *= parallel_state.get_context_parallel_world_size()
                    output_tensor /= num_microbatches
            else:
                # Pattern 2.a: loss and extra data only.
                assert len(outputs) == 2
                output_tensor, loss_reduced = outputs
                output_tensor *= parallel_state.get_context_parallel_world_size()
                output_tensor /= num_microbatches
            forward_data_store.append(loss_reduced)

    if config.timers is not None:
        config.timers('forward-compute').stop()

    # Publish the scale for the MoE auxiliary loss, spread over microbatches.
    if getattr(config, 'num_moe_experts', None) is not None:
        if config.grad_scale_func is not None:
            loss_scale = config.grad_scale_func(LOSS_BACKWARD_SCALE)
        else:
            loss_scale = torch.tensor(1.0)
        MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

    returns_decoder_pair = (
        get_model_type(model) == ModelType.encoder_and_decoder
        and encoder_decoder_xattn
        and parallel_state.is_inside_decoder()
    )
    if returns_decoder_pair:
        return [output_tensor, input_tensor[-1]], num_tokens
    if not unwrap_output_tensor:
        return [output_tensor], num_tokens
    return output_tensor, num_tokens
def get_tensor_shapes_wrapper(get_tensor_shapes):
    """Wrap ``get_tensor_shapes`` to rescale shapes when ``tp_2d`` is enabled.

    The wrapped function is always called first. If the global argument
    ``tp_2d`` is falsy its result is returned untouched. Otherwise each shape
    is replaced by a new three-element list in which entry 0 is floor-divided
    by ``tp_x``, entry 1 is kept, and entry 2 is floor-divided by ``tp_y``.

    Args:
        get_tensor_shapes (callable): The original shape function; expected to
            return an iterable of indexable shapes with at least three
            entries each.

    Returns:
        callable: The wrapping function, carrying the wrapped function's
        metadata via ``functools.wraps``.
    """

    @wraps(get_tensor_shapes)
    def wrapper(*args, **kwargs):
        shapes = get_tensor_shapes(*args, **kwargs)
        opts = get_args()
        if not opts.tp_2d:
            return shapes

        rescaled = []
        for shape in shapes:
            rescaled.append([shape[0] // opts.tp_x, shape[1], shape[2] // opts.tp_y])
        return rescaled

    return wrapper