from typing import List, Optional, Tuple
import torch
from ..utils import register_tensor_cast_op
@register_tensor_cast_op("quantize")
def _(
    x: torch.Tensor,
    scale: torch.Tensor,
    offset: Optional[torch.Tensor],
    out_dtype: torch.dtype = torch.int8,
) -> torch.Tensor:
    """
    Describe the output of static quantization.

    The op is documented as `out = clamp(round(x / scale) + offset)`. This
    function does not perform that arithmetic: it only allocates an
    uninitialized tensor that carries the shape and dtype of the result.

    Args:
        x: The input tensor to be quantized.
        scale: The quantization scale (not read by this function).
        offset: The optional quantization offset (not read by this function).
        out_dtype: The target data type for quantization (default: torch.int8).
    Returns:
        An uninitialized tensor with the same shape as `x` and dtype `out_dtype`.
    """
    # Only x and out_dtype determine the result; scale/offset are part of
    # the op signature but have no influence on the output description.
    quantized = torch.empty_like(x, dtype=out_dtype)
    return quantized
@register_tensor_cast_op("dynamic_quantize_asymmetric")
def _(
    x: torch.Tensor,
    dims: List[int],
    scale_dtype: torch.dtype = torch.float32,
    out_dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Dynamically quantize the input tensor `x` to `out_dtype` with asymmetric quantization.

    The quantization scale and offset are derived from the values of `x` along
    the specified dimensions. This function performs no arithmetic itself: it
    only allocates uninitialized tensors that carry the shapes and dtypes of
    the three outputs.

    Args:
        x: The input tensor to be quantized.
        dims: The dimensions along which to compute the quantization scale.
            Negative indices are accepted (they index `x.shape` like a list).
        scale_dtype: The data type for the quantization scale (default: torch.float32).
        out_dtype: The target data type for quantization (default: torch.int8).
    Returns:
        A tuple containing:
        - The quantized tensor, with the same shape as `x` and dtype `out_dtype`.
        - The quantization scale tensor. When `dims` is empty, the scale is a scalar tensor.
          Otherwise, the scale tensor has the same shape as `x` with the specified `dims`
          reduced to size 1.
        - The quantization offset tensor, the same shape as the scale tensor but with
          dtype torch.int32.
        The scale and offset tensors are always created on the "meta" device.
    Raises:
        IndexError: If an entry of `dims` is out of range for `x`.
    """
    if dims:
        # Keep every dimension of x, collapsing each reduced one to size 1.
        reduced = list(x.shape)
        for dim in dims:
            reduced[dim] = 1
        scale_shape = torch.Size(reduced)
    else:
        # No reduction dimensions: a single scale/offset for the whole tensor.
        scale_shape = torch.Size([])
    return (
        torch.empty_like(x, dtype=out_dtype),
        torch.empty(scale_shape, dtype=scale_dtype, device="meta"),
        torch.empty(scale_shape, dtype=torch.int32, device="meta"),
    )
@register_tensor_cast_op("dynamic_quantize_symmetric")
def _(
    x: torch.Tensor,
    dims: List[int],
    scale_dtype: torch.dtype = torch.float32,
    out_dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dynamically quantize the input tensor `x` to `out_dtype` with symmetric
    quantization (scale only, no offset).

    This function performs no arithmetic: it only allocates uninitialized
    tensors that carry the shapes and dtypes of the two outputs.

    Args:
        x: The input tensor to be quantized.
        dims: The dimensions along which to compute the quantization scale.
        scale_dtype: The data type for the quantization scale (default: torch.float32).
        out_dtype: The target data type for quantization (default: torch.int8).
    Returns:
        A tuple containing:
        - The quantized tensor, with the same shape as `x` and dtype `out_dtype`.
        - The quantization scale tensor on the "meta" device. When `dims` is empty,
          the scale is a scalar tensor. Otherwise it has the same shape as `x`
          with the specified `dims` reduced to size 1.
    """
    # An empty `dims` yields an empty shape (scalar scale); otherwise start
    # from x's shape and collapse each listed dimension to size 1.
    reduced = list(x.shape) if dims else []
    for axis in dims:
        reduced[axis] = 1

    quantized = torch.empty_like(x, dtype=out_dtype)
    scale = torch.empty(torch.Size(reduced), dtype=scale_dtype, device="meta")
    return quantized, scale
@register_tensor_cast_op("dynamic_quantize_mxfp4")
def _(
    x: torch.Tensor,
    group_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dynamically quantize the input tensor `x` to MXFP4.

    The quantization is symmetric and applied per channel group along the
    last dimension, each group holding `group_size` channels. This function
    performs no arithmetic: it only allocates uninitialized tensors that
    carry the shapes and dtypes of the two outputs.

    Args:
        x: The input tensor to be quantized.
        group_size: The channel group size for MXFP4 quantization.
    Returns:
        A tuple containing:
        - The quantized tensor, with the same shape as `x` and dtype torch.int4.
        - The quantization scale tensor of shape (K_group,) with dtype
          torch.float8_e8m0fnu on the "meta" device, where K_group is the
          last dimension of `x` divided by `group_size`, rounded up.
    """
    last_dim = x.shape[-1]
    # Ceiling division: a trailing partial group still gets its own scale.
    num_groups = (last_dim + group_size - 1) // group_size

    quantized = torch.empty_like(x, dtype=torch.int4)
    # NOTE(review): the scale is 1-D regardless of x's leading dimensions;
    # confirm this matches the layout the real kernel produces.
    scales = torch.empty((num_groups,), dtype=torch.float8_e8m0fnu, device="meta")
    return quantized, scales