import dataclasses
from functools import cached_property
from typing import Dict, List, Optional, Tuple
import torch
from .mode import QuantAlgorithm, QuantMode
from ..utils import ModelInitError, ParametersInvalid
# Strategy strings accepted for the "w8a8_static_linear" timestep target.
W8A8_STATIC_LINEAR_STRATEGIES = ("dynamic", "static")
# Strategy strings accepted for the "w4a4_linear" timestep target.
W4A4_LINEAR_STRATEGIES = ("W4A4", "W4A8")
# Strategy strings accepted for the "fa" timestep target.
FA_STRATEGIES = ("MXFP4", "FP8", "FLOAT")
# Target name -> allowed strategy strings. Insertion order is the canonical
# target order (w8a8_static_linear, w4a4_linear, fa).
VALID_STRATEGIES = dict(
    w8a8_static_linear=W8A8_STATIC_LINEAR_STRATEGIES,
    w4a4_linear=W4A4_LINEAR_STRATEGIES,
    fa=FA_STRATEGIES,
)
class TimestepPolicyConfig:
    r"""
    Per-timestep strategy table for several quantization targets.

    One ``step -> strategy`` mapping is kept for each target in
    ``VALID_STRATEGIES`` ("w8a8_static_linear", "w4a4_linear", "fa"), together
    with a per-target default returned for any step that was never registered.
    """

    def __init__(self, default_strategy="dynamic", w4a4_default_strategy="W4A4", fa_default_strategy="MXFP4"):
        r"""
        Create an empty strategy table with per-target defaults.

        Args:
            default_strategy: Default for the "w8a8_static_linear" target (dynamic/static).
            w4a4_default_strategy: Default for the "w4a4_linear" target (W4A4/W4A8).
            fa_default_strategy: Default for the "fa" target (MXFP4/FP8/FLOAT).

        Raises:
            TypeError: If any default is not a string.
            ValueError: If any default is not allowed for its target.
        """
        self._strategies = {name: {} for name in VALID_STRATEGIES}
        defaults = {
            "w8a8_static_linear": default_strategy,
            "w4a4_linear": w4a4_default_strategy,
            "fa": fa_default_strategy,
        }
        self._default_strategy = defaults
        # Validate in the same target order the mapping was declared in.
        for name, value in defaults.items():
            self._validate_strategy(value, name)

    @staticmethod
    def _resolve_legacy_linear_target(strategy):
        r"""
        Map the legacy "linear" target to a concrete target using ``strategy``.

        Returns "linear" unchanged when ``strategy`` belongs to neither linear
        family (including ``None``); ``_normalize_target`` then rejects it.
        """
        families = (
            (W8A8_STATIC_LINEAR_STRATEGIES, "w8a8_static_linear"),
            (W4A4_LINEAR_STRATEGIES, "w4a4_linear"),
        )
        for family, resolved in families:
            if strategy in family:
                return resolved
        return "linear"

    @classmethod
    def _normalize_target(cls, target, strategy=None):
        r"""
        Return the canonical target name for ``target``.

        Raises:
            TypeError: If ``target`` is not a string.
            ValueError: If the (possibly resolved) target is not in VALID_STRATEGIES.
        """
        if not isinstance(target, str):
            raise TypeError(f"target必须是字符串类型,实际类型:{type(target)}")
        resolved = cls._resolve_legacy_linear_target(strategy) if target == "linear" else target
        if resolved in VALID_STRATEGIES:
            return resolved
        raise ValueError(f"无效的target:{resolved},允许值:{tuple(VALID_STRATEGIES)}")

    @staticmethod
    def _validate_strategy(strategy, target):
        r"""
        Check that ``strategy`` is a string allowed for ``target``.

        Raises:
            TypeError: If ``strategy`` is not a string.
            ValueError: If ``strategy`` is not listed for ``target``.
        """
        if not isinstance(strategy, str):
            raise TypeError(f"strategy必须是字符串类型,实际类型:{type(strategy)}")
        allowed = VALID_STRATEGIES[target]
        if strategy not in allowed:
            raise ValueError(f"无效的策略类型:{strategy},允许值:{allowed}")

    @staticmethod
    def _normalize_step_range(step_range):
        r"""
        Turn ``step_range`` into an iterable of integer steps.

        An int becomes a one-element list. A list or range is returned as-is
        after checking that every element is an int.

        Raises:
            TypeError: For other input types or non-int elements.
        """
        if isinstance(step_range, int):
            return [step_range]
        if not isinstance(step_range, (list, range)):
            raise TypeError(f"step_range必须是int, list, range类型,实际类型:{type(step_range)}")
        if any(not isinstance(item, int) for item in step_range):
            raise TypeError("step_range列表必须包含整数元素")
        return step_range

    def register(self, step_range, strategy, target="w4a4_linear"):
        r"""
        Register ``strategy`` for every step in ``step_range``.

        Args:
            step_range: A single timestep (int), or a list/range of timesteps.
            strategy: Strategy string. "w8a8_static_linear" accepts dynamic/static,
                "w4a4_linear" accepts W4A4/W4A8, and "fa" accepts MXFP4/FP8/FLOAT.
            target: "w8a8_static_linear", "w4a4_linear" or "fa"; the legacy value
                "linear" is resolved from ``strategy``. Defaults to "w4a4_linear".

        Raises:
            TypeError / ValueError: On invalid target, strategy or step_range.
        """
        resolved = self._normalize_target(target, strategy)
        self._validate_strategy(strategy, resolved)
        steps = self._normalize_step_range(step_range)
        self._strategies[resolved].update(dict.fromkeys(steps, strategy))

    def get_strategy(self, step, target="w4a4_linear"):
        r"""
        Look up the strategy registered for ``step``.

        Args:
            step: The timestep (int), or None to get the target's default.
            target: "w8a8_static_linear", "w4a4_linear" or "fa". Defaults to "w4a4_linear".

        Returns:
            The registered strategy for ``step``, or the target's default strategy.

        Raises:
            ParametersInvalid: If ``step`` is neither None nor an int.

        NOTE(review): ``target="linear"`` is resolved here without a strategy, so
        it always yields ValueError from ``_normalize_target``.
        """
        resolved = self._normalize_target(target)
        if step is not None and not isinstance(step, int):
            raise ParametersInvalid(f"step must be the type of int, but currently got {type(step)}.")
        table = self._strategies[resolved]
        if step in table:
            return table[step]
        return self._default_strategy[resolved]
