import contextlib
from typing import List, Iterator, Union, Dict, Any
import torch
from megatron.core import mpu, parallel_state
from megatron.core.enums import ModelType
from megatron.core.parallel_state import (
get_pipeline_model_parallel_next_rank,
get_pipeline_model_parallel_prev_rank,
get_pipeline_model_parallel_group,
get_tensor_model_parallel_group
)
from megatron.core.pipeline_parallel.schedules import (
backward_step,
check_first_val_step,
clear_embedding_activation_buffer,
deallocate_output_tensor,
finish_embedding_wgrad_compute,
forward_backward_no_pipelining,
get_tensor_shapes,
set_current_microbatch,
)
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
from megatron.core.utils import (
get_attr_wrapped_model,
get_model_config,
get_model_type,
get_model_xattn,
)
from megatron.training import get_args
from megatron.training.utils import average_losses_across_data_parallel_group
from mindspeed_mm.patchs.layerwise_disaggregated_training import p2p_communication_patch
from mindspeed_mm.patchs.layerwise_disaggregated_training.parallel_state_patch import (
get_pipeline_model_parallel_group_alternate,
get_pipeline_model_parallel_group_last_to_first,
get_pipeline_model_parallel_group_first_to_last,
get_vdp_size,
is_vtp_enabled,
is_vdp_enabled,
is_vtp_stage_rank0,
get_vtp_stage_ranks,
get_vtp_my_stage_idx,
get_vtp_size_list
)
from mindspeed_mm.utils.utils import compute_token_level_loss
# Communication streams shared by every call of the pipelined schedule. They
# start out as None and are created lazily (torch.cuda.Stream /
# torch.cuda.default_stream) the first time
# forward_backward_pipelining_without_interleaving runs, then reused.
stream_ping = stream_pang = None
stream_last_to_first = stream_first_to_last = None
default_stream = None
def move_to_device(batch: Dict[str, Any], float_dtype: torch.dtype):
    """Move every tensor in ``batch`` to the current device, in place.

    Floating-point tensors are also cast to ``float_dtype``; other tensors
    keep their dtype. A value that is a list made up only of tensors is
    converted element-wise. Every other value is left untouched.

    Args:
        batch: Mapping of field name to value; modified in place.
        float_dtype: Target dtype for floating-point tensors. This module's
            caller passes ``get_args().params_dtype``.
    """
    # Resolve the device lazily and only once: batches with no tensors never
    # touch the device API, and the lookup is not repeated per tensor.
    device = None

    def _convert(tensor):
        nonlocal device
        if device is None:
            device = torch.cuda.current_device()
        dtype = float_dtype if torch.is_floating_point(tensor) else None
        return tensor.to(device=device, dtype=dtype)

    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = _convert(value)
        elif isinstance(value, list) and all(isinstance(t, torch.Tensor) for t in value):
            batch[key] = [_convert(t) for t in value]
def get_batch(data_iterator, is_vit_last_stage=False):
    """Fetch the next microbatch and prepare it for the model.

    The batch is moved to the current device, with floating tensors cast to
    ``params_dtype``. When both video keys are present, they are renamed onto
    ``pixel_values`` / ``image_grid_thw``.

    On the first pipeline stage, or when ``is_vit_last_stage`` is true, and
    with ``encoder_dp_balance`` enabled, ``pixel_values`` and
    ``batch['tranfer']`` come from ``EncoderBalanceComm.apply`` over the
    data-parallel group. Otherwise ``batch['tranfer']`` is set to ``None``.

    Raises:
        ValueError: if ``data_iterator`` is ``None``.
    """
    if data_iterator is None:
        raise ValueError("Data iterator is None. Unable to retrieve batch.")
    batch = next(data_iterator)
    args = get_args()
    move_to_device(batch, args.params_dtype)

    if 'pixel_values_videos' in batch and 'video_grid_thw' in batch:
        # Video inputs are fed to the model under the image key names.
        batch['pixel_values'] = batch.pop('pixel_values_videos')
        batch['image_grid_thw'] = batch.pop('video_grid_thw')

    owns_encoder_input = mpu.is_pipeline_first_stage() or is_vit_last_stage
    if owns_encoder_input and args.encoder_dp_balance:
        # NOTE(review): EncoderBalanceComm is not imported in the visible part
        # of this module; confirm it is in scope before enabling
        # encoder_dp_balance.
        batch['pixel_values'], batch['tranfer'] = EncoderBalanceComm.apply(
            batch['pixel_values'],
            mpu.get_data_parallel_group())
    else:
        batch['tranfer'] = None
    return batch
def get_tps(output_tensor):
    """Return the tokens-per-sample statistic, all-reduced over data parallel.

    ``output_tensor`` is unpacked as three dimensions; the middle one is the
    sequence length. Each rank contributes ``seq_len / dp_size * cp_size``.
    The contributions are summed with ``all_reduce`` (default op) over the
    data-parallel group.
    """
    _, seq_len, _ = output_tensor.shape
    dp_group = mpu.get_data_parallel_group()
    dp_world = torch.distributed.get_world_size(group=dp_group)
    cp_world = torch.distributed.get_world_size(group=mpu.get_context_parallel_group())
    tps = torch.tensor(seq_len, device=output_tensor.device) / dp_world * cp_world
    torch.distributed.all_reduce(tps, group=dp_group)
    return tps
def loss_func(output_tensor):
    """Turn the model's output dict into the loss expected by the schedule.

    ``output_tensor`` must provide ``'loss_dict'``, plus ``'logits'`` when
    ``args.log_tps`` is set.

    Returns:
        With ``args.calculate_per_token_loss``: a triple
        ``(loss, local_num_tokens, report)`` built from
        ``compute_token_level_loss``.
        Otherwise: a pair ``(loss / cp_world_size, report)``. Here ``loss`` is
        the local loss unsqueezed to one element, and ``report['loss']`` is its
        data-parallel average.
    """
    args = get_args()
    loss_dict = output_tensor['loss_dict']
    report = {}
    if args.log_tps:
        report["tokens per sample"] = get_tps(output_tensor['logits'])

    if args.calculate_per_token_loss:
        loss, local_num_tokens, reporting_loss = compute_token_level_loss(loss_dict)
        report["loss"] = (reporting_loss[0], reporting_loss[1])
        return loss[0].clone(), local_num_tokens, report

    loss = loss_dict['loss']
    report["loss"] = average_losses_across_data_parallel_group([loss])[0]
    scaled = loss.unsqueeze(0).clone()
    return scaled / mpu.get_context_parallel_world_size(), report
def get_forward_backward_func():
    """Pick the forward/backward schedule for the current pipeline layout.

    Returns ``forward_backward_pipelining_without_interleaving`` (defined in
    this module) when the pipeline-model-parallel world size is greater than
    one. Otherwise returns Megatron's ``forward_backward_no_pipelining``.

    The returned callable takes keyword arguments:
        forward_step_func: Callable returning ``(model_output, loss_fn)`` for a
            microbatch. ``loss_fn`` returns either ``(loss, info_dict)`` or
            ``(loss, num_tokens, info_dict)``. The pipelined schedule in this
            module replaces it with ``forward_step_impl``.
        data_iterator: Iterator, or list of iterators, over the data.
        model: The (possibly wrapped) Megatron module. The pipelined schedule
            requires a list of modules.
        num_microbatches (int): Microbatches processed in the step.
        seq_length (int): Sequence length of the current global batch.
        micro_batch_size (int): Sequences per microbatch.
        decoder_seq_length (int, optional): Decoder sequence length for
            dual-stack transformers.
        forward_only (bool, optional): Run only the forward pass.
        collect_non_loss_data (bool, optional): Collect arbitrary forward
            output instead of losses.
        first_val_step (bool, optional): Marks the first validation step.

    With sequence parallelism, the sequence dimension of communicated shapes
    is ``original_sequence_length / tensor_model_parallel_world_size``.
    """
    if parallel_state.get_pipeline_model_parallel_world_size() > 1:
        return forward_backward_pipelining_without_interleaving
    return forward_backward_no_pipelining
def forward_step_impl(data_iterator, model, batch=None):
    """Run the wrapped model on one microbatch.

    Args:
        data_iterator: Consumed through ``get_batch`` when ``batch`` is None.
        model: Wrapped model. ``model.module.module`` is the inner module whose
            ``add_image_encoder`` and ``image_encoder.post_process`` are read.
        batch: Optional pre-built batch dict. On the first pipeline stage all
            entries are passed to the model. Other stages pass only
            ``input_ids``, ``pixel_values``, ``attention_mask`` and
            ``image_grid_thw``.

    Returns:
        ``(output_tensor, loss_func)``.
    """
    inner = model.module.module
    is_vit_last_stage = inner.image_encoder.post_process if inner.add_image_encoder else False

    if batch is None:
        output_tensor = model(**get_batch(data_iterator, is_vit_last_stage))
    elif parallel_state.is_pipeline_first_stage(ignore_virtual=True):
        output_tensor = model(**batch)
    else:
        forwarded_keys = ('input_ids', 'pixel_values', 'attention_mask', 'image_grid_thw')
        output_tensor = model(**{key: batch[key] for key in forwarded_keys})
    return output_tensor, loss_func
