from enum import Enum
from dataclasses import dataclass, field
from typing import Optional, List, Dict
import torch
import torch_npu
class ScalingStrategyEnum(Enum):
    """
    Quantization scaling strategy.

    Members are looked up by their string value, e.g.
    ``ScalingStrategyEnum("dynamic")``. Trainable and QAT
    (Quantization-Aware Training) will be added later.
    """

    DYNAMIC = "dynamic"
    DELAYED = "delayed"
    TRAINABLE = "trainable"
    # The value uses a plain ASCII hyphen so it can be typed in recipe names
    # and config files; the earlier spelling contained a Unicode dash.
    QAT_W4A16 = "qat-w4a16"

    @classmethod
    def _missing_(cls, value):
        """
        Resolve values spelled with a Unicode dash instead of an ASCII hyphen.

        Keeps lookups that use the legacy spelling of ``QAT_W4A16`` working.
        Returns None for anything else so Enum raises its usual ValueError.
        """
        if not isinstance(value, str):
            return None
        # Hyphen, non-breaking hyphen, figure/en/em dash, horizontal bar, minus.
        dashes = "\u2010\u2011\u2012\u2013\u2014\u2015\u2212"
        normalized = value.translate(str.maketrans(dict.fromkeys(dashes, "-")))
        if normalized == value:
            return None
        for member in cls:
            if member.value == normalized:
                return member
        return None
class ScalingGranularityEnum(Enum):
    """
    Identifiers for the supported quantization granularities.

    Members are looked up by their string value, e.g.
    ``ScalingGranularityEnum("pertensor")``; these values are the granularity
    tokens accepted by ``QuantRecipe.from_recipe_name``.
    """
    PER_TENSOR = "pertensor"
    PER_CHANNEL = "perchannel"
    PER_GROUP = "pergroup"
    PER_BLOCK = "perblock"
    # ScalingGranularity.__post_init__ only accepts block size [1, 1, 32] for MX.
    MX = "mx"
@dataclass
class ScalingGranularity:
    """
    Quantization granularity together with an optional block size.

    Attributes:
        stype: Which granularity is used.
        block_size: Block size as a list of ints, or None when no block size
            is given. Optional, so ``ScalingGranularity(stype)`` is valid.

    Raises:
        ValueError: If ``stype`` is MX and ``block_size`` is anything other
            than None or ``[1, 1, 32]``.
    """

    stype: ScalingGranularityEnum
    # Defaults to None so granularities without a block size can be built
    # from the granularity type alone.
    block_size: Optional[List[int]] = None

    def __post_init__(self):
        """Fill in or validate the block size for the MX granularity."""
        if self.stype != ScalingGranularityEnum.MX:
            return
        if self.block_size is None:
            # MX has a single supported block size; use it when none is given.
            self.block_size = [1, 1, 32]
        elif self.block_size != [1, 1, 32]:
            raise ValueError(
                f"Invalid block size: {self.block_size} for MX, [1, 1, 32] is the only supported block size for MX")
@dataclass
class QuantRecipe:
    """
    Quantization recipe that defines the quantization strategy,
    granularity, and data types for inputs, weights, and gradients.

    Attributes:
        scaling_strategy: Scaling strategy; defaults to DYNAMIC.
        scaling_granularity: Scaling granularity; defaults to per-tensor
            with no block size.
        inputs_dtype: Dtype used for inputs, or None if unset.
        weight_dtype: Dtype used for weights, or None if unset.
        grads_dtype: Dtype used for gradients, or None if unset.
    """

    scaling_strategy: ScalingStrategyEnum = field(default=ScalingStrategyEnum.DYNAMIC)
    # ScalingGranularity cannot be constructed without a granularity type, so
    # the default is built explicitly instead of calling the bare class.
    # NOTE(review): per-tensor is chosen as the default granularity - confirm.
    scaling_granularity: ScalingGranularity = field(
        default_factory=lambda: ScalingGranularity(ScalingGranularityEnum.PER_TENSOR, None)
    )
    inputs_dtype: Optional[torch.dtype] = None
    weight_dtype: Optional[torch.dtype] = None
    grads_dtype: Optional[torch.dtype] = None

    @classmethod
    def from_recipe_name(cls, recipe_name: str):
        """
        Create a QuantRecipe instance based on the recipe name.

        Registered names are looked up first. Any other name must have the
        form ``<strategy>_<granularity>[-b0-b1-b2]_<inputs>_<weight>_<grads>``
        where the strategy and granularity are enum values, the optional
        ``-b0-b1-b2`` suffix is a three-integer block size, and the last
        three tokens are keys of ``_dtype_mapping``.

        Args:
            recipe_name (str): Name of the quantization recipe.

        Returns:
            QuantRecipe: The corresponding quantization recipe instance.

        Raises:
            ValueError: If the name is neither registered nor parseable. The
                underlying parsing error is attached as ``__cause__``.
        """
        if recipe_name in _registry_quant_recipes:
            registered = _registry_quant_recipes[recipe_name]
            # recipe_register stores the decorated object itself, which for
            # factory functions is the function; call it so callers always
            # receive a recipe instance.
            if callable(registered):
                return registered()
            return registered

        parts = recipe_name.split("_")
        if len(parts) != 5:
            raise ValueError(f"Unknown recipe name: {recipe_name}")

        strategy_token, granularity_token, inputs_token, weight_token, grads_token = parts
        try:
            strategy = ScalingStrategyEnum(strategy_token)
            # "pergroup-1-1-128" -> name "pergroup", block tokens ["1", "1", "128"].
            granularity_name, *block_tokens = granularity_token.split("-")
            if block_tokens:
                if len(block_tokens) != 3:
                    raise ValueError(f"Invalid block size: {block_tokens}")
                block_size = [int(token) for token in block_tokens]
            else:
                block_size = None
            granularity = ScalingGranularity(ScalingGranularityEnum(granularity_name), block_size)
            return cls(
                scaling_strategy=strategy,
                scaling_granularity=granularity,
                inputs_dtype=_dtype_mapping[inputs_token],
                weight_dtype=_dtype_mapping[weight_token],
                grads_dtype=_dtype_mapping[grads_token],
            )
        except (ValueError, KeyError) as err:
            # KeyError comes from an unknown dtype token; report it the same
            # way as every other malformed name and keep the cause attached.
            raise ValueError(f"Unknown recipe name: {recipe_name}") from err

    def get_key_dtype(self, key: str):
        """
        Return the dtype configured for ``key``.

        Args:
            key (str): One of "inputs", "weight" or "grads".

        Returns:
            The matching ``*_dtype`` attribute (may be None).

        Raises:
            ValueError: If ``key`` is not one of the three known names.
        """
        if key == "inputs":
            return self.inputs_dtype
        if key == "weight":
            return self.weight_dtype
        if key == "grads":
            return self.grads_dtype
        raise ValueError(f"Unknown key: {key}")
