from typing import List, Tuple
import torch
from ..utils import register_tensor_cast_op
@register_tensor_cast_op("rms_norm")
def _(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
) -> torch.Tensor:
    """Allocate the output tensor for the ``rms_norm`` op.

    No normalization arithmetic is performed here: ``weight`` and ``eps``
    are accepted but never read, and the result's contents are
    uninitialized.

    Returns:
        A contiguous tensor created with ``torch.empty_like(x)``.
    """
    out = torch.empty_like(x)
    return out.contiguous()
@register_tensor_cast_op("rms_norm_quant")
def _(
    x: torch.Tensor,
    weight: torch.Tensor,
    quant_scale: torch.Tensor,
    quant_offset: torch.Tensor,
    eps: float,
    out_dtype: torch.dtype = torch.int8,
) -> torch.Tensor:
    """Allocate the output tensor for the ``rms_norm_quant`` op.

    Only ``x`` and ``out_dtype`` are read; ``weight``, ``quant_scale``,
    ``quant_offset`` and ``eps`` are unused and no arithmetic is performed.

    Returns:
        A contiguous, uninitialized tensor created like ``x`` but with
        dtype ``out_dtype``.
    """
    quantized = torch.empty_like(x, dtype=out_dtype)
    return quantized.contiguous()
@register_tensor_cast_op("add_rms_norm")
def _(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
) -> torch.Tensor:
    """Allocate the output tensor for the ``add_rms_norm`` op.

    ``residual``, ``weight`` and ``eps`` are accepted but never read; no
    addition or normalization is carried out in this function.

    Returns:
        A contiguous, uninitialized tensor created with
        ``torch.empty_like(x)``.
    """
    out = torch.empty_like(x)
    return out.contiguous()
