# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
#
# See LICENSE for license information.

"""Base modules and utilities for TransformerEngine Ascend PyTorch API"""

import io
from functools import lru_cache
import os
import pickle
import warnings
from contextlib import contextmanager
from enum import Enum
from types import MethodType
from typing import (
    Any,
    Dict,
    Generator,
    List,
    Literal,
    Optional,
    Tuple,
    TypedDict,
    Union,
)

import torch
import torch.nn.functional as F  # noqa: N812
from torch.distributed.tensor import DTensor

from transformer_engine.common.recipe import DelayedScaling, Recipe
from ._common import _ParameterInitMeta, noop_cat

from ..constants import dist_group_type
from ..distributed import (
    _fsdp_gather_tensors,
    gather_along_dim,
    in_fp8_activation_recompute_phase,
    is_fp8_activation_recompute_enabled,
)
from ..quantization import (
    DelayedScalingRecipeState,
    Float8BlockScalingRecipeState,
    Float8CurrentScalingRecipeState,
    FP8GlobalStateManager,
    MXFP4BlockScalingRecipeState,
    MXFP8BlockScalingRecipeState,
    RecipeState,
    W4A8BlockScalingRecipeState,
)
from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
from ..tensor import (
    Float8TensorStorage,
    MXFP4TensorStorage,
    MXFP8TensorStorage,
    W4A8WeightTensorStorage,
)

# NOTE(review): presumably these select 2x (split) accumulation for the forward,
# dgrad and wgrad GEMMs respectively; they are not read in this part of the
# file -- confirm against the callers.
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
# Cache of placeholder weight-gradient tensors, filled and read by `get_dummy_wgrad`.
_dummy_wgrads = {}
# Holder for user-buffer communicator state; `destroy_ub` resets it to None.
_ub_communicators = None
# NOTE(review): stream priority bounds, unset at import time; no visible code assigns them.
_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
# NOTE(review): not referenced in the visible part of the file; purpose to be confirmed.
layers_atomic_ring_exchange = []

# NOTE(review): looks like a running count of created layers; no visible code updates it.
_LAYER_COUNT = 0


class UserBufferQuantizationMode(Enum):
    """Quantization mode of a UserBuffer.

    Used as the element type of the ``quantization_modes`` argument of
    ``initialize_ub``.
    """

    # No quantization.
    NONE = "none"
    # FP8 quantization.
    FP8 = "fp8"


def setup_dummy_wgrad(weight):
    """Return a placeholder wgrad for weights that track ``grad_added_to_main_grad``.

    Weights without a ``grad_added_to_main_grad`` attribute yield ``None``.
    Otherwise a cached dummy tensor shaped like the weight's main gradient is
    fetched via ``get_dummy_wgrad`` and ``weight.grad_added_to_main_grad`` is
    set to ``True``.
    """
    if not hasattr(weight, "grad_added_to_main_grad"):
        return None

    # Honor the optional per-weight request for a zero-filled placeholder.
    zero_fill = getattr(weight, "zero_out_wgrad", False)
    # Weights tagged with `__fsdp_param__` expose the main grad through an accessor.
    if hasattr(weight, "__fsdp_param__"):
        main_grad = weight.get_main_grad()
    else:
        main_grad = weight.main_grad

    placeholder = get_dummy_wgrad(main_grad.shape, weight.dtype, zero_fill, weight.device)
    weight.grad_added_to_main_grad = True
    return placeholder


def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False, device=None) -> torch.Tensor:
    """Return a cached placeholder tensor of the given shape, dtype and device.

    Parameters
    ----------
    shape: list
        Shape of the placeholder tensor.
    dtype: torch.dtype
        Data type of the placeholder tensor.
    zero: bool, default = False
        If ``True`` the tensor is zero-filled when it is first created.
    device: optional
        Target device; ``None`` selects ``"npu"``.

    Returns
    -------
    torch.Tensor
        A detached view of the cached tensor. Repeated calls with the same
        arguments share the same underlying storage.
    """
    target_device = "npu" if device is None else device
    # The device is part of the key so that a tensor cached for one device is
    # never handed out for a request targeting another device.
    key = (*shape, dtype, zero, str(target_device))
    cached = _dummy_wgrads.get(key)
    if cached is None:
        cached = torch.empty(
            shape,
            dtype=dtype,
            device=target_device,
            requires_grad=False,
        )
        if zero:
            cached.fill_(0)
        _dummy_wgrads[key] = cached
    return cached.detach()


def initialize_ub(
    shape: list,
    tp_size: int,
    use_fp8: bool = False,
    quantization_modes: List[UserBufferQuantizationMode] = None,
    dtype: torch.dtype = torch.bfloat16,
    ub_cfgs: Optional[Union[dict, List[dict]]] = None,
    bootstrap_backend: Union[str, torch.distributed.Backend] = None,
) -> None:
    """Initialize user buffers for communication.

    This implementation is a placeholder: every argument is accepted and
    ignored, and nothing is allocated.
    """


def destroy_ub() -> None:
    """Drop the module-level user-buffer communicator state."""
    # Rebind the module global so later lookups see ``None``.
    globals()["_ub_communicators"] = None


def fill_userbuffers_buffer_for_all_gather(
    tensor: torch.Tensor,
    ub_comm: Any,
    ub_obj: Any,
    tp_size: int,
    tp_rank: int,
    ub_algo: Any,
) -> None:
    """Fill user buffers for an all-gather operation.

    This implementation is a placeholder: the arguments are ignored and
    ``None`` is returned.
    """
    return None


def get_ub(
    shape: List[int],
    dtype: torch.dtype,
    ub_comm: Any,
    ub_algo: Any,
) -> Any:
    """Get a user buffer.

    This implementation is a placeholder: the arguments are ignored and
    ``None`` is returned.
    """
    return None


def _check_fp8_reduce_and_update():
    """Return ``FP8GlobalStateManager.is_first_fp8_module()`` (for backward reduce-and-update).

    While in the FP8 activation recompute phase, the value of
    ``quantization_state.is_first_fp8_module`` is captured before the query and
    written back afterwards, so the query leaves that field as it found it.
    """
    if not in_fp8_activation_recompute_phase():
        return FP8GlobalStateManager.is_first_fp8_module()

    qstate = FP8GlobalStateManager.quantization_state
    saved_flag = qstate.is_first_fp8_module
    is_first = FP8GlobalStateManager.is_first_fp8_module()
    # NOTE(review): presumably the query above mutates the flag -- confirm.
    qstate.is_first_fp8_module = saved_flag
    return is_first