# Registry of named recipes, filled by the ``recipe_register`` decorator.
# NOTE(review): the decorator stores the decorated object itself (for the
# recipes defined in this file, a zero-argument function returning a
# QuantRecipe), not a QuantRecipe instance as the annotation suggests -
# confirm which of the two is intended.
_registry_quant_recipes: Dict[str, QuantRecipe] = dict()
# Maps the dtype tokens used in recipe names to torch / torch_npu dtypes.
_dtype_mapping = {
    "E1M2": torch_npu.float4_e1m2fn_x2,
    "E2M1": torch_npu.float4_e2m1fn_x2,
    "E4M3": torch.float8_e4m3fn,
    "E5M2": torch.float8_e5m2,
    "HiF8": torch_npu.hifloat8,
    "BF16": torch.bfloat16,
    "FP16": torch.float16,
}
def recipe_register(name=None):
    """
    Decorator that stores an object in the quantization recipe registry.

    Works both bare (``@recipe_register``) and with an explicit name
    (``@recipe_register("my_recipe")``).

    Args:
        name (str, optional): Registry key. When omitted, or when the
            decorator is applied bare, the decorated object's ``__name__``
            is used.

    Returns:
        The decorated object when applied bare, otherwise a decorator that
        registers its argument and returns it unchanged.
    """

    def _store(target, registry_key):
        # Record the object under its key and hand it back untouched.
        _registry_quant_recipes[registry_key] = target
        return target

    # Bare usage: ``name`` is actually the object being decorated.
    if callable(name):
        return _store(name, name.__name__)

    def _decorator(target):
        registry_key = target.__name__ if name is None else name
        return _store(target, registry_key)

    return _decorator
@recipe_register("mxfp8")
def register_mxfp8_recipe():
    """
    Build the MXFP8 recipe, registered under the name "mxfp8".

    Returns:
        QuantRecipe: Dynamic scaling, MX granularity with block size
        [1, 1, 32], and the "E4M3" dtype for inputs, weights and gradients.
    """
    mx_granularity = ScalingGranularity(ScalingGranularityEnum.MX, [1, 1, 32])
    # All three tensor kinds share the same dtype in this recipe.
    e4m3 = _dtype_mapping["E4M3"]
    recipe = QuantRecipe(
        scaling_strategy=ScalingStrategyEnum.DYNAMIC,
        scaling_granularity=mx_granularity,
        inputs_dtype=e4m3,
        weight_dtype=e4m3,
        grads_dtype=e4m3,
    )
    return recipe
@recipe_register("delayed_hif8_pertensor")
def register_delayed_hif8_per_tensor_recipe():
    """
    Build the per-tensor HiF8 delayed-scaling recipe.

    Registered under the name "delayed_hif8_pertensor".

    Returns:
        QuantRecipe: Delayed scaling, per-tensor granularity with no block
        size, and the "HiF8" dtype for inputs, weights and gradients.
    """
    return QuantRecipe(
        scaling_strategy=ScalingStrategyEnum.DELAYED,
        # block_size is passed explicitly: per-tensor scaling has no block
        # size, and ScalingGranularity declares the field without a default.
        scaling_granularity=ScalingGranularity(ScalingGranularityEnum.PER_TENSOR, None),
        inputs_dtype=_dtype_mapping["HiF8"],
        weight_dtype=_dtype_mapping["HiF8"],
        grads_dtype=_dtype_mapping["HiF8"],
    )