from typing import List, Optional
import torch
from ..utils import exact_division, register_tensor_cast_op
@register_tensor_cast_op("all_to_all")
def _(
x: torch.Tensor,
output_split_sizes: List[int],
input_split_sizes: List[int],
rank: int,
rank_group: List[int],
) -> torch.Tensor:
output_num = sum(output_split_sizes)
return torch.empty((output_num, *x.shape[1:]), dtype=x.dtype, device=x.device)
@register_tensor_cast_op("all_reduce")
def _(x: torch.Tensor, rank: int, rank_group: List[int]) -> torch.Tensor:
return torch.empty_like(x)
@register_tensor_cast_op("reduce_scatter")
def _(x: torch.Tensor, dim: int, rank: int, rank_group: List[int]) -> torch.Tensor:
world_size = len(rank_group)
new_shape = list(x.shape)
new_shape[dim] = exact_division(new_shape[dim], world_size)
return torch.empty(new_shape, dtype=x.dtype, device=x.device)
@register_tensor_cast_op("all_gather")
def _(x: torch.Tensor, dim: int, rank: int, rank_group: List[int]) -> torch.Tensor:
world_size = len(rank_group)
new_shape = list(x.shape)
new_shape[dim] = new_shape[dim] * world_size
return torch.empty(new_shape, dtype=x.dtype, device=x.device)
@register_tensor_cast_op("matmul_all_reduce")
def _(
mat1: torch.Tensor,
mat2: torch.Tensor,
bias: Optional[torch.Tensor],
rank: int,
rank_group: List[int],
) -> torch.Tensor:
matmul_out = torch.matmul(mat1, mat2)
return torch.empty_like(matmul_out)
@register_tensor_cast_op("static_quant_linear_all_reduce")
def _(
x: torch.Tensor,
w: torch.Tensor,
w_scale: torch.Tensor,
w_offset: Optional[torch.Tensor],
x_scale: Optional[torch.Tensor],
x_offset: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
out_dtype: Optional[torch.dtype],
rank: int,
rank_group: List[int],
) -> torch.Tensor:
linear_out = torch.ops.tensor_cast.static_quant_linear.default(
x,
w,
w_scale,
w_offset,
x_scale,
x_offset,
bias,
out_dtype if out_dtype is not None else x.dtype,
)
return torch.empty_like(linear_out)
@register_tensor_cast_op("static_quant_linear_int4_all_reduce")
def _(
x: torch.Tensor,
w: torch.Tensor,
w_scale: torch.Tensor,
w_offset: Optional[torch.Tensor],
x_scale: Optional[torch.Tensor],
x_offset: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
out_dtype: Optional[torch.dtype],
rank: int,
rank_group: List[int],
) -> torch.Tensor:
linear_out = torch.ops.tensor_cast.static_quant_linear_int4.default(
x,
w,
w_scale,
w_offset,
x_scale,
x_offset,
bias,
out_dtype if out_dtype is not None else x.dtype,
)
return torch.empty_like(linear_out)
@register_tensor_cast_op("fp8_linear_all_reduce")
def _(
x: torch.Tensor,
w: torch.Tensor,
x_scale: torch.Tensor,
w_scale: torch.Tensor,
bias: Optional[torch.Tensor],
out_dtype: Optional[torch.dtype],
rank: int,
rank_group: List[int],
) -> torch.Tensor:
linear_out = torch.ops.tensor_cast.fp8_linear.default(
x,
w,
x_scale,
w_scale,
bias,
out_dtype,
)
return torch.empty_like(linear_out)
@register_tensor_cast_op("mxfp4_linear_all_reduce")
def _(
x: torch.Tensor,
w: torch.Tensor,
x_scale: torch.Tensor,
w_scale: torch.Tensor,
bias: Optional[torch.Tensor],
out_dtype: Optional[torch.dtype],
rank: int,
rank_group: List[int],
) -> torch.Tensor:
linear_out = torch.ops.tensor_cast.mxfp4_linear.default(
x,
w,
x_scale,
w_scale,
bias,
out_dtype,
)
return torch.empty_like(linear_out)