"""Define common classes for multi-parameter pipeline parallelism.
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
"""
from dataclasses import dataclass
from typing import Callable, Optional, Protocol
import torch
@dataclass
class InterleavingSchedulerArgs:
    """Define arguments for the interleaving scheduler.

    This is a plain bundle of callbacks and flags. Every field is required
    (none has a default), so all hooks must be supplied at construction time.
    The scheduler presumably calls these hooks instead of importing the
    underlying implementations directly, so alternative implementations can
    be injected — verify against the callers.
    """

    get_model_config: Callable
    """Get the model config."""
    get_pipeline_model_parallel_world_size: Callable
    """Get the pipeline model parallel world size."""
    get_pipeline_model_parallel_rank: Callable
    """Get the pipeline model parallel rank."""
    set_virtual_pipeline_model_parallel_rank: Callable
    """Set the virtual pipeline model parallel rank."""
    get_args: Callable
    """Get the args."""
    is_encoder_and_decoder: bool
    """Whether the model has both an encoder and a decoder."""
    is_pipeline_first_stage: Callable
    """Check whether the current stage is the first pipeline stage."""
    is_pipeline_last_stage: Callable
    """Check whether the current stage is the last pipeline stage."""
    deallocate_output_tensor: Callable
    """Deallocate the output tensor."""
    send_forward_backward_recv_forward_backward: Callable
    """Send forward and backward tensors and receive forward and backward tensors."""
    forward_step: Callable
    """Run a forward step."""
    check_first_val_step: Callable
    """Check the first validation step."""
    backward_step: Callable
    """Run a backward step."""
    recv_forward: Callable
    """Receive the forward tensor."""
    recv_backward: Callable
    """Receive the backward tensor."""
    send_forward_recv_forward: Callable
    """Send the forward tensor and receive the forward tensor."""
    send_backward_recv_backward: Callable
    """Send the backward tensor and receive the backward tensor."""
class Config(Protocol):
    """Define a protocol for configuration.

    Any object that exposes these attributes satisfies the protocol
    structurally; no subclassing is needed. The ``megatron.core.timers``
    reference below suggests the usual implementer is a Megatron-style model
    parallel config (assumption — verify against callers).

    The members with ``= None`` are plain class-level values. A ``Protocol``
    (unlike a dataclass) allows them to be interleaved with members that
    have no value.
    """

    timers: Callable
    """Timers object to call for various timing functions.
    See megatron.core.timers.Timers.
    """
    grad_scale_func: Optional[Callable]
    """If using loss scaling, this function should take the loss
    and return the scaled loss. If None, no function is called on the loss.
    """
    pipeline_dtype: torch.dtype
    """dtype used in p2p communication, usually params_dtype."""
    overlap_p2p_comm: bool
    """When True, some of the peer-to-peer communication
    for pipeline parallelism will overlap with computation.
    Must be False if batch_p2p_comm is True.
    """
    batch_p2p_comm: bool
    """Use batch_isend_irecv instead of individual isend/irecv calls.
    Must be False if overlap_p2p_comm is True.
    """
    barrier_with_L1_time: bool
    """If True, use a barrier with level 1 time measurements.
    It is up to the user to make sure that
    calling barrier with their timers will not result in hangs.
    A hang can happen if, for example,
    the user adds a level 1 timer that is not called by all ranks.
    """
    no_sync_func: Callable
    """Function that creates a context that suppresses asynchronous
    data-parallel communication. If
    the model is an instance of core.distributed.DistributedDataParallel,
    the default is to use core.distributed.DistributedDataParallel.no_sync.
    """
    grad_sync_func: Callable
    """Function that launches asynchronous gradient reductions
    (e.g. distributed optimizer gradient reduce-scatters).
    The function should take one argument: an iterable of parameters whose
    gradients are to be synchronized.
    """
    param_sync_func: Callable
    """Function that launches asynchronous parameter synchronizations
    (e.g. distributed optimizer parameter all-gathers).
    The function should take one argument: an iterable of parameters to
    be synchronized.
    """
    num_microbatches_with_partial_activation_checkpoints: Optional[int] = None
    """If an int, the number of microbatches in which not all of the layers
    are checkpointed and recomputed.
    The rest of the microbatches within the window of maximum outstanding
    microbatches recompute all layers
    (either full recompute or selective recompute). If
    None, checkpointing and recomputation
    are left up to the forward_step function.
    """
    deallocate_pipeline_outputs: bool
    """If True, output data is deallocated after the tensor
    is sent to the next pipeline stage.
    This helps save memory and
    does nothing when pipeline parallelism is not used.
    """
    finalize_model_grads_func: Optional[Callable] = None
    """Function that finalizes gradients on all workers.
    This could include ensuring that grads are
    all-reduced across the data parallelism, pipeline parallelism,
    and sequence parallelism dimensions.
    """
    calculate_per_token_loss: bool
    """Whether cross entropy loss is calculated over
    the actual number of non-padded tokens in the global batch,
    versus the default behavior of assuming all tokens are non-padded.
    """