"""Constants for TransformerEngine Ascend PyTorch API"""
from enum import Enum, IntEnum, unique
import torch
import torch_npu
# When the installed torch lacks an FP8 dtype attribute, alias it to bfloat16
# so that later references to ``torch.float8_*`` still resolve. Each name is
# checked on its own: a dtype that torch already provides is never overwritten
# just because the other one happens to be missing.
# NOTE(review): the alias is a placeholder, not a real 8-bit type.
for _fp8_name in ("float8_e4m3fn", "float8_e5m2"):
    if not hasattr(torch, _fp8_name):
        setattr(torch, _fp8_name, torch.bfloat16)
del _fp8_name  # do not leave the loop variable in the module namespace
@unique
class NPUVersion(IntEnum):
    """Integer identifiers for NPU versions.

    Members are plain integers (``IntEnum``), so they compare and order like
    ints. ``@unique`` rejects two members sharing a value.
    """

    # Lowest value; presumably "no / unknown NPU" -- confirm against callers.
    NONE = 0

    # NOTE(review): A2 / A3 / A5 look like hardware generation tags -- confirm.
    A2 = 2
    A3 = 3
    A5 = 5

    # Largest value in the enum; usable as an upper bound in comparisons.
    MAX_VERSION = 999
class Fp8Recipe(str, Enum):
    """Names of the supported quantization recipes.

    Each member is also a ``str`` whose value equals its own name, so a
    member compares equal to the matching plain string and can be looked up
    with ``Fp8Recipe("<name>")``.
    """

    delayed = "delayed"
    tensorwise = "tensorwise"

    # MX-format recipes.
    mxfp4 = "mxfp4"
    mxfp8 = "mxfp8"

    blockwise = "blockwise"
@unique
class TensorUsage(str, Enum):
    """Two-letter string tags describing how a tensor is used.

    Members are ``str`` instances, so they compare equal to their raw tag.
    NOTE(review): the tags read as <L|R operand><N|T layout>, i.e. left/right
    hand side in normal/transposed form -- confirm against the consumers.
    """

    # Left-hand-side tags.
    LHS = "LN"
    LHS_TRANS = "LT"

    # Right-hand-side tags.
    RHS = "RN"
    RHS_TRANS = "RT"
# TensorUsage members grouped by whether their name carries the _TRANS suffix.
USAGE_WITH_TRANS = (
    TensorUsage.LHS_TRANS,
    TensorUsage.RHS_TRANS,
)
USAGE_WITHOUT_TRANS = (
    TensorUsage.LHS,
    TensorUsage.RHS,
)
@unique
class ParallelMode(str, Enum):
    """String-valued enum with the two parallel modes, "column" and "row".

    Members are ``str`` instances and compare equal to their raw value.
    """

    COLUMN = "column"
    ROW = "row"
class QuantDtype:
    """Forward/backward dtype pair plus the extra matmul keyword arguments.

    Attributes:
        fwd: dtype given for the forward pass.
        bwd: dtype given for the backward pass.
        mm_kwargs: ``{"x1_dtype": fwd, "x2_dtype": fwd}`` when ``fwd`` is
            ``torch_npu.hifloat8``, otherwise an empty dict.
        gmm_kwargs: ``{"x_dtype": fwd, "weight_dtype": fwd}`` when ``fwd`` is
            ``torch_npu.hifloat8``, otherwise an empty dict.
    """

    def __init__(self, fwd: torch.dtype, bwd: torch.dtype):
        self.fwd = fwd
        self.bwd = bwd

        # Only the forward dtype decides whether explicit dtype kwargs are
        # built; ``bwd`` is stored but not consulted here.
        is_hif8 = fwd == torch_npu.hifloat8
        self.mm_kwargs = {"x1_dtype": fwd, "x2_dtype": fwd} if is_hif8 else {}
        self.gmm_kwargs = {"x_dtype": fwd, "weight_dtype": fwd} if is_hif8 else {}
class FP8FwdTensorIdx(IntEnum):
    """Integer slots for the forward-pass FP8 tensors.

    Three GEMMs are covered; each owns three consecutive indices in the
    order input, weight, output.
    """

    # GEMM 1
    GEMM1_INPUT = 0
    GEMM1_WEIGHT = 1
    GEMM1_OUTPUT = 2

    # GEMM 2
    GEMM2_INPUT = 3
    GEMM2_WEIGHT = 4
    GEMM2_OUTPUT = 5

    # GEMM 3
    GEMM3_INPUT = 6
    GEMM3_WEIGHT = 7
    GEMM3_OUTPUT = 8
class FP8BwdTensorIdx(IntEnum):
    """Integer slots for the backward-pass FP8 tensors.

    Three groups are covered; each owns three consecutive indices in the
    order grad-output, grad-input, grad-weight.
    """

    # Group 1
    GRAD_OUTPUT1 = 0
    GRAD_INPUT1 = 1
    GRAD_WEIGHT1 = 2

    # Group 2
    GRAD_OUTPUT2 = 3
    GRAD_INPUT2 = 4
    GRAD_WEIGHT2 = 5

    # Group 3
    GRAD_OUTPUT3 = 6
    GRAD_INPUT3 = 7
    GRAD_WEIGHT3 = 8
# Accepted GEMM parallel-mode values; ``None`` is one of the entries.
GemmParallelModes = ("row", "column", None)

# Accepted attention type names.
AttnTypes = ("self", "cross")

# Accepted attention bias type names.
AttnBiasTypes = (
    "pre_scale_bias",
    "post_scale_bias",
    "no_bias",
    "alibi",
)

# Block sizes for the MXFP4 / MXFP8 block-scaling constants.
# NOTE(review): presumably the number of elements sharing one scale -- confirm.
MXFP4_BLOCK_SCALING_SIZE = 32
MXFP8_BLOCK_SCALING_SIZE = 32

# Alias for the torch.distributed process-group class.
dist_group_type = torch.distributed.ProcessGroup
# Names exported by a star-import of this module. Every public enum, helper
# class and constant defined in this module is listed, so none of them is
# silently skipped by ``from ... import *``.
__all__ = [
    "NPUVersion",
    "Fp8Recipe",
    "TensorUsage",
    "USAGE_WITH_TRANS",
    "USAGE_WITHOUT_TRANS",
    "ParallelMode",
    "QuantDtype",
    "FP8FwdTensorIdx",
    "FP8BwdTensorIdx",
    "GemmParallelModes",
    "AttnTypes",
    "AttnBiasTypes",
    "MXFP4_BLOCK_SCALING_SIZE",
    "MXFP8_BLOCK_SCALING_SIZE",
    "dist_group_type",
]