from typing import Any, Tuple
import torch
import torch.distributed as dist
from torch.nn import functional as F
import mindspeed_mm.models.predictor.dits.hunyuanvideo15.utils
def broadcast(input_: torch.Tensor, group: dist.ProcessGroup):
    """Broadcast ``input_`` across ``group`` from the group's first rank.

    Args:
        input_: Tensor handed to ``dist.broadcast``.
        group: Process group over which the broadcast runs.
    """
    # dist.broadcast expects a global rank as the source, so group-rank 0 is
    # translated to its global rank first.
    dist.broadcast(
        input_,
        src=dist.get_global_rank(group, 0),
        group=group,
    )
def _all_to_all_4D(
    inputs: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None
) -> torch.Tensor:
    """Reshard a 4D tensor between dim 1 and dim 2 with an all-to-all.

    Two directions are supported (``P`` is the world size of ``group``):

    * ``scatter_idx=2, gather_idx=1`` (the defaults): the input
      ``(bs, shard_seqlen, hc, hs)`` becomes ``(bs, shard_seqlen * P, hc // P, hs)``.
      The dim-1 lengths of the first and last rank are compared; when the last
      rank is shorter it is zero-padded to the first rank's length before the
      exchange, and that padding is sliced off the end of the result.
    * ``scatter_idx=1, gather_idx=2``: the input ``(bs, seqlen, shard_hc, hs)``
      becomes ``(bs, seqlen // P, shard_hc * P, hs)``. When ``seqlen`` is not
      a multiple of ``P`` it is zero-padded up to the next multiple, and the
      last rank drops the padded tail from its result.

    Args:
        inputs (torch.Tensor): 4D tensor to reshard.
        scatter_idx (int): dimension to scatter, default 2.
        gather_idx (int): dimension to gather, default 1.
        group: torch process group; ``None`` selects the default group.

    Returns:
        torch.Tensor: the resharded tensor (shapes as described above).

    Raises:
        ValueError: if ``inputs`` is not 4D, if the last rank holds a longer
            dim 1 than the first rank, or if dim 2 is not divisible by ``P``
            in the default direction.
        RuntimeError: for any other ``(scatter_idx, gather_idx)`` pair.
    """
    if inputs.dim() != 4:
        raise ValueError(f"input must be 4D tensor, got {inputs.dim()} and shape {inputs.shape}")

    seq_world_size = dist.get_world_size(group)

    if scatter_idx == 2 and gather_idx == 1:
        # Exchange the per-rank dim-1 lengths so every rank learns whether the
        # last rank holds a shorter shard than the first one.
        seq_lens = [None] * seq_world_size
        dist.all_gather_object(seq_lens, inputs.shape[1], group)
        gap = 0
        if seq_lens[-1] != seq_lens[0]:
            if not seq_lens[0] > seq_lens[-1]:
                raise ValueError("seq_lens is invalid")
            gap = seq_lens[0] - seq_lens[-1]
            # dist.get_rank(group) also accepts group=None, unlike
            # dist.get_group_rank, so the default group keeps working.
            if dist.get_rank(group) == seq_world_size - 1:
                if inputs.shape[1] != seq_lens[-1]:
                    raise ValueError("inputs is invalid")
                # F.pad pairs run from the last dim backwards: the third
                # pair pads the end of dim 1.
                inputs = F.pad(inputs, (0, 0, 0, 0, 0, gap))

        bs, shard_seqlen, hc, hs = inputs.shape
        seqlen = shard_seqlen * seq_world_size
        if hc % seq_world_size != 0:
            raise ValueError(f'Invalid Head size: {hc}, which should be divisible by spsize {seq_world_size}')
        shard_hc = hc // seq_world_size

        # (bs, shard_seqlen, P, shard_hc, hs) -> (P, shard_seqlen, bs, shard_hc, hs)
        # so that the leading dim indexes the destination rank.
        input_t = (
            inputs.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs)
            .transpose(0, 2)
            .contiguous()
        )

        if seq_world_size > 1:
            output = torch.empty_like(input_t)
            dist.all_to_all_single(output, input_t, group=group)
        else:
            # Single rank: nothing to exchange and no receive buffer needed.
            output = input_t

        output = output.reshape(seqlen, bs, shard_hc, hs)
        output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
        if gap > 0:
            # Remove the padding contributed by the last rank.
            output = output[:, :-gap]
        return output

    if scatter_idx == 1 and gather_idx == 2:
        bs, seqlen, shard_hc, hs = inputs.shape
        hc = shard_hc * seq_world_size

        # Number of rows needed to round seqlen up to a multiple of P.
        gap = (-seqlen) % seq_world_size
        if gap:
            inputs = F.pad(inputs, (0, 0, 0, 0, 0, gap))
            seqlen += gap
        shard_seqlen = seqlen // seq_world_size

        # (bs, P, shard_seqlen, shard_hc, hs) -> (P, shard_hc, shard_seqlen, bs, hs)
        input_t = (
            inputs.reshape(bs, seq_world_size, shard_seqlen, shard_hc, hs)
            .transpose(0, 3)
            .transpose(0, 1)
            .contiguous()
            .reshape(seq_world_size, shard_hc, shard_seqlen, bs, hs)
        )

        if seq_world_size > 1:
            output = torch.empty_like(input_t)
            dist.all_to_all_single(output, input_t, group=group)
        else:
            output = input_t

        output = output.reshape(hc, shard_seqlen, bs, hs)
        output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs)
        # Only the last rank's shard contains the padded tail.
        if gap > 0 and dist.get_rank(group) == seq_world_size - 1:
            output = output[:, :-gap]
        return output

    raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2")
