import collections
from functools import wraps
from typing import Iterator, List, Union
import contextlib
import torch
from megatron.training import get_args
from megatron.core import parallel_state
from megatron.core.pipeline_parallel.schedules import deallocate_output_tensor, forward_step, backward_step, \
check_first_val_step
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.utils import get_model_config, get_model_type
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func as get_forward_backward_func_mg
from mindspeed.core.tensor_parallel.checkpoint_manager import get_pipeline_checkpoint_manager
from mindspeed.core.weight_grad_store import WeightGradStore
def get_forward_backward_func_ripipe_patch(*args, **kwargs):
    """Pick the forward/backward schedule (drop-in replacement for Megatron's getter).

    While autograd is enabled the ripipe interleaved schedule
    ``forward_backward_ripipe_pipelining`` is returned; otherwise the decision is
    handed back to Megatron's original ``get_forward_backward_func``.
    Any positional/keyword arguments are accepted only for signature
    compatibility and are ignored.
    """
    grad_enabled = torch.is_grad_enabled()
    if not grad_enabled:
        # No autograd (e.g. evaluation): fall back to Megatron's own selection.
        return get_forward_backward_func_mg()
    return forward_backward_ripipe_pipelining
def forward_backward_ripipe_pipelining(
*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int,
micro_batch_size: int,
decoder_seq_length: int = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
first_val_step: bool = None,
):
"""Run the interleaved (virtual pipeline) 1F1B schedule with ripipe/nanopipe extensions.

Almost directly copied from megatron's forward_backward_pipelining_with_interleaving
function, all modifications are annotated with 'ripipe related' or 'nanopipe related'.

Extensions, switched by fields of ``get_args()``:
  * ``recompute_in_bubble`` / ``recompute_in_advance`` (ripipe): the pipeline
    checkpoint manager is told when to disable/enable recomputation around forward
    steps and when to run ``recompute_next`` / ``recompute_next_force`` ahead of
    backward steps.
  * ``use_nanopipe`` / ``use_nanopipe_swap`` (nanopipe): ``WeightGradStore``
    decoupling is started for the first ``num_dx`` steady-state backward steps, and
    the stored work is popped at the end of the schedule.

All parameters are keyword-only.

Args:
    forward_step_func: forward-step callable, forwarded to ``forward_step``.
    data_iterator: list with one data iterator per model chunk.
    model: list of model chunks (``torch.nn.Module``).
    num_microbatches: microbatches run through each model chunk; must be
        divisible by the pipeline-model-parallel size.
    seq_length: sequence length used to build the p2p tensor shape.
    micro_batch_size: micro batch size used to build the p2p tensor shape.
    decoder_seq_length: must be None or equal to ``seq_length``.
    forward_only: if True, only forward passes are run.
    collect_non_loss_data: forwarded to ``forward_step``.
    first_val_step: forwarded to ``forward_step`` through ``check_first_val_step``.

Returns:
    ``forward_data_store``, the list passed to ``forward_step`` as its data store.

Raises:
    AssertionError: if ``model`` / ``data_iterator`` are not lists.
    ValueError: if both ``overlap_p2p_comm`` and ``batch_p2p_comm`` are enabled.
    RuntimeError: if ``num_microbatches`` is not divisible by the pipeline size,
        for encoder-and-decoder models, or for a differing decoder sequence length.
"""
# ripipe related: checkpoint manager sized by the number of virtual pipeline chunks;
# pre-recompute is switched on whenever either recompute mode is requested.
pipeline_checkpoint_manager = get_pipeline_checkpoint_manager(
num_of_chunks=parallel_state.get_virtual_pipeline_model_parallel_world_size())
args = get_args()
if args.recompute_in_bubble or args.recompute_in_advance:
pipeline_checkpoint_manager.open_ri_pipe = True
pipeline_checkpoint_manager.do_pre_recompute = True
# NOTE(review): the string below is a leftover no-op expression, not the docstring.
# The function actually returns forward_data_store, which is a list, not a dictionary.
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
assert isinstance(model, list), "interleaved pipeline parallelism expected model chunking"
assert all(isinstance(chunk, torch.nn.Module) for chunk in model), "invalid model chunking"
assert isinstance(
data_iterator, list
), "interleaved pipeline parallelism expected each model chunk to have a data iterator"
config = get_model_config(model[0])
if config.overlap_p2p_comm and config.batch_p2p_comm:
raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")
if config.timers is not None:
config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
# Build a no-sync context factory: a list of per-chunk no_sync funcs is combined into
# one ExitStack; a missing one falls back to a do-nothing context.
no_sync_func = config.no_sync_func
if isinstance(no_sync_func, list):
def multi_no_sync():
stack = contextlib.ExitStack()
for model_chunk_no_sync_func in config.no_sync_func:
stack.enter_context(model_chunk_no_sync_func())
return stack
no_sync_func = multi_no_sync
if no_sync_func is None:
no_sync_func = contextlib.nullcontext
no_sync_context = None
# Expand a single grad/param sync func into one entry per model chunk.
if config.grad_sync_func is not None and not isinstance(config.grad_sync_func, list):
config.grad_sync_func = [config.grad_sync_func for _ in model]
if config.param_sync_func is not None and not isinstance(config.param_sync_func, list):
config.param_sync_func = [config.param_sync_func for _ in model]
def disable_grad_sync():
"""Disable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is None:
no_sync_context = no_sync_func()
no_sync_context.__enter__()
def enable_grad_sync():
"""Enable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is not None:
no_sync_context.__exit__(None, None, None)
no_sync_context = None
# Grad reductions start disabled; they are re-enabled per chunk in backward_step_helper
# and again after the schedule finishes.
disable_grad_sync()
# Chunk ids whose gradient sync has already been triggered.
synchronized_model_chunks = set()
# Per-chunk queues of pending activations/grads; backward consumes them with pop(0).
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
forward_data_store = []
if not forward_only:
output_tensor_grads = [[] for _ in range(len(model))]
pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
if num_microbatches % pipeline_parallel_size != 0:
msg = f'number of microbatches ({num_microbatches}) is not divisible by '
msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) '
msg += 'when using interleaved schedule'
raise RuntimeError(msg)
model_type = get_model_type(model[0])
if model_type == ModelType.encoder_and_decoder:
raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")
if decoder_seq_length is not None and decoder_seq_length != seq_length:
raise RuntimeError(
"Interleaving is not supported with a different decoder sequence length."
)
# Shape of activations exchanged between stages. The sequence dim is divided by the
# context-parallel size, and by the tensor-parallel size under sequence parallelism.
# The optional tp_x / tp_y args (default 1) further divide the sequence and hidden dims.
tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
tensor_shape[0] = tensor_shape[0] // parallel_state.get_context_parallel_world_size()
if config.sequence_parallel:
tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()
tensor_shape[0] = tensor_shape[0] // getattr(args, "tp_x", 1)
tensor_shape[-1] = tensor_shape[-1] // getattr(args, "tp_y", 1)
num_model_chunks = len(model)
total_num_microbatches = num_microbatches * num_model_chunks
# NOTE(review): unlike Megatron, this function never sets all_warmup_microbatches to True,
# so every branch that tests it below always takes the same path. Confirm this is intended.
all_warmup_microbatches = False
if forward_only:
num_warmup_microbatches = total_num_microbatches
else:
# Earlier stages run more warmup forwards: 2 per downstream stage plus one full round
# per extra model chunk, capped at the total number of microbatches.
num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches
# nanopipe related: num_fwd is the warmup count of pipeline rank 0 (the formula above with
# rank 0). num_dx is therefore how many more warmup forwards rank 0 runs than this rank.
# The first num_dx steady-state backward steps run with WeightGradStore decoupling.
num_fwd = min((pipeline_parallel_size - 1) * 2 + (num_model_chunks - 1) * pipeline_parallel_size, total_num_microbatches)
num_dx = num_fwd - num_warmup_microbatches
# The last overlap_chunks_num = ceil(num_dx / pipeline_parallel_size) chunks get nano_flag
# False, and backward_step_helper skips their grad sync. nano_flag is later handed to
# WeightGradStore.pop, presumably so those chunks are synced after the flush (verify).
overlap_chunks_num = (num_dx + pipeline_parallel_size - 1) // pipeline_parallel_size
nano_flag = [True] * len(model)
for i in range(overlap_chunks_num):
nano_flag[-i - 1] = False
# ripipe related: counters that drive recompute_in_bubble (see get_ripipe_recompute_count_params).
num_microbatches_recompute, num_microbatches_recompute_forward, num_microbatches_recompute_steady_groups, \
num_microbatches_recompute_tail = get_ripipe_recompute_count_params(num_microbatches,
num_model_chunks,
num_warmup_microbatches)
max_outstanding_backprops = None
if config.num_microbatches_with_partial_activation_checkpoints is not None:
max_outstanding_backprops = num_warmup_microbatches + 1
# Sync the parameters of the first two model chunks up front; forward_step_helper syncs
# chunks 2 and later. NOTE(review): this indexes model[1], so it assumes at least 2 chunks.
if config.param_sync_func is not None:
config.param_sync_func[0](model[0].parameters())
config.param_sync_func[1](model[1].parameters())
def get_chunk_batch_id(microbatch_id, forward):
"""ripipe related, needed by recompute_in_bubble function.

Returns ``(group_id, intra_chunk_batch_id, model_chunk_id)``. These are the index of
the group of ``pipeline_parallel_size * num_model_chunks`` microbatches, the position
within the chunk's run of ``pipeline_parallel_size`` microbatches, and the chunk id
(reversed when ``forward`` is False).
"""
microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
if not forward:
model_chunk_id = num_model_chunks - model_chunk_id - 1
group_id = microbatch_id // (pipeline_parallel_size * num_model_chunks)
intra_chunk_batch_id = (microbatch_id_in_group % pipeline_parallel_size)
return group_id, intra_chunk_batch_id, model_chunk_id
def should_recompute(fk):
"""ripipe related, needed by recompute_in_bubble function, used to determine
whether a microbatch needs to be recomputed in the 1f1b stage.

``fk`` is the forward microbatch id. The decision uses the group id and the
intra-chunk index from ``get_chunk_batch_id``, compared against
``num_microbatches_recompute_steady_groups`` and ``num_microbatches_recompute_tail``.
"""
gid, intro_group_bid, chunk_id = get_chunk_batch_id(fk, forward=True)
if chunk_id == 0:
if gid < 2:
return False
elif gid < 2 + num_microbatches_recompute_steady_groups:
if intro_group_bid >= (1 + 2 * pipeline_parallel_rank):
return True
else:
if intro_group_bid >= pipeline_parallel_size - num_microbatches_recompute_tail:
return True
return False
def get_model_chunk_id(microbatch_id, forward):
"""Helper method to get the model chunk ID given the iteration number."""
microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
if not forward:
model_chunk_id = num_model_chunks - model_chunk_id - 1
return model_chunk_id
def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
"""Check if an iteration is the first for a model chunk."""
# num_microbatch_groups is unused here; it is kept from the Megatron original.
microbatch_group_size = pipeline_parallel_size * num_model_chunks
num_microbatch_groups = total_num_microbatches // microbatch_group_size
microbatch_group_id = microbatch_id // microbatch_group_size
microbatch_id_in_group = microbatch_id % microbatch_group_size
if microbatch_group_id == 0:
return microbatch_id_in_group % pipeline_parallel_size == 0
else:
return False
def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
"""Check if an iteration is the last for a model chunk."""
microbatch_group_size = pipeline_parallel_size * num_model_chunks
num_microbatch_groups = total_num_microbatches // microbatch_group_size
microbatch_group_id = microbatch_id // microbatch_group_size
microbatch_id_in_group = microbatch_id % microbatch_group_size
if microbatch_group_id == num_microbatch_groups - 1:
return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1
else:
return False
def forward_step_helper(microbatch_id, checkpoint_activations_microbatch):
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if config.param_sync_func is not None:
# Look ahead by pipeline_parallel_rank microbatches. At the first microbatch of a
# chunk, sync the parameters of the following chunk. Only ids 2..num_model_chunks-1
# qualify, because chunks 0 and 1 were synced before the schedule started.
param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank
if (
param_sync_microbatch_id < total_num_microbatches
and is_first_microbatch_for_model_chunk(param_sync_microbatch_id)
):
param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
if 1 < param_sync_chunk_id < num_model_chunks:
config.param_sync_func[param_sync_chunk_id](
model[param_sync_chunk_id].parameters()
)
if parallel_state.is_pipeline_first_stage():
# The first stage receives no activation; queue a None input when none is pending.
if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor, _ = forward_step(
forward_step_func,
data_iterator[model_chunk_id],
model[model_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
check_first_val_step(
first_val_step, forward_only, is_first_microbatch_for_model_chunk(microbatch_id),
),
)
output_tensors[model_chunk_id].append(output_tensor)
if forward_only:
# No backward will consume these tensors, so drop them immediately.
input_tensors[model_chunk_id].pop()
output_tensors[model_chunk_id].pop()
# ripipe related: call batch_fin for this chunk after its forward step.
# Presumably this finalizes per-microbatch checkpoint bookkeeping (verify in the manager).
if args.recompute_in_bubble or args.recompute_in_advance:
pipeline_checkpoint_manager.batch_fin(model_chunk_id)
return output_tensor
def backward_step_helper(microbatch_id):
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# nanopipe related: chunks with nano_flag False are not synced here. Their grad sync is
# left to the end of the schedule, after WeightGradStore.pop.
if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id) and nano_flag[model_chunk_id]:
enable_grad_sync()
synchronized_model_chunks.add(model_chunk_id)
if parallel_state.is_pipeline_last_stage():
# The last stage receives no output grad; queue None when nothing is pending.
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config
)
if config.grad_sync_func is not None:
# Launch grad sync for the chunk whose last backward microbatch was
# pipeline_parallel_rank steps ago. Chunks with nano_flag False are skipped.
grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank
if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
grad_sync_microbatch_id
):
grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False)
if nano_flag[grad_sync_chunk_id]:
enable_grad_sync()
config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
synchronized_model_chunks.add(grad_sync_chunk_id)
disable_grad_sync()
return input_tensor_grad
# Receive the first input for chunk 0 before the warmup loop.
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))
fwd_wait_handles = None
bwd_wait_handles = None
# Warmup phase: forward passes only.
for k in range(num_warmup_microbatches):
if fwd_wait_handles is not None:
for req in fwd_wait_handles if isinstance(fwd_wait_handles, list) else fwd_wait_handles.values():
req.wait()
if max_outstanding_backprops is not None:
checkpoint_activations_microbatch = (
k % max_outstanding_backprops
>= config.num_microbatches_with_partial_activation_checkpoints
)
else:
checkpoint_activations_microbatch = None
# ripipe related: the first pipeline_parallel_size * num_model_chunks warmup microbatches run
# with the manager's recompute disabled. Later ones consume num_microbatches_recompute_forward.
if args.recompute_in_bubble:
if k < pipeline_parallel_size * num_model_chunks:
pipeline_checkpoint_manager.disable_recompute()
else:
num_microbatches_recompute_forward -= 1
output_tensor = forward_step_helper(k, checkpoint_activations_microbatch)
if args.recompute_in_bubble or args.recompute_in_advance:
pipeline_checkpoint_manager.enable_recompute()
# Decide whether an input must be received for the next forward microbatch.
next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
recv_prev = True
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
if next_forward_model_chunk_id == 0:
recv_prev = False
if k == (total_num_microbatches - 1):
recv_prev = False
if parallel_state.is_pipeline_last_stage():
output_tensor = None
if not config.overlap_p2p_comm:
# At the last warmup step, also start receiving the first output grad (last chunk).
if (
k == (num_warmup_microbatches - 1)
and not forward_only
and not all_warmup_microbatches
):
input_tensor_grad = None
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
(
input_tensor,
output_tensor_grad,
) = p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor,
input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
else:
input_tensor = p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, config=config
)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
else:
input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
output_tensor,
recv_prev=recv_prev,
tensor_shape=tensor_shape,
config=config,
overlap_p2p_comm=True,
)
if (
k == (num_warmup_microbatches - 1)
and not forward_only
and not all_warmup_microbatches
):
input_tensor_grad = None
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
(
output_tensor_grad,
bwd_wait_handles,
) = p2p_communication.send_backward_recv_backward(
input_tensor_grad,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
overlap_p2p_comm=True,
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
# Steady 1F1B phase: each iteration runs forward microbatch forward_k and backward microbatch k.
for k in range(num_microbatches_remaining):
forward_k = k + num_warmup_microbatches
if max_outstanding_backprops is not None:
checkpoint_activations_microbatch = (
forward_k % max_outstanding_backprops
>= config.num_microbatches_with_partial_activation_checkpoints
)
else:
checkpoint_activations_microbatch = None
if config.overlap_p2p_comm:
if fwd_wait_handles is not None:
for req in fwd_wait_handles if isinstance(fwd_wait_handles, list) else fwd_wait_handles.values():
req.wait()
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
# ripipe related: keep recompute enabled while the forward budget lasts, or when
# should_recompute selects this microbatch. Otherwise disable it for this forward.
if args.recompute_in_bubble:
if num_microbatches_recompute_forward > 0:
num_microbatches_recompute_forward -= 1
elif num_microbatches_recompute > 0 and should_recompute(forward_k):
pass
else:
pipeline_checkpoint_manager.disable_recompute()
output_tensor = forward_step_helper(forward_k, checkpoint_activations_microbatch)
if args.recompute_in_bubble or args.recompute_in_advance:
pipeline_checkpoint_manager.enable_recompute()
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if parallel_state.is_pipeline_last_stage():
output_tensor = None
recv_prev = True
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
next_forward_model_chunk_id = get_model_chunk_id(
forward_k - (pipeline_parallel_size - 1), forward=True
)
if next_forward_model_chunk_id == (num_model_chunks - 1):
recv_prev = False
next_forward_model_chunk_id += 1
else:
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
if k == (num_microbatches_remaining - 1):
recv_prev = False
input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
output_tensor,
recv_prev=recv_prev,
tensor_shape=tensor_shape,
config=config,
overlap_p2p_comm=True,
)
# ripipe related: run recompute for the chunk of the upcoming backward step (k) before
# waiting on the backward receives. With only recompute_in_advance, the last stage skips it.
if args.recompute_in_advance or args.recompute_in_bubble:
vpp_rank = get_model_chunk_id(k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(vpp_rank)
if not parallel_state.is_pipeline_last_stage() or args.recompute_in_bubble:
pipeline_checkpoint_manager.recompute_next(vpp_rank)
if bwd_wait_handles is not None:
for req in bwd_wait_handles if isinstance(bwd_wait_handles, list) else bwd_wait_handles.values():
req.wait()
backward_k = k
# nanopipe related: decouple weight-grad computation for the first num_dx steady steps,
# flush decoupled blocks, and end decoupling after step num_dx - 1.
if k < num_dx and getattr(args, "use_nanopipe", False):
WeightGradStore.start_decouple()
if getattr(args, "use_nanopipe", False):
WeightGradStore.resize_ori_storage(getattr(args, "use_nanopipe_swap", False))
input_tensor_grad = backward_step_helper(backward_k)
if getattr(args, "use_nanopipe", False):
if WeightGradStore.is_decoupleBlock:
WeightGradStore.flush()
if k == num_dx - 1:
WeightGradStore.end_decouple()
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if parallel_state.is_pipeline_first_stage():
input_tensor_grad = None
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
next_backward_model_chunk_id = get_model_chunk_id(
backward_k - (pipeline_parallel_size - 1), forward=False
)
if next_backward_model_chunk_id == 0:
recv_next = False
next_backward_model_chunk_id -= 1
else:
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward(
input_tensor_grad,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
overlap_p2p_comm=True,
)
else:
# NOTE(review): unlike the overlap branch above, this branch performs no ripipe
# disable/enable_recompute or recompute_next calls, and never calls
# WeightGradStore.flush(). Confirm these modes are only meant to run with overlap_p2p_comm.
output_tensor = forward_step_helper(forward_k, checkpoint_activations_microbatch)
backward_k = k
if k < num_dx and getattr(args, "use_nanopipe", False):
WeightGradStore.start_decouple()
if getattr(args, "use_nanopipe", False):
WeightGradStore.resize_ori_storage(getattr(args, "use_nanopipe_swap", False))
input_tensor_grad = backward_step_helper(backward_k)
if k == num_dx - 1 and getattr(args, "use_nanopipe", False):
WeightGradStore.end_decouple()
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if parallel_state.is_pipeline_last_stage():
output_tensor = None
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if parallel_state.is_pipeline_first_stage():
input_tensor_grad = None
recv_prev = True
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
next_forward_model_chunk_id = get_model_chunk_id(
forward_k - (pipeline_parallel_size - 1), forward=True
)
if next_forward_model_chunk_id == (num_model_chunks - 1):
recv_prev = False
next_forward_model_chunk_id += 1
else:
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
next_backward_model_chunk_id = get_model_chunk_id(
backward_k - (pipeline_parallel_size - 1), forward=False
)
if next_backward_model_chunk_id == 0:
recv_next = False
next_backward_model_chunk_id -= 1
else:
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
if k == (num_microbatches_remaining - 1):
recv_prev = False
(
input_tensor,
output_tensor_grad,
) = p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor,
input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
)
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
if recv_prev:
input_tensors[next_forward_model_chunk_id].append(input_tensor)
if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
# Cooldown phase: the remaining backward passes.
if not forward_only:
# ripipe related: recompute ahead of the first cooldown backward microbatch.
if args.recompute_in_advance:
vpp_rank = get_model_chunk_id(num_microbatches_remaining, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(vpp_rank)
if not parallel_state.is_pipeline_last_stage():
pipeline_checkpoint_manager.recompute_next(vpp_rank)
# ripipe related: force one chunk-0 recompute before the cooldown backwards,
# restoring the current virtual pipeline rank afterwards.
if args.recompute_in_bubble and num_microbatches_recompute > 0:
old_vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
pipeline_checkpoint_manager.recompute_next_force(0)
parallel_state.set_virtual_pipeline_model_parallel_rank(old_vpp_rank)
if config.overlap_p2p_comm and bwd_wait_handles is not None:
for wait_handle in bwd_wait_handles if isinstance(bwd_wait_handles, list) else bwd_wait_handles.values():
wait_handle.wait()
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks - 1].append(
p2p_communication.recv_backward(tensor_shape, config=config)
)
# ripipe related: this counter is reused to count forced chunk-0 recomputes during cooldown.
# Presumably it starts at 1 to account for the forced recompute above.
if args.recompute_in_bubble:
num_microbatches_recompute_forward = 1
for k in range(num_microbatches_remaining, total_num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
if next_backward_model_chunk_id == (num_model_chunks - 1):
recv_next = False
if k == (total_num_microbatches - 1):
recv_next = False
# NOTE(review): this call always passes overlap_p2p_comm=True, but bwd_wait_handles are
# only waited on when config.overlap_p2p_comm is set. Confirm the receive has completed
# before out_tensor is consumed when overlap is off.
out_tensor, bwd_wait_handles = p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, config=config, overlap_p2p_comm=True
)
output_tensor_grads[next_backward_model_chunk_id].append(
out_tensor
)
# nanopipe related: with swap enabled, swap WeightGradStore tensors once, near the
# middle of the cooldown range [num_microbatches_remaining, total_num_microbatches).
if (
getattr(args, "use_nanopipe", False) and
getattr(args, "use_nanopipe_swap", False) and
k == max(
num_microbatches_remaining + 1,
(total_num_microbatches + num_microbatches_remaining) // 2
)
):
WeightGradStore.swap_tensors()
# ripipe related: keep forcing chunk-0 recomputes until num_microbatches_recompute of them
# have been issued.
if args.recompute_in_bubble and num_microbatches_recompute > 0 and \
num_microbatches_recompute_forward < num_microbatches_recompute:
old_vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
pipeline_checkpoint_manager.recompute_next_force(0)
parallel_state.set_virtual_pipeline_model_parallel_rank(old_vpp_rank)
num_microbatches_recompute_forward += 1
# ripipe related: recompute ahead of the next backward microbatch (k + 1), except after the last one.
if args.recompute_in_advance and k != (total_num_microbatches - 1):
vpp_rank = get_model_chunk_id(k + 1, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(vpp_rank)
if not parallel_state.is_pipeline_last_stage():
pipeline_checkpoint_manager.recompute_next(vpp_rank)
if config.overlap_p2p_comm and bwd_wait_handles is not None:
for wait_handle in bwd_wait_handles if isinstance(bwd_wait_handles, list) else bwd_wait_handles.values():
wait_handle.wait()
# nanopipe related: sync chunk 0 if it is still pending, then pop the stored weight-grad work.
# NOTE(review): this indexes config.grad_sync_func, which may be None. That appears safe only
# because, when it is None, chunk 0 was already added to synchronized_model_chunks by
# backward_step_helper. Confirm.
if getattr(args, "use_nanopipe", False):
if nano_flag[0] and 0 not in synchronized_model_chunks:
config.grad_sync_func[0](model[0].parameters())
synchronized_model_chunks.add(0)
overlap_arg = [pipeline_parallel_size, nano_flag, synchronized_model_chunks, config.grad_sync_func, model]
WeightGradStore.pop(overlap_arg)
# Re-enable grad sync and launch it for every chunk not synchronized yet.
enable_grad_sync()
if config.grad_sync_func is not None:
for model_chunk_id in range(num_model_chunks):
if model_chunk_id not in synchronized_model_chunks:
config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
synchronized_model_chunks.add(model_chunk_id)
if config.timers is not None:
config.timers('forward-backward').stop()
if config.finalize_model_grads_func is not None and not forward_only:
config.finalize_model_grads_func(model)
# ripipe related: end-of-iteration hook of the checkpoint manager.
if args.recompute_in_bubble or args.recompute_in_advance:
pipeline_checkpoint_manager.iter_fin()
return forward_data_store
def get_ripipe_recompute_count_params(num_microbatches, num_model_chunks, num_warmup_microbatches):
"""ripipe related, calculate the variables needed by the recompute_in_bubble function

Reads ``get_args().recompute_in_bubble`` and the pipeline-parallel world size and rank
from ``parallel_state``.

Args:
    num_microbatches: microbatches per model chunk.
    num_model_chunks: number of virtual pipeline model chunks.
    num_warmup_microbatches: warmup forward count of the calling pipeline rank.

Returns:
    A ``RecomputeCountParams`` namedtuple with fields
    ``num_microbatches_recompute``, ``num_microbatches_recompute_forward``,
    ``num_microbatches_recompute_steady_groups`` and ``num_microbatches_recompute_tail``.
"""
args = get_args()
pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
# Start with all counters at zero.
num_microbatches_recompute_steady_groups = 0
num_microbatches_recompute_tail = 0
num_microbatches_recompute = 0
num_microbatches_recompute_forward = 0
if args.recompute_in_bubble and num_microbatches // pipeline_parallel_size > 1:
# Warmup microbatches beyond the first num_model_chunks * pipeline_parallel_size, plus one,
# clamped to be non-negative.
num_microbatches_recompute = num_warmup_microbatches + 1 - num_model_chunks * pipeline_parallel_size
if num_microbatches_recompute < 0:
num_microbatches_recompute = 0
num_microbatches_recompute_forward = num_microbatches_recompute
# With at least three rounds of pipeline_parallel_size microbatches per chunk, derive the
# steady-group and tail counts used by should_recompute in the schedule.
if num_microbatches_recompute > 0 and num_microbatches // pipeline_parallel_size >= 3:
num_microbatches_recompute_steady_groups = (num_microbatches // pipeline_parallel_size) - 3
num_microbatches_recompute_tail = 2 + 2 * pipeline_parallel_rank
if num_microbatches_recompute_steady_groups == 0:
if num_microbatches_recompute_tail >= pipeline_parallel_size - 1 - 2 * pipeline_parallel_rank:
num_microbatches_recompute_tail = 0
num_microbatches_recompute_steady_groups = 1
else:
num_microbatches_recompute_tail = 1
# NOTE: the namedtuple type is re-created on every call.
params = collections.namedtuple('RecomputeCountParams',
['num_microbatches_recompute', 'num_microbatches_recompute_forward',
'num_microbatches_recompute_steady_groups', 'num_microbatches_recompute_tail'])
return params(num_microbatches_recompute, num_microbatches_recompute_forward,
num_microbatches_recompute_steady_groups, num_microbatches_recompute_tail)