from enum import Enum
from typing import Optional, NamedTuple
import torch
import torch_npu
# Some torch builds do not expose the FP8 dtypes. Alias each MISSING dtype to
# bfloat16 so that later references to torch.float8_* in this module still
# resolve. Each attribute is checked on its own so that a dtype torch really
# provides is never overwritten just because the other one is absent.
# NOTE: the alias is only a placeholder; bfloat16 is not an 8-bit format.
if not hasattr(torch, 'float8_e4m3fn'):
    torch.float8_e4m3fn = torch.bfloat16
if not hasattr(torch, 'float8_e5m2'):
    torch.float8_e5m2 = torch.bfloat16
class FP8Format:
    """Parameters describing one 8-bit floating-point format.

    Attributes:
        max: The ``range_max`` constructor argument (presumably the largest
            magnitude the format can represent).
        ebits: The ``ebits`` constructor argument (presumably exponent bits).
        mbits: The ``mbits`` constructor argument (presumably mantissa bits).
        dtype: Torch dtype associated with the format, or ``None`` when the
            format has no torch dtype of its own.
    """

    def __init__(self, range_max: float, ebits: int, mbits: int, dtype: Optional[torch.dtype]):
        self.dtype = dtype
        self.mbits = mbits
        self.ebits = ebits
        self.max = range_max

    @property
    def quant_type(self):
        """Return ``dtype``, or ``torch_npu.hifloat8`` when ``dtype`` is ``None``."""
        # torch_npu.hifloat8 is only looked up when no torch dtype was supplied.
        return torch_npu.hifloat8 if self.dtype is None else self.dtype
class FormatEnum(Enum):
    """Individual FP8 formats, each described by an ``FP8Format`` instance."""

    # Formats that carry a torch dtype.
    E4M3 = FP8Format(range_max=448, ebits=4, mbits=3, dtype=torch.float8_e4m3fn)
    E5M2 = FP8Format(range_max=57344, ebits=5, mbits=2, dtype=torch.float8_e5m2)

    # Formats created with dtype=None, for which FP8Format.quant_type yields
    # torch_npu.hifloat8. The three variants differ only in range_max.
    HIF8 = FP8Format(range_max=57344, ebits=5, mbits=2, dtype=None)
    HIF8_224 = FP8Format(range_max=224, ebits=5, mbits=2, dtype=None)
    HIF8_15 = FP8Format(range_max=15, ebits=5, mbits=2, dtype=None)
class _FormatConfig(NamedTuple):
    """Triple of formats, one per tensor role; each defaults to ``FormatEnum.E4M3``.

    The field names match the values of ``TensorKey``.
    """

    # Format used for the "inputs" role.
    inputs: FormatEnum = FormatEnum.E4M3
    # Format used for the "weight" role.
    weight: FormatEnum = FormatEnum.E4M3
    # Format used for the "grads" role.
    grads: FormatEnum = FormatEnum.E4M3
class Format(Enum):
    """Named FP8 presets; each value is a ``_FormatConfig`` of per-tensor formats."""

    # All three roles use the _FormatConfig default (E4M3).
    E4M3 = _FormatConfig()
    # inputs and weight keep the E4M3 default; grads use E5M2.
    HYBRID = _FormatConfig(grads=FormatEnum.E5M2)
    # inputs and weight use HIF8_15; grads use HIF8_224.
    HIF8 = _FormatConfig(
        inputs=FormatEnum.HIF8_15,
        weight=FormatEnum.HIF8_15,
        grads=FormatEnum.HIF8_224,
    )

    @classmethod
    def from_config_fp8(cls, key: str):
        """Look up a preset by name, ignoring the case of ``key``.

        Args:
            key: Preset name; it is upper-cased before the lookup.

        Returns:
            The member named ``key.upper()``, or ``None`` if there is none.
        """
        return cls.__members__.get(key.upper())
class Fp8Recipe(str, Enum):
    """Identifiers of the FP8 recipes.

    Because ``str`` is mixed in, each member is itself a string and compares
    equal to its plain string value.
    """

    delayed = "delayed"
    tensorwise = "tensorwise"
    mxfp8 = "mxfp8"
    # The member name uses "_" where the string value uses "-".
    mxfp8_32x32 = "mxfp8-32x32"
    blockwise = "blockwise"
