from typing import Optional, List, Tuple
import torch
import torch.distributed as dist
from .utils import cal_split_sizes, reorder_output, cal_split_sizes_multi
from ...distributed.parallel_state import get_parallel_state
PERMUTE_DIMS1 = {
4: (1, 2, 3, 0),
5: (1, 2, 3, 0, 4),
}
PERMUTE_DIMS2 = {
4: (1, 2, 0, 3),
5: (1, 2, 0, 3, 4),
}
def adjust_tensor_dimensions(tensor, scatter_idx, gather_idx):
"""
Adjusts the dimensions of a tensor to move scatter_idx and gather_idx to dim 0 and dim 1 respectively.
Args:
tensor (torch.Tensor): The input tensor.
scatter_idx (int): The index of the dimension to scatter.
gather_idx (int): The index of the dimension to gather.
Returns:
tuple: A tuple containing the adjusted tensor and the list of adjusted dimensions.
"""
dims = list(range(tensor.dim()))
if gather_idx == 0:
if scatter_idx != 1:
dims[1], dims[gather_idx] = dims[gather_idx], dims[1]
dims[0], dims[scatter_idx] = dims[scatter_idx], dims[0]
else:
dims[scatter_idx], dims[gather_idx] = dims[gather_idx], dims[scatter_idx]
elif gather_idx == 1:
if scatter_idx != 0:
dims[0], dims[scatter_idx] = dims[scatter_idx], dims[0]
else:
if scatter_idx == 0:
dims[1], dims[gather_idx] = dims[gather_idx], dims[1]
else:
dims[0], dims[scatter_idx] = dims[scatter_idx], dims[0]
dims[1], dims[gather_idx] = dims[gather_idx], dims[1]
return tensor.permute(dims).contiguous(), dims
def unadjust_tensor_dimensions(tensor, adjusted_dims):
"""
Reverses the dimension adjustments using the list of adjusted dimensions.
Args:
tensor (torch.Tensor): The tensor whose dimensions need to be restored.
adjusted_dims (list): The list of adjusted dimensions used during the adjustment process.
Returns:
torch.Tensor: The tensor with its dimensions reverted to the original order.
"""
inverse_dims = [0] * len(adjusted_dims)
for new_pos, old_pos in enumerate(adjusted_dims):
inverse_dims[old_pos] = new_pos
unadjusted_tensor = tensor.permute(inverse_dims).contiguous()
return unadjusted_tensor
def _all_to_all(
input_: torch.Tensor,
group: dist.ProcessGroup,
scatter_dim: int,
gather_dim: int,
gather_size: Optional[int] = None
):
"""
Helper function to perform the all-to-all operation. It scatters the input tensor along the specified scatter
dimension and then gathers it along the specified gather dimension. The function supports aligned and unaligned
data.
Args:
input_ (torch.Tensor): The input tensor to be processed.
group (dist.ProcessGroup): The process group perform the operation within.
scatter_dim (int): The index of the dimension that needs to be scattered.
gather_dim (int): The index of the dimension that needs to be gathered.
gather_size (Optional[int]): The total size of the output tensor along the `gather_dim`. If not provided, it
will be calculated as the product of the original size of the `gather_dim` of the input tensor and the
`world_size`.
Returns:
torch.Tensor: The resulting tensor after performing the all-to-all operation.
Note:
- The tensor will be split into `world_size` chunks along the `scatter_dim`. Each process will receive one
chunk. If the total size of the `scatter_dim` is not divisible by `world_size`, the extra elements will be
distributed to the first few processes, ensuring that no process receives more than one additional element
compared to the others.
- The tensor will be gathered along the `gather_dim`, with each process contributing its part to form the
final output tensor. The gathering process also supports unaligned data, where the remainder elements
are distributed to the first few processes.
"""
world_size = dist.get_world_size(group)
if world_size == 1:
return input_
scatter_size = input_.size(scatter_dim)
if gather_size is None:
gather_size = input_.size(gather_dim) * world_size
gather_mod = gather_size % world_size
scatter_mod = scatter_size % world_size
if gather_mod == 0 and scatter_mod == 0:
return _aligned_all_to_all(input_, group, scatter_dim, gather_dim)
elif gather_mod != 0 and scatter_mod != 0:
return _full_unaligned_all_to_all(input_, group, scatter_dim, gather_dim, gather_size)
else:
return _partial_unaligned_all_to_all(input_, group, scatter_dim, gather_dim, gather_size)
def _full_unaligned_all_to_all(
input_: torch.Tensor,
group: dist.ProcessGroup,
scatter_dim: int,
gather_dim: int,
gather_size: Optional[int] = None
):
"""
Helper function to perform the all-to-all operation. It scatters the input tensor along the specified scatter
dimension and then gathers it along the specified gather dimension. This function supports unaligned scatter
and gather sizes.
Args:
input_ (torch.Tensor): The input tensor to be processed.
world_size (int): The number of processes in the process group.
group (dist.ProcessGroup): The process group to perform the operation within.
scatter_dim (int): The index of the dimension that needs to be scattered.
gather_dim (int): The index of the dimension that needs to be gathered.
gather_size (Optional[int]): The total size of the output tensor along the `gather_dim`. If not provided, it
will be calculated as the product of the original size of the `gather_dim` of the input tensor and the
`world_size`.
Returns:
torch.Tensor: The resulting tensor after performing the all-to-all operation.
"""
world_size = dist.get_world_size(group)
rank = dist.get_rank(group)
scatter_sizes = cal_split_sizes(dim_size=input_.size(scatter_dim), world_size=world_size)
input_list = [t.contiguous() for t in torch.split(input_, scatter_sizes, scatter_dim)]
gather_sizes = cal_split_sizes(dim_size=gather_size, world_size=world_size)
output_list = []
tensor_shape_base = input_list[rank].size()
for i in range(world_size):
tensor_shape = list(tensor_shape_base)
tensor_shape[gather_dim] = gather_sizes[i]
output_list.append(torch.empty(tensor_shape, dtype=input_.dtype, device=input_.device))
dist.all_to_all(output_list, input_list, group=group)
return torch.cat(output_list, dim=gather_dim).contiguous()
def _aligned_all_to_all(
input_: torch.Tensor,
group: dist.ProcessGroup,
scatter_dim: int,
gather_dim: int,
):
"""
Helper function to perform the all-to-all operation. It scatters the input tensor along the specified scatter
dimension and then gathers it along the specified gather dimension.
Special note: The function only supports aligned data (both scatter_size and gather_size are divisible by
world_size)
"""
world_size = dist.get_world_size(group)
inp_shape = list(input_.shape)
inp_shape[scatter_dim] = inp_shape[scatter_dim] // world_size
if scatter_dim == 0:
input_t = input_.reshape([world_size] + inp_shape).contiguous()
else:
input_t = input_.reshape([-1, world_size] + inp_shape[scatter_dim:]).transpose(0, 1).contiguous()
output = torch.empty_like(input_t)
dist.all_to_all_single(output, input_t, group=group)
output = output.view([world_size] + inp_shape).contiguous()
output_dim = output.dim()
if gather_dim == 1:
output = output.transpose(0, 1).contiguous()
elif gather_dim == 2:
output = output.permute(*PERMUTE_DIMS2[output_dim]).contiguous()
elif gather_dim == 3:
output = output.permute(*PERMUTE_DIMS1[output_dim]).contiguous()
output = output.view(inp_shape[:gather_dim] + [inp_shape[gather_dim] * world_size, ] + inp_shape[gather_dim + 1:]
).contiguous()
return output
def _partial_unaligned_all_to_all(
input_: torch.Tensor,
group: dist.ProcessGroup,
scatter_dim: int,
gather_dim: int,
gather_size: Optional[int] = None
):
"""
Helper function to perform the all-to-all operation. It scatters the input tensor along the specified scatter
dimension and then gathers it along the specified gather dimension. The function supports aligned and unaligned
data.
Special note: In the case of aligned data (both scatter_size and gather_size are divisible by world_size),
_partial_unaligned_all_to_all function performs worse than _aligned_all_to_all function. Therefore, in the case of
aligning data, it is recommended to use _aligned_all_to_all function.
"""
world_size = dist.get_world_size(group)
input_ = input_.contiguous()
rank = dist.get_rank(group=group)
scatter_size = input_.size(scatter_dim)
if gather_size is None:
gather_size = input_.size(gather_dim) * world_size
scatter_size_per_rank = scatter_size // world_size
scatter_size_remainder = scatter_size % world_size
input_split_sizes = [scatter_size_per_rank + (1 if i < scatter_size_remainder else 0) for i in range(world_size)]
gather_size_per_rank = gather_size // world_size
gather_size_remainder = gather_size % world_size
output_split_sizes = [gather_size_per_rank + (1 if i < gather_size_remainder else 0) for i in range(world_size)]
reshaped_input, reshaped_input_dims = adjust_tensor_dimensions(input_, scatter_dim, gather_dim)
reshaped_input_shape = list(reshaped_input.shape)
if scatter_size % world_size == 0:
reshaped_input = reshaped_input.view(
[world_size, input_.size(scatter_dim) // world_size, input_.size(gather_dim)] + reshaped_input_shape[2:]
).transpose(1, 2).contiguous()
output_dims = reshaped_input_dims
output_dims[1], output_dims[0] = output_dims[0], output_dims[1]
output = torch.empty((gather_size, input_split_sizes[rank], *reshaped_input_shape[2:]),
dtype=input_.dtype, device=input_.device)
output_shape = list(output.shape)
dist.all_to_all_single(
output,
reshaped_input,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes if scatter_size % world_size != 0 else [1 for _ in range(world_size)],
group=group,
)
if gather_size % world_size == 0 and scatter_size % world_size != 0:
output = output.view(
[world_size, input_split_sizes[rank], gather_size // world_size] + reshaped_input_shape[2:]
).transpose(1, 2).reshape(output_shape).contiguous()
unadjust_output_ = unadjust_tensor_dimensions(output, output_dims)
return unadjust_output_
class _AllToAll(torch.autograd.Function):
"""Custom autograd function that performs an all-to-all communication.
This function supports both aligned and unaligned data.
"""
@staticmethod
def forward(ctx, input_, process_group, scatter_dim, gather_dim, gather_size=None):
"""
Forward pass: Perform all-to-all communication by scattering the input tensor along the specified scatter
dimension and then gathering it along the specified gather dimension.
Args:
input_ (torch.Tensor): The input tensor to be processed.
process_group (dist.ProcessGroup): The process group to perform the operation within.
scatter_dim (int): The index of the dimension that needs to be scattered.
gather_dim (int): The index of the dimension that needs to be gathered.
gather_size (int): The size of the gather dimension.
Returns:
torch.Tensor: The resulting tensor after performing the all-to-all operation.
"""
ctx.process_group = process_group
ctx.scatter_dim = scatter_dim
ctx.scatter_size = input_.size(scatter_dim)
ctx.gather_dim = gather_dim
ctx.gather_size = gather_size
output = _all_to_all(
input_, process_group, scatter_dim, gather_dim, gather_size
)
return output
@staticmethod
def backward(ctx, grad_output):
"""
Backward pass: Perform the reverse all-to-all communication
Args:
grad_output (torch.Tensor): The gradient of the output with respect to the loss.
Returns:
tuple: The gradient of the input with respect to the loss and `None` for other arguments.
"""
grad_output = _all_to_all(
grad_output,
ctx.process_group,
ctx.gather_dim,
ctx.scatter_dim,
ctx.scatter_size
)
return (
grad_output,
None,
None,
None,
None,
None
)
def _split(
input_: torch.Tensor,
pg: dist.ProcessGroup,
dim: int = -1,
split_sizes: Optional[List[int]] = None
) -> torch.Tensor:
"""
Splits a tensor across the specified dimension and returns the part corresponding to the current rank,
supporting aligned and unaligned data.
Args:
input_ (torch.Tensor): The input tensor to be split.
pg (dist.ProcessGroup): The process group to perform the operation within.
dim (int, optional): The dimension along which to split the tensor. Defaults to -1 (last dimension).
split_sizes (Optional[List[int]], optional): A list of sizes for each part of the tensor to be split.
If not provided, the tensor will be split equally among the processes, with the remainder
distributed to the first few processes. Defaults to None.
Returns:
torch.Tensor: The part of the tensor corresponding to the current rank in the process group.
"""
if split_sizes is not None and not isinstance(split_sizes, list):
raise ValueError("split_sizes must be a list if provided.")
world_size = dist.get_world_size(pg)
if world_size == 1:
return input_
if split_sizes is None:
dim_size = input_.size(dim)
base_size = dim_size // world_size
remainder = dim_size % world_size
split_sizes = [base_size + 1 if i < remainder else base_size for i in range(world_size)]
tensor_list = torch.split(input_, split_sizes, dim=dim)
rank = dist.get_rank(pg)
output = tensor_list[rank].contiguous()
return output
def _gather(input_: torch.Tensor,
pg: dist.ProcessGroup,
dim: int = -1,
gather_sizes: Optional[List[int]] = None):
"""
Gathers tensors from all processes in the process group and concatenates them along the specified dimension,
supporting aligned and unaligned data.
Args:
input_ (torch.Tensor): The input tensor to be gathered.
pg (dist.ProcessGroup): The process group to perform the operation within.
dim (int, optional): The dimension along which to concatenate the gathered tensors. Defaults to -1 (last dimension).
gather_sizes (Optional[List[int]], optional): A list of sizes for each part of the tensor to be gathered.
If not provided, it is assumed that all tensors have the same shape as the input tensor. Defaults to None.
Returns:
torch.Tensor: The concatenated tensor after gathering from all processes in the process group.
"""
if gather_sizes is not None and not isinstance(gather_sizes, list):
raise ValueError("gather_sizes must be a list if provided.")
world_size = dist.get_world_size(pg)
if world_size == 1:
return input_
input_ = input_.contiguous()
if gather_sizes:
tensor_list = []
tensor_shape_base = input_.size()
for i in range(world_size):
tensor_shape = list(tensor_shape_base)
tensor_shape[dim] = gather_sizes[i]
tensor_list.append(torch.empty(tensor_shape, dtype=input_.dtype, device=input_.device))
else:
tensor_list = [torch.empty_like(input_, dtype=input_.dtype, device=input_.device) for _ in range(world_size)]
torch.distributed.all_gather(tensor_list, input_, group=pg)
output = torch.cat(tensor_list, dim=dim).contiguous()
return output
class _GatherForwardSplitBackward(torch.autograd.Function):
"""
Custom autograd function that gathers the input tensor from all processes in the model parallel region and
concatenates them.
During the backward pass, it splits the gradients and scales them according to the gradient scaling mode.
"""
@staticmethod
def symbolic(graph, input_, process_group, dim, gather_sizes):
"""
Define the symbolic representation of the custom operation.
"""
return _gather(input_, process_group, dim, gather_sizes)
@staticmethod
def forward(ctx, input_, process_group, dim, gather_sizes, grad_scale="up"):
"""
Forward pass: Gathers tensors from all processes in the specified process group and concatenates them along the specified dimension.
Args:
input_ (torch.Tensor): The input tensor to be processed.
process_group (dist.ProcessGroup): The process group to perform the operation within.
dim (int): The dimension along which to concatenate the gathered tensors.
gather_sizes (Optional[List[int]], optional): A list of sizes for each part of the tensor to be gathered.
grad_scale (str, optional): Gradient scaling mode. Can be "up", "down", or None. Defaults to "up".
Returns:
torch.Tensor: The resulting tensor after gathering and concatenating.
"""
ctx.mode = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
ctx.gather_sizes = gather_sizes
return _gather(input_, process_group, dim, ctx.gather_sizes)
@staticmethod
def backward(ctx, grad_output):
"""
Backward pass: Distribute the gradients to the input tensors and scales them according to the gradient scaling mode.
Args:
grad_output (torch.Tensor): The gradient of the output.
Returns:
torch.Tensor: The gradient of the input with respect to the loss.
"""
if ctx.grad_scale == "up":
grad_output = grad_output * dist.get_world_size(ctx.mode)
elif ctx.grad_scale == "down":
grad_output = grad_output / dist.get_world_size(ctx.mode)
return _split(grad_output, ctx.mode, ctx.dim, ctx.gather_sizes), None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
Custom autograd function that splits the input tensor and keeps only the corresponding chunk for the current rank.
During the backward pass, it gathers the gradients and scales them according to the gradient scaling mode.
"""
@staticmethod
def symbolic(graph, input_, process_group, dim, split_sizes):
return _split(input_, process_group, dim, split_sizes)
@staticmethod
def forward(ctx, input_, process_group, dim, split_sizes, grad_scale):
ctx.mode = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
ctx.split_sizes = split_sizes
return _split(input_, process_group, dim, ctx.split_sizes)
@staticmethod
def backward(ctx, grad_output):
if ctx.grad_scale == "up":
grad_output = grad_output * dist.get_world_size(ctx.mode)
elif ctx.grad_scale == "down":
grad_output = grad_output / dist.get_world_size(ctx.mode)
return _gather(grad_output, ctx.mode, ctx.dim, ctx.split_sizes), None, None, None, None
def all_to_all(
input_: torch.Tensor,
process_group: dist.ProcessGroup,
scatter_dim: int = 2,
gather_dim: int = 1,
gather_size: Optional[int] = None
):
"""
Performs an all-to-all operation on the input tensor. The input tensor is scattered along the specified scatter
dimension and then gathered along the specified gather dimension.
This function supports both aligned and unaligned data.
Args:
input_ (torch.Tensor): The input tensor to be processed.
process_group (dist.ProcessGroup): The process group to perform the operation within.
scatter_dim (int, optional): The index of the dimension that needs to be scattered. Defaults to 2.
gather_dim (int, optional): The index of the dimension that needs to be gathered. Defaults to 1.
gather_size (Optional[int]): The total size of the output tensor along the `gather_dim`. If not provided, it
will be calculated as the product of the original size of the `gather_dim` of the input tensor and the
`world_size`.
Returns:
torch.Tensor: The resulting tensor after performing the all-to-all operation.
"""
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, gather_size)
def split_forward_gather_backward(
input_: torch.Tensor,
process_group: dist.ProcessGroup,
dim: int,
split_sizes: Optional[List[int]] = None,
grad_scale: str = "down"
) -> torch.Tensor:
"""
Splits the input tensor and keeps only the corresponding chunk for the current rank.
During the backward pass, it gathers the gradients and scales them according to the gradient scaling mode.
This function supports both aligned and unaligned data.
Args:
input_ (torch.Tensor): The input tensor to be processed.
process_group (dist.ProcessGroup): The process group to perform the operation within.
dim (int): The dimension along which to split the tensor.
split_sizes (Optional[List[int]], optional): A list of sizes for each part of the tensor to be split.
If not provided, the tensor will be split equally among the processes. Defaults to None.
grad_scale (str, optional): Gradient scaling mode. Can be "up", "down", or None. Defaults to "down".
Returns:
torch.Tensor: The resulting tensor after splitting and keeping only the corresponding chunk.
"""
return _SplitForwardGatherBackward.apply(input_, process_group, dim, split_sizes, grad_scale)
def gather_forward_split_backward(
input_: torch.Tensor,
process_group: dist.ProcessGroup,
dim: int,
gather_sizes: Optional[List[int]] = None,
grad_scale: str = "up"
) -> torch.Tensor:
"""
Gathers the input tensor from all processes in the model parallel region and concatenates them along the specified
dimension. During the backward pass, it splits the gradients and scales them according to the gradient scaling mode.
This function handles both aligned and unaligned data during the gather and scatter operations.
Args:
input_ (torch.Tensor): The input tensor to be processed.
process_group (dist.ProcessGroup): The process group to perform the operation within.
dim (int): The dimension along which to concatenate the gathered tensors.
gather_sizes (Optional[List[int]], optional): A list of sizes for each part of the tensor to be gathered.
If not provided, it is assumed that all tensors have the same shape as the input tensor. Defaults to None.
grad_scale (str, optional): Gradient scaling mode. Can be "up", "down", or None. Defaults to "up".
Returns:
torch.Tensor: The resulting tensor after gathering and concatenating.
"""
return _GatherForwardSplitBackward.apply(input_, process_group, dim, gather_sizes, grad_scale)
class _SplitForwardGatherBackWardWithMegatronCP(torch.autograd.Function):
'''
Split the input tensor in the forward pass and gather the gradients in the backward pass.
It will be implemented in Mindspeed in the future.
'''
@staticmethod
def forward(ctx, val, cp_rank, cp_size, seq_dim, cp_group=None):
val = val.view(
*val.shape[0:seq_dim],
2 * cp_size,
val.shape[seq_dim] // (2 * cp_size),
*val.shape[(seq_dim + 1):],
)
index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=val.device)
val = val.index_select(seq_dim, index)
val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2):])
ctx.cp_group = cp_group
ctx.cp_rank = cp_rank
ctx.cp_size = cp_size
ctx.seq_dim = seq_dim
return val
@staticmethod
def backward(ctx, grad_output):
grad_input = {}
grad_input = _gather(grad_output, ctx.cp_group, dim=ctx.seq_dim) / ctx.cp_size
grad_input = reorder_output(grad_input, ctx.cp_rank, ctx.cp_size, ctx.cp_group, dim=ctx.seq_dim)
return grad_input, None, None, None, None
class _GatherForwardSplitBackWardWithMegatronCP(torch.autograd.Function):
'''
Split the input tensor in the forward pass and gather the gradients in the backward pass with megatron cp(Ring Attention)
It will be implemented in Mindspeed in the future.
'''
@staticmethod
def forward(ctx, val, cp_rank, cp_size, seq_dim, cp_group=None):
val = _gather(val, cp_group, dim=seq_dim)
val = reorder_output(val, cp_rank, cp_size, cp_group, dim=seq_dim)
ctx.cp_group = cp_group
ctx.cp_rank = cp_rank
ctx.cp_size = cp_size
ctx.seq_dim = seq_dim
return val
@staticmethod
def backward(ctx, grad_output):
cp_group = ctx.cp_group
cp_rank = ctx.cp_rank
cp_size = ctx.cp_size
seq_dim = ctx.seq_dim
grad_output = grad_output.view(
*grad_output.shape[0:seq_dim],
2 * cp_size,
grad_output.shape[seq_dim] // (2 * cp_size),
*grad_output.shape[(seq_dim + 1):],
) * cp_size
index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=grad_output.device)
grad_output = grad_output.index_select(seq_dim, index)
grad_input = grad_output.view(*grad_output.shape[0:seq_dim], -1, *grad_output.shape[(seq_dim + 2):])
return grad_input, None, None, None, None
def load_balanced_split_forward_gather_backward(
input_: torch.Tensor,
process_group: torch.distributed.ProcessGroup,
dim: int = 0
) -> torch.Tensor:
cp_size = torch.distributed.get_world_size(group=process_group)
cp_rank = torch.distributed.get_rank(group=process_group)
return _SplitForwardGatherBackWardWithMegatronCP.apply(input_, cp_rank, cp_size, dim, process_group)
def load_balanced_gather_forward_split_backward(
input_: torch.Tensor,
process_group: torch.distributed.ProcessGroup,
dim: int = 0
) -> torch.Tensor:
cp_size = torch.distributed.get_world_size(group=process_group)
cp_rank = torch.distributed.get_rank(group=process_group)
return _GatherForwardSplitBackWardWithMegatronCP.apply(input_, cp_rank, cp_size, dim, process_group)
def split_forward_gather_backward_with_cp(
input_: torch.Tensor,
dim: int,
) -> torch.Tensor:
"""
Perform a context-parallel-aware tensor split during forward pass and gather during backward pass.
This function supports multiple context parallel (CP) algorithms:
- Ulysses-style CP: uniform or non-uniform split across CP ranks.
- Ring Attention: typically used with sequence parallelism and ring-based communication.
- Hybrid CP: combines ring and Ulysses-style splitting in nested groups.
"""
ps = get_parallel_state()
seq_len = input_.shape[dim]
if ps.is_ring_enable():
if seq_len % (2 * ps.get_ring_group_size()) != 0:
raise ValueError(f"Seq lens should be multiple of 2 * ring_size, but got seq_len: {seq_len}, ring_size: {ps.get_ring_group_size()}")
input_ = load_balanced_split_forward_gather_backward(input_, ps.get_ring_group(), dim=dim)
seq_len = input_.shape[dim]
if ps.is_ulysses_enable():
split_gather_sizes = cal_split_sizes(seq_len, ps.get_ulysses_group_size())
input_ = split_forward_gather_backward(input_, ps.get_ulysses_group(), dim=dim, split_sizes=split_gather_sizes)
return input_
def gather_forward_split_backward_with_cp(
input_: torch.Tensor,
dim: int,
gather_size: int
):
"""
Perform a context-parallel-aware tensor gather during forward pass and split during backward pass.
This is the reverse operation of split_forward_gather_backward_with_cp.
"""
ps = get_parallel_state()
if ps.is_ring_enable():
if gather_size % (2 * ps.get_ring_group_size()) != 0:
raise ValueError(f"Total gather size should be multiple of 2 * ring_size, but got total gather size: {gather_size}, ring_size: {ps.get_ring_group_size()}")
gather_size = gather_size // ps.get_ring_group_size()
if ps.is_ulysses_enable():
gather_size_list = cal_split_sizes(gather_size, ps.get_ulysses_group_size())
input_ = gather_forward_split_backward(input_, ps.get_ulysses_group(), dim=dim, gather_sizes=gather_size_list)
if ps.is_ring_enable():
input_ = load_balanced_gather_forward_split_backward(input_, ps.get_ring_group(), dim=dim)
return input_
def packed_data_split_forward_gather_backward(
hidden_states: torch.Tensor,
process_group: torch.distributed.ProcessGroup,
sequence_lengths: List[int],
dim: int = 0
) -> torch.Tensor:
"""
Split packed hidden states for sequence parallelism, handling multiple variable-length sequences.
This function:
1. Splits each sequence across all processes in the process group
2. Returns the local chunks for the current rank
Args:
hidden_states: Concatenated hidden states of all sequences [total_seq_len, hidden_dim]
process_group: Distributed process group for parallelism
sequence_lengths: List of lengths for each individual sequence
dim: Dimension along which to split (default: 0 - sequence dimension)
Returns:
local_hidden_states: Local chunks of hidden states for current rank
"""
world_size = torch.distributed.get_world_size(process_group)
local_sequence_chunks = []
current_position = 0
for _, seq_len in enumerate(sequence_lengths):
sequence_end = current_position + seq_len
current_sequence = hidden_states.narrow(dim=dim, start=current_position, length=seq_len)
current_position = sequence_end
sequence_split_sizes = cal_split_sizes(seq_len, world_size)
local_chunk = split_forward_gather_backward(
current_sequence,
process_group=process_group,
dim=dim,
grad_scale="down",
split_sizes=sequence_split_sizes
)
local_sequence_chunks.append(local_chunk)
local_hidden_states = torch.cat(local_sequence_chunks, dim=dim)
return local_hidden_states
def packed_data_gather_forward_split_backward(
local_hidden_states: torch.Tensor,
all_split_sizes: torch.Tensor,
process_group: torch.distributed.ProcessGroup,
dim: int = 0
) -> torch.Tensor:
"""
Gather distributed sequence chunks and reconstruct original packed sequences.
Reconstruction process:
1. For each sequence: gather chunks from all ranks and concatenate
2. Concatenate all reconstructed sequences along sequence dimension
Args:
local_hidden_states: Local hidden states [local_seq_len, hidden_dim]
all_split_sizes: Split sizes tensor [world_size, num_sequences] showing how each sequence
was distributed across ranks
process_group: Distributed process group for communication
dim: Dimension along which sequences were split (default: 0 - sequence dimension)
Returns:
reconstructed_sequences: Reconstructed packed sequences [total_seq_len, hidden_dim]
"""
world_size, num_sequences = all_split_sizes.shape
rank = torch.distributed.get_rank(process_group)
rank_total_sizes = []
for r in range(world_size):
total_size_this_rank = all_split_sizes[r].sum().item()
rank_total_sizes.append(total_size_this_rank)
all_gathered_tensor = gather_forward_split_backward(
local_hidden_states,
process_group=process_group,
dim=dim,
grad_scale="up",
gather_sizes=rank_total_sizes
)
rank_chunks = torch.split(all_gathered_tensor, rank_total_sizes, dim=dim)
rank_seq_chunks = [list(torch.split(rank_chunks[i], all_split_sizes[i].tolist(), dim=dim)) for i in range(world_size)]
reconstructed_sequences = []
for _ in range(num_sequences):
sequence_chunks = [rank_seq_chunks[i].pop(0) for i in range(world_size)]
reconstructed_sequences.extend(sequence_chunks)
return torch.cat(reconstructed_sequences, dim=dim)
def packed_data_split_forward_gather_backward_with_cp(
x: torch.Tensor,
dim: int,
seq_lens: List[int]
):
"""
Split packed sequences across context parallel (CP) ranks during the forward pass,
and gather full gradients during the backward pass.
This function supports three CP strategies:
- **Ulysses CP**: Splits the entire packed sequence uniformly (or near-uniformly) across all CP ranks.
- **Ring Attention**: Splits *each individual sample sequence* (e.g., image tokens) across CP ranks,
then concatenates the resulting shards to form a new packed tensor.
- **Hybrid CP**: First applies ring-based splitting per sample, then further splits the result
using Ulysses within each ring subgroup.
Args:
x: Concatenated sequences: s1+s2+s3+...
"""
ps = get_parallel_state()
if ps.is_ring_enable():
x = packed_data_split_forward_gather_backward(x, ps.get_ring_group(), seq_lens, dim)
if ps.is_ulysses_enable():
split_gather_sizes = cal_split_sizes(x.shape[dim], ps.get_ulysses_group_size())
x = split_forward_gather_backward(x, ps.get_ulysses_group(), dim=dim, split_sizes=split_gather_sizes)
return x
def packed_data_gather_forward_split_backward_with_cp(
x: torch.Tensor,
dim: int,
seq_lens: List[int]
):
"""
Gather visual sequences across context parallel (CP) ranks during the forward pass,
and split gradients back during the backward pass.
This function supports multiple CP strategies:
- **Ulysses CP**: All-gather full sequence using precomputed per-rank sequence lengths.
- **Ring Attention**: Reconstruct packed tensor from sequence chunks distributed across CP ranks.
- **Hybrid CP**: First gather within Ulysses subgroups, then across ring-based CP groups.
"""
ps = get_parallel_state()
all_split_sizes_tensor = cal_split_sizes_multi(seq_lens, ps.get_ring_group_size())
gather_sizes = cal_split_sizes(all_split_sizes_tensor[ps.get_ring_rank()].sum(), ps.get_ulysses_group_size())
if ps.is_ulysses_enable():
x = gather_forward_split_backward(x, ps.get_ulysses_group(), dim=dim, gather_sizes=gather_sizes)
if ps.is_ring_enable():
x = packed_data_gather_forward_split_backward(x, all_split_sizes_tensor, ps.get_ring_group(), dim=dim)
return x