import contextlib
from typing import Callable, Iterator, List, Optional, Union, Tuple
import torch
import torch.distributed as dist
from megatron import core
from megatron.core import ModelParallelConfig, parallel_state
from megatron.core.enums import ModelType
from megatron.core.utils import get_model_config, get_model_type
from megatron.core.pipeline_parallel.schedules import (
get_tensor_shapes,
forward_step,
backward_step,
deallocate_output_tensor,
check_first_val_step,
clear_embedding_activation_buffer,
finish_embedding_wgrad_compute
)
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.pipeline_parallel.p2p_communication import (
Shape,
_communicate_shapes,
_communicate,
_batched_p2p_ops,
_p2p_ops
)
from megatron.core.parallel_state import (
get_pipeline_model_parallel_group,
get_pipeline_model_parallel_next_rank,
get_pipeline_model_parallel_prev_rank,
)
from megatron.training import get_args
from mindspeed.core.parallel_state import get_pipeline_parallel_group_for_new_stream
# Process-wide state shared by the schedules below, all lazily initialised on first use:
#   forward_comm_stream  - side stream for activation send/recv
#   backward_comm_stream - side stream for gradient send/recv
#   default_stream       - handle to torch.cuda.default_stream()
#   scheduler_plan       - this stage's list of 'F<i>'/'B<i>' tags
forward_comm_stream = backward_comm_stream = default_stream = scheduler_plan = None
def recv_forward(tensor_shapes, config, group):
    """Post non-blocking receives of activations from the previous pipeline stage.

    One receive is issued through ``_communicate`` (with ``wait_on_reqs=False``) for each
    entry of ``tensor_shapes``. Entries whose shape is ``None`` yield ``None`` for both the
    tensor and its handle. So does every entry when this rank is the first pipeline stage.

    Args:
        tensor_shapes: iterable of shapes (or ``None``) to receive.
        config: model-parallel config forwarded to ``_communicate``.
        group: process group used for the point-to-point operations.

    Returns:
        (tensors, handles): two lists aligned with ``tensor_shapes``. A non-None handle is
        the un-waited request collection returned by ``_communicate``. It must be waited on
        before the matching tensor is read.
    """
    received, pending = [], []
    for shape in tensor_shapes:
        tensor, handle = None, None
        if not (shape is None or core.parallel_state.is_pipeline_first_stage()):
            tensor, _, handle = _communicate(
                tensor_send_next=None,
                tensor_send_prev=None,
                recv_prev=True,
                recv_next=False,
                tensor_shape=shape,
                config=config,
                group=group,
                wait_on_reqs=False,
            )
        received.append(tensor)
        pending.append(handle)
    return received, pending
def recv_backward(tensor_shapes, config, group):
    """Post non-blocking receives of output-tensor gradients from the next pipeline stage.

    One receive is issued through ``_communicate`` (with ``wait_on_reqs=False``) for each
    entry of ``tensor_shapes``. Entries whose shape is ``None`` yield ``None`` for both the
    gradient and its handle. So does every entry when this rank is the last pipeline stage.

    Args:
        tensor_shapes: iterable of shapes (or ``None``) to receive.
        config: model-parallel config forwarded to ``_communicate``.
        group: process group used for the point-to-point operations.

    Returns:
        (grads, handles): two lists aligned with ``tensor_shapes``. A non-None handle is
        the un-waited request collection returned by ``_communicate``.
    """
    grads, pending = [], []
    for shape in tensor_shapes:
        grad, handle = None, None
        if not (shape is None or core.parallel_state.is_pipeline_last_stage()):
            _, grad, handle = _communicate(
                tensor_send_next=None,
                tensor_send_prev=None,
                recv_prev=False,
                recv_next=True,
                tensor_shape=shape,
                config=config,
                group=group,
                wait_on_reqs=False,
            )
        grads.append(grad)
        pending.append(handle)
    return grads, pending
def send_forward(output_tensors, tensor_shapes, config, group):
    """Post non-blocking sends of activations to the next pipeline stage.

    ``output_tensors`` may be a single tensor or a list; it is paired element-wise with
    ``tensor_shapes``. Nothing is sent for entries whose shape is ``None``, nor for any
    entry when this rank is the last pipeline stage. The request handles returned by
    ``_communicate`` are not kept.
    """
    tensors = output_tensors if isinstance(output_tensors, list) else [output_tensors]
    for tensor, shape in zip(tensors, tensor_shapes):
        if shape is not None and not core.parallel_state.is_pipeline_last_stage():
            _communicate(
                tensor_send_next=tensor,
                tensor_send_prev=None,
                recv_prev=False,
                recv_next=False,
                tensor_shape=None,
                config=config,
                group=group,
                wait_on_reqs=False,
            )
