from typing import List, Optional
import torch
from ..utils import register_tensor_cast_op
@register_tensor_cast_op("grouped_matmul")
def _(
x: List[torch.Tensor],
w: List[torch.Tensor],
bias: List[Optional[torch.Tensor]],
) -> torch.Tensor:
"""
Perform grouped quantized matrix multiplication. The arguments follow
the same convention as `static_quant_linear` but are grouped as lists.
The output is a concatenation of the individual matmul results, not a list
of tensors.
"""
M = sum(xi.shape[0] for xi in x)
N = w[0].shape[1]
return torch.empty((M, N), dtype=x[0].dtype, device="meta")
@register_tensor_cast_op("grouped_matmul_quant")
@register_tensor_cast_op("grouped_matmul_quant_int4")
def _(
x: List[torch.Tensor],
w: List[torch.Tensor],
w_scale: List[torch.Tensor],
w_offset: List[Optional[torch.Tensor]],
x_scale: List[torch.Tensor],
x_offset: List[Optional[torch.Tensor]],
bias: List[Optional[torch.Tensor]],
out_dtype: Optional[torch.dtype],
) -> torch.Tensor:
"""Similar to `grouped_matmul` but with quantization parameters."""
if out_dtype is None:
out_dtype = x[0].dtype
M = sum(xi.shape[0] for xi in x)
N = w[0].shape[1]
return torch.empty((M, N), dtype=out_dtype, device="meta")
@register_tensor_cast_op("grouped_matmul_fp8")
@register_tensor_cast_op("grouped_matmul_mxfp4")
def _(
x: List[torch.Tensor],
w: List[torch.Tensor],
w_scale: List[torch.Tensor],
x_scale: List[torch.Tensor],
bias: List[Optional[torch.Tensor]],
out_dtype: Optional[torch.dtype],
) -> torch.Tensor:
"""Similar to `grouped_matmul` but for FP8 quantization."""
if out_dtype is None:
out_dtype = x[0].dtype
M = sum(xi.shape[0] for xi in x)
N = w[0].shape[1]
return torch.empty((M, N), dtype=out_dtype, device="meta")
@register_tensor_cast_op("grouped_matmul_swiglu")
def _(
x: List[torch.Tensor],
w: List[torch.Tensor],
bias: List[Optional[torch.Tensor]],
) -> torch.Tensor:
M = sum(xi.shape[0] for xi in x)
N = w[0].shape[1] if w else 0
gmm_out_shape = (M, N)
dtype = x[0].dtype if x else torch.float32
swiglu_out_shape = gmm_out_shape
return torch.empty(swiglu_out_shape, dtype=dtype, device="meta")
@register_tensor_cast_op("grouped_matmul_quant_swiglu")
@register_tensor_cast_op("grouped_matmul_quant_int4_swiglu")
def _(
x: List[torch.Tensor],
w: List[torch.Tensor],
w_scale: List[torch.Tensor],
w_offset: List[Optional[torch.Tensor]],
x_scale: List[torch.Tensor],
x_offset: List[Optional[torch.Tensor]],
bias: List[Optional[torch.Tensor]],
out_dtype: Optional[torch.dtype],
) -> torch.Tensor:
if out_dtype is None:
out_dtype = x[0].dtype if x else torch.float32
M = sum(xi.shape[0] for xi in x)
N = w[0].shape[1] if w else 0
gmm_out_shape = (M, N)
swiglu_out_shape = gmm_out_shape
return torch.empty(swiglu_out_shape, dtype=out_dtype, device="meta")
@register_tensor_cast_op("grouped_matmul_fp8_swiglu")
@register_tensor_cast_op("grouped_matmul_mxfp4_swiglu")
def _(
x: List[torch.Tensor],
w: List[torch.Tensor],
w_scale: List[torch.Tensor],
x_scale: List[torch.Tensor],
bias: List[Optional[torch.Tensor]],
out_dtype: Optional[torch.dtype],
) -> torch.Tensor:
if out_dtype is None:
out_dtype = x[0].dtype if x else torch.float32
M = sum(xi.shape[0] for xi in x)
N = w[0].shape[1] if w else 0
gmm_out_shape = (M, N)
swiglu_out_shape = gmm_out_shape
return torch.empty(swiglu_out_shape, dtype=out_dtype, device="meta")