@dataclasses.dataclass
class QuantConfig:
    r"""
    Quantization configuration for a whole model or for individual layers.

    ``__post_init__`` normalizes and validates the fields:
    - ``quant_algo`` strings are upper-cased and converted to ``QuantAlgorithm``;
    - an ``exclude_layers`` list is converted to a tuple;
    - dict values in ``quantized_layers`` are parsed into configs of the same class.

    Invalid values raise ``ModelInitError``.
    """

    # Path string; only its type is checked here.
    quant_des_path: Optional[str] = None
    # Algorithm for this config; str values are normalized to QuantAlgorithm.
    quant_algo: Optional[QuantAlgorithm] = None
    # Layer name -> per-layer config (dict values are parsed via parse_from_dict).
    quantized_layers: Optional[Dict[str, "QuantConfig"]] = None
    # Layer names to exclude; a list is accepted and converted to a tuple.
    exclude_layers: Optional[Tuple[str, ...]] = None
    # Restricted to torch.float16 or torch.bfloat16.
    dtype: torch.dtype = torch.bfloat16
    use_nz: Optional[bool] = None
    timestep_config: Optional[TimestepPolicyConfig] = None
    # Non-bool int or None.
    mxfp4_scale_alg: Optional[int] = None
    # Stored as float; bool is rejected even though it is an int subclass.
    mxfp4_dst_type_max: float = 7.25

    def __post_init__(self):
        self.quant_algo = self._normalize_quant_algo(self.quant_algo)
        self.exclude_layers = self._normalize_exclude_layers(self.exclude_layers)
        self.quantized_layers = self._normalize_quantized_layers(self.quantized_layers)
        if self.quant_des_path is not None and not isinstance(self.quant_des_path, str):
            raise ModelInitError("self.quant_des_path must be a string or None.")
        if self.quant_algo is not None and not isinstance(self.quant_algo, QuantAlgorithm):
            raise ModelInitError(
                f'self.quant_algo must be an instance of QuantAlgorithm, but actually got {type(self.quant_algo)}.'
            )
        if not isinstance(self.dtype, torch.dtype) or self.dtype not in (torch.float16, torch.bfloat16):
            raise ModelInitError("self.dtype must be torch.float16 or torch.bfloat16.")
        if self.timestep_config is not None and not isinstance(self.timestep_config, TimestepPolicyConfig):
            raise ModelInitError("self.timestep_config must be an instance of TimestepPolicyConfig.")
        if self.use_nz is not None and not isinstance(self.use_nz, bool):
            raise ModelInitError("self.use_nz must be a bool.")
        # bool is an int subclass; reject it explicitly, as done for mxfp4_dst_type_max below.
        if self.mxfp4_scale_alg is not None and (
            isinstance(self.mxfp4_scale_alg, bool) or not isinstance(self.mxfp4_scale_alg, int)
        ):
            raise ModelInitError("mxfp4_scale_alg must be an int or None.")
        if isinstance(self.mxfp4_dst_type_max, bool) or not isinstance(self.mxfp4_dst_type_max, (int, float)):
            raise ModelInitError("mxfp4_dst_type_max must be a float.")
        self.mxfp4_dst_type_max = float(self.mxfp4_dst_type_max)

    @staticmethod
    def _normalize_quant_algo(quant_algo):
        """Convert a string algorithm name (case-insensitive) to QuantAlgorithm; pass other values through."""
        if isinstance(quant_algo, str):
            return QuantAlgorithm(quant_algo.upper())
        return quant_algo

    @staticmethod
    def _normalize_exclude_layers(exclude_layers):
        """Return exclude_layers as a tuple of str (or None); raise ModelInitError otherwise."""
        if exclude_layers is None:
            return None
        if isinstance(exclude_layers, list):
            exclude_layers = tuple(exclude_layers)
        if not isinstance(exclude_layers, tuple):
            raise ModelInitError("self.exclude_layers must be a tuple")
        for layer in exclude_layers:
            if not isinstance(layer, str):
                raise ModelInitError("Items in exclude_layers must be strings")
        return exclude_layers

    @classmethod
    def _normalize_quantized_layers(cls, quantized_layers):
        """Return a new name -> config dict, parsing dict values with ``cls.parse_from_dict``."""
        if quantized_layers is None:
            return None
        if not isinstance(quantized_layers, dict):
            raise ModelInitError("self.quantized_layers must be a dictionary")
        normalized = {}
        for name, layer_config in quantized_layers.items():
            if not isinstance(name, str):
                raise ModelInitError("Keys in self.quantized_layers must be strings")
            if isinstance(layer_config, dict):
                layer_config = cls.parse_from_dict(layer_config)
            if not isinstance(layer_config, QuantConfig):
                raise ModelInitError("Values in self.quantized_layers must be instances of QuantConfig")
            normalized[name] = layer_config
        return normalized

    @classmethod
    def parse_from_dict(cls, config: dict):
        """Build a config from a mapping of field names to values."""
        return cls(**dict(config))

    @classmethod
    def from_kwargs(cls, kwargs: dict):
        """
        Build a config from a generic kwargs dict, picking only known field names.

        ``timestep_policy`` is accepted as an alias of ``timestep_config``.

        Raises:
            ParametersInvalid: If both aliases are given as different objects.
        """
        timestep_config = kwargs.get('timestep_config', None)
        timestep_policy = kwargs.get('timestep_policy', None)
        if timestep_config is not None and timestep_policy is not None and timestep_config is not timestep_policy:
            raise ParametersInvalid("timestep_config and timestep_policy cannot both be set to different objects.")
        config_kwargs = {}
        for name in (
            'quant_des_path',
            'quant_algo',
            'quantized_layers',
            'exclude_layers',
            'dtype',
            'use_nz',
            'mxfp4_scale_alg',
            'mxfp4_dst_type_max',
        ):
            if name in kwargs:
                config_kwargs[name] = kwargs[name]
        config_kwargs['timestep_config'] = timestep_config if timestep_config is not None else timestep_policy
        return cls(**config_kwargs)

    def merged_with_user(self, user_config: Optional["QuantConfig"]):
        """
        Return a new config in which every non-None field of ``user_config`` overrides ``self``.

        If the user sets ``quant_algo`` but not ``quantized_layers``, each inherited
        per-layer config is copied with the user's ``quant_algo``. The result is
        constructed in one step so that ``__post_init__`` (and any subclass state it
        derives, such as cached per-layer modes) reflects the merged values.
        Returns ``self`` unchanged when ``user_config`` is None.
        """
        if user_config is None:
            return self
        overrides = {}
        for field in dataclasses.fields(QuantConfig):
            value = getattr(user_config, field.name)
            if value is not None:
                overrides[field.name] = value
        if user_config.quant_algo is not None and user_config.quantized_layers is None and self.quantized_layers:
            overrides['quantized_layers'] = {
                name: dataclasses.replace(layer_config, quant_algo=user_config.quant_algo)
                for name, layer_config in self.quantized_layers.items()
            }
        return dataclasses.replace(self, **overrides)

    def to_kwargs(self):
        """Return kwargs: 'dtype' always, 'use_nz'/'timestep_config' when set, and 'quant_config' (self)."""
        kwargs = {'dtype': self.dtype}
        if self.use_nz is not None:
            kwargs['use_nz'] = self.use_nz
        if self.timestep_config is not None:
            kwargs['timestep_config'] = self.timestep_config
        kwargs['quant_config'] = self
        return kwargs

    @cached_property
    def layer_quantization_mode(self):
        """
        Per-layer QuantMode dict when quantized_layers is set, else the QuantMode of quant_algo.

        Cached on first access; later mutation of the fields is not reflected.
        """
        if self.quantized_layers is not None:
            return {
                name: QuantMode.from_quant_algo(layer_config.quant_algo)
                for name, layer_config in self.quantized_layers.items()
            }
        return QuantMode.from_quant_algo(self.quant_algo)

    @cached_property
    def quant_algorithms_list(self):
        """Distinct quant_algo values across quantized_layers, in first-seen order ([] if none)."""
        if not self.quantized_layers:
            return []
        # dict.fromkeys de-duplicates while keeping a deterministic order (unlike set()).
        return list(dict.fromkeys(layer_config.quant_algo for layer_config in self.quantized_layers.values()))

    def serialize_to_dict(self):
        """Return ``dataclasses.asdict(self)``."""
        return dataclasses.asdict(self)