def send_backward(input_tensor_grads, tensor_shapes, config, group):
    """Post non-blocking sends of input-tensor gradients to the previous pipeline stage.

    ``input_tensor_grads`` may be a single tensor or a list; it is paired element-wise with
    ``tensor_shapes``. Nothing is sent for entries whose shape is ``None``, nor for any
    entry when this rank is the first pipeline stage. The request handles returned by
    ``_communicate`` are not kept.
    """
    grads = input_tensor_grads if isinstance(input_tensor_grads, list) else [input_tensor_grads]
    for grad, shape in zip(grads, tensor_shapes):
        if shape is not None and not core.parallel_state.is_pipeline_first_stage():
            _communicate(
                tensor_send_next=None,
                tensor_send_prev=grad,
                recv_prev=False,
                recv_next=False,
                tensor_shape=None,
                config=config,
                group=group,
                wait_on_reqs=False,
            )
def _communicate(
    *,
    tensor_send_next: Optional[torch.Tensor],
    tensor_send_prev: Optional[torch.Tensor],
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Shape,
    config: ModelParallelConfig,
    wait_on_reqs: bool = True,
    group: dist.ProcessGroup = None
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[Union[dict, list]]]:
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Args:
        tensor_send_next (torch.Tensor, optional):
            Tensor to send to next rank (no tensor sent if None)
        tensor_send_prev (torch.Tensor, optional):
            Tensor to send to prev rank (no tensor sent if None)
        recv_prev (boolean, required):
            whether tensor should be received from previous rank.
        recv_next (boolean, required):
            whether tensor should be received from next rank.
        tensor_shape (List[int] or torch.Size, required):
            shape of tensor to receive (this method assumes that all
            tensors sent and received in a single function call are
            the same shape). When ``config.variable_seq_lengths`` is set the
            receive buffers use the shapes exchanged by ``_communicate_shapes``
            instead, but ``tensor_shape`` must still be non-None for any receive.
        config (ModelParallelConfig, required):
            supplies ``pipeline_dtype``, ``variable_seq_lengths``,
            ``use_ring_exchange_p2p``, ``batch_p2p_comm`` and ``batch_p2p_sync``.
        wait_on_reqs (boolean, optional, default=True):
            Wait on every request before returning. Must be True when
            ``config.batch_p2p_comm`` is set.
        group (torch.distributed.ProcessGroup, optional):
            Process group handed to the point-to-point op function.

    Returns:
        tuple containing
        - tensor_recv_prev: torch.Tensor if recv_prev is True, None otherwise.
        - tensor_recv_next: torch.Tensor if recv_next is True, None otherwise.
        - reqs: None when the requests were waited on here; otherwise whatever the
          p2p function returned. This file treats ``_p2p_ops`` results as a dict of
          named requests; the batched and ring-exchange paths return lists.

    Raises:
        RuntimeError: a receive is requested without ``pipeline_dtype`` or ``tensor_shape``.
        AssertionError: ``batch_p2p_comm`` is enabled while ``wait_on_reqs`` is False.
    """
    tensor_recv_prev = None
    tensor_recv_next = None
    if not config.variable_seq_lengths:
        recv_prev_shape = tensor_shape
        recv_next_shape = tensor_shape
    else:
        recv_prev_shape, recv_next_shape = _communicate_shapes(
            tensor_send_next, tensor_send_prev, recv_prev, recv_next, config
        )
    if recv_prev:
        if config.pipeline_dtype is None:
            raise RuntimeError("pipeline_dtype must be provided if recv_prev is True")
        if tensor_shape is None:
            raise RuntimeError(
                "tensor_shape must be specified if recv_prev is True. "
                "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
            )
        tensor_recv_prev = torch.empty(
            recv_prev_shape,
            requires_grad=True,
            device=torch.cuda.current_device(),
            dtype=config.pipeline_dtype,
        )
    if recv_next:
        if config.pipeline_dtype is None:
            raise RuntimeError("dtype must be provided if recv_next is True")
        if tensor_shape is None:
            raise RuntimeError(
                "tensor_shape must be specified if recv_next is True. "
                "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
            )
        tensor_recv_next = torch.empty(
            recv_next_shape,
            requires_grad=True,
            device=torch.cuda.current_device(),
            dtype=config.pipeline_dtype,
        )
    if config.use_ring_exchange_p2p:
        def _ring_exchange_wrapper(**kwargs):
            torch.distributed.ring_exchange(**kwargs)
            return []
        p2p_func = _ring_exchange_wrapper
    elif config.batch_p2p_comm:
        # Batched ops can only be completed here, so the caller must ask us to wait.
        if not wait_on_reqs:
            raise AssertionError('wait_on_reqs must be True when use batch_p2p_comm')
        p2p_func = _batched_p2p_ops
    else:
        p2p_func = _p2p_ops
    reqs = p2p_func(
        tensor_send_prev=tensor_send_prev,
        tensor_recv_prev=tensor_recv_prev,
        tensor_send_next=tensor_send_next,
        tensor_recv_next=tensor_recv_next,
        group=group,
        prev_pipeline_rank=get_pipeline_model_parallel_prev_rank(),
        next_pipeline_rank=get_pipeline_model_parallel_next_rank(),
    )
    if wait_on_reqs and len(reqs) > 0:
        # _p2p_ops yields a dict of named requests while the batched / ring-exchange
        # paths yield lists; calling .values() on a list would raise AttributeError.
        for req in (reqs.values() if isinstance(reqs, dict) else reqs):
            req.wait()
        reqs = None
    if config.batch_p2p_comm and config.batch_p2p_sync:
        torch.cuda.synchronize()
    return tensor_recv_prev, tensor_recv_next, reqs
def generate_1f1b_scheduler_plan(pp_size, num_micro_batch):
    """Build the default non-interleaved 1F1B schedule for every pipeline stage.

    Each stage's plan is a list of tags: ``'F<i>'`` runs the forward pass of the i-th
    microbatch (1-based) and ``'B<i>'`` its backward pass. Stage ``r`` runs
    ``min(pp_size - r - 1, num_micro_batch)`` warm-up forwards, then alternates one forward
    with one backward, and finally drains the outstanding backwards. Every stage therefore
    contains exactly ``num_micro_batch`` forward tags and ``num_micro_batch`` backward tags.

    Args:
        pp_size: pipeline-parallel world size.
        num_micro_batch: number of microbatches in one iteration.

    Returns:
        dict mapping ``'stage<r>'`` (for r in ``range(pp_size)``) to that stage's tag list.
    """
    scheduler_plan_all_stages = {}
    for pp_rank in range(pp_size):
        # Clamp warm-up to the available microbatches (mirrors the min() applied by the
        # non-interleaved driver); without it a stage would schedule more forwards and
        # backwards than there are microbatches.
        num_warmup = max(min(pp_size - pp_rank - 1, num_micro_batch), 0)
        num_stable = num_micro_batch - num_warmup
        plan = ['F{}'.format(i) for i in range(1, num_warmup + 1)]
        for i in range(num_stable):
            plan.append('F{}'.format(num_warmup + i + 1))
            plan.append('B{}'.format(i + 1))
        # Cool-down: the backwards of the microbatches still in flight.
        plan.extend('B{}'.format(i) for i in range(num_stable + 1, num_micro_batch + 1))
        scheduler_plan_all_stages['stage{}'.format(pp_rank)] = plan
    return scheduler_plan_all_stages
def forward_backward_pipelining_without_interleaving(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: int = None,
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
    first_val_step: bool = None
):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Where the step order comes from:
    - The order of forward (``'F<i>'``) and backward (``'B<i>'``) steps on this stage is
      ``get_args().pp_schedule_list['stage<rank>']`` when that option is set. That plan
      is cached for later calls.
    - Otherwise the order is ``generate_1f1b_scheduler_plan(pp_size, num_microbatches)``,
      rebuilt on every call so it always matches ``num_microbatches``.

    How communication is overlapped:
    - Activation receives and sends are issued on ``forward_comm_stream`` with the
      pipeline-parallel group.
    - Gradient receives and sends are issued on ``backward_comm_stream`` with the group
      from ``get_pipeline_parallel_group_for_new_stream()``.

    Side effect: sets ``config.batch_p2p_comm = False``, because ``_communicate``
    rejects batched p2p when asked not to wait on its requests.

    Returns:
        ``forward_data_store``, the list that ``forward_step`` appends to.

    Raises:
        AssertionError: more than one model chunk or data iterator is given.
        RuntimeError: no schedule exists for this pipeline stage.
    """
    if isinstance(model, list):
        if not len(model) == 1:
            raise AssertionError("non-interleaved pipeline parallelism does not support model chunking")
        model = model[0]
    if isinstance(data_iterator, list):
        if not len(data_iterator) == 1:
            raise AssertionError("non-pipeline-parallel schedule does not support model chunking")
        data_iterator = data_iterator[0]
    config = get_model_config(model)
    if config.timers is not None:
        config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

    # Grad reductions stay disabled until the last backward step (or the end of the schedule).
    no_sync_func = config.no_sync_func
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext
    no_sync_context = None

    def disable_grad_sync():
        """Disable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is None:
            no_sync_context = no_sync_func()
            no_sync_context.__enter__()

    def enable_grad_sync():
        """Enable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is not None:
            no_sync_context.__exit__(None, None, None)
            no_sync_context = None

    disable_grad_sync()
    # Only used to size the partial activation-checkpointing window; the real order of
    # F/B steps comes from the scheduler plan below.
    num_warmup_microbatches = (
        parallel_state.get_pipeline_model_parallel_world_size()
        - parallel_state.get_pipeline_model_parallel_rank()
        - 1
    )
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    max_outstanding_backprops = None
    if config.num_microbatches_with_partial_activation_checkpoints is not None:
        max_outstanding_backprops = num_warmup_microbatches + 1
    model_type = get_model_type(model)
    rank = parallel_state.get_pipeline_model_parallel_rank()
    recv_tensor_shapes = get_tensor_shapes(
        rank=rank - 1,
        model_type=model_type,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        decoder_seq_length=decoder_seq_length,
        config=config,
        encoder_decoder_xattn=False,
    )
    send_tensor_shapes = get_tensor_shapes(
        rank=rank,
        model_type=model_type,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        decoder_seq_length=decoder_seq_length,
        config=config,
        encoder_decoder_xattn=False,
    )
    # Activations kept for the backward pass (FIFO); not needed when forward_only.
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    forward_data_store = []

    def wait_helper(wait_handlers):
        """Wait on every request in each non-None handle collection."""
        for reqs in wait_handlers:
            if reqs is not None:
                for req in reqs.values():
                    req.wait()

    # Side streams are created once per process and reused across calls.
    global forward_comm_stream
    if forward_comm_stream is None:
        forward_comm_stream = torch.cuda.Stream()
    global backward_comm_stream
    if backward_comm_stream is None:
        backward_comm_stream = torch.cuda.Stream()
    global default_stream
    if default_stream is None:
        default_stream = torch.cuda.default_stream()

    global scheduler_plan
    arguments = get_args()
    key = 'stage{}'.format(parallel_state.get_pipeline_model_parallel_rank())
    pp_schedule_list = getattr(arguments, "pp_schedule_list", None)
    if pp_schedule_list:
        # A user-supplied plan is fixed for the whole run: look it up once and cache it.
        if scheduler_plan is None:
            scheduler_plan = pp_schedule_list.get(key)
    else:
        # The default plan depends on num_microbatches, which may differ between calls
        # (e.g. evaluation or batch-size ramp-up); rebuild it rather than reuse a stale one.
        scheduler_plan = generate_1f1b_scheduler_plan(
            parallel_state.get_pipeline_model_parallel_world_size(), num_microbatches
        ).get(key)
    if scheduler_plan is None:
        raise RuntimeError("no pipeline schedule found for {}".format(key))

    # _communicate raises if batched p2p is combined with wait_on_reqs=False, which is
    # how every send/recv below is issued.
    config.batch_p2p_comm = False
    fwd_wait_handles, bwd_wait_handles = None, None
    for current_tag_id, tag in enumerate(scheduler_plan):
        if tag.startswith('F'):
            if max_outstanding_backprops is not None:
                checkpoint_activations_microbatch = (
                    current_tag_id % max_outstanding_backprops >= config.num_microbatches_with_partial_activation_checkpoints
                )
            else:
                checkpoint_activations_microbatch = None
            with torch.cuda.stream(forward_comm_stream):
                input_tensor, fwd_wait_handles = recv_forward(
                    recv_tensor_shapes, config, get_pipeline_model_parallel_group()
                )
            wait_helper(fwd_wait_handles)
            output_tensor, _ = forward_step(
                forward_step_func,
                data_iterator,
                model,
                num_microbatches,
                input_tensor,
                forward_data_store,
                config,
                collect_non_loss_data,
                checkpoint_activations_microbatch,
                check_first_val_step(first_val_step, forward_only, current_tag_id == 0)
            )
            with torch.cuda.stream(forward_comm_stream):
                # Do not send before the default stream has produced output_tensor.
                forward_comm_stream.wait_stream(default_stream)
                send_forward(
                    output_tensor,
                    send_tensor_shapes,
                    config,
                    get_pipeline_model_parallel_group()
                )
                # Tell the caching allocator these tensors are in use on the side stream.
                for tensor in output_tensor:
                    if tensor is not None:
                        tensor.record_stream(forward_comm_stream)
            if not forward_only:
                input_tensors.append(input_tensor)
                output_tensors.append(output_tensor)
                deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)
        else:
            if forward_only:
                continue
            if current_tag_id == len(scheduler_plan) - 1:
                # Last step of the plan: let grad reductions run during this backward.
                if config.grad_sync_func is None or rank == 0:
                    enable_grad_sync()
            with torch.cuda.stream(backward_comm_stream):
                output_tensor_grads, bwd_wait_handles = recv_backward(
                    send_tensor_shapes, config, get_pipeline_parallel_group_for_new_stream()
                )
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)
            wait_helper(bwd_wait_handles)
            input_tensor_grad = backward_step(
                input_tensor,
                output_tensor,
                output_tensor_grads,
                model_type,
                config
            )
            with torch.cuda.stream(backward_comm_stream):
                backward_comm_stream.wait_stream(default_stream)
                send_backward(
                    input_tensor_grad,
                    recv_tensor_shapes,
                    config,
                    get_pipeline_parallel_group_for_new_stream()
                )
                for tensor in input_tensor_grad:
                    if tensor is not None:
                        tensor.record_stream(backward_comm_stream)
    if not forward_only:
        # Launch any grad reductions that are still pending.
        if no_sync_context is not None:
            enable_grad_sync()
            if config.grad_sync_func is not None:
                config.grad_sync_func(model.parameters())
    if config.timers is not None:
        config.timers('forward-backward').stop()
    if config.finalize_model_grads_func is not None and not forward_only:
        config.finalize_model_grads_func([model])
    return forward_data_store