def forward_step(
    forward_step_func,
    data_iterator,
    model,
    num_microbatches,
    input_tensor,
    forward_data_store,
    config,
    collect_non_loss_data=False,
    checkpoint_activations_microbatch=None,
    is_first_microbatch=False,
    current_microbatch=None,
    encoder_decoder_xattn=False,
    is_end_stage=False,
    batch=None,
):
    """Run one microbatch through ``model`` and, on the loss stage, reduce it.

    The loss/reduce function is applied only on the first pipeline stage when
    ``is_end_stage`` is true. Every other call returns the raw model output.

    Args:
        forward_step_func (callable):
            Called as ``forward_step_func(data_iterator, model, batch)``. When no
            ``batch`` is given and ``checkpoint_activations_microbatch`` is set,
            it is called with that flag as the third argument instead. It must
            return ``(output, reduce_fn)``, where ``reduce_fn(output)`` returns
            one of:
              a. ``(loss, reduced_data)``: ``loss`` is multiplied by the
                 context-parallel world size and divided by ``num_microbatches``.
              b. ``(loss, num_tokens, reduced_data)``: unless
                 ``config.calculate_per_token_loss`` is set, ``loss`` is divided
                 by ``num_tokens``, multiplied by the context-parallel world size
                 and divided by ``num_microbatches``.
              c. Arbitrary data when ``collect_non_loss_data=True``. The function
                 is then called with ``non_loss_data=True`` and the result is
                 stored as-is.
        data_iterator (iterator): The data iterator.
        model (nn.Module): The model to perform the forward step on.
        num_microbatches (int): Number of microbatches in the step.
        input_tensor (Tensor or list[Tensor]): Input tensor(s), installed on the
            model through its ``set_input_tensor`` attribute.
        forward_data_store (list): Receives ``reduced_data`` (cases a/b) or the
            whole reduce-function output (case c).
        config (object): Model config (timers, autocast, loss-scaling flags).
        collect_non_loss_data (bool, optional): Select case (c). Defaults to False.
        checkpoint_activations_microbatch (optional): Activation-checkpoint flag,
            forwarded to ``forward_step_func`` only when ``batch`` is None.
            Defaults to None.
        is_first_microbatch (bool, optional): Calls
            ``model.set_is_first_microbatch()`` when available. Defaults to False.
        current_microbatch (int, optional): Passed to ``set_current_microbatch``.
            Defaults to None.
        encoder_decoder_xattn (bool, optional): For encoder-and-decoder models
            inside the decoder, also return ``input_tensor[-1]``. Defaults to False.
        is_end_stage (bool, optional): Marks the pass on which the first stage
            computes the loss. Defaults to False.
        batch (dict, optional): Pre-built batch handed to ``forward_step_func``.

    Returns:
        tuple: ``(output, num_tokens)``. ``output`` is a tensor when
        ``input_tensor`` was not a list, and a list otherwise. It is
        ``[output, input_tensor[-1]]`` for encoder-decoder cross-attention in the
        decoder. ``num_tokens`` is a tensor that is 0 unless case (b) applied.

    Raises:
        ValueError: if the reduce function returns neither 2 nor 3 items.
    """
    if config.timers is not None:
        config.timers('forward-compute', log_level=2).start()

    if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
        model.set_is_first_microbatch()
    if current_microbatch is not None:
        set_current_microbatch(model, current_microbatch)

    unwrap_output_tensor = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_output_tensor = True

    set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
    set_input_tensor(input_tensor)

    if config.enable_autocast:
        context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
    else:
        context_manager = contextlib.nullcontext()
    with context_manager:
        if checkpoint_activations_microbatch is None or batch is not None:
            # forward_step_impl takes the pre-built batch as its third positional
            # argument. Passing the checkpoint flag in that slot would discard
            # the batch and unpack a flag as ``**batch``.
            output_tensor, loss_function = forward_step_func(data_iterator, model, batch)
        else:
            output_tensor, loss_function = forward_step_func(
                data_iterator, model, checkpoint_activations_microbatch
            )

    num_tokens = torch.tensor(0, dtype=torch.int)
    # In this schedule the loss is computed on the first pipeline rank, on the
    # pass flagged by is_end_stage.
    if parallel_state.is_pipeline_first_stage(ignore_virtual=True) and is_end_stage:
        if not collect_non_loss_data:
            outputs = loss_function(output_tensor)
            if len(outputs) == 3:
                output_tensor, num_tokens, loss_reduced = outputs
                if not config.calculate_per_token_loss:
                    output_tensor /= num_tokens
                    output_tensor *= parallel_state.get_context_parallel_world_size()
                    output_tensor /= num_microbatches
            elif len(outputs) == 2:
                # Legacy behaviour: average over the number of microbatches.
                output_tensor, loss_reduced = outputs
                output_tensor *= parallel_state.get_context_parallel_world_size()
                output_tensor /= num_microbatches
            else:
                raise ValueError(
                    "loss function must return (loss, reduced) or "
                    f"(loss, num_tokens, reduced), got {len(outputs)} items"
                )
            forward_data_store.append(loss_reduced)
        else:
            data = loss_function(output_tensor, non_loss_data=True)
            forward_data_store.append(data)

    if config.timers is not None:
        config.timers('forward-compute').stop()

    # Set the scale used by the MoE aux-loss and MTP-loss auto-scalers. As with
    # the main loss above, per-token-loss runs skip the division by
    # num_microbatches.
    for enabling_attr, scaler in (
        ('num_moe_experts', MoEAuxLossAutoScaler),
        ('mtp_num_layers', MTPLossAutoScaler),
    ):
        if getattr(config, enabling_attr, None) is None:
            continue
        ones = torch.ones(1, device=output_tensor.device)
        loss_scale = config.grad_scale_func(ones) if config.grad_scale_func is not None else ones
        if config.calculate_per_token_loss:
            scaler.set_loss_scale(loss_scale)
        else:
            scaler.set_loss_scale(loss_scale / num_microbatches)

    model_type = get_model_type(model)
    if (
        model_type == ModelType.encoder_and_decoder
        and encoder_decoder_xattn
        and parallel_state.is_inside_decoder()
    ):
        return [output_tensor, input_tensor[-1]], num_tokens

    if unwrap_output_tensor:
        return output_tensor, num_tokens
    return [output_tensor], num_tokens
def recv_forward_with_reqs(tensor_shapes, config, is_end_stage: bool = False, **kwargs):
    """Post a forward receive for every non-``None`` shape (non-interleaved schedule).

    Delegates to ``p2p_communication_patch.recv_forward_with_reqs``.

    Returns:
        ``(input_tensors, reqs_list)``. ``input_tensors`` has one entry per
        shape, ``None`` where the shape is ``None``. ``reqs_list`` only holds
        the request objects of shapes that were actually received.
    """
    received = []
    pending = []
    for shape in tensor_shapes:
        if shape is None:
            received.append(None)
            continue
        tensor, reqs = p2p_communication_patch.recv_forward_with_reqs(
            shape, config, is_end_stage, **kwargs
        )
        received.append(tensor)
        pending.append(reqs)
    return received, pending
def recv_backward_with_reqs(tensor_shapes, config, is_end_stage=False, **kwargs):
    """Post a backward (gradient) receive for every non-``None`` shape.

    Delegates to ``p2p_communication_patch.recv_backward_with_reqs``.

    Returns:
        ``(output_tensor_grads, reqs_list)``. Gradients are aligned with
        ``tensor_shapes``, with ``None`` for ``None`` shapes. ``reqs_list``
        only holds entries for shapes that were received.
    """
    grads = []
    pending = []
    for shape in tensor_shapes:
        if shape is None:
            grads.append(None)
            continue
        grad, reqs = p2p_communication_patch.recv_backward_with_reqs(
            shape, config, is_end_stage, **kwargs
        )
        grads.append(grad)
        pending.append(reqs)
    return grads, pending
def send_forward(
    output_tensors, tensor_shapes, config, is_end_stage: bool = False, **kwargs
):
    """Forward-send each output tensor whose shape is not ``None``.

    A single tensor is treated as a one-element list. Each send goes through
    ``p2p_communication_patch.send_forward`` with ``wait_on_reqs=False``.
    Returns ``None``.
    """
    tensors = output_tensors if isinstance(output_tensors, list) else [output_tensors]
    for tensor, shape in zip(tensors, tensor_shapes):
        if shape is not None:
            p2p_communication_patch.send_forward(tensor, config, is_end_stage, wait_on_reqs=False, **kwargs)
def send_backward(
    input_tensor_grads, tensor_shapes, config, is_end_stage: bool = False, **kwargs
):
    """Backward-send each input gradient whose shape is not ``None``.

    A single gradient is treated as a one-element list. Each send goes through
    ``p2p_communication_patch.send_backward`` with ``wait_on_reqs=False``.
    Returns ``None``.
    """
    grads = input_tensor_grads if isinstance(input_tensor_grads, list) else [input_tensor_grads]
    for grad, shape in zip(grads, tensor_shapes):
        if shape is not None:
            p2p_communication_patch.send_backward(grad, config, is_end_stage, wait_on_reqs=False, **kwargs)
