import torch
from ... import config
_RMS_NORM_DTYPE_LIST = [torch.float16, torch.bfloat16]
_DEFAULT_EPS_SCALAR_WORKAROUND = {"eps": 1e-6}
def _create_pattern_result(pattern, replacement, example_inputs):
return pattern, replacement, example_inputs, _DEFAULT_EPS_SCALAR_WORKAROUND
class RMSNormPattern:
"""
Pattern for RMS normalization.
This pattern computes the RMS normalization of the input tensor.
"""
@staticmethod
def create(dtype):
def get_inputs():
hidden_states = torch.empty(2, 4, dtype=dtype, device="meta")
weight = torch.empty(4, dtype=dtype, device="meta")
return [hidden_states, weight]
def pattern(hidden_states, weight, eps):
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + eps)
out = weight * hidden_states.to(dtype)
return out
def replacement(hidden_states, weight, eps):
out = torch.ops.tensor_cast.rms_norm(hidden_states, weight, eps)
return out
return _create_pattern_result(pattern, replacement, get_inputs())
class AddRMSNormPattern:
@staticmethod
def create():
def get_inputs():
hidden_states = torch.empty(2, 4, device="meta")
residual = torch.empty(2, 4, device="meta")
weight = torch.empty(4, device="meta")
return [hidden_states, residual, weight]
def pattern(hidden_states, residual, weight, eps):
out = torch.ops.tensor_cast.rms_norm(hidden_states + residual, weight, eps)
return out
def replacement(hidden_states, residual, weight, eps):
out = torch.ops.tensor_cast.add_rms_norm(hidden_states, residual, weight, eps)
return out
return _create_pattern_result(pattern, replacement, get_inputs())
class AddRMSNorm2Pattern:
"""AddRMSNorm2 pattern that produces both the output and the residual."""
@staticmethod
def create():
def get_inputs():
hidden_states = torch.empty(2, 4, device="meta")
residual = torch.empty(2, 4, device="meta")
weight = torch.empty(4, device="meta")
return [hidden_states, residual, weight]
def pattern(hidden_states, residual, weight, eps):
residual = hidden_states + residual
out = torch.ops.tensor_cast.rms_norm(residual, weight, eps)
return out, residual
def replacement(hidden_states, residual, weight, eps):
out, residual = torch.ops.tensor_cast.add_rms_norm2(hidden_states, residual, weight, eps)
return out, residual
return _create_pattern_result(pattern, replacement, get_inputs())
class RMSNormQuantPattern:
@staticmethod
def create(eps: float = 1e-6):
def get_inputs():
hidden_states = torch.empty(2, 4, device="meta")
weight = torch.empty(4, device="meta")
scale = torch.empty(1, device="meta")
offset = torch.empty(1, device="meta")
return [hidden_states, weight, scale, offset]
def pattern(hidden_states, weight, scale, offset):
out = torch.ops.tensor_cast.rms_norm(hidden_states, weight, eps)
out = torch.ops.tensor_cast.quantize(out, scale, offset)
return out
def replacement(hidden_states, weight, scale, offset):
out = torch.ops.tensor_cast.rms_norm_quant(hidden_states, weight, scale, offset, eps)
return out
return (pattern, replacement, get_inputs())
class AddRMSNormQuantPattern:
@staticmethod
def create(eps: float = 1e-6):
def get_inputs():
hidden_states = torch.empty(2, 4, device="meta")
residual = torch.empty(2, 4, device="meta")
weight = torch.empty(4, device="meta")
scale = torch.empty(1, device="meta")
offset = torch.empty(1, device="meta")
return [hidden_states, residual, weight, scale, offset]
def pattern(hidden_states, residual, weight, scale, offset):
out = torch.ops.tensor_cast.rms_norm_quant(hidden_states + residual, weight, scale, offset, eps)
return out
def replacement(hidden_states, residual, weight, scale, offset):
out = torch.ops.tensor_cast.add_rms_norm_quant(hidden_states, residual, weight, scale, offset, eps)
return out
return (pattern, replacement, get_inputs())
class AddRMSNormQuant2Pattern:
"""AddRMSNormQuant2 pattern that produces both the output and the residual."""
@staticmethod
def create(eps: float = 1e-6):
def get_inputs():
hidden_states = torch.empty(2, 4, device="meta")
residual = torch.empty(2, 4, device="meta")
weight = torch.empty(4, device="meta")
scale = torch.empty(4, device="meta")
offset = torch.empty(4, device="meta")
return [hidden_states, residual, weight, scale, offset]
def pattern(hidden_states, residual, weight, scale, offset):
residual = hidden_states + residual
out = torch.ops.tensor_cast.rms_norm_quant(residual, weight, scale, offset, eps)
return out, residual
def replacement(hidden_states, residual, weight, scale, offset):
out, residual = torch.ops.tensor_cast.add_rms_norm_quant2(
hidden_states, residual, weight, scale, offset, eps
)
return out, residual
return (pattern, replacement, get_inputs())
class RMSNormDynamicQuantPattern:
"""Pattern for RMS norm followed by dynamic quantization (symmetric or asymmetric)."""
@staticmethod
def create(
eps: float = 1e-6,
symmetric: bool = True,
per_sample: bool = False,
scale_dtype: torch.dtype = torch.float32,
out_dtype: torch.dtype = torch.int8,
):
dims = [-1] if per_sample else []
def get_inputs():
hidden_states = torch.empty(2, 4, device="meta")
weight = torch.empty(4, device="meta")
return [hidden_states, weight]
def pattern(hidden_states, weight):
out = torch.ops.tensor_cast.rms_norm(hidden_states, weight, eps)
if symmetric:
result = torch.ops.tensor_cast.dynamic_quantize_symmetric(
out, dims, scale_dtype=scale_dtype, out_dtype=out_dtype
)
return result
else:
result = torch.ops.tensor_cast.dynamic_quantize_asymmetric(
out, dims, scale_dtype=scale_dtype, out_dtype=out_dtype
)
return result
def replacement(hidden_states, weight):
if symmetric:
result = torch.ops.tensor_cast.rms_norm_dynamic_quant_symmetric(
hidden_states,
weight,
eps,
dims,
scale_dtype=scale_dtype,
out_dtype=out_dtype,
)
return result
else:
result = torch.ops.tensor_cast.rms_norm_dynamic_quant_asymmetric(
hidden_states,
weight,
eps,
dims,
scale_dtype=scale_dtype,
out_dtype=out_dtype,
)
return result
return (pattern, replacement, get_inputs())
class AddRMSNormDynamicQuantPattern:
"""Pattern for add RMS norm followed by dynamic quantization (symmetric or asymmetric)."""
@staticmethod
def create(
eps: float = 1e-6,
symmetric: bool = True,
per_sample: bool = False,
scale_dtype: torch.dtype = torch.float32,
out_dtype: torch.dtype = torch.int8,
):
dims = [-1] if per_sample else []
def get_inputs():
hidden_states = torch.empty(2, 4, device="meta")
residual = torch.empty(2, 4, device="meta")
weight = torch.empty(4, device="meta")
return [hidden_states, residual, weight]
def pattern(hidden_states, residual, weight):
if symmetric:
result = torch.ops.tensor_cast.rms_norm_dynamic_quant_symmetric(
hidden_states + residual,
weight,
eps,
dims,
scale_dtype=scale_dtype,
out_dtype=out_dtype,
)
return result
else:
result = torch.ops.tensor_cast.rms_norm_dynamic_quant_asymmetric(
hidden_states + residual,
weight,
eps,
dims,
scale_dtype=scale_dtype,
out_dtype=out_dtype,
)
return result
def replacement(hidden_states, residual, weight):
if symmetric:
result = torch.ops.tensor_cast.add_rms_norm_dynamic_quant_symmetric(
hidden_states,
residual,
weight,
eps,
dims,
scale_dtype=scale_dtype,
out_dtype=out_dtype,
)
return result
else:
result = torch.ops.tensor_cast.add_rms_norm_dynamic_quant_asymmetric(
hidden_states,
residual,
weight,
eps,
dims,
scale_dtype=scale_dtype,
out_dtype=out_dtype,
)
return result
return (pattern, replacement, get_inputs())
class AddRMSNormDynamicQuant2Pattern:
"""Pattern for add RMS norm2 followed by dynamic quantization (symmetric or asymmetric)."""
@staticmethod
def create(
eps: float = 1e-6,
symmetric: bool = True,
per_sample: bool = False,
scale_dtype: torch.dtype = torch.float32,
out_dtype: torch.dtype = torch.int8,
):
dims = [-1] if per_sample else []
def get_inputs():
hidden_states = torch.empty(2, 4, device="meta")
residual = torch.empty(2, 4, device="meta")
weight = torch.empty(4, device="meta")
return [hidden_states, residual, weight]
def pattern(hidden_states, residual, weight):
residual = hidden_states + residual
if symmetric:
result = torch.ops.tensor_cast.rms_norm_dynamic_quant_symmetric(
residual,
weight,
eps,
dims,
scale_dtype=scale_dtype,
out_dtype=out_dtype,
)
return *result, residual
else:
result = torch.ops.tensor_cast.rms_norm_dynamic_quant_asymmetric(
residual,
weight,
eps,
dims,
scale_dtype=scale_dtype,
out_dtype=out_dtype,
)
return *result, residual
def replacement(hidden_states, residual, weight):
if symmetric:
out, scale, residual = torch.ops.tensor_cast.add_rms_norm_dynamic_quant2_symmetric(
hidden_states,
residual,
weight,
eps,
dims,
scale_dtype=scale_dtype,
out_dtype=out_dtype,
)
return out, scale, residual
else:
out, scale, offset, residual = torch.ops.tensor_cast.add_rms_norm_dynamic_quant2_asymmetric(
hidden_states,
residual,
weight,
eps,
dims,
scale_dtype=scale_dtype,
out_dtype=out_dtype,
)
return out, scale, offset, residual
return (pattern, replacement, get_inputs())
class RMSNormDynamicQuantMXFP4Pattern:
"""Pattern for RMS norm followed by MXFP4 dynamic quantization."""
@staticmethod
def create(eps: float = 1e-6, group_size: int = 32):
def get_inputs():
hidden_states = torch.empty(2, 4, device="meta")
weight = torch.empty(4, device="meta")
return [hidden_states, weight]
def pattern(hidden_states, weight):
out = torch.ops.tensor_cast.rms_norm(hidden_states, weight, eps)
return torch.ops.tensor_cast.dynamic_quantize_mxfp4(out, group_size=group_size)
def replacement(hidden_states, weight):
return torch.ops.tensor_cast.rms_norm_dynamic_quant_mxfp4(
hidden_states,
weight,
eps,
group_size,
)
return (pattern, replacement, get_inputs())
class AddRMSNormDynamicQuantMXFP4Pattern:
"""Pattern for add RMS norm followed by MXFP4 dynamic quantization."""
@staticmethod
def create(eps: float = 1e-6, group_size: int = 64):
def get_inputs():
hidden_states = torch.empty(2, 4, device="meta")
residual = torch.empty(2, 4, device="meta")
weight = torch.empty(4, device="meta")
return [hidden_states, residual, weight]
def pattern(hidden_states, residual, weight):
return torch.ops.tensor_cast.rms_norm_dynamic_quant_mxfp4(
hidden_states + residual,
weight,
eps,
group_size,
)
def replacement(hidden_states, residual, weight):
return torch.ops.tensor_cast.add_rms_norm_dynamic_quant_mxfp4(
hidden_states,
residual,
weight,
eps,
group_size,
)
return (pattern, replacement, get_inputs())
class AddRMSNormDynamicQuant2MXFP4Pattern:
"""Pattern for add RMS norm2 followed by MXFP4 dynamic quantization."""
@staticmethod
def create(eps: float = 1e-6, group_size: int = 64):
def get_inputs():
hidden_states = torch.empty(2, 4, device="meta")
residual = torch.empty(2, 4, device="meta")
weight = torch.empty(4, device="meta")
return [hidden_states, residual, weight]
def pattern(hidden_states, residual, weight):
residual = hidden_states + residual
result = torch.ops.tensor_cast.rms_norm_dynamic_quant_mxfp4(
residual,
weight,
eps,
group_size,
)
return *result, residual
def replacement(hidden_states, residual, weight):
out, scale, residual = torch.ops.tensor_cast.add_rms_norm_dynamic_quant2_mxfp4(
hidden_states,
residual,
weight,
eps,
group_size,
)
return out, scale, residual
return (pattern, replacement, get_inputs())
def register_all_patterns():
from . import register_pattern
if config.compilation.fusion_patterns.enable_rms_norm:
for dtype in _RMS_NORM_DTYPE_LIST:
pattern, replacement, example_inputs, scalar_workaround = RMSNormPattern.create(dtype)
register_pattern(
f"rms_norm_pattern_{dtype}",
pattern,
replacement,
example_inputs,
scalar_workaround=scalar_workaround,
level=0,
)
if config.compilation.fusion_patterns.enable_rms_norm_quant:
register_pattern(
"rms_norm_quant_pattern",
*RMSNormQuantPattern.create(),
)
if config.compilation.fusion_patterns.enable_add_rms_norm:
pattern, replacement, example_inputs, scalar_workaround = AddRMSNormPattern.create()
register_pattern(
"add_rms_norm_pattern",
pattern,
replacement,
example_inputs,
scalar_workaround=scalar_workaround,
level=1,
)
pattern, replacement, example_inputs, scalar_workaround = AddRMSNorm2Pattern.create()
register_pattern(
"add_rms_norm2_pattern",
pattern,
replacement,
example_inputs,
scalar_workaround=scalar_workaround,
level=1,
)
if config.compilation.fusion_patterns.enable_rms_norm_quant:
register_pattern(
"add_rms_norm_quant_pattern",
*AddRMSNormQuantPattern.create(),
)
register_pattern(
"add_rms_norm_quant2_pattern",
*AddRMSNormQuant2Pattern.create(),
)
if config.compilation.fusion_patterns.enable_rms_norm_quant:
for symmetric in [True, False]:
for per_sample in [True, False]:
variant_name = "symmetric" if symmetric else "asymmetric"
variant_name += "_per_sample" if per_sample else "_per_tensor"
register_pattern(
f"rms_norm_dynamic_quant_{variant_name}_pattern",
*RMSNormDynamicQuantPattern.create(symmetric=symmetric, per_sample=per_sample),
)
if config.compilation.fusion_patterns.enable_add_rms_norm:
register_pattern(
f"add_rms_norm_dynamic_quant_{variant_name}_pattern",
*AddRMSNormDynamicQuantPattern.create(symmetric=symmetric, per_sample=per_sample),
)
register_pattern(
f"add_rms_norm_dynamic_quant2_{variant_name}_pattern",
*AddRMSNormDynamicQuant2Pattern.create(symmetric=symmetric, per_sample=per_sample),
)
for group_size in [32, 64]:
register_pattern(
f"rms_norm_dynamic_quant_mxfp4_g{group_size}_pattern",
*RMSNormDynamicQuantMXFP4Pattern.create(group_size=group_size),
)
if config.compilation.fusion_patterns.enable_add_rms_norm:
register_pattern(
f"add_rms_norm_dynamic_quant_mxfp4_g{group_size}_pattern",
*AddRMSNormDynamicQuantMXFP4Pattern.create(group_size=group_size),
)
register_pattern(
f"add_rms_norm_dynamic_quant2_mxfp4_g{group_size}_pattern",
*AddRMSNormDynamicQuant2MXFP4Pattern.create(group_size=group_size),
)