@dataclasses.dataclass
class LayerQuantConfig(QuantConfig):
    r"""
    QuantConfig variant focused on per-layer quantization.

    After the base validation it records ``auto_quant_mode``: the per-layer
    QuantMode dict when layers are configured, otherwise an empty dict.
    """

    def __post_init__(self):
        super().__post_init__()
        # Only touch (and thus cache) the per-layer modes when layers exist.
        if self.quantized_layers:
            self.auto_quant_mode = self.layer_quantization_mode
        else:
            self.auto_quant_mode = {}

    @cached_property
    def layer_quantization_mode(self):
        """Layer name -> QuantMode, or {} when quantized_layers is None (cached)."""
        layers = self.quantized_layers
        if layers is None:
            return {}
        return {name: QuantMode.from_quant_algo(cfg.quant_algo) for name, cfg in layers.items()}

    def serialize_to_dict(self):
        """Serialize quant_algo and each layer config, dropping 'exclude_layers' from layer dicts."""
        serialized_layers = {}
        for name, layer_config in (self.quantized_layers or {}).items():
            layer_dict = layer_config.serialize_to_dict()
            layer_dict.pop('exclude_layers', None)
            serialized_layers[name] = layer_dict
        return {'quant_algo': self.quant_algo, 'quantized_layers': serialized_layers}
