from typing import Optional
import torch
from ..utils import register_tensor_cast_op
@register_tensor_cast_op("static_quant_linear")
@register_tensor_cast_op("static_quant_linear_int4")
def _(
x: torch.Tensor,
w: torch.Tensor,
w_scale: torch.Tensor,
w_offset: Optional[torch.Tensor],
x_scale: Optional[torch.Tensor],
x_offset: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
out_dtype: Optional[torch.dtype],
) -> torch.Tensor:
"""
Semantically equivalent to
`out = (dequant(x) @ dequant(w) + bias).to_dtype(out_dtype)` if x_scale is not None
`out = (x @ dequant(w) + bias).to_dtype(out_dtype)` otherwise
Args:
x: (M, K)
x_scale: scalar, (M, 1)
x_offset: scalar, (M, 1)
w: (K, N), for int4 (K/2, N)
w_scale: scalar or (N,)
w_offset: scalar or (N,)
bias: Optional[(N,)] bias tensor
"""
if out_dtype is None:
out_dtype = x.dtype
return torch.empty((x.shape[0], w.shape[1]), dtype=out_dtype, device="meta")
@register_tensor_cast_op("fp8_linear")
def _(
x: torch.Tensor,
w: torch.Tensor,
x_scale: torch.Tensor,
w_scale: torch.Tensor,
bias: Optional[torch.Tensor],
out_dtype: Optional[torch.dtype],
) -> torch.Tensor:
"""
FP8 linear operation. Both x and w are in FP8 format.
Semantically equivalent to:
`out = (dequant_fp8(x) @ dequant_fp8(w) + bias).to_dtype(out_dtype)`
Args:
x: (M, K) in FP8 format
w: (K, N) in FP8 format
x_scale: scalar or (M, 1) for activation scale
w_scale: scalar or (N,) for weight scale
bias: Optional[(N,)] bias tensor
out_dtype: output data type
"""
if out_dtype is None:
out_dtype = torch.float16
return torch.empty((x.shape[0], w.shape[1]), dtype=out_dtype, device="meta")
@register_tensor_cast_op("mxfp4_linear")
def _(
x: torch.Tensor,
w: torch.Tensor,
x_scale: torch.Tensor,
w_scale: torch.Tensor,
bias: Optional[torch.Tensor],
out_dtype: Optional[torch.dtype],
) -> torch.Tensor:
"""
MXFP4 linear operation. Both x and w are in MXFP4 format with logical
dimension as (M, K) and (K, N). The MXFP4 format is logically represented
with torch.int4.
Semantically equivalent to:
`out = (dequant_mxfp4(x) @ dequant_mxfp4(w) + bias).to_dtype(out_dtype)`
Args:
x: (M, K) in MXFP4 format
w: (K, N) in MXFP4 format
x_scale: (Kg,) for activation scale in torch.float8_e8m0fnu
w_scale: (Kg,) for weight scale in torch.float8_e8m0fnu
bias: Optional[(N,)] bias tensor
out_dtype: output data type
"""
if out_dtype is None:
out_dtype = torch.float16
return torch.empty((x.shape[0], w.shape[1]), dtype=out_dtype, device="meta")