from dataclasses import dataclass
from enum import Enum
from typing import Optional, NamedTuple
from dataclasses import field
import torch
import torch_npu
from mindspeed.fsdp.parallel_engine_config import QuantizeConfig
class FP8Format:
    """Numeric description of one 8-bit floating-point format.

    Attributes:
        max: The ``range_max`` value supplied at construction.
        ebits: The ``ebits`` value supplied at construction
            (presumably the exponent bit count).
        mbits: The ``mbits`` value supplied at construction
            (presumably the mantissa bit count).
        dtype: The torch dtype backing this format, or ``None`` when the
            format has no dtype of its own; see :attr:`quant_type`.
    """

    def __init__(self, range_max: float, ebits: int, mbits: int, dtype: Optional[torch.dtype]):
        self.max = range_max
        self.ebits = ebits
        self.mbits = mbits
        self.dtype = dtype

    @property
    def quant_type(self):
        """Return the dtype to quantize to.

        Returns ``self.dtype`` when it is set; when ``self.dtype`` is ``None``
        the result is ``torch_npu.hifloat8`` instead.
        """
        if self.dtype is None:
            return torch_npu.hifloat8
        return self.dtype

    # Deliberately no __eq__/__hash__: instances compare by identity, so two
    # formats with identical numeric parameters remain distinct objects.
    def __repr__(self) -> str:
        """Return a readable description instead of the default object address."""
        return (
            f"{type(self).__name__}(max={self.max!r}, ebits={self.ebits!r}, "
            f"mbits={self.mbits!r}, dtype={self.dtype!r})"
        )
class FormatEnum(Enum):
    """Available 8-bit float formats; each member's value is an ``FP8Format``.

    ``E5M2`` and ``HIF8`` are given the same numeric parameters and differ
    only in their dtype: ``HIF8`` passes ``dtype=None``, which
    ``FP8Format.quant_type`` resolves to ``torch_npu.hifloat8``.
    """

    E4M3 = FP8Format(range_max=448, ebits=4, mbits=3, dtype=torch.float8_e4m3fn)
    E5M2 = FP8Format(range_max=57344, ebits=5, mbits=2, dtype=torch.float8_e5m2)
    HIF8 = FP8Format(range_max=57344, ebits=5, mbits=2, dtype=None)
class _FormatConfig(NamedTuple):
    """Per-tensor-role format selection; every role defaults to ``E4M3``."""

    # Format used for the 'inputs' role.
    inputs: FormatEnum = FormatEnum.E4M3
    # Format used for the 'weight' role.
    weight: FormatEnum = FormatEnum.E4M3
    # Format used for the 'grads' role.
    grads: FormatEnum = FormatEnum.E4M3
class Format(Enum):
    """Named quantization recipes; each member's value is a ``_FormatConfig``.

    * ``E4M3``   - E4M3 for inputs, weight and grads.
    * ``HYBRID`` - E4M3 for inputs and weight, E5M2 for grads.
    * ``HIF8``   - HIF8 for inputs, weight and grads.
    """

    E4M3 = _FormatConfig()
    HYBRID = _FormatConfig(grads=FormatEnum.E5M2)
    HIF8 = _FormatConfig(FormatEnum.HIF8, FormatEnum.HIF8, FormatEnum.HIF8)

    @classmethod
    def from_config_fp8(cls, key: str):
        """Look up a recipe by name, ignoring case.

        Args:
            key: Recipe name; it is upper-cased before the lookup.

        Returns:
            The attribute of this class named ``key.upper()``, or ``None``
            when no such attribute exists.
        """
        member_name = key.upper()
        return getattr(cls, member_name, None)
@dataclass
class QuantBaseConfig:
    """Base quantization settings shared by the concrete configs.

    Attributes:
        quant_format: Name of a ``Format`` member (matched case-insensitively).
            Defaults to ``"E4M3"``.
        block_size: Defaults to 32 (presumably the quantization block size;
            it is not used inside this class).
    """

    quant_format: str = "E4M3"
    block_size: int = 32

    @classmethod
    def get_key_dtype(cls, key):
        """Return the quantization dtype for one tensor role.

        NOTE(review): this is a classmethod, so it reads ``cls.quant_format``,
        i.e. the class-level value (the declared default unless the class
        attribute is reassigned), not a ``quant_format`` passed to an
        instance. Confirm this is intended before relying on per-instance
        formats.

        Args:
            key: ``'inputs'`` or ``'weight'`` select the matching role; any
                other value selects the ``grads`` role.

        Returns:
            The ``quant_type`` of the format configured for that role.

        Raises:
            ValueError: If ``cls.quant_format`` does not name a ``Format``
                member.
        """
        key_format = Format.from_config_fp8(cls.quant_format)
        if key_format is None:
            # Fail with an actionable message instead of the opaque
            # "'NoneType' object has no attribute 'value'" further down.
            valid = ", ".join(member.name for member in Format)
            raise ValueError(
                f"Unsupported quant_format {cls.quant_format!r}; expected one of: {valid}"
            )
        config = key_format.value
        # Anything other than 'inputs' / 'weight' falls through to grads.
        role = key if key in ('inputs', 'weight') else 'grads'
        return getattr(config, role).value.quant_type
@dataclass
class MXFP8LinearConfig(QuantBaseConfig):
    """Settings for the MXFP8 linear conversion, extending ``QuantBaseConfig``.

    Attributes:
        mxfp8_ignored_modules: List of strings, empty by default (presumably
            names of modules to leave unconverted - confirm against the
            converter).
        mxfp8_apply_modules: List of strings, empty by default (presumably
            names of modules to convert - confirm against the converter).
        converter: Converter identifier string; defaults to
            ``"quantize.linear.mx"``.
    """

    # default_factory gives every instance its own fresh list.
    mxfp8_ignored_modules: list[str] = field(default_factory=list)
    mxfp8_apply_modules: list[str] = field(default_factory=list)

    converter: str = "quantize.linear.mx"
def get_mxfp8linear_config(quant_config: QuantizeConfig):
    """Build an ``MXFP8LinearConfig`` from a ``QuantizeConfig``.

    Copies ``quant_format``, ``block_size``, ``quant_ignored_modules`` and
    ``quant_apply_modules`` from ``quant_config``; ``converter`` keeps its
    default value.

    Args:
        quant_config: Source of the quantization settings.

    Returns:
        A new ``MXFP8LinearConfig`` populated from ``quant_config``.
    """
    settings = {
        "quant_format": quant_config.quant_format,
        "block_size": quant_config.block_size,
        "mxfp8_ignored_modules": quant_config.quant_ignored_modules,
        "mxfp8_apply_modules": quant_config.quant_apply_modules,
    }
    return MXFP8LinearConfig(**settings)