class SeqAllToAll4D(torch.autograd.Function):
    """Autograd wrapper around ``_all_to_all_4D``.

    The backward pass applies the same operation to the incoming gradient
    with the scatter and gather indices exchanged.
    """

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        inputs: torch.Tensor,
        scatter_idx: int,
        gather_idx: int,
    ) -> torch.Tensor:
        """Reshard ``inputs`` and remember the arguments for backward."""
        ctx.group, ctx.scatter_idx, ctx.gather_idx = group, scatter_idx, gather_idx
        resharded = _all_to_all_4D(inputs, scatter_idx, gather_idx, group=group)
        return resharded

    @staticmethod
    def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None, None]:
        """Route the gradient back through the reverse all-to-all.

        Only ``inputs`` receives a gradient; the group and the two indices
        get ``None``.
        """
        # Note the swapped order: gather_idx is passed as the scatter index.
        grad_inputs = SeqAllToAll4D.apply(
            ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx
        )
        return None, grad_inputs, None, None
def all_to_all_4D(
    input_: torch.Tensor,
    group: dist.ProcessGroup,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    """Differentiable entry point for the 4D all-to-all.

    Args:
        input_: Tensor forwarded to ``SeqAllToAll4D``.
        group: Process group used for the exchange.
        scatter_dim: Dimension to scatter, default 2.
        gather_dim: Dimension to gather, default 1.

    Returns:
        The result of ``SeqAllToAll4D.apply``.
    """
    resharded = SeqAllToAll4D.apply(group, input_, scatter_dim, gather_dim)
    return resharded
def _all_to_all(
    input_: torch.Tensor,
    world_size: int,
    group: dist.ProcessGroup,
    scatter_dim: int,
    gather_dim: int,
):
    """Split ``input_`` along ``scatter_dim``, exchange, concat along ``gather_dim``.

    Args:
        input_: Tensor to exchange.
        world_size: Number of chunks to split into / receive.
        group: Process group used for ``dist.all_to_all``.
        scatter_dim: Dimension along which ``input_`` is split.
        gather_dim: Dimension along which received chunks are concatenated.

    Returns:
        A contiguous tensor built from the received chunks.
    """
    pieces = torch.tensor_split(input_, world_size, scatter_dim)
    send_chunks = [piece.contiguous() for piece in pieces]

    # Every receive buffer is shaped like the FIRST send chunk, so the
    # exchange presumes all chunks share that shape.
    template = send_chunks[0]
    recv_chunks = [torch.empty_like(template) for _ in range(world_size)]

    dist.all_to_all(recv_chunks, send_chunks, group=group)

    gathered = torch.cat(recv_chunks, dim=gather_dim)
    return gathered.contiguous()
class _AllToAll(torch.autograd.Function):
    """Differentiable all-to-all built on ``_all_to_all``.

    Forward arguments:
        input_: input tensor
        process_group: communication group
        scatter_dim: dimension that is split across ranks
        gather_dim: dimension along which received chunks are concatenated
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        """Run the exchange and stash the arguments for backward."""
        world_size = dist.get_world_size(process_group)
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.world_size = world_size
        return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, grad_output):
        """Send the gradient back with scatter and gather dims swapped."""
        grad_input = _all_to_all(
            grad_output,
            ctx.world_size,
            ctx.process_group,
            ctx.gather_dim,
            ctx.scatter_dim,
        )
        # Only input_ is differentiable; group and dims get None.
        return grad_input, None, None, None
def all_to_all(
    input_: torch.Tensor,
    group: dist.ProcessGroup,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    """Differentiable all-to-all over ``group``.

    Args:
        input_: Tensor forwarded to ``_AllToAll``.
        group: Process group used for the exchange.
        scatter_dim: Dimension to split across ranks, default 2.
        gather_dim: Dimension to concatenate received chunks along, default 1.

    Returns:
        The result of ``_AllToAll.apply``.
    """
    exchanged = _AllToAll.apply(input_, group, scatter_dim, gather_dim)
    return exchanged
class _Reduce_Scatter(torch.autograd.Function):
    """Reduce-scatter communication with autograd support.

    Forward arguments:
        op: reduction operation passed to ``dist.reduce_scatter``
        group: communication group
        tensor: output buffer that receives this rank's reduced chunk
        *input_tensor_list: the tensors to reduce and scatter
    """

    @staticmethod
    def forward(ctx, op, group, tensor, *input_tensor_list):
        """Run ``dist.reduce_scatter`` and return the output buffer."""
        ctx.group = group
        tensor = tensor.contiguous()
        input_tensor_list = tuple(t.contiguous() for t in input_tensor_list)
        # Remembered so backward can allocate one gradient per input tensor.
        ctx.input_shapes = tuple(t.shape for t in input_tensor_list)
        dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        """All-gather the output gradient to form the input-list gradients.

        Element ``i`` of the input list is reduced into the output held by
        rank ``i``, so its gradient is rank ``i``'s ``grad_output``.
        NOTE(review): this is the gradient of a SUM reduction; other
        reduction ops are not handled specially here.

        Returns:
            ``None`` for ``op``, ``group`` and ``tensor``, followed by one
            gradient per entry of ``input_tensor_list``.
        """
        grad_output = grad_output.contiguous()
        gathered = [
            torch.empty(shape, dtype=grad_output.dtype, device=grad_output.device)
            for shape in ctx.input_shapes
        ]
        dist.all_gather(gathered, grad_output, group=ctx.group)
        # autograd requires exactly one entry per forward argument.
        return (None, None, None, *gathered)
class _AllGather(torch.autograd.Function):
    """All-gather communication with autograd support.

    Ranks may contribute tensors of different shapes; the per-rank shapes are
    exchanged before the gather.

    Forward arguments:
        input_: input tensor
        dim: dimension along which to concatenate
        group: communication group (``None`` selects the default group)
    """

    @staticmethod
    def forward(ctx, input_, dim, group):
        """Gather ``input_`` from every rank and concatenate along ``dim``."""
        world_size = dist.get_world_size(group)

        # Learn every rank's shape so receive buffers can be sized per rank.
        shapes = [None] * world_size
        dist.all_gather_object(shapes, input_.shape, group)

        ctx.dim = dim
        ctx.group = group
        ctx.input_size = input_.shape[dim]
        # Cache the per-rank extents along `dim`; backward needs exactly
        # these values, so no second size exchange is required there.
        ctx.split_sizes = [shape[dim] for shape in shapes]

        tensor_list = [
            torch.empty(shape, dtype=input_.dtype, device=input_.device)
            for shape in shapes
        ]
        dist.all_gather(tensor_list, input_.contiguous(), group=group)
        return torch.cat(tensor_list, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        """Reduce-scatter the gradient so each rank gets its own slice summed.

        Returns:
            The gradient for ``input_`` and ``None`` for ``dim`` and ``group``.
        """
        group = ctx.group
        # dist.get_rank(group) also works when group is None, which
        # dist.get_group_rank does not accept.
        rank = dist.get_rank(group)

        grad_chunks = torch.split(grad_output, ctx.split_sizes, dim=ctx.dim)

        # Use a fresh output buffer: the chunks are views of grad_output, and
        # writing the reduced result into one of them would modify the
        # incoming gradient in place.
        own_chunk = grad_chunks[rank]
        out_buffer = torch.empty(
            own_chunk.shape, dtype=grad_output.dtype, device=grad_output.device
        )
        grad_input = _Reduce_Scatter.apply(
            dist.ReduceOp.SUM, group, out_buffer, *grad_chunks
        )
        return grad_input, None, None
@torch.compiler.disable
def all_gather(input_: torch.Tensor, dim: int = 1, group=None):
    """All-gather ``input_`` across ``group`` and concatenate along ``dim``.

    Args:
        input_ (torch.Tensor): Tensor contributed by this rank.
        dim (int, optional): Dimension along which to concatenate. Defaults to 1.
        group: Process group forwarded to ``_AllGather``. Defaults to ``None``.

    Returns:
        torch.Tensor: The result of ``_AllGather.apply``, i.e. the per-rank
        tensors concatenated along ``dim``.
    """
    gathered = _AllGather.apply(input_, dim, group)
    return gathered