# Online quantization types for which fallback_timesteps may be configured.
_W4A4_QUANT_TYPES = (
    QuantAlgorithm.W4A4_MXFP4_DYNAMIC,
    QuantAlgorithm.W4A4_MXFP4_DUALSCALE,
)
# Allowed values of OnlineQuantConfig.quant_type: the two W8A8 variants
# followed by the W4A4 types above (order preserved).
SUPPORTED_ONLINE_QUANT_TYPES = (
    QuantAlgorithm.W8A8_DYNAMIC,
    QuantAlgorithm.W8A8_MXFP8,
) + _W4A4_QUANT_TYPES
# Allowed values in OnlineQuantConfig.fallback_layers.
SUPPORTED_ONLINE_FALLBACK_TYPES = (QuantAlgorithm.W8A8, QuantAlgorithm.W16A16)
@dataclasses.dataclass
class OnlineQuantConfig:
    r"""
    Configuration for online quantization.

    ``__post_init__`` normalizes and validates the fields:
    - ``quant_type`` and the values of ``fallback_layers`` may be strings, which are
      upper-cased and converted to ``QuantAlgorithm``;
    - ``fallback_layers`` defaults to an empty dict;
    - ``fallback_timesteps`` (list/set/range of int) is stored as a list.

    Invalid values raise ``ModelInitError``.
    """

    # Must be one of SUPPORTED_ONLINE_QUANT_TYPES.
    quant_type: QuantAlgorithm = QuantAlgorithm.W8A8_DYNAMIC
    # Layer-name pattern -> algorithm from SUPPORTED_ONLINE_FALLBACK_TYPES.
    fallback_layers: Optional[Dict[str, QuantAlgorithm]] = None
    # Only allowed when quant_type is one of _W4A4_QUANT_TYPES.
    fallback_timesteps: Optional[List[int]] = None

    def __post_init__(self):
        self.quant_type = self._parse_quant_type(self.quant_type)
        self.fallback_layers = self._parse_fallback_layers(self.fallback_layers)
        if self.fallback_timesteps is not None:
            self.fallback_timesteps = self._parse_fallback_timesteps(self.fallback_timesteps)

    @staticmethod
    def _parse_quant_type(quant_type):
        """Normalize quant_type and check it is a supported online quantization type."""
        quant_type = QuantConfig._normalize_quant_algo(quant_type)
        if not isinstance(quant_type, QuantAlgorithm):
            raise ModelInitError(
                f'self.quant_type must be an instance of QuantAlgorithm, but actually got {type(quant_type)}.'
            )
        if quant_type not in SUPPORTED_ONLINE_QUANT_TYPES:
            raise ModelInitError(
                f'self.quant_type must be one of {SUPPORTED_ONLINE_QUANT_TYPES}, but actually got {quant_type}.'
            )
        return quant_type

    @staticmethod
    def _parse_fallback_layers(fallback_layers):
        """Return a new pattern -> QuantAlgorithm dict (None becomes {}), validating every entry."""
        if fallback_layers is None:
            fallback_layers = {}
        if not isinstance(fallback_layers, dict):
            raise ModelInitError("self.fallback_layers must be a dict mapping layer name patterns to QuantAlgorithm")
        parsed_fallback = {}
        for pattern, algo in fallback_layers.items():
            if not isinstance(pattern, str):
                raise ModelInitError("Keys in fallback_layers must be strings")
            algo = QuantConfig._normalize_quant_algo(algo)
            if not isinstance(algo, QuantAlgorithm):
                raise ModelInitError(f"Values in fallback_layers must be QuantAlgorithm, got {type(algo)}")
            if algo not in SUPPORTED_ONLINE_FALLBACK_TYPES:
                raise ModelInitError(
                    f"Fallback algorithm for '{pattern}' must be one of {SUPPORTED_ONLINE_FALLBACK_TYPES}, "
                    f"but got {algo}"
                )
            parsed_fallback[pattern] = algo
        return parsed_fallback

    def _parse_fallback_timesteps(self, fallback_timesteps):
        """Validate fallback_timesteps against the (already parsed) quant_type and return it as a list."""
        if self.quant_type not in _W4A4_QUANT_TYPES:
            raise ModelInitError(
                f"fallback_timesteps is only supported for W4A4 quantization types "
                f"{_W4A4_QUANT_TYPES}, but quant_type is {self.quant_type}"
            )
        if not isinstance(fallback_timesteps, (list, set, range)):
            raise ModelInitError(
                f"fallback_timesteps must be a list, set, or range, got {type(fallback_timesteps)}"
            )
        for ts in fallback_timesteps:
            # bool is an int subclass; a True/False timestep is almost certainly a mistake.
            if isinstance(ts, bool) or not isinstance(ts, int):
                raise ModelInitError(f"All elements in fallback_timesteps must be int, got {type(ts)}")
        return list(fallback_timesteps)

    @classmethod
    def parse_from_dict(cls, config: dict):
        """Build a config from a mapping of field names to values."""
        return cls(**dict(config))

    def serialize_to_dict(self):
        """
        Return a plain dict with enum values replaced by their ``.value``.

        'fallback_timesteps' is included only when set, and is a copy, so mutating
        the result does not affect this config.
        """
        fallback = {
            pattern: algo.value if isinstance(algo, QuantAlgorithm) else algo
            for pattern, algo in self.fallback_layers.items()
        }
        result = {
            'quant_type': self.quant_type.value if isinstance(self.quant_type, QuantAlgorithm) else self.quant_type,
            'fallback_layers': fallback,
        }
        if self.fallback_timesteps is not None:
            result['fallback_timesteps'] = list(self.fallback_timesteps)
        return result