class FP8Meta(TypedDict):
    """Typed layout of a module's ``fp8_meta`` dictionary.

    NOTE(review): modules start with only ``fp8_checkpoint`` and ``fp8_group``
    populated; the remaining keys are filled in later, so they are not
    guaranteed to be present at all times.
    """

    # Recipe state for the forward / backward tensors.
    scaling_fwd: RecipeState
    scaling_bwd: RecipeState

    # Whether FP8 state should be stored in checkpoints.
    fp8_checkpoint: bool
    # Process group associated with FP8 (None until FP8 metadata is initialized).
    fp8_group: Optional[torch.distributed.ProcessGroup]
    # Quantization recipe currently in effect for the module.
    recipe: Recipe
    # Number of GEMMs the module performs; sizes the per-direction quantizer lists.
    num_gemms: int
    # Maximum representable values of the forward / backward FP8 formats.
    fp8_max_fwd: float
    fp8_max_bwd: float
    # ---------- cache ----------
    # NOTE(review): presumably (fwd_pos, fwd_key, bwd_pos, bwd_key) into the
    # global amax buffers -- confirm against FP8GlobalStateManager.
    buffer_index_and_autocast_key: Tuple[int, str, int, str]
    # NOTE(review): looks like a position into a global buffer used during
    # activation recompute -- confirm against FP8GlobalStateManager.
    global_fp8_buffer_pos_fwd_recompute: int


class _Quantizers(TypedDict):
    """Typed layout of a module's ``quantizers`` dictionary."""

    # Quantizers for forward-pass tensors.
    scaling_fwd: List["Quantizer"]
    # Quantizers for backward-pass tensors.
    scaling_bwd: List["Quantizer"]


def _is_weight_workspace_valid(
    workspace: QuantizedTensorStorage,
    quantizer: Quantizer,
) -> bool:
    """Tell whether a cached weight workspace still satisfies the quantizer's usages.

    For each known storage type, a usage requested by ``quantizer`` must be
    backed by a non-``None`` buffer on ``workspace``. Storage types that are
    not listed below are always considered valid.
    """
    # (quantizer usage flag, workspace buffer attribute) pairs to verify.
    if isinstance(workspace, Float8TensorStorage):
        requirements = (("columnwise_usage", "_data"),)
    elif isinstance(workspace, (MXFP8TensorStorage, MXFP4TensorStorage)):
        requirements = (
            ("rowwise_usage", "_rowwise_data"),
            ("columnwise_usage", "_columnwise_data"),
        )
    elif isinstance(workspace, W4A8WeightTensorStorage):
        requirements = (
            ("rowwise_usage", "_mxfp4_rowwise_data"),
            ("columnwise_usage", "_mxfp8_columnwise_data"),
        )
    else:
        requirements = ()

    # Attributes are read lazily, and only for usages that are actually requested.
    return not any(
        getattr(quantizer, usage_flag) and getattr(workspace, buffer_attr) is None
        for usage_flag, buffer_attr in requirements
    )


def quantize_weight(
    *,
    tensor: Optional[torch.Tensor] = None,
    quantizer: Optional[Quantizer] = None,
    workspace: Optional[QuantizedTensorStorage] = None,
    update_workspace: bool = True,
    skip_update_flag: Optional[torch.Tensor] = None,
    fsdp_group: Optional["dist_group_type"] = None,
    workspace_dtype: Optional[torch.dtype] = None,
    cache: bool = False,
    group_list=None,
) -> Tuple[QuantizedTensorStorage, Optional[QuantizedTensorStorage]]:
    """Quantize a weight tensor, optionally reusing a cached workspace.

    Parameters
    ----------
    tensor: torch.Tensor, optional
        Weight tensor to quantize.
    quantizer: Quantizer, optional
        Quantizer for casting the weight.
    workspace: QuantizedTensorStorage, optional
        Previously cached workspace (from the module's ``_fp8_workspaces``).
        ``None`` indicates a cache miss.
    update_workspace: bool, default = True
        Whether to update an existing workspace with fresh values.
    skip_update_flag: torch.Tensor, optional
        Device-side flag to conditionally skip the update. When given, the
        update is always issued and the flag is forwarded as ``noop_flag``.
    fsdp_group: dist_group_type, optional
        FSDP process group the weights are distributed over.
    workspace_dtype: torch.dtype, optional
        High-precision dtype for debug quantization workspaces.
    cache: bool, default = False
        If ``True`` and a new workspace is created, it will be returned
        as the second element so the caller can store it.
    group_list: optional
        When not ``None``, the grouped quantization entry points
        (``grouped_quantize`` / ``grouped_quantize_``) are used and receive
        this value.

    Returns
    -------
    (weightmat, new_workspace)
        *weightmat*: quantized weight ready for GEMM.
        *new_workspace*: non-``None`` only when a brand-new workspace was
        created **and** ``cache=True``.  The caller should store it in
        ``_fp8_workspaces``.

    Raises
    ------
    ValueError
        If a workspace must be updated but ``tensor`` is missing, or if a new
        workspace must be built but ``tensor`` or ``quantizer`` is missing.
    """

    # Weights that are already quantized only need their usages enabled.
    if isinstance(tensor, QuantizedTensor) and quantizer is not None:
        update_rowwise = True if quantizer.rowwise_usage else None
        update_columnwise = True if quantizer.columnwise_usage else None
        tensor.update_usage(
            rowwise_usage=update_rowwise,
            columnwise_usage=update_columnwise,
        )
        return tensor, None

    # Discard a cached workspace that lacks a buffer the quantizer now needs.
    if workspace is not None and quantizer is not None:
        if not _is_weight_workspace_valid(workspace, quantizer):
            workspace = None

    # NOTE(review): a shape mismatch presumably means the workspace is still
    # FSDP-sharded and must be gathered first -- confirm.
    if (
        workspace is not None
        and tensor is not None
        and fsdp_group is not None
        and workspace.data.shape != tensor.data.shape
    ):
        _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], workspace)

    # Cache hit: refresh the workspace in place if requested.
    if workspace is not None:
        if skip_update_flag is not None:
            update_workspace = True
        if update_workspace:
            if tensor is None:
                raise ValueError("tensor kwarg must be provided to update FP8 workspace")
            if hasattr(workspace, "quantize_"):
                if group_list is None:
                    workspace.quantize_(tensor, noop_flag=skip_update_flag)
                else:
                    workspace.grouped_quantize_(tensor, group_list, noop_flag=skip_update_flag)
        return workspace, None

    # Cache miss: build a new workspace from scratch.
    if tensor is None or quantizer is None:
        raise ValueError("tensor and quantizer kwargs must be provided to construct FP8 workspace")

    def _quantize_fresh():
        """Run the plain or grouped quantization, depending on ``group_list``."""
        if group_list is None:
            return quantizer.quantize(tensor, dtype=workspace_dtype)
        return quantizer.grouped_quantize(tensor, group_list, dtype=workspace_dtype)

    if not cache:
        return _quantize_fresh(), None

    # A workspace meant for caching is produced with ``quantizer.internal``
    # switched off. The original setting is restored even if quantization
    # raises, so a failure cannot leave the shared quantizer modified.
    saved_internal = quantizer.internal
    quantizer.internal = False
    try:
        out = _quantize_fresh()
    finally:
        quantizer.internal = saved_internal
    return out, out