def _vtp_send_forward_wrapper(output_tensors, tensor_shapes, config, is_end_stage=False, **kwargs):
"""VTP-aware forward send: uses rank0 async P2P. Returns isend work handles."""
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
handles = []
for output_tensor, tensor_shape in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
continue
h = p2p_communication_patch.vtp_send_forward(output_tensor, **kwargs)
if h is not None:
handles.extend(h)
return handles
def _vtp_recv_forward_wrapper(tensor_shapes, config, async_op=False, is_end_stage=False, **kwargs):
"""VTP-aware forward recv: uses rank0 irecv + deferred broadcast.
When async_op=True, returns (input_tensors, reqs_list) for overlap with compute.
When async_op=False, blocks until recv + broadcast complete.
"""
input_tensors = []
reqs_list = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
input_tensors.append(None)
else:
if async_op:
tensor, reqs = p2p_communication_patch.vtp_recv_forward(
tensor_shape, config, async_op=True, **kwargs
)
input_tensors.append(tensor)
reqs_list.append(reqs)
else:
tensor = p2p_communication_patch.vtp_recv_forward(
tensor_shape, config, async_op=False, **kwargs
)
input_tensors.append(tensor)
if async_op:
return input_tensors, reqs_list
return input_tensors
def _vtp_send_backward_wrapper(input_tensor_grads, tensor_shapes, config, is_end_stage=False, **kwargs):
"""VTP-aware backward send: uses rank0 async P2P. Returns isend work handles."""
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
handles = []
for input_tensor_grad, tensor_shape in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
continue
h = p2p_communication_patch.vtp_send_backward(input_tensor_grad, **kwargs)
if h is not None:
handles.extend(h)
return handles
def _vtp_recv_backward_wrapper(tensor_shapes, config, async_op=False, is_end_stage=False, **kwargs):
"""VTP-aware backward recv: uses rank0 irecv + deferred broadcast.
When async_op=True, returns (output_tensor_grads, reqs_list) for overlap.
When async_op=False, blocks until recv + broadcast complete.
"""
output_tensor_grads = []
reqs_list = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
output_tensor_grads.append(None)
else:
if async_op:
tensor, reqs = p2p_communication_patch.vtp_recv_backward(
tensor_shape, config, async_op=True, **kwargs
)
output_tensor_grads.append(tensor)
reqs_list.append(reqs)
else:
tensor = p2p_communication_patch.vtp_recv_backward(
tensor_shape, config, async_op=False, **kwargs
)
output_tensor_grads.append(tensor)
if async_op:
return output_tensor_grads, reqs_list
return output_tensor_grads
def get_all_batchs(mbn, data_iterator, model, config, vit_hidden_size):
    """Pre-fetch (first stage) or reconstruct (other stages) every microbatch of the step.

    First pipeline stage: reads ``mbn * get_vdp_size()`` batches from
    ``data_iterator[0]``. It packs their shape and padding metadata into an
    int64 table ``tensor_shapes`` (one row per microbatch,
    ``5 + 5 * config.micro_batch_size`` columns) and sends it with
    ``_broadcast``.

    Other stages: receive that table and build placeholder batches. These have
    zero ``input_ids``/``pixel_values`` and ``labels=None``; their
    ``attention_mask`` and ``image_grid_thw`` are rebuilt from the metadata.

    Row layout: ``[seq_len, mbs, hidden_size, pixel_rows, pixel_cols]``,
    followed per sample by ``[padding_side, padding_num, t, h, w]``.
    ``padding_side`` is 1 when the sample's first token is unmasked and 0
    otherwise (left padding).
    NOTE(review): packing 5 values per sample assumes ``image_grid_thw`` has
    exactly one 3-column row per sample; confirm for multi-image samples.

    Args:
        mbn: Number of microbatches for this rank. On the first stage it is
            multiplied by ``get_vdp_size()``.
        data_iterator: List whose first element is the data iterator. It is
            only read on the first stage.
        model: Wrapped model. ``model.module.module`` flags decide whether
            this rank is the ViT's last stage.
        config: Provides ``hidden_size`` and ``micro_batch_size``.
        vit_hidden_size: Last dimension of the ViT receive shapes.

    Returns:
        ``(all_batchs, recv_tensor_shapes, vit_recv_tensor_shapes)``.
        ``all_batchs`` is three lists that each reference the same batch dicts.
        ``recv_tensor_shapes[i]`` is ``[(seq_len, mbs, hidden_size)]`` and
        ``vit_recv_tensor_shapes[i]`` is ``[(pixel_rows, 1, vit_hidden_size)]``.
    """
    # Device string is hard-coded to the "npu" device type.
    device = f"npu:{torch.cuda.current_device()}"
    data_type = torch.int64
    hidden_size = config.hidden_size
    all_batchs = [[], [], []]
    recv_tensor_shapes = []
    vit_recv_tensor_shapes = []

    def _split_item(item):
        """
        Split item into get_vdp_size() parts with round-robin strategy.
        example: [1,2,3,4,5,6] vdp=3 -> [[1,4],[2,5],[3,6]]
        """
        len_item = item.size(0)
        device = item.device
        indices = torch.arange(len_item, device=device)
        part_indices = indices % get_vdp_size()
        item_list = []
        for i in range(get_vdp_size()):
            mask = (part_indices == i)
            item_list.append(item[mask])
        return item_list

    def _broadcast(item):
        """Broadcast ``item`` from the first pipeline rank to this rank.

        The path depends on whether VTP and/or VDP are enabled. Every rank must
        call this the same number of times, because it issues collectives.
        """
        if item is not None:
            if is_vtp_enabled() and is_vdp_enabled():
                if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
                    # One pipeline group per VDP replica; replica i gets every
                    # vdp-th row (see _split_item). Only VTP stage rank 0 sends.
                    split_item_list = _split_item(item)
                    for i, group in enumerate(parallel_state.get_pipeline_model_parallel_group()):
                        if not is_vtp_stage_rank0():
                            continue
                        torch.distributed.broadcast(split_item_list[i], parallel_state.get_pipeline_model_parallel_first_rank(), group=group)
                elif is_vtp_stage_rank0():
                    torch.distributed.broadcast(item, parallel_state.get_pipeline_model_parallel_first_rank(), group=parallel_state.get_pipeline_model_parallel_group())
                # Downstream stages: rank 0 of the VTP stage relays what it
                # received to the rest of its tensor-parallel group. Non-rank0
                # ranks must reach this broadcast to receive the data.
                if not parallel_state.is_pipeline_first_stage(ignore_virtual=True):
                    vtp_intra_group = get_tensor_model_parallel_group()
                    if vtp_intra_group is not None:
                        stage_ranks = get_vtp_stage_ranks()
                        my_stage = get_vtp_my_stage_idx()
                        tp_group = parallel_state.get_tensor_model_parallel_group()
                        torch.distributed.broadcast(item, stage_ranks[my_stage][0], group=tp_group)
            elif is_vdp_enabled():
                if isinstance(parallel_state.get_pipeline_model_parallel_group(), list):
                    # First stage holds several pipeline groups: send each its slice.
                    split_item_list = _split_item(item)
                    for i, group in enumerate(parallel_state.get_pipeline_model_parallel_group()):
                        torch.distributed.broadcast(split_item_list[i], parallel_state.get_pipeline_model_parallel_first_rank(),
                                                    group=group)
                else:
                    torch.distributed.broadcast(item, parallel_state.get_pipeline_model_parallel_first_rank(),
                                                group=parallel_state.get_pipeline_model_parallel_group())
            elif is_vtp_enabled():
                # Only VTP stage rank 0 takes part in the cross-stage broadcast.
                if is_vtp_stage_rank0():
                    torch.distributed.broadcast(item, parallel_state.get_pipeline_model_parallel_first_rank(),
                                                group=parallel_state.get_pipeline_model_parallel_group())
                # Every rank then joins the intra-stage broadcast from its stage's rank 0.
                vtp_intra_group = get_tensor_model_parallel_group()
                if vtp_intra_group is not None:
                    stage_ranks = get_vtp_stage_ranks()
                    my_stage = get_vtp_my_stage_idx()
                    torch.distributed.broadcast(item, stage_ranks[my_stage][0], group=vtp_intra_group)
            else:
                torch.distributed.broadcast(item, parallel_state.get_pipeline_model_parallel_first_rank(),
                                            group=parallel_state.get_pipeline_model_parallel_group())

    def get_batch_infos(attention_infos, thws, shapes, i_forward):
        """Rebuild ``attention_mask`` and ``image_grid_thw`` for microbatch ``i_forward``."""
        seq_len, mbs = shapes[i_forward][0][0], shapes[i_forward][0][1]
        attention_mask = torch.ones(mbs, seq_len, device=device, dtype=data_type)
        for i, padding_info in enumerate(attention_infos[i_forward]):
            padding_side, padding_num = padding_info[0], padding_info[1]
            if padding_num == 0:
                continue
            # padding_side 0 means the first token was masked, so padding is on the left.
            if padding_side == 0:
                attention_mask[i, :padding_num] = torch.zeros(padding_num, device=device, dtype=data_type)
            else:
                attention_mask[i, -padding_num:] = torch.zeros(padding_num, device=device, dtype=data_type)
        image_grid_thw = torch.tensor(thws[i_forward], device=device, dtype=data_type)
        return attention_mask, image_grid_thw

    is_vit_last_stage = False
    if model.module.module.add_image_encoder:
        is_vit_last_stage = model.module.module.image_encoder.post_process
    # The first stage reads the batches for all VDP replicas.
    if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
        mbn = mbn * get_vdp_size()
    tensor_shapes = torch.empty(
        mbn,
        5 + 5 * config.micro_batch_size,
        device=device,
        dtype=data_type
    )
    if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
        for i in range(mbn):
            batch = get_batch(data_iterator[0], is_vit_last_stage)
            mbs, seq_len = batch["input_ids"].shape[0], batch["input_ids"].shape[1]
            if batch['pixel_values'] is not None:
                x_pixel_values, y_pixel_values = batch['pixel_values'].shape[0], batch['pixel_values'].shape[1]
            else:
                x_pixel_values = y_pixel_values = 0
            tensor_shapes[i, :5] = torch.tensor([seq_len, mbs, hidden_size, x_pixel_values, y_pixel_values], device=device, dtype=data_type)
            attention_mask = batch["attention_mask"]
            image_grid_thw = batch["image_grid_thw"]
            # Per sample: 1 if the first token is unmasked, plus the count of masked tokens.
            padding_side = (attention_mask[:, 0] != 0).long().unsqueeze(1)
            padding_num = (seq_len - attention_mask.sum(dim=1)).unsqueeze(1)
            tensor_shapes[i][5:] = torch.cat([padding_side, padding_num, image_grid_thw], dim=1).flatten()
            tensor_shape = [(seq_len, mbs, config.hidden_size)]
            vit_tensor_shape = [(x_pixel_values, 1, vit_hidden_size)]
            all_batchs[0].append(batch)
            all_batchs[1].append(batch)
            all_batchs[2].append(batch)
            recv_tensor_shapes.append(tensor_shape)
            vit_recv_tensor_shapes.append(vit_tensor_shape)
        # Called once after the loop, matching the single call on the other stages.
        _broadcast(tensor_shapes)
    else:
        # Non-first stages: receive the metadata table and rebuild placeholder batches.
        _broadcast(tensor_shapes)
        tensor_shapes_tolist = tensor_shapes.tolist()
        shapes = [[tuple(shape[:3])] for shape in tensor_shapes_tolist]
        recv_tensor_shapes = shapes
        vit_shapes = [[(shape[3], 1, vit_hidden_size)] for shape in tensor_shapes_tolist]
        vit_recv_tensor_shapes = vit_shapes
        pixel_shapes = [(shape[3], shape[4]) for shape in tensor_shapes_tolist]
        attention_infos = [[(shape[5 + 5 * i: 7 + 5 * i]) for i in range(config.micro_batch_size)] for shape in tensor_shapes_tolist]
        thws = [[(shape[7 + 5 * i: 10 + 5 * i]) for i in range(config.micro_batch_size)] for shape in tensor_shapes_tolist]
        for i in range(mbn):
            x_pixel_values, y_pixel_values = pixel_shapes[i][0], pixel_shapes[i][1]
            # NOTE(review): placeholder dtype is hard-coded to bfloat16, not
            # params_dtype; confirm this is intended.
            if x_pixel_values > 0 and y_pixel_values > 0:
                pixel_values = torch.zeros(x_pixel_values, y_pixel_values, device=device, dtype=torch.bfloat16)
            else:
                pixel_values = None
            seq_len, mbs = shapes[i][0][0], shapes[i][0][1]
            input_ids = torch.zeros(mbs, seq_len, device=device, dtype=data_type)
            attention_mask, image_grid_thw = get_batch_infos(attention_infos, thws, shapes, i)
            batch = {
                'input_ids': input_ids,
                'labels': None,
                'pixel_values': pixel_values,
                'attention_mask': attention_mask,
                'image_grid_thw': image_grid_thw,
                'tranfer': None
            }
            all_batchs[0].append(batch)
            all_batchs[1].append(batch)
            all_batchs[2].append(batch)
    return all_batchs, recv_tensor_shapes, vit_recv_tensor_shapes
