import fnmatch
from enum import auto, Enum
from typing import Callable, Optional
import torch
import torch.nn as nn
from .core.quantization.datatypes import QuantizeAttentionAction
from .utils import DTYPE_FP4, DTYPE_FP8, pattern_match
class LinearQuantType(Enum):
    """Quantization modes selectable for Linear layers.

    The W<n>A<m> member names presumably encode weight/activation bit widths
    (TODO confirm); the dtypes each mode maps to are defined by
    ``quant_type_to_dynamic_quant_dtype`` and ``quant_type_to_weight_dtype``.
    """

    # Values are assigned by auto() in declaration order (1, 2, 3, ...).
    W8A16 = auto()
    W8A8 = auto()
    W4A8 = auto()
    FP8 = auto()
    MXFP4 = auto()
def quant_type_to_dynamic_quant_dtype(
    quant_type: LinearQuantType,
) -> Optional[torch.dtype]:
    """Map a linear quantization type to its dynamic-quantization dtype.

    Args:
        quant_type: The linear quantization mode to translate.

    Returns:
        ``torch.int8`` for W8A8 and W4A8, ``DTYPE_FP8`` for FP8,
        ``DTYPE_FP4`` for MXFP4, and ``None`` for W8A16.

    Raises:
        ValueError: If ``quant_type`` is not one of the modes listed above.
    """
    int8_modes = (LinearQuantType.W8A8, LinearQuantType.W4A8)
    if quant_type in int8_modes:
        return torch.int8
    if quant_type == LinearQuantType.FP8:
        return DTYPE_FP8
    if quant_type == LinearQuantType.MXFP4:
        return DTYPE_FP4
    if quant_type == LinearQuantType.W8A16:
        # No dynamic-quant dtype for this mode (presumably because
        # activations stay in 16-bit -- TODO confirm).
        return None
    raise ValueError(f"Unsupported quant_type for dynamic quant: {quant_type}")
def quant_type_to_weight_dtype(quant_type: LinearQuantType) -> torch.dtype:
    """Map a linear quantization type to the dtype used for its weights.

    Args:
        quant_type: The linear quantization mode to translate.

    Returns:
        ``torch.int8`` for W8A8, W4A8 and W8A16, ``DTYPE_FP8`` for FP8,
        and ``DTYPE_FP4`` for MXFP4.

    Raises:
        ValueError: If ``quant_type`` is not one of the modes listed above.
    """
    int8_weight_modes = (
        LinearQuantType.W8A8,
        LinearQuantType.W4A8,
        LinearQuantType.W8A16,
    )
    if quant_type in int8_weight_modes:
        return torch.int8
    if quant_type == LinearQuantType.FP8:
        return DTYPE_FP8
    if quant_type == LinearQuantType.MXFP4:
        return DTYPE_FP4
    raise ValueError(f"Unsupported quant_type for weight quant: {quant_type}")
class AttentionQuantType(Enum):
    """Quantization types available for attention.

    Members are resolved by name from a ``QuantizeAttentionAction`` and are
    translated to torch dtypes through ``_QUANT_TYPE_TO_TORCH_DTYPE_MAP``.
    """

    # Values are assigned by auto() in declaration order (1, 2).
    INT8 = auto()
    FP8 = auto()
def get_attention_quant_type(action: QuantizeAttentionAction) -> AttentionQuantType:
    """Return the ``AttentionQuantType`` member whose name equals ``action.name``.

    Args:
        action: Quantization action; only its ``name`` attribute is used for
            the lookup.

    Returns:
        The ``AttentionQuantType`` member with the same name as ``action``.

    Raises:
        ValueError: If ``AttentionQuantType`` has no member named
            ``action.name``.
    """
    member_name = action.name
    try:
        # Index by member name instead of getattr(): getattr() would also
        # return non-member class attributes (e.g. "mro", "__members__"),
        # whereas Enum name lookup only ever yields real members.
        return AttentionQuantType[member_name]
    except KeyError:
        raise ValueError(
            f"Unsupported quantization action: {action}. Ensure '{member_name}' is defined in AttentionQuantType."
        ) from None
# Lookup table from attention quantization type to the torch dtype that
# represents it; consumed by get_torch_dtype_from_quant_type().
_QUANT_TYPE_TO_TORCH_DTYPE_MAP = {
    AttentionQuantType.INT8: torch.int8,
    AttentionQuantType.FP8: torch.float8_e4m3fn,
}
def get_torch_dtype_from_quant_type(quant_type: AttentionQuantType) -> torch.dtype:
    """Translate an attention quantization type into its torch dtype.

    Args:
        quant_type: Key to look up in ``_QUANT_TYPE_TO_TORCH_DTYPE_MAP``.

    Returns:
        The torch dtype registered for ``quant_type``.

    Raises:
        ValueError: If ``quant_type`` has no entry in the lookup table; the
            message lists the supported keys.
    """
    if quant_type in _QUANT_TYPE_TO_TORCH_DTYPE_MAP:
        return _QUANT_TYPE_TO_TORCH_DTYPE_MAP[quant_type]
    raise ValueError(
        f"Unsupported attention quant type: {quant_type}. "
        f"Supported types: {list(_QUANT_TYPE_TO_TORCH_DTYPE_MAP.keys())}"
    )