def forward_backward_pipelining_with_interleaving_patch(
    *,
    forward_step_func,
    data_iterator: Union[Iterator, List[Iterator]],
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    num_microbatches: int,
    seq_length: int,
    micro_batch_size: int,
    decoder_seq_length: int = None,
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
    first_val_step: bool = None,
):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    Each entry of ``model`` is one virtual-pipeline chunk, paired with the data
    iterator at the same index.

    How the p2p tensor shape is computed:
    - Start from ``[seq_length, micro_batch_size, hidden_size]``.
    - Divide the sequence dim by the context-parallel size.
    - Under sequence parallelism, also divide it by the tensor-parallel size.
    - Finally divide the sequence dim by ``get_args().tp_x`` and the hidden dim by
      ``get_args().tp_y``.

    The schedule has three phases:
    - warm-up forwards;
    - a steady phase pairing one forward with one backward;
    - cool-down backwards, skipped when ``forward_only``.

    Raises:
        AssertionError: ``model`` / ``data_iterator`` are not lists of chunks.
        ValueError: both ``overlap_p2p_comm`` and ``batch_p2p_comm`` are enabled.
        RuntimeError: ``num_microbatches`` is not divisible by the pipeline size,
            the model is encoder-decoder, or ``decoder_seq_length`` differs
            from ``seq_length``.

    Returns:
        ``forward_data_store``, the list that ``forward_step`` appends to."""
    if not isinstance(model, list):
        raise AssertionError("interleaved pipeline parallelism expected model chunking")
    if not all(isinstance(chunk, torch.nn.Module) for chunk in model):
        raise AssertionError("invalid model chunking")
    if not isinstance(data_iterator, list):
        raise AssertionError("interleaved pipeline parallelism expected each model chunk to have a data iterator")
    config = get_model_config(model[0])
    if config.overlap_p2p_comm and config.batch_p2p_comm:
        raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")
    # Only needed when gradients are finalized at the end of this function.
    if config.finalize_model_grads_func is not None and not forward_only:
        embedding_module = clear_embedding_activation_buffer(config, model)
    if config.timers is not None:
        config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
    # Disable async grad reductions; a list of per-chunk no_sync functions is
    # combined into one ExitStack-based context.
    no_sync_func = config.no_sync_func
    if isinstance(no_sync_func, list):
        def multi_no_sync():
            stack = contextlib.ExitStack()
            for model_chunk_no_sync_func in config.no_sync_func:
                stack.enter_context(model_chunk_no_sync_func())
            return stack
        no_sync_func = multi_no_sync
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext
    no_sync_context = None
    # Normalise single sync callables into one entry per model chunk
    # (note: this rewrites config in place).
    if config.grad_sync_func is not None and not isinstance(config.grad_sync_func, list):
        config.grad_sync_func = [config.grad_sync_func for _ in model]
    if config.param_sync_func is not None and not isinstance(config.param_sync_func, list):
        config.param_sync_func = [config.param_sync_func for _ in model]

    def disable_grad_sync():
        """Disable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is None:
            no_sync_context = no_sync_func()
            no_sync_context.__enter__()

    def enable_grad_sync():
        """Enable asynchronous grad reductions"""
        nonlocal no_sync_context
        if no_sync_context is not None:
            no_sync_context.__exit__(None, None, None)
            no_sync_context = None

    disable_grad_sync()
    # Model chunk IDs whose grads have already been synchronized.
    synchronized_model_chunks = set()
    # Per-chunk FIFO queues of activations (and, for training, output grads).
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()
    forward_data_store = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]
    pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
    if num_microbatches % pipeline_parallel_size != 0:
        msg = f'number of microbatches ({num_microbatches}) is not divisible by '
        msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) '
        msg += 'when using interleaved schedule'
        raise RuntimeError(msg)
    model_type = get_model_type(model[0])
    if model_type == ModelType.encoder_and_decoder:
        raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")
    if decoder_seq_length is not None and decoder_seq_length != seq_length:
        raise RuntimeError(
            "Interleaving is not supported with a different decoder sequence length."
        )
    # Shape of every tensor exchanged between stages.
    tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
    tensor_shape[0] = tensor_shape[0] // parallel_state.get_context_parallel_world_size()
    if config.sequence_parallel:
        tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()
    # Further shard sequence by tp_x and hidden by tp_y (presumably the 2-D
    # tensor-parallel layout; TODO confirm both default to 1 when it is disabled).
    tensor_shape[0] = tensor_shape[0] // get_args().tp_x
    tensor_shape[-1] = tensor_shape[-1] // get_args().tp_y
    # Each rank runs num_microbatches iterations per model chunk.
    num_model_chunks = len(model)
    total_num_microbatches = num_microbatches * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = total_num_microbatches
    else:
        if num_microbatches == pipeline_parallel_size:
            # Pipeline exactly full: everything runs as warm-up, then cool-down.
            num_warmup_microbatches = total_num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
    num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches
    max_outstanding_backprops = None
    if config.num_microbatches_with_partial_activation_checkpoints is not None:
        max_outstanding_backprops = num_warmup_microbatches + 1
    # Synchronize params for the first two model chunks up front (assumes >= 2 chunks).
    if config.param_sync_func is not None:
        config.param_sync_func[0](model[0].parameters())
        config.param_sync_func[1](model[1].parameters())

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number."""
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        # Backward visits chunks in reverse order.
        if not forward:
            model_chunk_id = num_model_chunks - model_chunk_id - 1
        return model_chunk_id

    def get_microbatch_id_in_model_chunk(iteration_id, forward):
        """Helper method to get the microbatch_id within model chunk given the iteration number."""
        # Only forward iteration IDs are supported.
        assert forward
        iteration_group_id = iteration_id // (pipeline_parallel_size * num_model_chunks)
        microbatch_id_in_model_chunk = (iteration_group_id * pipeline_parallel_size) + (
            iteration_id % pipeline_parallel_size
        )
        return microbatch_id_in_model_chunk

    def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
        """Check if an iteration is the first for a model chunk."""
        microbatch_group_size = pipeline_parallel_size * num_model_chunks
        num_microbatch_groups = total_num_microbatches // microbatch_group_size  # unused here
        microbatch_group_id = microbatch_id // microbatch_group_size
        microbatch_id_in_group = microbatch_id % microbatch_group_size
        if microbatch_group_id == 0:
            return microbatch_id_in_group % pipeline_parallel_size == 0
        else:
            return False

    def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
        """Check if an iteration is the last for a model chunk."""
        microbatch_group_size = pipeline_parallel_size * num_model_chunks
        num_microbatch_groups = total_num_microbatches // microbatch_group_size
        microbatch_group_id = microbatch_id // microbatch_group_size
        microbatch_id_in_group = microbatch_id % microbatch_group_size
        if microbatch_group_id == num_microbatch_groups - 1:
            return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1
        else:
            return False

    def forward_step_helper(microbatch_id, current_microbatch, checkpoint_activations_microbatch):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
        # Launch param sync for an upcoming chunk; offsetting by the pipeline rank
        # makes all ranks launch it at the same point.
        if config.param_sync_func is not None:
            param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank
            if (
                param_sync_microbatch_id < total_num_microbatches
                and is_first_microbatch_for_model_chunk(param_sync_microbatch_id)
            ):
                param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
                # Chunks 0 and 1 were synchronized before the schedule started.
                if 1 < param_sync_chunk_id < num_model_chunks:
                    config.param_sync_func[param_sync_chunk_id](
                        model[param_sync_chunk_id].parameters()
                    )
        # The first virtual stage reads from the data iterator, so its input is None.
        if parallel_state.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor, num_tokens = forward_step(
            forward_step_func,
            data_iterator[model_chunk_id],
            model[model_chunk_id],
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            collect_non_loss_data,
            checkpoint_activations_microbatch,
            check_first_val_step(
                first_val_step, forward_only, is_first_microbatch_for_model_chunk(microbatch_id),
            ),
            current_microbatch=current_microbatch,
        )
        output_tensors[model_chunk_id].append(output_tensor)
        nonlocal total_num_tokens
        total_num_tokens += num_tokens.item()
        # Forward-only: nothing needs to be kept for a backward pass.
        if forward_only:
            input_tensors[model_chunk_id].pop()
            output_tensors[model_chunk_id].pop()
        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
        # Default grad sync: re-enable reductions for this chunk's last microbatch.
        if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id):
            enable_grad_sync()
            synchronized_model_chunks.add(model_chunk_id)
        # The last virtual stage computes the loss itself, so it has no incoming grad.
        if parallel_state.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = backward_step(
            input_tensor, output_tensor, output_tensor_grad, model_type, config
        )
        # Custom grad sync, offset by pipeline rank so all ranks launch it together.
        if config.grad_sync_func is not None:
            grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank
            if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
                grad_sync_microbatch_id
            ):
                grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False)
                enable_grad_sync()
                config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
                synchronized_model_chunks.add(grad_sync_chunk_id)
        disable_grad_sync()
        return input_tensor_grad

    # ---- Warm-up: forward passes only. ----
    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))
    fwd_wait_handles = None
    bwd_wait_handles = None
    for k in range(num_warmup_microbatches):
        if fwd_wait_handles is not None:
            for req in fwd_wait_handles.values():
                req.wait()
        cur_model_chunk_id = get_model_chunk_id(k, forward=True)
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                k % max_outstanding_backprops
                >= config.num_microbatches_with_partial_activation_checkpoints
            )
        else:
            checkpoint_activations_microbatch = None
        current_microbatch = get_microbatch_id_in_model_chunk(k, forward=True)
        output_tensor = forward_step_helper(
            k, current_microbatch, checkpoint_activations_microbatch
        )
        # Decide whether a tensor should be received from the previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
        recv_prev = True
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (total_num_microbatches - 1):
            recv_prev = False
        # The last virtual stage has no activation to send downstream.
        if parallel_state.is_pipeline_last_stage():
            output_tensor = None
        if not config.overlap_p2p_comm:
            # On the last warm-up step also receive the first output grad.
            if (
                k == (num_warmup_microbatches - 1)
                and not forward_only
                and not all_warmup_microbatches
            ):
                input_tensor_grad = None
                recv_next = True
                if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False
                (
                    input_tensor,
                    output_tensor_grad,
                ) = p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    config=config,
                )
                output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
            else:
                input_tensor = p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, config=config
                )
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        else:
            # Overlapped p2p: keep the handles and wait on them at the next step.
            input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
                output_tensor,
                recv_prev=recv_prev,
                tensor_shape=tensor_shape,
                config=config,
                overlap_p2p_comm=True,
            )
            if (
                k == (num_warmup_microbatches - 1)
                and not forward_only
                and not all_warmup_microbatches
            ):
                input_tensor_grad = None
                recv_next = True
                if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False
                (
                    output_tensor_grad,
                    bwd_wait_handles,
                ) = p2p_communication.send_backward_recv_backward(
                    input_tensor_grad,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    config=config,
                    overlap_p2p_comm=True,
                )
                output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
    # ---- Steady state: one forward and one backward per iteration. ----
    for k in range(num_microbatches_remaining):
        forward_k = k + num_warmup_microbatches
        if max_outstanding_backprops is not None:
            checkpoint_activations_microbatch = (
                forward_k % max_outstanding_backprops
                >= config.num_microbatches_with_partial_activation_checkpoints
            )
        else:
            checkpoint_activations_microbatch = None
        cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        current_microbatch = get_microbatch_id_in_model_chunk(forward_k, forward=True)
        if config.overlap_p2p_comm:
            if fwd_wait_handles is not None:
                for req in fwd_wait_handles.values():
                    req.wait()
            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
            output_tensor = forward_step_helper(
                forward_k, current_microbatch, checkpoint_activations_microbatch
            )
            forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
            parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
            if parallel_state.is_pipeline_last_stage():
                output_tensor = None
            # Decide whether/where to receive the next activation. On the first physical
            # stage the chunk ID is computed (pipeline_parallel_size - 1) iterations back.
            recv_prev = True
            if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
                next_forward_model_chunk_id = get_model_chunk_id(
                    forward_k - (pipeline_parallel_size - 1), forward=True
                )
                if next_forward_model_chunk_id == (num_model_chunks - 1):
                    recv_prev = False
                next_forward_model_chunk_id += 1
            else:
                next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
            # Last iteration: one extra activation was already received before the loop.
            if k == (num_microbatches_remaining - 1):
                recv_prev = False
            input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
                output_tensor,
                recv_prev=recv_prev,
                tensor_shape=tensor_shape,
                config=config,
                overlap_p2p_comm=True,
            )
            if bwd_wait_handles is not None:
                for req in bwd_wait_handles.values():
                    req.wait()
            backward_k = k
            input_tensor_grad = backward_step_helper(backward_k)
            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
            parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
            # The first virtual stage has no input grad to send upstream.
            if parallel_state.is_pipeline_first_stage():
                input_tensor_grad = None
            # Mirror of the recv_prev logic, (pipeline_parallel_size - 1) iterations back
            # on the last physical stage.
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                next_backward_model_chunk_id = get_model_chunk_id(
                    backward_k - (pipeline_parallel_size - 1), forward=False
                )
                if next_backward_model_chunk_id == 0:
                    recv_next = False
                next_backward_model_chunk_id -= 1
            else:
                next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
            output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward(
                input_tensor_grad,
                recv_next=recv_next,
                tensor_shape=tensor_shape,
                config=config,
                overlap_p2p_comm=True,
            )
        else:
            # No p2p overlap: compute both steps, then exchange in one combined call.
            output_tensor = forward_step_helper(
                forward_k, current_microbatch, checkpoint_activations_microbatch
            )
            backward_k = k
            input_tensor_grad = backward_step_helper(backward_k)
            forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
            parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
            if parallel_state.is_pipeline_last_stage():
                output_tensor = None
            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
            parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
            if parallel_state.is_pipeline_first_stage():
                input_tensor_grad = None
            recv_prev = True
            if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
                next_forward_model_chunk_id = get_model_chunk_id(
                    forward_k - (pipeline_parallel_size - 1), forward=True
                )
                if next_forward_model_chunk_id == (num_model_chunks - 1):
                    recv_prev = False
                next_forward_model_chunk_id += 1
            else:
                next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                next_backward_model_chunk_id = get_model_chunk_id(
                    backward_k - (pipeline_parallel_size - 1), forward=False
                )
                if next_backward_model_chunk_id == 0:
                    recv_next = False
                next_backward_model_chunk_id -= 1
            else:
                next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
            if k == (num_microbatches_remaining - 1):
                recv_prev = False
            (
                input_tensor,
                output_tensor_grad,
            ) = p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
                tensor_shape=tensor_shape,
                config=config,
            )
            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
        # Queue the received tensors under the chunk that will consume them.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
    deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
    # ---- Cool-down: remaining backward passes (training only). ----
    if not forward_only:
        if config.overlap_p2p_comm and bwd_wait_handles is not None:
            for wait_handle in bwd_wait_handles.values():
                wait_handle.wait()
        # With all_warmup_microbatches no grad was received during warm-up.
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks - 1].append(
                p2p_communication.recv_backward(tensor_shape, config=config)
            )
        for k in range(num_microbatches_remaining, total_num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (total_num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, config=config
                )
            )
        # Launch any remaining grad reductions for chunks not yet synchronized.
        enable_grad_sync()
        if config.grad_sync_func is not None:
            for model_chunk_id in range(num_model_chunks):
                if model_chunk_id not in synchronized_model_chunks:
                    config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
                    synchronized_model_chunks.add(model_chunk_id)
    if config.finalize_model_grads_func is not None and not forward_only:
        finish_embedding_wgrad_compute(config, embedding_module)
        config.finalize_model_grads_func(
            model, total_num_tokens if config.calculate_per_token_loss else None
        )
    if config.timers is not None:
        config.timers('forward-backward').stop()
    return forward_data_store