import torch
import torch.distributed as dist
from .parallel_mgr import get_sequence_parallel_size
def _all_to_all_func(input_, world_size, process_group, scatter_dim=2, gather_dim=1):
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
dist.all_to_all(output_list, input_list, group=process_group)
return torch.cat(output_list, dim=gather_dim).contiguous()
def split_sequence(input_, process_group: dist.ProcessGroup, dim: int, pad: int):
    """Return this rank's shard of ``input_`` along ``dim``.

    When the group has more than one rank, ``pad`` zeros are first appended
    along ``dim``. The padded tensor is then cut into ``world_size`` equal
    chunks, and the chunk at this rank's index is returned as a contiguous
    tensor.

    With a single-rank group, the input is returned unchanged and no padding
    is applied. ``gather_sequence`` likewise skips its trimming for
    single-rank groups.

    Args:
        input_: tensor to shard.
        process_group: group whose size and rank determine the split.
        dim: dimension to split along.
        pad: number of zero entries to append along ``dim`` before splitting.

    Returns:
        This rank's contiguous shard, or ``input_`` itself for a single-rank
        group.

    Raises:
        ValueError: if the padded length along ``dim`` is not a multiple of
            the group size. Previously, the remainder was silently dropped.
    """
    world_size = dist.get_world_size(process_group)
    rank = dist.get_rank(process_group)
    if world_size == 1:
        return input_
    if pad > 0:
        pad_size = list(input_.shape)
        pad_size[dim] = pad
        zeros = torch.zeros(pad_size, dtype=input_.dtype, device=input_.device)
        input_ = torch.cat([input_, zeros], dim=dim)
    dim_size = input_.size(dim)
    if dim_size % world_size != 0:
        raise ValueError(
            f"split_sequence: padded size {dim_size} along dim {dim} "
            f"is not divisible by world_size {world_size}"
        )
    chunk = dim_size // world_size
    # narrow selects only this rank's slice instead of materializing all views.
    return input_.narrow(dim, rank * chunk, chunk).contiguous()
def gather_sequence(input_, process_group: dist.ProcessGroup, dim: int, pad: int):
    """Concatenate every rank's shard along ``dim`` and strip trailing padding.

    The local shard is made contiguous first. For a single-rank group, that
    contiguous shard is returned as-is and ``pad`` is ignored. Otherwise, the
    shards of all ranks are collected with ``dist.all_gather``, joined in rank
    order along ``dim``, and the last ``pad`` entries along ``dim`` are
    removed when ``pad`` is positive.

    Args:
        input_: this rank's shard. Every rank is assumed to hold the same
            shape, since receive buffers are built with ``empty_like``.
        process_group: group to gather over.
        dim: dimension to concatenate along.
        pad: number of trailing entries along ``dim`` to drop after gathering.
    """
    local = input_.contiguous()
    group_size = dist.get_world_size(process_group)
    if group_size == 1:
        return local

    shards = [torch.empty_like(local) for _ in range(group_size)]
    dist.all_gather(shards, local, group=process_group)
    gathered = torch.cat(shards, dim=dim)

    if pad <= 0:
        return gathered
    return gathered.narrow(dim, 0, gathered.size(dim) - pad)
# Module-level padding amounts, written by set_spatial_pad / set_temporal_pad
# and read back by the matching getters. Both start at 0.
# NOTE: "SPTIAL" is a misspelling of "SPATIAL"; the name is kept unchanged in
# case other modules reference the variable directly.
SPTIAL_PAD = 0
TEMPORAL_PAD = 0


def _compute_pad(dim_size: int) -> int:
    """Return how many elements to append to ``dim_size`` to make it divisible
    by the sequence-parallel size.

    The result is 0 when ``dim_size`` is already a multiple. Assumes
    ``get_sequence_parallel_size()`` returns a positive int -- TODO confirm.
    """
    sp_size = get_sequence_parallel_size()
    # With a positive divisor, Python's % is non-negative, so the outer % maps
    # an exact multiple (inner result == sp_size) back to 0.
    return (sp_size - (dim_size % sp_size)) % sp_size


def set_spatial_pad(dim_size: int) -> None:
    """Record the padding needed to make ``dim_size`` divisible by the
    sequence-parallel size in ``SPTIAL_PAD``."""
    global SPTIAL_PAD
    SPTIAL_PAD = _compute_pad(dim_size)


def get_spatial_pad() -> int:
    """Return the padding last recorded by ``set_spatial_pad`` (0 initially)."""
    return SPTIAL_PAD


def set_temporal_pad(dim_size: int) -> None:
    """Record the padding needed to make ``dim_size`` divisible by the
    sequence-parallel size in ``TEMPORAL_PAD``."""
    global TEMPORAL_PAD
    TEMPORAL_PAD = _compute_pad(dim_size)


def get_temporal_pad() -> int:
    """Return the padding last recorded by ``set_temporal_pad`` (0 initially)."""
    return TEMPORAL_PAD
def all_to_all_with_pad(
    input_: torch.Tensor,
    process_group: dist.ProcessGroup,
    **kwargs
):
    """Run an all-to-all, zero-padding before the scatter and trimming after the gather.

    Recognised keyword arguments (any others are ignored):
        scatter_dim (int): dimension split across ranks (default 2).
        gather_dim (int): dimension concatenated across ranks (default 1).
        scatter_pad (int): zeros appended along ``scatter_dim`` before the
            exchange when positive (default 0).
        gather_pad (int): trailing entries removed along ``gather_dim`` after
            the exchange when positive (default 0).

    Returns:
        The exchanged tensor, narrowed along ``gather_dim`` if ``gather_pad``
        is positive.
    """
    opts = {"scatter_dim": 2, "gather_dim": 1, "scatter_pad": 0, "gather_pad": 0}
    opts.update(kwargs)
    scatter_dim, gather_dim = opts["scatter_dim"], opts["gather_dim"]
    scatter_pad, gather_pad = opts["scatter_pad"], opts["gather_pad"]

    world_size = dist.get_world_size(process_group)

    if scatter_pad > 0:
        # Zero block matching input_ in every dimension except scatter_dim.
        filler_shape = list(input_.shape)
        filler_shape[scatter_dim] = scatter_pad
        input_ = torch.cat([input_, input_.new_zeros(filler_shape)], dim=scatter_dim)

    exchanged = _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim)

    if gather_pad > 0:
        kept = exchanged.size(gather_dim) - gather_pad
        exchanged = exchanged.narrow(gather_dim, 0, kept)
    return exchanged