@register_tensor_cast_op("add_rms_norm2")
def _(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Allocate the two output tensors for the ``add_rms_norm2`` op.

    ``residual``, ``weight`` and ``eps`` are unused. Both outputs are
    separate allocations modelled on ``x``.
    NOTE(review): presumably the normalized result and the summed
    residual -- confirm against the op's schema.

    Returns:
        A pair of contiguous, uninitialized tensors, each created with
        ``torch.empty_like(x)``.
    """
    first = torch.empty_like(x).contiguous()
    second = torch.empty_like(x).contiguous()
    return first, second
@register_tensor_cast_op("add_rms_norm_quant")
def _(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    quant_scale: torch.Tensor,
    quant_offset: torch.Tensor,
    eps: float,
    out_dtype: torch.dtype = torch.int8,
) -> torch.Tensor:
    """Allocate the output tensor for the ``add_rms_norm_quant`` op.

    Only ``x`` and ``out_dtype`` influence the result; every other
    argument is accepted but unused, and nothing is computed.

    Returns:
        A contiguous, uninitialized tensor created like ``x`` but with
        dtype ``out_dtype``.
    """
    quantized = torch.empty_like(x, dtype=out_dtype)
    return quantized.contiguous()
@register_tensor_cast_op("add_rms_norm_quant2")
def _(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    quant_scale: torch.Tensor,
    quant_offset: torch.Tensor,
    eps: float,
    out_dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Allocate the two output tensors for the ``add_rms_norm_quant2`` op.

    Only ``x`` and ``out_dtype`` are read; the remaining arguments are
    unused and no arithmetic is performed.

    Returns:
        ``(quantized, unquantized)`` where ``quantized`` is created like
        ``x`` with dtype ``out_dtype`` and ``unquantized`` keeps ``x``'s
        dtype. Both are contiguous and uninitialized.
    """
    quantized = torch.empty_like(x, dtype=out_dtype).contiguous()
    unquantized = torch.empty_like(x).contiguous()
    return quantized, unquantized
@register_tensor_cast_op("rms_norm_dynamic_quant_symmetric")
def _(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    dims: List[int],
    scale_dtype: torch.dtype = torch.float32,
    out_dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Produce the outputs of ``rms_norm_dynamic_quant_symmetric``.

    Delegates entirely to
    ``torch.ops.tensor_cast.dynamic_quantize_symmetric`` applied to ``x``;
    ``weight`` and ``eps`` are accepted but unused.

    Returns:
        Whatever the delegated op returns, unchanged (annotated as a pair
        of tensors).
    """
    quantize = torch.ops.tensor_cast.dynamic_quantize_symmetric
    return quantize(x, dims, scale_dtype=scale_dtype, out_dtype=out_dtype)
@register_tensor_cast_op("rms_norm_dynamic_quant_asymmetric")
def _(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    dims: List[int],
    scale_dtype: torch.dtype = torch.float32,
    out_dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Produce the outputs of ``rms_norm_dynamic_quant_asymmetric``.

    Delegates entirely to
    ``torch.ops.tensor_cast.dynamic_quantize_asymmetric`` applied to ``x``;
    ``weight`` and ``eps`` are accepted but unused.

    Returns:
        Whatever the delegated op returns, unchanged (annotated as a
        triple of tensors).
    """
    quantize = torch.ops.tensor_cast.dynamic_quantize_asymmetric
    return quantize(x, dims, scale_dtype=scale_dtype, out_dtype=out_dtype)
@register_tensor_cast_op("add_rms_norm_dynamic_quant_symmetric")
def _(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    dims: List[int],
    scale_dtype: torch.dtype = torch.float32,
    out_dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Produce the outputs of ``add_rms_norm_dynamic_quant_symmetric``.

    Delegates entirely to
    ``torch.ops.tensor_cast.dynamic_quantize_symmetric`` applied to ``x``;
    ``residual``, ``weight`` and ``eps`` are accepted but unused.

    Returns:
        Whatever the delegated op returns, unchanged (annotated as a pair
        of tensors).
    """
    quantize = torch.ops.tensor_cast.dynamic_quantize_symmetric
    return quantize(x, dims, scale_dtype=scale_dtype, out_dtype=out_dtype)
@register_tensor_cast_op("add_rms_norm_dynamic_quant_asymmetric")
def _(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    dims: List[int],
    scale_dtype: torch.dtype = torch.float32,
    out_dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Produce the outputs of ``add_rms_norm_dynamic_quant_asymmetric``.

    Delegates entirely to
    ``torch.ops.tensor_cast.dynamic_quantize_asymmetric`` applied to ``x``;
    ``residual``, ``weight`` and ``eps`` are accepted but unused.

    Returns:
        Whatever the delegated op returns, unchanged (annotated as a
        triple of tensors).
    """
    quantize = torch.ops.tensor_cast.dynamic_quantize_asymmetric
    return quantize(x, dims, scale_dtype=scale_dtype, out_dtype=out_dtype)
@register_tensor_cast_op("add_rms_norm_dynamic_quant2_symmetric")
def _(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    dims: List[int],
    scale_dtype: torch.dtype = torch.float32,
    out_dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Produce the outputs of ``add_rms_norm_dynamic_quant2_symmetric``.

    Combines the two results of
    ``torch.ops.tensor_cast.dynamic_quantize_symmetric`` (applied to ``x``)
    with one extra contiguous, uninitialized tensor created with
    ``torch.empty_like(x)``. ``residual``, ``weight`` and ``eps`` are
    unused.

    Returns:
        ``(quantized, scale, unquantized)`` -- the delegated op's two
        outputs followed by the extra ``x``-like tensor.
    """
    # The x-like tensor is allocated first, matching the op call order.
    unquantized = torch.empty_like(x).contiguous()
    quantize = torch.ops.tensor_cast.dynamic_quantize_symmetric
    quantized, scale = quantize(
        x, dims, scale_dtype=scale_dtype, out_dtype=out_dtype
    )
    return quantized, scale, unquantized
@register_tensor_cast_op("add_rms_norm_dynamic_quant2_asymmetric")
def _(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    dims: List[int],
    scale_dtype: torch.dtype = torch.float32,
    out_dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Produce the outputs of ``add_rms_norm_dynamic_quant2_asymmetric``.

    Combines the three results of
    ``torch.ops.tensor_cast.dynamic_quantize_asymmetric`` (applied to
    ``x``) with one extra contiguous, uninitialized tensor created with
    ``torch.empty_like(x)``. ``residual``, ``weight`` and ``eps`` are
    unused.

    Returns:
        ``(quantized, scale, offset, unquantized)`` -- the delegated op's
        three outputs followed by the extra ``x``-like tensor.
    """
    # The x-like tensor is allocated first, matching the op call order.
    unquantized = torch.empty_like(x).contiguous()
    quantize = torch.ops.tensor_cast.dynamic_quantize_asymmetric
    quantized, scale, offset = quantize(
        x, dims, scale_dtype=scale_dtype, out_dtype=out_dtype
    )
    return quantized, scale, offset, unquantized
@register_tensor_cast_op("rms_norm_dynamic_quant_mxfp4")
def _(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    group_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Produce the outputs of ``rms_norm_dynamic_quant_mxfp4``.

    Delegates entirely to ``torch.ops.tensor_cast.dynamic_quantize_mxfp4``
    applied to ``x`` with the given ``group_size``; ``weight`` and ``eps``
    are accepted but unused.

    Returns:
        Whatever the delegated op returns, unchanged (annotated as a pair
        of tensors).
    """
    quantize = torch.ops.tensor_cast.dynamic_quantize_mxfp4
    return quantize(x, group_size=group_size)
@register_tensor_cast_op("add_rms_norm_dynamic_quant_mxfp4")
def _(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    group_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Produce the outputs of ``add_rms_norm_dynamic_quant_mxfp4``.

    Delegates entirely to ``torch.ops.tensor_cast.dynamic_quantize_mxfp4``
    applied to ``x`` with the given ``group_size``; ``residual``,
    ``weight`` and ``eps`` are accepted but unused.

    Returns:
        Whatever the delegated op returns, unchanged (annotated as a pair
        of tensors).
    """
    quantize = torch.ops.tensor_cast.dynamic_quantize_mxfp4
    return quantize(x, group_size=group_size)
@register_tensor_cast_op("add_rms_norm_dynamic_quant2_mxfp4")
def _(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    group_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Produce the outputs of ``add_rms_norm_dynamic_quant2_mxfp4``.

    Combines the two results of
    ``torch.ops.tensor_cast.dynamic_quantize_mxfp4`` (applied to ``x``
    with the given ``group_size``) with one extra contiguous,
    uninitialized tensor created with ``torch.empty_like(x)``.
    ``residual``, ``weight`` and ``eps`` are unused.

    Returns:
        ``(quantized, scale, unquantized)`` -- the delegated op's two
        outputs followed by the extra ``x``-like tensor.
    """
    # The x-like tensor is allocated first, matching the op call order.
    unquantized = torch.empty_like(x).contiguous()
    quantize = torch.ops.tensor_cast.dynamic_quantize_mxfp4
    quantized, scale = quantize(x, group_size=group_size)
    return quantized, scale, unquantized