"""
DEPRECATED in favor of `transformer_engine.pytorch.quantization.py`.
"""
import warnings
# Import-time deprecation notice. `stacklevel=2` asks the warnings machinery
# to attribute the warning to the caller's frame rather than to this file.
warnings.warn(
    message=(
        "Using deprecated internal API from Transformer Engine. "
        "transformer_engine.pytorch.fp8 will be removed in a "
        "future release."
    ),
    category=DeprecationWarning,
    stacklevel=2,
)
from transformer_engine.common.recipe import (
DelayedScaling,
Float8BlockScaling,
Float8CurrentScaling,
Format,
MXFP4BlockScaling,
MXFP8BlockScaling,
Recipe,
)
from .quantization import (
check_recipe_support,
get_fp8_torch_dtype,
DelayedScalingRecipeState,
Float8BlockScalingRecipeState,
Float8CurrentScalingRecipeState,
MXFP4BlockScalingRecipeState,
MXFP8BlockScalingRecipeState,
RecipeState,
)
from .quantization import (
get_default_recipe as get_default_fp8_recipe,
)
from .quantization.utils import get_fp8_te_dtype
from .quantization.manager import (
FP8GlobalStateManager,
_amax_and_scale_update,
_compute_amax_and_update_history,
_compute_scaling_factor,
_default_get_amax_and_update_history,
_default_sf_compute,
_update_amax_history,
check_fp8_block_scaling_support,
check_fp8_support,
check_mxfp4_support,
check_mxfp8_support,
check_nvfp4_support,
fp8_autocast,
fp8_model_init,
get_fp8_max,
split_and_copy,
)