def forward_backward_pipelining_without_interleaving(
*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int,
micro_batch_size: int,
decoder_seq_length: int = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
first_val_step: bool = None,
):
"""
Run non-interleaved 1F1B schedule, with communication between pipeline
stages. Returns dictionary with losses if the last stage, empty dict otherwise.
"""
if not isinstance(model, list):
raise TypeError("cloud-edge pipeline parallelism expected model chunking")
if not all(isinstance(chunk, torch.nn.Module) for chunk in model):
raise TypeError("invalid model chunking")
if not parallel_state.is_pipeline_first_stage(ignore_virtual=True):
data_iterator = [None]
config = get_model_config(model[0])
config.variable_seq_lengths = False
config.layerwise_disaggregated_training = True
forward_step_func = forward_step_impl
if config.finalize_model_grads_func is not None and not forward_only:
embedding_module = clear_embedding_activation_buffer(config, model)
if config.timers is not None:
config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
no_sync_func = config.no_sync_func
if no_sync_func is None:
no_sync_func = contextlib.nullcontext
no_sync_context = None
def disable_grad_sync():
"""Disable asynchronous grad reductions"""
nonlocal no_sync_context
if isinstance(no_sync_func, list):
for func in no_sync_func:
no_sync_context = func()
no_sync_context.__enter__()
else:
no_sync_context = no_sync_func()
no_sync_context.__enter__()
def enable_grad_sync():
"""Enable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is not None:
no_sync_context.__exit__(None, None, None)
no_sync_context = None
disable_grad_sync()
num_warmup_microbatches = (
parallel_state.get_pipeline_model_parallel_world_size()
- parallel_state.get_pipeline_model_parallel_rank()
)
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
max_outstanding_backprops = None
if config.num_microbatches_with_partial_activation_checkpoints is not None:
max_outstanding_backprops = num_warmup_microbatches + 1
model_type = get_model_type(model[0])
encoder_decoder_xattn = get_model_xattn(model[0])
rank = parallel_state.get_pipeline_model_parallel_rank()
recv_tensor_shapes = get_tensor_shapes(
rank=rank - 1,
model_type=model_type,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
decoder_seq_length=decoder_seq_length,
config=config,
encoder_decoder_xattn=encoder_decoder_xattn,
)
send_tensor_shapes = get_tensor_shapes(
rank=rank,
model_type=model_type,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
decoder_seq_length=decoder_seq_length,
config=config,
encoder_decoder_xattn=encoder_decoder_xattn,
)
input_tensors = []
output_tensors = []
vit_input_tensors = []
vit_output_tensors = []
total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()
forward_data_store = []
global default_stream
if default_stream is None:
default_stream = torch.cuda.default_stream()
global stream_ping
if stream_ping is None:
stream_ping = torch.cuda.Stream()
global stream_pang
if stream_pang is None:
stream_pang = torch.cuda.Stream()
global stream_last_to_first
if stream_last_to_first is None:
stream_last_to_first = torch.cuda.Stream()
global stream_first_to_last
if stream_first_to_last is None:
stream_first_to_last = torch.cuda.Stream()
group_ping = get_pipeline_model_parallel_group()
group_pang = get_pipeline_model_parallel_group_alternate()
group_last_to_first = get_pipeline_model_parallel_group_last_to_first()
group_first_to_last = get_pipeline_model_parallel_group_first_to_last()
vtp_active = is_vtp_enabled()
vtp_need_asymmetric_fwd = False
vtp_need_asymmetric_bwd = False
vtp_send_forward_group = None
vtp_recv_forward_group = None
vtp_send_backward_group = None
vtp_recv_backward_group = None
if vtp_active:
vtp_size_list = get_vtp_size_list()
my_stage = get_vtp_my_stage_idx()
pp_size = parallel_state.get_pipeline_model_parallel_world_size()
next_stage = (my_stage + 1) % pp_size
prev_stage = (my_stage - 1) % pp_size
vtp_need_asymmetric_fwd = vtp_size_list[my_stage] != vtp_size_list[next_stage]
vtp_need_asymmetric_bwd = vtp_size_list[my_stage] != vtp_size_list[prev_stage]
if parallel_state.get_pipeline_model_parallel_rank() % 2 == 0:
receive_forward_stream = receive_backward_stream = stream_ping
send_forward_stream = send_backward_stream = stream_pang
receive_forward_group = receive_backward_group = group_ping
send_forward_group = send_backward_group = group_pang
else:
receive_forward_stream = receive_backward_stream = stream_pang
send_forward_stream = send_backward_stream = stream_ping
receive_forward_group = receive_backward_group = group_pang
send_forward_group = send_backward_group = group_ping
if parallel_state.get_pipeline_model_parallel_world_size() % 2 == 1:
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
receive_forward_stream = stream_last_to_first
receive_forward_group = group_last_to_first
send_backward_stream = stream_first_to_last
send_backward_group = group_first_to_last
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
receive_backward_stream = stream_first_to_last
receive_backward_group = group_first_to_last
send_forward_stream = stream_last_to_first
send_forward_group = group_last_to_first
if vtp_need_asymmetric_fwd or vtp_need_asymmetric_bwd:
vtp_send_forward_group = send_forward_group
vtp_recv_forward_group = receive_forward_group
vtp_send_backward_group = send_backward_group
vtp_recv_backward_group = receive_backward_group
if not isinstance(receive_forward_group, list):
receive_forward_group = [receive_forward_group]
if not isinstance(receive_backward_group, list):
receive_backward_group = [receive_backward_group]
if not isinstance(send_forward_group, list):
send_forward_group = [send_forward_group]
if not isinstance(send_backward_group, list):
send_backward_group = [send_backward_group]
_vtp_pending_sends = []
def wait_helper(reqs_list):
is_wait = False
recv_prev = False
for reqs in reqs_list:
if reqs is None:
continue
if "recv_prev" in reqs.keys():
recv_prev = True
for req in reqs if isinstance(reqs, list) else reqs.values():
req.wait()
is_wait = True
if is_wait:
if recv_prev:
default_stream.wait_stream(receive_forward_stream)
else:
default_stream.wait_stream(receive_backward_stream)
reqs_list = []
def send_forward_with_stream(
output_tensor, send_tensor_shapes, config, is_end_stage=False, **kwargs
):
with torch.cuda.stream(send_forward_stream):
send_forward_stream.wait_stream(default_stream)
if vtp_need_asymmetric_fwd:
if (parallel_state.is_pipeline_first_stage(ignore_virtual=True)
and is_end_stage):
return
if is_vdp_enabled():
handles = _vtp_send_forward_wrapper(
output_tensor, send_tensor_shapes, config,
is_end_stage=is_end_stage, **kwargs
)
_vtp_pending_sends.extend(handles)
else:
handles = _vtp_send_forward_wrapper(
output_tensor, send_tensor_shapes, config,
group=vtp_send_forward_group,
is_end_stage=is_end_stage,
)
_vtp_pending_sends.extend(handles)
else:
send_forward(
output_tensor, send_tensor_shapes, config, is_end_stage, **kwargs
)
if output_tensor is not None:
if isinstance(output_tensor, list):
for output_tensor_i in output_tensor:
if output_tensor_i is not None:
output_tensor_i.record_stream(send_forward_stream)
else:
output_tensor.record_stream(send_forward_stream)
def recv_forward_with_stream(
    recv_tensor_shapes, config, is_end_stage=False, **kwargs
):
    """Receive forward-direction tensors on ``receive_forward_stream``.

    When ``vtp_need_asymmetric_bwd`` is set:
      * on the pipeline's first stage with ``is_end_stage`` false, nothing is
        received and ``[None]`` stands in for the tensors;
      * otherwise ``_vtp_recv_forward_wrapper`` is called, with the caller's
        ``kwargs`` if VDP is enabled, else with ``group=vtp_recv_forward_group``.
    Otherwise ``recv_forward_with_reqs`` is called with the caller's kwargs.

    Every received non-``None`` tensor is ``record_stream``-ed on
    ``default_stream``.

    ``kwargs["wait_on_reqs"]`` (default ``True``) selects blocking mode:
      * blocking: ``default_stream`` waits on ``receive_forward_stream`` and
        only the tensor list is returned;
      * non-blocking: ``(tensors, reqs)`` is returned; ``reqs`` is later fed to
        ``wait_helper`` by the schedule.
    On the VTP path any truthy value blocks. On the plain path only an absent
    flag or a literal ``True`` blocks.
    """
    with torch.cuda.stream(receive_forward_stream):
        if vtp_need_asymmetric_bwd:
            blocking = kwargs.get("wait_on_reqs", True)
            first_stage = parallel_state.is_pipeline_first_stage(ignore_virtual=True)
            if first_stage and not is_end_stage:
                default_stream.wait_stream(receive_forward_stream)
                return [None] if blocking else ([None], [])
            wrapper_kwargs = (
                kwargs if is_vdp_enabled() else {"group": vtp_recv_forward_group}
            )
            result = _vtp_recv_forward_wrapper(
                recv_tensor_shapes,
                config,
                async_op=not blocking,
                is_end_stage=is_end_stage,
                **wrapper_kwargs,
            )
            if blocking:
                received, pending = result, []
            else:
                received, pending = result
        else:
            blocking = kwargs.get("wait_on_reqs", True) is True
            received, pending = recv_forward_with_reqs(
                recv_tensor_shapes, config, is_end_stage, **kwargs
            )
        for tensor in received:
            if tensor is not None:
                tensor.record_stream(default_stream)
        if blocking:
            default_stream.wait_stream(receive_forward_stream)
            return received
        return received, pending
def send_backward_with_stream(
    input_tensor_grad, recv_tensor_shapes, config, is_end_stage=False, **kwargs
):
    """Issue a backward-direction send of ``input_tensor_grad`` on ``send_backward_stream``.

    The send stream is first ordered after ``default_stream``.
    ``recv_tensor_shapes`` is passed through as the shapes of the tensors being
    sent; callers in this schedule pass ``send_tensor_shapes``.

    When ``vtp_need_asymmetric_bwd`` is set:
      * on the pipeline's first stage with ``is_end_stage`` false, nothing is sent;
      * otherwise ``_vtp_send_backward_wrapper`` is used, and its handles are
        appended to ``_vtp_pending_sends``.
    With VDP enabled, the wrapper gets the caller's ``kwargs``; without it, the
    wrapper gets ``group=vtp_send_backward_group``. Otherwise ``send_backward``
    receives the caller's ``kwargs`` unchanged.

    Every non-``None`` gradient sent is ``record_stream``-ed on
    ``send_backward_stream``.
    """
    with torch.cuda.stream(send_backward_stream):
        send_backward_stream.wait_stream(default_stream)
        if not vtp_need_asymmetric_bwd:
            send_backward(
                input_tensor_grad, recv_tensor_shapes, config, is_end_stage, **kwargs
            )
        else:
            first_stage = parallel_state.is_pipeline_first_stage(ignore_virtual=True)
            if first_stage and not is_end_stage:
                return
            wrapper_kwargs = (
                kwargs if is_vdp_enabled() else {"group": vtp_send_backward_group}
            )
            _vtp_pending_sends.extend(
                _vtp_send_backward_wrapper(
                    input_tensor_grad,
                    recv_tensor_shapes,
                    config,
                    is_end_stage=is_end_stage,
                    **wrapper_kwargs,
                )
            )
        if input_tensor_grad is None:
            return
        sent = (
            input_tensor_grad
            if isinstance(input_tensor_grad, list)
            else [input_tensor_grad]
        )
        for grad in sent:
            if grad is not None:
                grad.record_stream(send_backward_stream)
def recv_backward_with_stream(
    recv_tensor_shapes, config, is_end_stage=False, **kwargs
):
    """Receive backward-direction gradients on ``receive_backward_stream``.

    When ``vtp_need_asymmetric_fwd`` is set:
      * on the pipeline's first stage with ``is_end_stage`` true, nothing is
        received and ``([None], [])`` is returned;
      * otherwise ``_vtp_recv_backward_wrapper`` is called, with the caller's
        ``kwargs`` if VDP is enabled, else with ``group=vtp_recv_backward_group``.
    Otherwise ``recv_backward_with_reqs`` is called with the caller's kwargs.

    Every received non-``None`` gradient is ``record_stream``-ed on
    ``default_stream``.

    Unlike the forward receive, this always returns a ``(grads, reqs)`` tuple.
    When ``kwargs["wait_on_reqs"]`` is truthy (the default),
    ``default_stream`` waits on ``receive_backward_stream`` and ``reqs`` is
    ``[]``.
    """
    blocking = kwargs.get("wait_on_reqs", True)
    with torch.cuda.stream(receive_backward_stream):
        if vtp_need_asymmetric_fwd:
            first_stage = parallel_state.is_pipeline_first_stage(ignore_virtual=True)
            if first_stage and is_end_stage:
                default_stream.wait_stream(receive_backward_stream)
                return [None], []
            wrapper_kwargs = (
                kwargs if is_vdp_enabled() else {"group": vtp_recv_backward_group}
            )
            result = _vtp_recv_backward_wrapper(
                recv_tensor_shapes,
                config,
                async_op=not blocking,
                is_end_stage=is_end_stage,
                **wrapper_kwargs,
            )
            if blocking:
                grads, pending = result, []
            else:
                grads, pending = result
        else:
            grads, pending = recv_backward_with_reqs(
                recv_tensor_shapes, config, is_end_stage, **kwargs
            )
        for grad in grads:
            if grad is not None:
                grad.record_stream(default_stream)
        if blocking:
            default_stream.wait_stream(receive_backward_stream)
            return grads, []
        return grads, pending
if hasattr(model[0], 'module'):
float16_wrapper = model[0].module
else:
float16_wrapper = model[0]
if hasattr(float16_wrapper, 'module'):
ldt_vlm_model = float16_wrapper.module
else:
ldt_vlm_model = float16_wrapper
vit_hidden_size = 0
if hasattr(ldt_vlm_model, 'image_encoder'):
try:
mlp_linear = ldt_vlm_model.image_encoder.encoder.blocks.layers[0].mlp.linear_fc1
vit_hidden_size = mlp_linear.input_size
except Exception as e:
raise AssertionError(f"Failed to read hidden size from VIT: {str(e)}")
all_batchs, recv_forward_tensor_shapes, vit_recv_fwd_tensor_shapes = get_all_batchs(
num_microbatches, data_iterator, model[0], config, vit_hidden_size)
recv_backward_tensor_shapes = recv_forward_tensor_shapes.copy()
vit_recv_bwd_tensor_shapes = vit_recv_fwd_tensor_shapes.copy()
pp_group = get_pipeline_model_parallel_group()
if not isinstance(pp_group, list):
pp_group = [pp_group]
next_rank = get_pipeline_model_parallel_next_rank()
if not isinstance(next_rank, list):
next_rank = [next_rank]
prev_rank = get_pipeline_model_parallel_prev_rank()
if not isinstance(prev_rank, list):
prev_rank = [prev_rank]
num_vit_warmup = min(parallel_state.get_pipeline_model_parallel_world_size(), num_microbatches)
num_forward_end_backward_start = int(
(4 * parallel_state.get_pipeline_model_parallel_world_size() + 1) / 6 + .00001
)
input_tensor_tmp = None
vit_input_tensor_tmp = None
vdp_input_tensor_tmp = None
input_tensor_queue = []
vit_input_tensor_queue = []
vdp_input_tensor_queue = []
reqs_list = []
vit_reqs_list = []
vdp_reqs_list = []
reqs_queue = []
vit_reqs_queue = []
vdp_reqs_queue = []
mbn = num_microbatches * get_vdp_size() if parallel_state.is_pipeline_first_stage(ignore_virtual=True) else num_microbatches
vit_fwd_num = vit_bwd_num = mbn
llm_fwd_num = llm_bwd_num = mbn
last_stage_febs_num = mbn
group_iter = [i_group for i_group in range(len(pp_group))]
def set_vpp_rank(vpp_rank):
    """Set Megatron's virtual pipeline-parallel rank to ``vpp_rank``.

    In this schedule it is called with the index (0, 1 or 2) of the ``model``
    chunk about to run, before that chunk's forward or backward step.
    """
    setter = parallel_state.set_virtual_pipeline_model_parallel_rank
    setter(vpp_rank)
set_vpp_rank(0)
for i in range(num_vit_warmup):
last_iteration = i == (num_vit_warmup - 1)
if max_outstanding_backprops is not None:
checkpoint_activations_microbatch = (
i % max_outstanding_backprops
>= config.num_microbatches_with_partial_activation_checkpoints
)
else:
checkpoint_activations_microbatch = None
if i == 0:
for rfg, nr, pr in zip(receive_forward_group, next_rank, prev_rank):
if not parallel_state.is_pipeline_first_stage(ignore_virtual=True):
recv_tensor_shapes = vit_recv_fwd_tensor_shapes.pop(0)
input_tensor = recv_forward_with_stream(
recv_tensor_shapes,
config,
group=rfg,
next_rank=nr,
prev_rank=pr,
)
input_tensor_queue.append(input_tensor)
reqs_queue.append([])
for i_group, sfg, nr, pr in zip(group_iter, send_forward_group, next_rank, prev_rank):
reqs_list = reqs_queue.pop(0)
wait_helper(reqs_list)
if i_group == 0 and not last_iteration:
for future_rfg, future_nr, future_pr in zip(receive_forward_group, next_rank, prev_rank):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
input_tensor_queue.append([None])
reqs_queue.append([])
else:
recv_tensor_shapes = vit_recv_fwd_tensor_shapes.pop(0)
input_tensor_tmp, reqs_list = recv_forward_with_stream(
recv_tensor_shapes,
config,
group=future_rfg,
next_rank=future_nr,
prev_rank=future_pr,
wait_on_reqs=False,
)
input_tensor_queue.append(input_tensor_tmp)
reqs_queue.append(reqs_list)
set_vpp_rank(1)
if i_group == 0 and last_iteration and num_warmup_microbatches > 0:
for future_rfg, future_nr, future_pr in zip(receive_forward_group, next_rank, prev_rank):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
recv_tensor_shapes = vit_recv_fwd_tensor_shapes.pop(0)
else:
recv_tensor_shapes = recv_forward_tensor_shapes.pop(0)
input_tensor_tmp, reqs_list = recv_forward_with_stream(
recv_tensor_shapes,
config,
group=future_rfg,
next_rank=future_nr,
prev_rank=future_pr,
is_end_stage=True,
wait_on_reqs=False,
)
input_tensor_queue.append(input_tensor_tmp)
reqs_queue.append(reqs_list)
set_vpp_rank(0)
this_iterator = None
this_model = model[0]
input_tensor = input_tensor_queue.pop(0)
output_tensor, num_tokens = forward_step(
forward_step_func,
this_iterator,
this_model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
check_first_val_step(first_val_step, forward_only, i == 0),
current_microbatch=i,
encoder_decoder_xattn=encoder_decoder_xattn,
batch=all_batchs[0].pop(0)
)
vit_fwd_num -= 1
total_num_tokens += num_tokens
send_forward_with_stream(
output_tensor,
send_tensor_shapes,
config,
group=sfg,
next_rank=nr,
prev_rank=pr,
)
if not forward_only:
vit_input_tensors.append(input_tensor)
vit_output_tensors.append(output_tensor)
deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)
for i in range(num_warmup_microbatches):
set_vpp_rank(1)
last_iteration = i == (num_warmup_microbatches - 1)
if max_outstanding_backprops is not None:
checkpoint_activations_microbatch = (
i % max_outstanding_backprops
>= config.num_microbatches_with_partial_activation_checkpoints
)
else:
checkpoint_activations_microbatch = None
for i_group, sfg, nr, pr in zip(group_iter, send_forward_group, next_rank, prev_rank):
reqs_list = reqs_queue.pop(0)
wait_helper(reqs_list)
if (
i_group == 0
and i >= num_forward_end_backward_start - 1
and vit_fwd_num > 0
):
set_vpp_rank(0)
for future_rfg, future_nr, future_pr in zip(receive_forward_group, next_rank, prev_rank):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
vit_input_tensor_queue.append([None])
vit_reqs_queue.append([])
else:
recv_tensor_shapes = vit_recv_fwd_tensor_shapes.pop(0)
vit_input_tensor_tmp, vit_reqs_list = recv_forward_with_stream(
recv_tensor_shapes,
config,
group=future_rfg,
next_rank=future_nr,
prev_rank=future_pr,
wait_on_reqs=False,
)
vit_input_tensor_queue.append(vit_input_tensor_tmp)
vit_reqs_queue.append(vit_reqs_list)
set_vpp_rank(1)
if (
i_group == 0
and llm_fwd_num > 1
and (not last_iteration or (last_iteration and not parallel_state.is_pipeline_first_stage(ignore_virtual=True)))
):
for future_rfg, future_nr, future_pr in zip(receive_forward_group, next_rank, prev_rank):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
recv_tensor_shapes = vit_recv_fwd_tensor_shapes.pop(0)
else:
recv_tensor_shapes = recv_forward_tensor_shapes.pop(0)
input_tensor_tmp, reqs_list = recv_forward_with_stream(
recv_tensor_shapes,
config,
group=future_rfg,
next_rank=future_nr,
prev_rank=future_pr,
is_end_stage=True,
wait_on_reqs=False,
)
input_tensor_queue.append(input_tensor_tmp)
reqs_queue.append(reqs_list)
if (
i_group == 0
and last_iteration
and parallel_state.is_pipeline_first_stage(ignore_virtual=True)
and num_forward_end_backward_start > 0
):
set_vpp_rank(2)
for future_rfg, future_nr, future_pr in zip(receive_forward_group, next_rank, prev_rank):
recv_tensor_shapes = recv_forward_tensor_shapes.pop(0)
vdp_input_tensor_tmp, vdp_reqs_list = recv_forward_with_stream(
recv_tensor_shapes,
config,
group=future_rfg,
next_rank=future_nr,
prev_rank=future_pr,
is_end_stage=True,
wait_on_reqs=False
)
vdp_input_tensor_queue.append(vdp_input_tensor_tmp)
vdp_reqs_queue.append(vdp_reqs_list)
set_vpp_rank(1)
this_iterator = None
this_model = model[1]
input_tensor = input_tensor_queue.pop(0)
output_tensor, num_tokens = forward_step(
forward_step_func,
this_iterator,
this_model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
check_first_val_step(first_val_step, forward_only, i == 0),
current_microbatch=i,
encoder_decoder_xattn=encoder_decoder_xattn,
batch=all_batchs[1].pop(0)
)
llm_fwd_num -= 1
total_num_tokens += num_tokens
send_forward_with_stream(
output_tensor,
send_tensor_shapes,
config,
group=sfg,
next_rank=nr,
prev_rank=pr,
)
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)
if i >= num_forward_end_backward_start - 1 and vit_fwd_num > 0:
for sfg, nr, pr in zip(send_forward_group, next_rank, prev_rank):
set_vpp_rank(0)
reqs_list = vit_reqs_queue.pop(0)
wait_helper(vit_reqs_list)
this_iterator = None
this_model = model[0]
input_tensor = vit_input_tensor_queue.pop(0)
output_tensor, num_tokens = forward_step(
forward_step_func,
this_iterator,
this_model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
check_first_val_step(first_val_step, forward_only, i == 0),
current_microbatch=i,
encoder_decoder_xattn=encoder_decoder_xattn,
batch=all_batchs[0].pop(0)
)
vit_fwd_num -= 1
total_num_tokens += num_tokens
send_forward_with_stream(
output_tensor,
send_tensor_shapes,
config,
group=sfg,
next_rank=nr,
prev_rank=pr,
)
if not forward_only:
vit_input_tensors.append(input_tensor)
vit_output_tensors.append(output_tensor)
deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)
set_vpp_rank(2)
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
for i in range(num_forward_end_backward_start):
if last_stage_febs_num == 0:
break
last_iteration = i == (num_forward_end_backward_start - 1)
for i_group, sbg, nr, pr in zip(group_iter, send_backward_group, next_rank, prev_rank):
vdp_reqs_list = vdp_reqs_queue.pop(0)
wait_helper(vdp_reqs_list)
if i_group == 0 and not last_iteration:
for future_rfg, future_nr, future_pr in zip(receive_forward_group, next_rank, prev_rank):
recv_tensor_shapes = recv_forward_tensor_shapes.pop(0)
vdp_input_tensor_tmp, vdp_reqs_list = recv_forward_with_stream(
recv_tensor_shapes,
config,
group=future_rfg,
next_rank=future_nr,
prev_rank=future_pr,
is_end_stage=True,
wait_on_reqs=False,
)
vdp_input_tensor_queue.append(vdp_input_tensor_tmp)
vdp_reqs_queue.append(vdp_reqs_list)
this_iterator = None
this_model = model[2]
input_tensor_end = vdp_input_tensor_queue.pop(0)
output_tensor_end, num_tokens = forward_step(
forward_step_func,
this_iterator,
this_model,
num_microbatches,
input_tensor_end,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
check_first_val_step(first_val_step, forward_only, i == 0),
current_microbatch=i,
encoder_decoder_xattn=encoder_decoder_xattn,
is_end_stage=True,
batch=all_batchs[2].pop(0)
)
last_stage_febs_num -= 1
total_num_tokens += num_tokens
if not forward_only:
output_tensor_grad_end = [None] * len(recv_tensor_shapes)
deallocate_output_tensor(output_tensor_end[0], config.deallocate_pipeline_outputs)
input_tensor_grad_end = backward_step(
input_tensor_end, output_tensor_end, output_tensor_grad_end, model_type, config
)
if last_iteration:
input_tensor_end = None
send_backward_with_stream(
input_tensor_grad_end,
send_tensor_shapes,
config,
group=sbg,
next_rank=nr,
prev_rank=pr,
is_end_stage=True,
)
output_tensor_grad_tmp = None
vit_output_tensor_grad_tmp = None
output_tensor_grad_queue = []
vit_output_tensor_grad_queue = []
for i in range(num_microbatches):
last_iteration = i == num_microbatches - 1
if i == 0:
if parallel_state.is_pipeline_first_stage(ignore_virtual=True) and llm_fwd_num > 0:
set_vpp_rank(1)
for rfg, nr, pr in zip(receive_forward_group, next_rank, prev_rank):
recv_tensor_shapes = vit_recv_fwd_tensor_shapes.pop(0)
input_tensor_tmp = recv_forward_with_stream(
recv_tensor_shapes,
config,
group=rfg,
next_rank=nr,
prev_rank=pr,
is_end_stage=True
)
input_tensor_queue.append(input_tensor_tmp)
reqs_queue.append([])
if vit_fwd_num > 0 and i + num_warmup_microbatches >= num_forward_end_backward_start - 1:
set_vpp_rank(0)
for rfg, nr, pr in zip(receive_forward_group, next_rank, prev_rank):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
vit_input_tensor_queue.append([None])
vit_reqs_queue.append([])
else:
recv_tensor_shapes = vit_recv_fwd_tensor_shapes.pop(0)
vit_input_tensor_tmp, vit_reqs_list = recv_forward_with_stream(
recv_tensor_shapes,
config,
group=rfg,
next_rank=nr,
prev_rank=pr,
wait_on_reqs=False
)
vit_input_tensor_queue.append(vit_input_tensor_tmp)
vit_reqs_queue.append(vit_reqs_list)
if parallel_state.is_pipeline_first_stage(ignore_virtual=True) and last_stage_febs_num > 0:
set_vpp_rank(2)
for rfg, nr, pr in zip(receive_forward_group, next_rank, prev_rank):
recv_tensor_shapes = recv_forward_tensor_shapes.pop(0)
vdp_input_tensor_tmp, vdp_reqs_list = recv_forward_with_stream(
recv_tensor_shapes,
config,
group=rfg,
next_rank=nr,
prev_rank=pr,
is_end_stage=True,
wait_on_reqs=False
)
vdp_input_tensor_queue.append(vdp_input_tensor_tmp)
vdp_reqs_queue.append(vdp_reqs_list)
if llm_fwd_num > 0:
set_vpp_rank(1)
for sfg, nr, pr in zip(send_forward_group, next_rank, prev_rank):
reqs_list = reqs_queue.pop(0)
wait_helper(reqs_list)
this_iterator = None
this_model = model[1]
input_tensor = input_tensor_queue.pop(0)
output_tensor, num_tokens = forward_step(
forward_step_func,
this_iterator,
this_model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
check_first_val_step(
first_val_step, forward_only, (llm_fwd_num == mbn)
),
current_microbatch=i + num_warmup_microbatches,
encoder_decoder_xattn=encoder_decoder_xattn,
batch=all_batchs[1].pop(0)
)
llm_fwd_num -= 1
total_num_tokens += num_tokens
send_forward_with_stream(
output_tensor,
send_tensor_shapes,
config,
group=sfg,
next_rank=nr,
prev_rank=pr,
)
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)
for i_group, sfg, nr, pr in zip(group_iter, send_forward_group, next_rank, prev_rank):
if vit_fwd_num > 0 and i + num_warmup_microbatches >= num_forward_end_backward_start - 1:
set_vpp_rank(0)
vit_reqs_list = vit_reqs_queue.pop(0)
wait_helper(vit_reqs_list)
if i_group == 0 and llm_bwd_num > 0:
set_vpp_rank(1)
for future_rbg, future_nr, future_pr in zip(receive_backward_group, next_rank, prev_rank):
recv_tensor_shapes = recv_backward_tensor_shapes.pop(0)
output_tensor_grad_tmp, reqs_list = recv_backward_with_stream(
recv_tensor_shapes,
config,
group=future_rbg,
next_rank=future_nr,
prev_rank=future_pr,
wait_on_reqs=False,
)
output_tensor_grad_queue.append(output_tensor_grad_tmp)
reqs_queue.append(reqs_list)
if i_group == 0 and vit_bwd_num > 0 and i >= num_forward_end_backward_start and not forward_only:
set_vpp_rank(0)
for future_rbg, future_nr, future_pr in zip(receive_backward_group, next_rank, prev_rank):
recv_tensor_shapes = vit_recv_bwd_tensor_shapes.pop(0)
vit_output_tensor_grad_tmp, vit_reqs_list = recv_backward_with_stream(
recv_tensor_shapes,
config,
group=future_rbg,
next_rank=future_nr,
prev_rank=future_pr,
wait_on_reqs=False
)
vit_output_tensor_grad_queue.append(vit_output_tensor_grad_tmp)
vit_reqs_queue.append(vit_reqs_list)
vit_output_tensor = None
if vit_fwd_num > 0 and i + num_warmup_microbatches >= num_forward_end_backward_start - 1:
set_vpp_rank(0)
this_iterator = None
this_model = model[0]
input_tensor = vit_input_tensor_queue.pop(0)
vit_output_tensor, num_tokens = forward_step(
forward_step_func,
this_iterator,
this_model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
check_first_val_step(
first_val_step, forward_only, (vit_fwd_num == mbn)
),
current_microbatch=i + num_warmup_microbatches,
encoder_decoder_xattn=encoder_decoder_xattn,
batch=all_batchs[0].pop(0)
)
vit_fwd_num -= 1
total_num_tokens += num_tokens
if parallel_state.get_pipeline_model_parallel_world_size() > 2 or not parallel_state.is_pipeline_last_stage(ignore_virtual=True):
send_forward_with_stream(
vit_output_tensor,
send_tensor_shapes,
config,
group=sfg,
next_rank=nr,
prev_rank=pr,
)
if not forward_only:
vit_input_tensors.append(input_tensor)
vit_output_tensors.append(vit_output_tensor)
deallocate_output_tensor(vit_output_tensor[0], config.deallocate_pipeline_outputs)
if parallel_state.is_pipeline_first_stage(ignore_virtual=True) and last_stage_febs_num > 0:
for sbg, nr, pr in zip(send_backward_group, next_rank, prev_rank):
set_vpp_rank(2)
vdp_reqs_list = vdp_reqs_queue.pop(0)
wait_helper(vdp_reqs_list)
this_iterator = None
this_model = model[2]
input_tensor_end = vdp_input_tensor_queue.pop(0)
output_tensor_end, num_tokens = forward_step(
forward_step_func,
this_iterator,
this_model,
num_microbatches,
input_tensor_end,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
check_first_val_step(
first_val_step, forward_only, (last_stage_febs_num == mbn)
),
current_microbatch=i + num_warmup_microbatches,
encoder_decoder_xattn=encoder_decoder_xattn,
is_end_stage=True,
batch=all_batchs[2].pop(0)
)
last_stage_febs_num -= 1
total_num_tokens += num_tokens
if not forward_only:
deallocate_output_tensor(output_tensor_end[0], config.deallocate_pipeline_outputs)
output_tensor_grad_end = [None] * len(recv_tensor_shapes)
input_tensor_grad_end = backward_step(
input_tensor_end, output_tensor_end, output_tensor_grad_end, model_type, config
)
send_backward_with_stream(
input_tensor_grad_end,
send_tensor_shapes,
config,
group=sbg,
next_rank=nr,
prev_rank=pr,
is_end_stage=True,
)
for i_group, sbg, nr, pr in zip(group_iter, send_backward_group, next_rank, prev_rank):
if llm_bwd_num > 0:
set_vpp_rank(1)
reqs_list = reqs_queue.pop(0)
wait_helper(reqs_list)
if i_group == 0 and not last_iteration and llm_fwd_num > 0:
set_vpp_rank(1)
for future_rfg, future_nr, future_pr in zip(receive_forward_group, next_rank, prev_rank):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
recv_tensor_shapes = vit_recv_fwd_tensor_shapes.pop(0)
else:
recv_tensor_shapes = recv_forward_tensor_shapes.pop(0)
input_tensor_tmp, reqs_list = recv_forward_with_stream(
recv_tensor_shapes,
config,
group=future_rfg,
next_rank=future_nr,
prev_rank=future_pr,
is_end_stage=True,
wait_on_reqs=False,
)
input_tensor_queue.append(input_tensor_tmp)
reqs_queue.append(reqs_list)
if llm_bwd_num > 0:
set_vpp_rank(1)
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = output_tensor_grad_queue.pop(0)
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config
)
llm_bwd_num -= 1
send_backward_with_stream(
input_tensor_grad,
send_tensor_shapes,
config,
group=sbg,
next_rank=nr,
prev_rank=pr,
is_end_stage=True,
)
for i_group, sbg, nr, pr in zip(group_iter, send_backward_group, next_rank, prev_rank):
if vit_bwd_num > 0 and i >= num_forward_end_backward_start and not forward_only:
set_vpp_rank(0)
vit_reqs_list = vit_reqs_queue.pop(0)
wait_helper(vit_reqs_list)
if i_group == 0 and not last_iteration:
if vit_fwd_num > 0 and i + num_warmup_microbatches + 1 >= num_forward_end_backward_start - 1:
set_vpp_rank(0)
for future_rfg, future_nr, future_pr in zip(receive_forward_group, next_rank, prev_rank):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
vit_input_tensor_queue.append([None])
vit_reqs_queue.append([])
else:
recv_tensor_shapes = vit_recv_fwd_tensor_shapes.pop(0)
vit_input_tensor_tmp, vit_reqs_list = recv_forward_with_stream(
recv_tensor_shapes,
config,
group=future_rfg,
next_rank=future_nr,
prev_rank=future_pr,
wait_on_reqs=False,
)
vit_input_tensor_queue.append(vit_input_tensor_tmp)
vit_reqs_queue.append(vit_reqs_list)
if parallel_state.is_pipeline_first_stage(ignore_virtual=True) and last_stage_febs_num > 0:
set_vpp_rank(2)
for future_rfg, future_nr, future_pr in zip(receive_forward_group, next_rank, prev_rank):
recv_tensor_shapes = recv_forward_tensor_shapes.pop(0)
vdp_input_tensor_tmp, vdp_reqs_list = recv_forward_with_stream(
recv_tensor_shapes,
config,
group=future_rfg,
next_rank=future_nr,
prev_rank=future_pr,
is_end_stage=True,
wait_on_reqs=False,
)
vdp_input_tensor_queue.append(vdp_input_tensor_tmp)
vdp_reqs_queue.append(vdp_reqs_list)
elif i_group == 0 and vit_bwd_num > 0 and last_iteration:
set_vpp_rank(0)
for future_rbg, future_nr, future_pr in zip(receive_backward_group, next_rank, prev_rank):
recv_tensor_shapes = vit_recv_bwd_tensor_shapes.pop(0)
vit_output_tensor_grad_tmp, vit_reqs_list = recv_backward_with_stream(
recv_tensor_shapes,
config,
group=future_rbg,
next_rank=future_nr,
prev_rank=future_pr,
wait_on_reqs=False,
)
vit_output_tensor_grad_queue.append(vit_output_tensor_grad_tmp)
vit_reqs_queue.append(vit_reqs_list)
if vit_bwd_num > 0 and i >= num_forward_end_backward_start and not forward_only:
set_vpp_rank(0)
if vit_bwd_num == 1:
if config.grad_sync_func is None or rank == 0:
enable_grad_sync()
input_tensor = vit_input_tensors.pop(0)
output_tensor = vit_output_tensors.pop(0)
output_tensor_grad = vit_output_tensor_grad_queue.pop(0)
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config
)
vit_bwd_num -= 1
send_backward_with_stream(
input_tensor_grad,
send_tensor_shapes,
config,
group=sbg,
next_rank=nr,
prev_rank=pr,
)
if (
vit_output_tensor is not None
and parallel_state.get_pipeline_model_parallel_world_size() == 2
and parallel_state.is_pipeline_last_stage(ignore_virtual=True)
):
set_vpp_rank(0)
send_forward_with_stream(
vit_output_tensor,
send_tensor_shapes,
config,
group=send_forward_group,
)
if not forward_only:
vit_input_tensors.append(input_tensor)
vit_output_tensors.append(vit_output_tensor)
deallocate_output_tensor(vit_output_tensor[0], config.deallocate_pipeline_outputs)
set_vpp_rank(0)
if not forward_only:
while vit_bwd_num > 0:
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
last_iteration = vit_bwd_num == get_vdp_size()
else:
last_iteration = vit_bwd_num == 1
if last_iteration:
if config.grad_sync_func is None or rank == 0:
enable_grad_sync()
for i_group, sbg, nr, pr in zip(group_iter, send_backward_group, next_rank, prev_rank):
vit_reqs_list = vit_reqs_queue.pop(0)
wait_helper(vit_reqs_list)
if i_group == 0 and not last_iteration:
for future_rbg, future_nr, future_pr in zip(receive_backward_group, next_rank, prev_rank):
recv_tensor_shapes = vit_recv_bwd_tensor_shapes.pop(0)
vit_output_tensor_grad_tmp, vit_reqs_list = recv_backward_with_stream(
recv_tensor_shapes,
config,
group=future_rbg,
next_rank=future_nr,
prev_rank=future_pr,
wait_on_reqs=False,
)
vit_output_tensor_grad_queue.append(vit_output_tensor_grad_tmp)
vit_reqs_queue.append(vit_reqs_list)
input_tensor = vit_input_tensors.pop(0)
output_tensor = vit_output_tensors.pop(0)
output_tensor_grad = vit_output_tensor_grad_queue.pop(0)
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config
)
vit_bwd_num -= 1
send_backward_with_stream(
input_tensor_grad,
send_tensor_shapes,
config,
group=sbg,
next_rank=nr,
prev_rank=pr,
)
if no_sync_context is not None:
enable_grad_sync()
if config.grad_sync_func is not None:
for this_model in model:
config.grad_sync_func(this_model.parameters())
if config.finalize_model_grads_func is not None and not forward_only:
finish_embedding_wgrad_compute(config, embedding_module)
this_model = model if parallel_state.is_pipeline_first_stage(ignore_virtual=True) else [model[0]]
config.finalize_model_grads_func(
this_model, total_num_tokens if config.calculate_per_token_loss else None
)
if config.timers is not None:
config.timers('forward-backward').stop()
_drain_vtp_sends()
if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
create_cudagraphs()
return forward_data_store