def get_torch_quant_type(action: QuantizeAttentionAction) -> AttentionQuantType:
    """Return the ``AttentionQuantType`` member matching ``action`` by name.

    This function has the same contract as ``get_attention_quant_type`` and
    delegates to it so the two entry points share a single implementation
    and cannot drift apart.

    Args:
        action: Quantization action whose ``name`` selects the member.

    Returns:
        The ``AttentionQuantType`` member with the same name as ``action``.

    Raises:
        ValueError: If ``action.name`` does not resolve to an
            ``AttentionQuantType`` member.
    """
    return get_attention_quant_type(action)
class QuantGranularity(Enum):
    """Granularity options for quantization.

    How each option is applied is decided by the code that consumes this
    enum, which is not part of this section.
    """

    # Values are assigned by auto() in declaration order (1, 2, 3).
    PER_TENSOR = auto()
    PER_SAMPLE = auto()
    PER_GROUP = auto()
class QuantScheme(Enum):
    """Quantization scheme selector: symmetric or asymmetric."""

    # Values are assigned by auto() in declaration order (1, 2).
    SYMMETRIC = auto()
    ASYMMETRIC = auto()
def get_quant_config(name, quant_config, default_config_name):
    """Resolve the linear quantization config that applies to ``name``.

    Lookup order:
      1. an exact key match in ``quant_config.linear_configs``;
      2. the first key containing ``*`` or ``?`` whose glob pattern matches
         ``name`` (keys are tried in the mapping's iteration order);
      3. the entry stored under ``default_config_name``.

    Args:
        name: Module name to look up.
        quant_config: Object exposing a ``linear_configs`` mapping of
            name-or-pattern -> config. It is only read, never modified.
        default_config_name: Key of the fallback entry.

    Returns:
        The matching config, or ``None`` when nothing matches and the
        fallback key is absent.
    """
    linear_configs = quant_config.linear_configs
    if name in linear_configs:
        return linear_configs[name]

    # Wildcard keys are filtered on every call rather than cached on
    # quant_config: the substring checks are negligible next to the glob
    # match itself, and a cache would go stale if linear_configs changed
    # (and would silently add an attribute to the caller's object).
    for pattern, config in linear_configs.items():
        if "*" not in pattern and "?" not in pattern:
            continue
        # fnmatchcase: module names are case-sensitive, and fnmatch.fnmatch
        # would normalise case (and path separators) on Windows.
        if fnmatch.fnmatchcase(name, pattern):
            return config

    return linear_configs.get(default_config_name)
def replace_module(name, new_module, root_module):
    """Swap the submodule at dotted path ``name`` for ``new_module``.

    Args:
        name: Dotted path of the target relative to ``root_module``
            (for example ``"encoder.layers.0.proj"``).
        new_module: Object to assign in place of the current submodule.
        root_module: Module from which ``name`` is resolved. A falsy
            ``root_module`` turns the call into a no-op.
    """
    if root_module:
        # Everything before the last dot addresses the owner; the final
        # component is the attribute to overwrite on it.
        owner_path, _, attr_name = name.rpartition(".")
        owner = root_module.get_submodule(owner_path) if owner_path else root_module
        setattr(owner, attr_name, new_module)
def quantize_linear_modules(
    root_module: nn.Module,
    quant_linear_cls: Optional["QuantLinearBase"],
    quant_config: Optional["QuantConfig"],
    default_config_name: str,
    strip_module_fn: Optional[Callable[[str], str]],
) -> None:
    """
    Quantize Linear modules in a root module with specified quantization config and class.

    Every ``torch.nn.Linear`` found under ``root_module`` whose name is not matched by
    ``quant_config.modules_to_not_convert`` and for which a truthy config is resolved is
    replaced in place by ``quant_linear_cls(module, cfg)``.

    The call is a no-op when ``root_module``, ``quant_linear_cls`` or ``quant_config`` is
    missing. If ``root_module`` is itself a Linear layer it cannot be swapped in place
    (there is no parent to assign to); it is skipped and a warning is emitted.

    Args:
        root_module: (nn.Module) Root module containing Linear layers to be quantized
        quant_linear_cls: (QuantLinearBase) Quantized Linear class to replace original Linear modules;
            called as ``quant_linear_cls(module, cfg)``
        quant_config: (QuantConfig) Quantization configuration object with linear config rules and exclude list
        default_config_name: (str) Fallback config name if no match found for a target Linear module
        strip_module_fn:
            (Optional[Callable[[str], str]]) Function to clean/normalize module names,
            None = use raw module name without modification
    """
    # quant_config is declared Optional; without it there are no rules to apply.
    if not quant_linear_cls or not root_module or quant_config is None:
        return

    # Snapshot the traversal so that swapping children below never touches a
    # module tree that the named_modules() generator is still walking.
    for name, module in list(root_module.named_modules()):
        if pattern_match(name, quant_config.modules_to_not_convert):
            continue
        if not isinstance(module, torch.nn.Linear):
            continue

        # Configs are keyed by the normalized name; the replacement itself
        # must use the real path inside root_module.
        module_name = strip_module_fn(name) if strip_module_fn else name
        cfg = get_quant_config(module_name, quant_config, default_config_name)
        if not cfg:
            continue

        if module is root_module:
            # The root has the empty name; assigning to it would register the
            # quantized layer as a child called "" instead of replacing anything.
            import warnings

            warnings.warn(
                "quantize_linear_modules: root_module is itself a Linear layer and "
                "cannot be replaced in place; skipping it.",
                stacklevel=2,
            )
            continue

        replace_module(name, quant_linear_cls(module, cfg), root_module)