class TensorKey(str, Enum):
    """Names of the tensor roles; each member is a string equal to its own name."""

    inputs = "inputs"
    weight = "weight"
    grads = "grads"
class MatmulKey(tuple, Enum):
    """Matmul kinds, each valued by the pair of ``TensorKey`` roles it combines.

    Because ``tuple`` is mixed in, every member is itself a 2-tuple.
    NOTE(review): ``dx`` / ``dw`` presumably denote the gradient matmuls with
    respect to the input and the weight -- confirm against the callers.
    """

    # inputs paired with weight
    forward = (TensorKey.inputs, TensorKey.weight)
    # grads paired with weight
    dx = (TensorKey.grads, TensorKey.weight)
    # grads paired with inputs
    dw = (TensorKey.grads, TensorKey.inputs)
# Pair of boolean flags for each matmul kind. get_matmul_wise_by_tensor_key
# reads this table when is_fp8_tensor_2d(tensor) is true.
# NOTE(review): the meaning of the two flags is not visible in this file;
# presumably one flag per operand, in the order of the MatmulKey value -- confirm.
MATMUL_WISE_MAP = {
    MatmulKey.forward: (False, False),
    MatmulKey.dx: (False, True),
    MatmulKey.dw: (True, True),
}
# Pair of boolean flags for each matmul kind. get_matmul_wise_by_tensor_key
# reads this table when is_fp8_tensor_2d(tensor) is false.
# NOTE(review): the meaning of the two flags is not visible in this file;
# presumably one flag per operand, in the order of the MatmulKey value -- confirm.
MATMUL_WISE_MAP_NORMAL = {
    MatmulKey.forward: (False, True),
    MatmulKey.dx: (False, False),
    MatmulKey.dw: (True, False),
}
def get_matmul_wise_by_tensor_key(tensor, key):
    """Return the flag pair stored under ``key`` in the table chosen by ``tensor``.

    ``MATMUL_WISE_MAP`` is used when ``is_fp8_tensor_2d(tensor)`` is true,
    ``MATMUL_WISE_MAP_NORMAL`` otherwise.

    Args:
        tensor: Object handed to ``is_fp8_tensor_2d`` to select the table.
        key: Lookup key; both tables are keyed by ``MatmulKey`` members.

    Returns:
        The 2-tuple of booleans stored under ``key``.

    Raises:
        KeyError: If ``key`` is not in the selected table.
    """
    # Imported at call time rather than at module load
    # (presumably to avoid a circular import -- confirm).
    from mindspeed.te.pytorch.fp8 import is_fp8_tensor_2d

    if is_fp8_tensor_2d(tensor):
        table = MATMUL_WISE_MAP
    else:
        table = MATMUL_WISE_MAP_NORMAL
    return table[key]
def amax_compute_max(amax, amax_history, last_history_index):
    """Overwrite ``amax`` in place with the maximum of ``amax_history`` along dim 0.

    Args:
        amax: Destination tensor, updated in place via ``copy_``.
        amax_history: Tensor reduced along its first dimension
            (presumably one row per recorded step -- confirm).
        last_history_index: Unused here; it keeps the signature identical to
            ``amax_compute_most_recent``.

    Returns:
        None.
    """
    history_peak = torch.amax(amax_history, dim=0)
    # non_blocking=True is forwarded to copy_, which may allow an asynchronous copy.
    amax.copy_(history_peak, non_blocking=True)
def amax_compute_most_recent(amax: torch.Tensor, amax_history, last_history_index):
    """Overwrite ``amax`` in place with ``amax_history[last_history_index]``.

    Args:
        amax: Destination tensor, updated in place via ``copy_``.
        amax_history: Indexable history of amax values.
        last_history_index: Index of the history entry to copy.

    Returns:
        None.
    """
    latest_entry = amax_history[last_history_index]
    # non_blocking=True is forwarded to copy_, which may allow an asynchronous copy.
    amax.copy_(latest_entry, non_blocking=True)
# Dispatch table: algorithm name -> function that updates amax in place
# from the amax history.
AMAX_COMPUTE_MAP = {
    'max': amax_compute_max,
    'most_recent': amax_compute_most_recent,
}