class TransformerEngineBaseModule(torch.nn.Module):
    def __init__(self, name):
        """Set up the default FP8 and parallelism bookkeeping of the module.

        Parameters
        ----------
        name
            Identifier stored on the module as ``self.name``.
        """
        super().__init__()
        self.name = name
        # FP8 flags; refreshed from FP8GlobalStateManager by `init_fp8_metadata`.
        self.fp8_initialized = False
        self.fp8 = False
        self.fp8_calibration = False
        # Only the two always-present keys are populated up front.
        self.fp8_meta: FP8Meta = {"fp8_checkpoint": False, "fp8_group": None}
        self.fp8_meta_tensors_initialized = False
        # Filled per direction by `set_meta_tensor`.
        self.quantizers: _Quantizers = {"scaling_fwd": [], "scaling_bwd": []}
        # Tensor-parallel defaults (no parallelism).
        self.tp_group = None
        self.tp_size = 1
        self.sequence_parallel = False
        self.param_init_meta = {}
        # Snapshot of the global FP8-parameter settings at construction time.
        self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
        self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val()
        self.fsdp_wrapped = False
        self.fsdp_group = None
        # Cached quantized-weight workspaces; cleared by `init_fp8_metadata`
        # when the recipe type changes.
        self._fp8_workspaces: Dict[str, "QuantizedTensor"] = {}
        # Resolved by `set_activation_dtype`.
        self.activation_dtype: Optional[torch.dtype] = None
        self.wgrad_accumulation_and_reduce_hooks = []
        self.wgrad_store = None
        # Overwritten via `fast_setattr` on every `prepare_forward` call and
        # consulted by `set_activation_dtype`.
        self.allow_different_data_and_param_types = False

    def fast_setattr(self, name: str, value: Any) -> None:
        """
        Store ``value`` under ``name`` directly in the instance dictionary.

        This skips the Module's attribute-setting machinery, so it is only
        suitable for regular attributes - not for properties, parameters or
        buffers.
        """
        vars(self)[name] = value

    def module_setattr(self, name: str, value: Any) -> None:
        """
        Regular version of the Module's set attribute function.

        Delegates to the parent class' ``__setattr__``. Should be used only
        when the fast version (``fast_setattr``) cannot be used - for the
        properties, parameters and buffers.
        """
        super().__setattr__(name, value)

    @property
    def is_fsdp2(self) -> bool:
        """Whether this module is wrapped with FSDP2.

        The answer is computed on first access and then memoized on the
        instance. It is ``False`` when the FSDP state helper cannot be
        imported or raises ``RuntimeError`` for this module.
        """
        # Memoize per instance rather than with ``lru_cache``: a cache keyed on
        # ``self`` holds strong references to modules and, being bounded,
        # evicts entries once more modules than its capacity are queried.
        # NOTE: the memo lives in ``__dict__`` and is therefore carried along
        # when the module object is copied.
        memo = self.__dict__.get("_is_fsdp2_memo")
        if memo is not None:
            return memo

        try:
            from ..distributed import _get_module_fsdp_state

            _get_module_fsdp_state(self)
        except (RuntimeError, ImportError):
            memo = False
        else:
            memo = True
        self.__dict__["_is_fsdp2_memo"] = memo
        return memo

    def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
        """
        Delayed scaling only.

        Resize the amax history of the selected recipe state(s) to `length`
        rows, truncating or zero-padding at the end as needed.

        `fwd=None` processes both directions, otherwise only the forward
        (`True`) or backward (`False`) state.

        .. warning::
            This changes the underlying amax memory location.
        """
        if fwd is None:
            meta_keys = ("scaling_fwd", "scaling_bwd")
        elif fwd:
            meta_keys = ("scaling_fwd",)
        else:
            meta_keys = ("scaling_bwd",)

        for meta_key in meta_keys:  # type: Literal["scaling_fwd", "scaling_bwd"]
            # Non-parameter FP8 modules (e.g. DPA) may lack the recipe state.
            if meta_key not in self.fp8_meta:
                continue
            state = self.fp8_meta[meta_key]
            assert isinstance(state, DelayedScalingRecipeState)

            current_len = state.amax_history.shape[0]
            if current_len == length:
                continue
            if length < current_len:
                state.amax_history = state.amax_history[:length].clone()
            else:
                # Append zero rows after the existing history.
                state.amax_history = F.pad(
                    state.amax_history, pad=(0, 0, 0, length - current_len)
                )

            # The history tensor was replaced, so rebuild everything pointing at it.
            self.quantizers[meta_key] = state.make_quantizers()
            self._update_weight_quantizers()

            # Refresh the global amax buffer entries of this module, if registered.
            if FP8GlobalStateManager.get_buffer_info() not in self.fp8_meta:
                continue
            fwd_pos, fwd_key, bwd_pos, bwd_key = self.fp8_meta[
                FP8GlobalStateManager.get_buffer_info()
            ]
            qstate = FP8GlobalStateManager.quantization_state
            for pos, buffer_key in ((fwd_pos, fwd_key), (bwd_pos, bwd_key)):
                if buffer_key not in qstate.global_amax_buffer:
                    continue
                if buffer_key not in qstate.global_amax_history_buffer:
                    raise RuntimeError(
                        "TE internal error during amax history change: "
                        f"buffer_key '{buffer_key}' found in global_amax_buffer "
                        "but missing from global_amax_history_buffer"
                    )
                # NOTE(review): only the amax entry is refreshed here; the
                # matching global_amax_history_buffer entry is checked but not
                # updated -- confirm this is intended.
                qstate.global_amax_buffer[buffer_key][pos] = state.amax_history[0]

    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
        """Create (or keep) the recipe state and quantizers for one direction.

        If meta tensors were already initialized and the stored state matches
        the type of `recipe`, it is kept; for delayed scaling the amax history
        length is additionally synced with the recipe. Otherwise a fresh
        recipe state and its quantizers are created.
        """
        state_key: Literal["scaling_fwd", "scaling_bwd"] = (
            "scaling_fwd" if fwd else "scaling_bwd"
        )

        if self.fp8_meta_tensors_initialized:
            existing = self.fp8_meta[state_key]
            if recipe.delayed() and isinstance(existing, DelayedScalingRecipeState):
                self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd)
                return
            # Recipe predicate -> state type that can be reused as is.
            reusable = (
                ("mxfp8", MXFP8BlockScalingRecipeState),
                ("mxfp4", MXFP4BlockScalingRecipeState),
                ("w4a8", W4A8BlockScalingRecipeState),
                ("float8_current_scaling", Float8CurrentScalingRecipeState),
                ("float8_block_scaling", Float8BlockScalingRecipeState),
            )
            for predicate_name, state_cls in reusable:
                if getattr(recipe, predicate_name)() and isinstance(existing, state_cls):
                    return

        # Per GEMM: up to 3 FP8 tensors forward (input, weight, output) and
        # 2 backward (grad_output, grad_input).
        tensors_per_gemm = 3 if fwd else 2
        num_fp8_tensors = self.fp8_meta["num_gemms"] * tensors_per_gemm

        new_state = RecipeState.create(
            recipe,
            mode=("forward" if fwd else "backward"),
            num_quantizers=num_fp8_tensors,
        )
        self.fp8_meta[state_key] = new_state
        self.quantizers[state_key] = new_state.make_quantizers()

    def _update_weight_quantizers(self) -> None:
        """Attach the current weight quantizers to the quantized weight tensors.

        Raises ``ValueError`` when the module reports a different number of
        weight tensors and weight quantizers.
        """
        tensors = self._get_weight_tensors()
        quantizers = self._get_weight_quantizers()
        if len(tensors) != len(quantizers):
            raise ValueError(
                f"Number of weight tensors ({len(tensors)}) and quantizers "
                f"({len(quantizers)}) must match"
            )
        for weight, quantizer in zip(tensors, quantizers):
            # Only quantized storages with an available quantizer are touched.
            if quantizer is None or not isinstance(weight, QuantizedTensorStorage):
                continue
            weight.update_quantizer(quantizer)

    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
        """Return the module's weight tensors; subclasses must override this."""
        cls_name = self.__class__.__name__
        raise NotImplementedError(
            f"{cls_name} class does not implement _get_weight_tensors function"
        )

    def _get_weight_quantizers(self) -> List[Quantizer]:
        """Return the module's weight quantizers; subclasses must override this."""
        cls_name = self.__class__.__name__
        raise NotImplementedError(
            f"{cls_name} class does not implement _get_weight_quantizers function"
        )

    def init_fp8_meta_tensors(self, recipe: Recipe) -> None:
        """Set up forward and backward recipe state, then mark it initialized."""
        for is_forward in (True, False):
            self.set_meta_tensor(is_forward, recipe)

        self.fast_setattr("fp8_meta_tensors_initialized", True)

    def get_fp8_meta_tensors(self) -> Optional[Dict[str, List[torch.Tensor]]]:
        """Get copies of the scales and amax histories.

        Returns
        -------
        dict or None
            ``None`` if either the forward or the backward recipe state is
            missing from ``fp8_meta``. Otherwise a dict with keys
            ``"scaling_fwd"`` and ``"scaling_bwd"``, each mapping to
            ``[scale_clone, amax_history_clone]``. The layout matches what
            ``reset_fp8_meta_tensors`` accepts.
        """
        keys = ("scaling_fwd", "scaling_bwd")
        if any(key not in self.fp8_meta for key in keys):
            return None

        # Snapshots must not take part in autograd.
        with torch.no_grad():
            return {
                key: [
                    self.fp8_meta[key].scale.clone(),
                    self.fp8_meta[key].amax_history.clone(),
                ]
                for key in keys
            }

    def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None:
        """Reset scales and amax histories in place.

        Without `fp8_meta_tensors`, scales are set to one and amax histories
        to zero. Otherwise the values are copied from
        ``fp8_meta_tensors[key][0]`` (scale) and ``fp8_meta_tensors[key][1]``
        (amax history). A ``KeyError`` is raised if a direction present in
        ``fp8_meta`` is missing from `fp8_meta_tensors`.
        """
        with torch.no_grad():
            for key in ("scaling_fwd", "scaling_bwd"):
                if key not in self.fp8_meta:
                    continue
                state = self.fp8_meta[key]

                if fp8_meta_tensors is None:
                    # Default reset: unit scales, empty history.
                    state.scale.copy_(torch.ones_like(state.scale))
                    state.amax_history.copy_(torch.zeros_like(state.amax_history))
                    continue

                if key not in fp8_meta_tensors:
                    raise KeyError(
                        f"Cannot reset fp8 tensors: key '{key}' not found in fp8_meta_tensors. "
                        f"Available keys: {list(fp8_meta_tensors.keys())}"
                    )
                state.scale.copy_(fp8_meta_tensors[key][0])
                state.amax_history.copy_(fp8_meta_tensors[key][1])

    def get_extra_state(self) -> torch.Tensor:
        """Serialize the FP8 state for checkpointing.

        This implementation is working around a few issues:

        (1) PyTorch's "extra state" infrastructure might be able to
            support any picklable type, but they make no guarantees.
            We have experienced problems (e.g. in ONNX export) with
            non-tensor extra state.
        (2) PyTorch's checkpointing infrastructure does not remap
            devices for "extra state" like it does for "state dict".
            Thus, we want to avoid putting extra state on the device
            since it may be loaded on the wrong device.
        (3) The extra state consists of many small tensors. If we
            want to copy them all to CPU, then we need to avoid the
            overhead of many device-CPU memory transfers.

        Returns
        -------
        torch.Tensor
            An empty ``uint8`` tensor when no FP8 state needs saving,
            otherwise a ``uint8`` tensor holding the pickled state dict.
        """

        def to_cpu(src: torch.Tensor) -> torch.Tensor:
            """Helper function to make CPU copy of tensor

            Memory transfer is asynchronous w.r.t. host, so the device should
            be synchronized before using result.

            """
            dst = torch.empty_like(src, device="cpu")
            dst.copy_(src, non_blocking=True)
            return dst

        # Store FP8 state only if needed
        fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
        if not fp8_checkpoint:
            return torch.empty(0, dtype=torch.uint8)

        # Copy tensors to CPU and store
        state = {"recipe": self.fp8_meta["recipe"]}
        if state["recipe"].delayed():
            state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
            state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
            state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
            state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)

        # Store other picklable values (the buffer cache entry is skipped)
        state["extra_fp8_variables"] = {
            k: v
            for k, v in self.fp8_meta.items()
            if k != "buffer_index_and_autocast_key"
            and isinstance(v, (bool, int, float, str, tuple, list))
        }

        # Wait for the asynchronous copies above. Synchronize whichever
        # accelerator backend is present (NPU first, matching
        # ``set_extra_state``) instead of assuming CUDA is available.
        npu = getattr(torch, "npu", None)
        if npu is not None and npu.is_available():
            npu.synchronize()
        elif torch.cuda.is_available():
            torch.cuda.synchronize()

        # Serialize state into byte tensor
        state_serialized = bytearray(pickle.dumps(state))
        return torch.frombuffer(state_serialized, dtype=torch.uint8)

    def set_extra_state(self, state: torch.Tensor) -> None:
        """Load FP8 state previously produced by ``get_extra_state``.

        Parameters
        ----------
        state
            ``None`` or an empty tensor (nothing to load), a ``uint8`` tensor
            with pickled data (default format), or an ``io.BytesIO`` object
            (deprecated format).

        Raises
        ------
        RuntimeError
            If `state` is of an unsupported type.
        """
        # Maintain backwards compatibility with older checkpoints.
        if state is None:
            return

        # Pick the accelerator backend once: NPU when present, CUDA otherwise.
        npu = getattr(torch, "npu", None)
        use_npu = npu is not None and npu.is_available()

        # Load state
        if isinstance(state, torch.Tensor):
            # No FP8 is indicated by an empty tensor we don't need to unpickle.
            if state.numel() == 0:
                return
            # Default format: byte tensor with pickled data.
            # SECURITY: unpickling executes arbitrary code; only load
            # checkpoints from trusted sources.
            state = pickle.loads(state.detach().cpu().numpy().tobytes())  # nosec
        elif isinstance(state, io.BytesIO):
            # Deprecated format with io.BytesIO
            state.seek(0)
            state = torch.load(state, map_location="npu" if use_npu else "cuda")
        else:
            raise RuntimeError("Unsupported checkpoint format.")

        if state is None:
            return

        # TE 1.x checkpoint compatibility: add DelayedScaling recipe if missing
        if "recipe" not in state:
            # TE 1.x only supported delayed scaling, which was the default recipe
            state["recipe"] = DelayedScaling()
            # TE 1.x also saved scale_inv, which is not needed with Recipe object
            state.pop("scale_inv_fwd", None)
            state.pop("scale_inv_bwd", None)

        # Load extra items
        self.fp8_meta.update(state["extra_fp8_variables"])
        self.fp8_meta["recipe"] = state["recipe"]
        # The recompute buffer position is not restored from a checkpoint.
        self.fp8_meta.pop("global_fp8_buffer_pos_fwd_recompute", None)

        # Initialize before loading
        self.init_fp8_meta_tensors(self.fp8_meta["recipe"])

        def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
            """Helper function to copy tensor from CPU

            Memory transfer is asynchronous w.r.t. host, so the device should
            be synchronized before using result.

            """
            dst.copy_(src, non_blocking=True)

        # Load tensors
        if self.fp8_meta["recipe"].delayed():
            copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale)
            copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history)
            copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale)
            copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history)

        # Wait for the asynchronous copies on whichever backend is present.
        if use_npu:
            npu.synchronize()
        elif torch.cuda.is_available():
            torch.cuda.synchronize()

    def set_activation_dtype(self, inp: torch.Tensor) -> None:
        """Resolve and store the activation dtype used by the module.

        Under native AMP (``torch.autocast``) the dtype is fixed to float16.
        Otherwise the input's dtype is used; unless
        ``allow_different_data_and_param_types`` is set, every parameter must
        share that dtype or a ``TypeError`` is raised.
        """
        # Native AMP (`torch.autocast`) gets highest priority
        if torch.is_autocast_enabled():
            # NOTE: temporarily hard-coded to float16; this branch is not
            # expected to be taken for now.
            self.fast_setattr("activation_dtype", torch.float16)
            return

        dtype = inp.dtype
        # Same dtype as last time: all checks below were already performed.
        if dtype == self.activation_dtype:
            return

        if not self.allow_different_data_and_param_types:
            for name, param in self.named_parameters():
                if param is None or param.dtype == dtype:
                    continue
                raise TypeError(
                    "Data types for parameters must match when outside of autocasted "
                    f"region. Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
                )
        self.fast_setattr("activation_dtype", dtype)

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        Also marks the group as initialized, which `prepare_forward`
        requires when `tp_size > 1`.

        Parameters
        ----------
        tp_group : ProcessGroup or None
                  tensor parallel process group.
        """
        self.fast_setattr("tp_group", tp_group)
        # Recorded even when `tp_group` is None.
        self.fast_setattr("tp_group_initialized", True)

    def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
        """Return the module's own trainable quantized parameters, or None if there are none."""
        trainable_quantized = [
            param
            for param in self.parameters(recurse=False)
            if isinstance(param, QuantizedTensor) and param.requires_grad
        ]
        # An empty result is reported as None rather than as an empty list.
        return trainable_quantized or None

    def init_fp8_metadata(self, num_gemms: int = 1) -> None:
        """Initialize fp8 related metadata and tensors during fprop.

        Mirrors the global FP8 settings onto the module, (re)creates the
        recipe state when FP8 is active and the recipe changed, and clears
        the cached weight workspaces when the recipe type changed.

        Parameters
        ----------
        num_gemms: int, default = 1
            Number of GEMMs of the module; stored in ``fp8_meta["num_gemms"]``.
        """
        meta = self.fp8_meta

        # Snapshot the global FP8 settings onto the module.
        fp8 = FP8GlobalStateManager.is_fp8_enabled()
        fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
        fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
        self.fast_setattr("fp8_parameters", fp8_parameters)
        self.fast_setattr("fp8", fp8)
        self.fast_setattr("fp8_calibration", fp8_calibration)
        fp8_enabled = fp8 or fp8_calibration
        meta["fp8_checkpoint"] = fp8_enabled

        # Recipe in effect before this call; used below to detect a type change.
        _original_recipe = None

        if fp8_parameters or fp8_enabled:
            _original_recipe = meta.get("recipe", None)
            if self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe:
                # FP8 init has already been run and recipe is the same, don't do anything.
                return
            meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
        else:
            # If fp8 isn't enabled, turn off and return.
            self.fast_setattr("fp8_initialized", False)
            return

        # FP8 parameters need recipe state even before FP8 compute was initialized.
        if fp8_parameters and not self.fp8_initialized:
            meta["num_gemms"] = num_gemms
            self.init_fp8_meta_tensors(meta["recipe"])

        if fp8_enabled:
            # Set FP8 and other FP8 metadata
            meta["num_gemms"] = num_gemms
            meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

            # Set FP8_MAX per tensor according to recipe
            if hasattr(meta["recipe"], "fp8_format"):
                meta["fp8_max_fwd"] = meta["recipe"].fp8_format.value.fwd.value.max
                meta["fp8_max_bwd"] = meta["recipe"].fp8_format.value.bwd.value.max

            # Allocate scales and amaxes
            self.init_fp8_meta_tensors(meta["recipe"])
            self.fast_setattr("fp8_initialized", True)

            # NOTE(review): the recipe was already stored above; this looks redundant.
            meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()

        # Warn when the recipe class changed to an unrelated type (neither is a
        # subclass of the other).
        _current_recipe = meta["recipe"]
        if _original_recipe is not None and not (
            issubclass(_current_recipe.__class__, _original_recipe.__class__)
            or issubclass(_original_recipe.__class__, _current_recipe.__class__)
        ):
            warnings.warn(
                f"Recipe type changed from {_original_recipe.__class__.__name__} "
                f"to {_current_recipe.__class__.__name__}. "
                "This may affect model behavior."
            )
            # Clear cached workspaces as they were created with the old recipe/quantizer type
            self._fp8_workspaces.clear()

    def prepare_forward(
        self,
        inp: torch.Tensor,
        num_gemms: int = 1,
        allow_non_contiguous: bool = False,
        allow_different_data_and_param_types: bool = False,
    ) -> torch.Tensor:
        """Validate the input and set up FP8 state before the forward computation.

        Parameters
        ----------
        inp: torch.Tensor
            Input of the forward pass.
        num_gemms: int, default = 1
            Number of GEMMs of the module, forwarded to `init_fp8_metadata`.
        allow_non_contiguous: bool, default = False
            If ``False``, a non-contiguous input is made contiguous.
        allow_different_data_and_param_types: bool, default = False
            If ``True``, `set_activation_dtype` skips the check that all
            parameters share the input's dtype.

        Returns
        -------
        torch.Tensor
            The (possibly made contiguous) input.

        Raises
        ------
        RuntimeError
            If the input is on neither a CUDA nor an NPU device, or if
            ``tp_size > 1`` and no tensor parallel group has been set.
        ValueError
            If sequence parallelism is used with a delayed-scaling recipe
            that disables amax reduction.
        """
        self.fast_setattr(
            "allow_different_data_and_param_types", allow_different_data_and_param_types
        )
        self.fast_setattr("forwarded_at_least_once", True)
        # Activation recomputation is used and this is the second forward phase.
        if self.fp8 and in_fp8_activation_recompute_phase():
            FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
        else:
            if not (inp.is_cuda or getattr(inp, "is_npu", False)):
                raise RuntimeError(
                    f"TransformerEngine needs CUDA or NPU. Got input on device: {inp.device}"
                )

            # `tp_group_initialized` only exists once set_tensor_parallel_group()
            # has been called, so read it with a default to make sure the
            # descriptive error below is raised instead of an AttributeError.
            if self.tp_size > 1 and not getattr(self, "tp_group_initialized", False):
                raise RuntimeError(
                    "Tensor parallel group not initialized. Call "
                    "set_tensor_parallel_group() before forward pass when tp_size > 1."
                )

            self.set_activation_dtype(inp)
            self.init_fp8_metadata(num_gemms=num_gemms)

            delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
            if delayed_scaling_recipe:
                if self.sequence_parallel and not self.fp8_meta["recipe"].reduce_amax:
                    raise ValueError(
                        "Amax reduction across tensor parallel group is "
                        "necessary when using sequence parallelism with FP8."
                    )

                if not FP8GlobalStateManager.fp8_graph_capturing():
                    FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)

                # Activation recomputation is used and this is the first forward phase.
                if self.training and is_fp8_activation_recompute_enabled():
                    FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

        if not allow_non_contiguous and not inp.is_contiguous():
            inp = inp.contiguous()
        return inp

    def end_forward(self):
        """Restore the FP8 meta tensors at the end of a recompute forward pass.

        Only acts when FP8 is enabled, the recipe uses delayed scaling and
        the FP8 activation recompute phase is active.
        """
        restore_needed = (
            self.fp8
            and self.fp8_meta["recipe"].delayed()
            and in_fp8_activation_recompute_phase()
        )
        if restore_needed:
            FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)

    @contextmanager
    def prepare_forward_ctx(
        self,
        inp: torch.Tensor,
        num_gemms: int = 1,
        allow_non_contiguous: bool = False,
        allow_different_data_and_param_types: bool = False,
    ) -> Generator[torch.Tensor, None, None]:
        """Context manager pairing `prepare_forward` with `end_forward`.

        Yields the prepared input. `end_forward` runs when the block exits,
        including when it exits with an exception.
        """
        prepared = self.prepare_forward(
            inp,
            num_gemms,
            allow_non_contiguous,
            allow_different_data_and_param_types,
        )
        try:
            yield prepared
        finally:
            self.end_forward()

    def set_nccl_overlap_warning_if_tp(self) -> None:
        """Warn if `CUDA_DEVICE_MAX_CONNECTIONS` is not 1 while TP is in use.

        When using TP, the NCCL communication needs to be scheduled
        before the GEMM for there to be a guaranteed overlap. From the
        host side in TE, the comm calls are always launched first, but
        to ensure that the GEMM isn't scheduled first, the environment
        variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to
        force a single channel.

        Nothing happens when `tp_size` is 1.
        """
        if self.tp_size == 1:
            return
        # An unset variable is treated as 0, i.e. as "not 1".
        num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
        if num_cuda_work_queues != 1:
            # The trailing space keeps "backward" and "GEMMs" apart after the
            # implicit string concatenation.
            warnings.warn(
                "To guarantee overlapping TP and SP collectives with the backward "
                "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
            )

    @staticmethod
    def grad_output_preprocess(
        ctx,
        grad_output: torch.Tensor,
        row_parallel_mode: bool,
        quantizer: Optional[Quantizer],
    ) -> Tuple[Union[torch.Tensor, QuantizedTensor], ...]:
        """Prepare the incoming output gradient for the backward pass.

        Steps, in order:
          1. Flatten `grad_output` to 2D (last dimension kept) and make it
             contiguous.
          2. If `row_parallel_mode` and `ctx.sequence_parallel` are set and
             `ctx.ub_overlap_ag` is not, gather it over `ctx.tp_group` with
             `gather_along_dim`.
          3. If `ctx.use_bias`, compute the bias gradient as the sum over
             dim 0 of the (possibly gathered) gradient.
          4. If `ctx.fp8` or `ctx.debug` is set and the gradient is not
             already a `QuantizedTensor`, quantize it with `quantizer`.

        Returns
        -------
        tuple
            `(grad_output, grad_bias)`; `grad_bias` is None when
            `ctx.use_bias` is falsy.
        """
        last_dim = grad_output.shape[-1]
        grad_output = grad_output.reshape((-1, last_dim)).contiguous()

        needs_gather = row_parallel_mode and ctx.sequence_parallel and not ctx.ub_overlap_ag
        if needs_gather:
            grad_output, _ = gather_along_dim(grad_output, ctx.tp_group)

        # Bias gradient is taken on the (possibly gathered) high-precision
        # gradient, i.e. before any quantization below.
        if ctx.use_bias:
            grad_bias = grad_output.view((-1, grad_output.shape[-1])).sum(dim=0)
        else:
            grad_bias = None

        # Quantization is only required for the FP8 and debug paths.
        wants_quantized = ctx.fp8 or ctx.debug
        if wants_quantized and not isinstance(grad_output, QuantizedTensor):
            grad_output = quantizer(grad_output)
        return grad_output, grad_bias

    def register_parameter(self, name, param, **kwargs):
        """
        Register `param` with PyTorch and record its deferred-initialization
        metadata.

        The extra keyword arguments are stored as a `_ParameterInitMeta` under
        `self.param_init_meta[name]`, but only the first time a given name is
        seen. FSDP2 can call register_parameter again to change parameters to
        DTensors, and it does so without the custom fp8-specific kwargs, so a
        later call must not reset or lose the metadata captured at init.
        """
        super().register_parameter(name, param)
        if not hasattr(self, "param_init_meta"):
            return
        if name in self.param_init_meta:
            # Already recorded during init; keep the original metadata.
            return
        self.param_init_meta[name] = _ParameterInitMeta(**kwargs)

    def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
        """
        Reset all module parameters to initial values. Unless deferred initialization
        is specified, all parameters on a 'meta' device are also materialized on the
        'npu' device before the values are reset to initial.

        Parameters
        ----------
        defer_init : bool, optional, default = False
            If truthy, return immediately without touching any parameter.

        For every direct (non-recursive) parameter the method runs the stored
        `init_fn`, optionally quantizes the result when primary weights are kept
        in FP8, re-wraps it as an `nn.Parameter` (rebuilding the DTensor when the
        original was one) and re-registers it through `self.module_setattr`.
        """
        if defer_init:
            return

        for name, param in self.named_parameters(recurse=False):
            # Check if parameter is a DTensor (FSDP2) or regular tensor
            is_dtensor = isinstance(param, DTensor)
            # Keep the original DTensor so its mesh/placements/shape/stride can
            # be reused when the DTensor is rebuilt further down.
            dtensor_param = param if is_dtensor else None
            # Need to update/quantize local tensor in case of DTensor
            param = param._local_tensor if is_dtensor else param
            # Ensure parameter is on a real device
            if param.device == torch.device("meta"):
                param = torch.empty_like(param, device="npu")
            # Initialize the parameter values on device
            init_fn = self.param_init_meta[name].init_fn
            get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker
            if get_rng_state_tracker is None:
                init_fn(param)
            else:
                # Run the initializer inside a forked RNG state; a named fork is
                # used when the module defines a non-empty `rng_tracker_name`.
                if hasattr(self, "rng_tracker_name") and self.rng_tracker_name:
                    with get_rng_state_tracker().fork(self.rng_tracker_name):
                        init_fn(param)
                else:
                    with get_rng_state_tracker().fork():
                        init_fn(param)

            # Wrap parameters in QuantizedTensor if needed
            fp8_meta_index = self.param_init_meta[name].fp8_meta_index
            high_precision_init_val = None
            if self.primary_weights_in_fp8 and fp8_meta_index is not None:
                # Keep high-precision values on CPU if needed
                if self.preserve_high_precision_init_val:
                    high_precision_init_val = param.detach().cpu()

                # Configure quantizer
                quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
                if quantizer is None:
                    raise RuntimeError("Weight quantizer has not been initialized")
                # Column-wise usage is requested only when autograd is enabled.
                quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
                quantizer.internal = False
                # TODO @Muu: when current scaling is implemented later, look into
                # how to re-enable the block below.
                # if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer):
                #     device_mesh = dtensor_param.device_mesh
                #     amax_reduction_group = (
                #         device_mesh.get_group(mesh_dim="shard")
                #         if device_mesh.ndim > 1
                #         else device_mesh.get_group()
                #     )
                #     quantizer.amax_reduction_group = amax_reduction_group
                #     quantizer.with_amax_reduction = True
                # Quantize parameter
                param = quantizer(param)

            # Redo parameter wrap in case we broke it above
            # NOTE: Currently this can only be broken when primary weights are in Fp8 but
            #       re-applying the nn.Parameter() wrap is a no-op when the input is already
            #       a parameter so we always re-apply it just for extra safety.
            if is_dtensor:
                # recreate the DTensor from the parameter.
                dtensor_param = DTensor.from_local(
                    param,
                    device_mesh=dtensor_param.device_mesh,
                    placements=dtensor_param.placements,
                    shape=dtensor_param.size(),
                    stride=dtensor_param.stride(),
                )
                dtensor_param = torch.nn.Parameter(dtensor_param)
            else:
                param = torch.nn.Parameter(param)

            # Keep high-precision values on CPU if needed
            if high_precision_init_val is not None:
                # - Master weights are initialized from model weights, if we use fp8 primary
                #   weights to initialize master weights, the numerical values of master weights
                #   are not consistent with the numerical values when we initialize them from
                #   bf16/fp16 weights.
                # - So we add a `_high_precision_init_val` attribute to each model weight to store
                #   the original bf16/fp16 weight on cpu before casting it to fp8. And users can
                #   use `get_high_precision_init_val` to get this cpu tensor.
                # - This cpu tensor is not needed once the master weight is initialized, so users
                #   should call `clear_high_precision_init_val` to remove it after master weight
                #   is initialized.

                # Accessor bound to the parameter: returns the stored CPU copy,
                # or None once it has been cleared.
                def get(self):
                    if hasattr(self, "_high_precision_init_val"):
                        return self._high_precision_init_val
                    return None

                # Counterpart of `get`: drops the stored CPU copy if present.
                def clear(self):
                    if hasattr(self, "_high_precision_init_val"):
                        del self._high_precision_init_val

                # NOTE(review): in the DTensor branch these attributes are attached
                # to the local `param`, while `dtensor_param` is what gets registered
                # on the module below - confirm they stay reachable from the
                # registered parameter.
                param._high_precision_init_val = high_precision_init_val
                param.get_high_precision_init_val = MethodType(get, param)
                param.clear_high_precision_init_val = MethodType(clear, param)

            # Update the parameter based on its type
            if not is_dtensor:
                self.module_setattr(name, param)
            else:
                self.module_setattr(name, dtensor_param)

    def get_weight_workspace(
        self,
        *,
        tensor: Optional[torch.Tensor] = None,
        quantizer: Optional[Quantizer] = None,
        cache_name: Optional[str] = None,
        update_workspace: bool = True,
        skip_update_flag: Optional[torch.Tensor] = None,
        # fsdp_group: Optional[dist_group_type] = None,
        workspace_dtype: Optional[torch.dtype] = None,
    ) -> QuantizedTensor:
        """Get workspace buffer for weights and maybe update its values

        The workspace buffer may be cached for future function calls.

        Parameters
        ----------
        tensor : torch.Tensor, optional
            Values to copy into workspace. Required if the workspace
            is being constructed or updated.
        quantizer: Quantizer, optional
            Quantizer used to cast the weights. Required if the
            workspace is being constructed or updated.
        cache_name: str, optional
            Key for caching.
        update_workspace: bool, default = True
            Update workspace with values from `tensor`.
        skip_update_flag: torch.Tensor, optional
            Flag tensor to skip updating the workspace. Takes precedence
            over `update_workspace` if provided.
        workspace_dtype: torch.dtype, default = None
            If weight workspace contains high-precision tensor - for example
            for debug quantization, this is dtype of the tensor.

        Returns
        -------
        QuantizedTensor
            `tensor` itself when it is already quantized (and a quantizer is
            given), otherwise the cached or newly built workspace.

        Raises
        ------
        ValueError
            If a workspace has to be constructed without `tensor` or
            `quantizer`, or has to be updated without `tensor`.

        Notes
        -----
        The `fsdp_group` argument (FSDP process group that the weights are
        distributed over) is currently disabled in this implementation.
        """
        # TODO @Muu: remove this method in the next version and follow the
        # implementation of the latest TE release.

        # Handle case where weights are already quantized
        # Note: Make sure weights have required usages, but do not
        # destroy unnecessary usages since they may be used later.
        if isinstance(tensor, QuantizedTensor) and quantizer is not None:
            update_rowwise_usage = True if quantizer.rowwise_usage else None
            update_columnwise_usage = True if quantizer.columnwise_usage else None
            tensor.update_usage(
                rowwise_usage=update_rowwise_usage,
                columnwise_usage=update_columnwise_usage,
            )

            return tensor

        # Try getting workspace from cache
        out = None
        if cache_name is not None:
            out = self._fp8_workspaces.get(cache_name, None)

        # Reset cache if workspace is invalid, i.e. it lacks a data buffer
        # that the quantizer's requested usage needs.
        if out is not None and quantizer is not None:
            reset_cache = False
            if isinstance(out, Float8TensorStorage):
                if quantizer.columnwise_usage and out._data is None:
                    reset_cache = True
            elif isinstance(out, (MXFP8TensorStorage, MXFP4TensorStorage)):
                if quantizer.rowwise_usage and out._rowwise_data is None:
                    reset_cache = True
                elif quantizer.columnwise_usage and out._columnwise_data is None:
                    reset_cache = True
            elif isinstance(out, W4A8WeightTensorStorage):
                if quantizer.rowwise_usage and out._mxfp4_rowwise_data is None:
                    reset_cache = True
                elif quantizer.columnwise_usage and out._mxfp8_columnwise_data is None:
                    reset_cache = True
            if reset_cache:
                out.clear()  # TODO @Muu: verify whether clear() is still needed here
                out = None
                del self._fp8_workspaces[cache_name]

        # TODO @Muu: the FSDP-related logic was removed here; check later
        # whether it needs to be added back.

        # Construct workspace if needed
        if out is None:
            if tensor is None or quantizer is None:
                raise ValueError(
                    "tensor and quantizer kwargs must be provided to construct FP8 workspace"
                )

            if cache_name is None:
                return quantizer.quantize(tensor, dtype=workspace_dtype)

            # Ensure the tensor in the cache is an instance of torch.Tensor,
            # as it persists beyond a single forward pass.
            # Setting internal=True would cause the data to be removed in prepare_for_saving(...).
            quantizer_internal = quantizer.internal
            quantizer.internal = False
            try:
                out = quantizer.quantize(tensor, dtype=workspace_dtype)
            finally:
                # Restore the caller's setting even if quantization fails, so
                # the quantizer is never left permanently altered.
                quantizer.internal = quantizer_internal

            # Update cache
            self._fp8_workspaces[cache_name] = out
            return out

        # Update workspace if needed
        if skip_update_flag is not None:
            update_workspace = True
        if update_workspace:
            if tensor is None:
                raise ValueError("tensor kwarg must be provided to update FP8 workspace")
            # NOTE(review): a cached workspace without `quantize_` is returned
            # without being refreshed - confirm this silent skip is intended.
            if hasattr(out, "quantize_"):
                out.quantize_(tensor, noop_flag=skip_update_flag)
        return out

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """
        Load tensors and extra state (including fp8 metadata) for this module.

        The fp8 metadata is essential for copying fp8 tensors, as the copy_
        function uses the scale_inv parameter from fp8_meta to set the correct
        scaling factor for the new tensor. The extra state therefore has to be
        applied before the tensors are copied, not after, as is typically done
        in _load_from_state_dict.

        This early load only happens when self.primary_weights_in_fp8=True and
        the extra-state key for `prefix` is present in `state_dict`; otherwise
        the call goes straight to the parent implementation.
        """
        if self.primary_weights_in_fp8:
            suffix = torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX
            key = prefix + suffix
            if key in state_dict:
                # Apply the fp8 metadata first so the tensor copies done by the
                # parent implementation see the right scaling factors.
                self.set_extra_state(state_dict[key])

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def register_wgrad_accumulation_and_reduce_hooks(self, wgrad_accumulation_and_reduce_hook):
        """
        Register a hook for manual control of weight gradient accumulation and reduce.

        The hook is appended to `self.wgrad_accumulation_and_reduce_hooks`, and the
        registered hooks are invoked, in registration order, at the end of
        backward_dw(). Call this before backward(), and set
        skip_wgrad_accumulation_and_reduce to True so that the accumulation and
        reduce are skipped in backward() itself.
        """
        registered_hooks = self.wgrad_accumulation_and_reduce_hooks
        registered_hooks.append(wgrad_accumulation_and_reduce_hook)

    def need_backward_dw(self):
        """
        Tell whether the delayed weight gradient computation should run.

        Returns False when the module has no `wgrad_store`; otherwise returns
        whatever `wgrad_store.delay_wgrad_compute()` reports. backward_dw()
        checks this first and returns without doing anything when it is falsy.
        Users can also call it directly before calling backward_dw().
        """
        store = self.wgrad_store
        if store is None:
            return False
        return store.delay_wgrad_compute()

    def backward_dw(self):
        """
        Execute the delayed weight gradient computation.

        Does nothing unless need_backward_dw() is truthy. Otherwise it pops the
        stored (wgrad, bgrad) pair from `self.wgrad_store`, assigns the weight
        gradient unless `fuse_wgrad_accumulation` is set, assigns the bias
        gradient only when the bias has no gradient yet, and finally triggers
        the registered wgrad accumulation and reduce hooks.
        """
        if not self.need_backward_dw():
            return

        (wgrad, bgrad), _ = self.wgrad_store.pop()

        if not self.fuse_wgrad_accumulation:
            weight = noop_cat(self._get_weight_tensors())
            weight.grad = wgrad.to(weight.dtype)

        if self.use_bias and bgrad is not None:
            # TODO @Muu: look into how to untangle the bgrad handling here later
            bias = noop_cat([getattr(self, bias_name) for bias_name in self.bias_names])
            if bias.grad is None:
                bias.grad = bgrad.to(bias.dtype)

        # Drop the local references before the hooks run.
        del wgrad, bgrad
        self._trigger_wgrad_accumulation_and_reduce_hooks()

    def _trigger_wgrad_accumulation_and_reduce_hooks(self):
        """
        Call every registered wgrad accumulation and reduce hook, in
        registration order and with no arguments.
        """
        for hook in self.wgrad_accumulation_and_reduce_hooks:
            hook()