# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
# Copyright (c) 2025 Advanced Micro Devices Inc.
# See LICENSE for license information.


"""Python fallback implementation of FusedAdam for Ascend NPU."""

from __future__ import annotations

import math
import os
import warnings
from collections.abc import Iterable
from typing import Optional

import torch
import torch_npu
from torch.distributed._tensor import DTensor

from ..quantized_tensor import QuantizedTensor
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Tensor


class FusedAdam(torch.optim.Optimizer):
    """Adam/AdamW optimizer with the Transformer Engine FusedAdam interface.

    This NPU implementation intentionally does not depend on NVIDIA's
    ``transformer_engine_torch`` CUDA extension.  It uses ordinary PyTorch tensor
    operations as a correctness-first fallback, while keeping the public
    constructor and ``step`` signatures compatible with TE/Apex FusedAdam.

    Optimizer states (``exp_avg``, ``exp_avg_sq`` and, with ``master_weights``
    for fp16/bf16 parameters, ``master_param``) can be stored in a reduced
    precision dtype:

    * fp32 / bf16 states are stored directly (with a unit scale recorded);
    * fp16 states are stored divided by a per-tensor fp32 scale kept in
      ``self._scales``;
    * ``torch.uint8`` states are quantized to FP8 through a
      ``Float8CurrentScalingQuantizer`` (one per parameter/state).

    Each ``step`` reads the states back as unscaled fp32, performs the Adam math
    in fp32 and writes them back in the configured storage dtype.
    """

    def __init__(
        self,
        params: Iterable[torch.nn.Parameter | dict],
        lr: float = 1e-3,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        amsgrad: bool = False,
        *,
        bias_correction=True,
        adam_w_mode=True,
        capturable=False,
        master_weights=False,
        master_weight_dtype=torch.float32,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.float32,
        use_decoupled_grad=False,
        store_param_remainders=False,
        set_grad_none: Optional[bool] = None,
    ):
        """Create the optimizer.

        Args:
            params: Parameters or param-group dicts to optimize.
            lr, betas, eps, weight_decay: Adam hyperparameters.  Python numbers
                or single-element tensors are accepted.  Only ``lr`` is stored in
                the defaults as a converted Python float; the others are stored
                as given and converted on every ``step``.
            amsgrad: Must be ``False`` (AMSGrad is not supported).
            bias_correction: Apply Adam bias correction to the moments.
            adam_w_mode: ``True`` for decoupled weight decay (AdamW), ``False``
                to add ``weight_decay * param`` to the gradient (L2).
            capturable: Must be ``False`` in this fallback.
            master_weights: Keep a master copy for fp16/bf16 parameters.
            master_weight_dtype: ``torch.float32`` or ``torch.float16``.
            exp_avg_dtype, exp_avg_sq_dtype: Storage dtype of the moments;
                fp32, fp16, bf16, or ``torch.uint8`` (FP8 storage).
            use_decoupled_grad: Read gradients from ``param.decoupled_grad``
                instead of ``param.grad``.
            store_param_remainders: Must be ``False`` in this fallback.
            set_grad_none: Deprecated; use ``zero_grad(set_to_none=...)``.

        Raises:
            RuntimeError: If ``amsgrad`` is requested.
            ValueError: For out-of-range hyperparameters.
            NotImplementedError: For options this fallback cannot honor.
        """
        if amsgrad:
            raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
        lr_value = self._scalar_to_float(lr, "lr")
        eps_value = self._scalar_to_float(eps, "eps")
        weight_decay_value = self._scalar_to_float(weight_decay, "weight_decay")
        beta1_value = self._scalar_to_float(betas[0], "beta1")
        beta2_value = self._scalar_to_float(betas[1], "beta2")
        if lr_value < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps_value < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= beta1_value < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= beta2_value < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if weight_decay_value < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        if capturable:
            raise NotImplementedError(
                "NPU FusedAdam fallback does not support capturable=True yet."
            )
        if master_weights and master_weight_dtype not in (torch.float32, torch.float16):
            raise NotImplementedError(
                "NPU FusedAdam fallback currently supports only "
                "master_weight_dtype=torch.float32 or torch.float16."
            )
        self._check_state_dtype(exp_avg_dtype, "exp_avg_dtype")
        self._check_state_dtype(exp_avg_sq_dtype, "exp_avg_sq_dtype")
        if torch.uint8 in (exp_avg_dtype, exp_avg_sq_dtype) and not self._float8_state_available():
            raise NotImplementedError(
                "NPU FusedAdam fallback requires TransformerEngineNPU Float8Tensor, "
                "torch_npu.npu_dynamic_quant, and torch.float8_e4m3fn support when "
                "exp_avg_dtype or exp_avg_sq_dtype is torch.uint8."
            )
        if store_param_remainders:
            raise NotImplementedError(
                "NPU FusedAdam fallback does not support store_param_remainders=True yet."
            )

        defaults = {
            "lr": lr_value,
            "bias_correction": bias_correction,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
        }
        super().__init__(params, defaults)

        self.adam_w_mode = 1 if adam_w_mode else 0
        self.capturable = capturable
        self.master_weights = master_weights
        self.master_weight_dtype = master_weight_dtype
        self.exp_avg_dtype = exp_avg_dtype
        self.exp_avg_sq_dtype = exp_avg_sq_dtype
        self.use_decoupled_grad = use_decoupled_grad
        # Always False here: True is rejected above.
        self.store_param_remainders = False
        # Target max magnitude of a scaled state for each low-precision dtype.
        self.dtype_to_range_map = {
            torch.float16: torch.finfo(torch.float16).max / 2.0,
            torch.uint8: 448.0,
        }
        # param -> {state_name: fp32 scalar scale}
        self._scales = {}
        # param -> {state_name: Float8CurrentScalingQuantizer}
        self._state_quantizers = {}
        # "sqrt" stores sqrt(exp_avg_sq) in FP8; any other value stores it as is.
        self.fp8_exp_avg_sq_storage = os.getenv(
            "TE_NPU_FUSED_ADAM_FP8_EXP_AVG_SQ_STORAGE",
            "sqrt",
        ).lower()

        self.set_grad_none = set_grad_none
        if self.set_grad_none is not None:
            warnings.warn(
                "set_grad_none kwarg in FusedAdam constructor is deprecated. "
                "Use set_to_none kwarg in zero_grad instead.",
                DeprecationWarning,
            )

    def zero_grad(self, set_to_none: Optional[bool] = None) -> None:
        """Reset parameter gradients.

        If ``set_to_none`` is omitted, use the deprecated constructor option
        ``set_grad_none`` when provided; otherwise match TE FusedAdam's default
        of setting gradients to ``None``.

        With ``use_decoupled_grad`` the ``decoupled_grad`` attribute of each
        parameter is reset instead of ``param.grad``.

        Raises:
            ValueError: If ``set_to_none`` conflicts with ``set_grad_none``.
        """

        if self.set_grad_none is not None:
            if set_to_none is not None and set_to_none != self.set_grad_none:
                raise ValueError(
                    f"Called zero_grad with set_to_none={set_to_none}, "
                    "but FusedAdam was initialized with "
                    f"set_grad_none={self.set_grad_none}"
                )
            set_to_none = self.set_grad_none
        if set_to_none is None:
            set_to_none = True

        if not self.use_decoupled_grad:
            super().zero_grad(set_to_none=set_to_none)
            return

        for group in self.param_groups:
            for param in group["params"]:
                if set_to_none:
                    param.decoupled_grad = None
                else:
                    decoupled_grad = getattr(param, "decoupled_grad", None)
                    if decoupled_grad is not None:
                        decoupled_grad.zero_()

    def state_dict(self):
        """Return a state dict with low-precision optimizer states saved as FP32.

        ``exp_avg``, ``exp_avg_sq`` and ``master_param`` entries are replaced by
        the unscaled fp32 tensors returned by ``get_unscaled_state``.  Each
        per-parameter dict is shallow-copied first, so the live optimizer state
        is not replaced.
        """

        state_dict = super().state_dict()

        for group, packed_group in zip(self.param_groups, state_dict["param_groups"]):
            for param, param_id in zip(group["params"], packed_group["params"]):
                packed_state = state_dict["state"].get(param_id)
                param_state = self.state.get(param)
                if not isinstance(packed_state, dict) or not isinstance(param_state, dict):
                    continue
                packed_state = dict(packed_state)
                state_dict["state"][param_id] = packed_state
                for name in ("exp_avg", "exp_avg_sq", "master_param"):
                    if name in param_state:
                        packed_state[name] = self.get_unscaled_state(param, name)

        return state_dict

    def load_state_dict(self, state_dict):
        """Load optimizer state and re-encode it in the configured storage dtypes.

        Saved ``exp_avg`` / ``exp_avg_sq`` / ``master_param`` tensors are copied
        *before* delegating to ``torch.optim.Optimizer.load_state_dict``, which
        may cast floating-point state to the parameter's dtype.  Those
        pre-captured copies are then written back through ``set_scaled_state``.
        All cached scales and quantizers are rebuilt.

        Raises:
            RuntimeError: For legacy plain ``uint8`` moment tensors.
            NotImplementedError: For non-floating ``master_param`` state.
        """
        saved_to_current_params = {}
        for saved_group, current_group in zip(
            state_dict["param_groups"],
            self.param_groups,
        ):
            for saved_param_id, param in zip(
                saved_group["params"],
                current_group["params"],
            ):
                saved_to_current_params[saved_param_id] = param

        saved_state_tensors = {}
        for param_id, param in saved_to_current_params.items():
            saved_state = state_dict["state"].get(param_id)
            if not isinstance(saved_state, dict):
                continue
            for name in ("exp_avg", "exp_avg_sq", "master_param"):
                value = saved_state.get(name)
                if self._is_float8_state(value):
                    saved_state_tensors[(param, name)] = (
                        value.dequantize(dtype=torch.float32).detach().clone()
                    )
                elif torch.is_tensor(value):
                    if name in ("exp_avg", "exp_avg_sq") and value.dtype == torch.uint8:
                        raise RuntimeError(
                            "legacy plain uint8 FP8 optimizer state is not supported. "
                            "Please load from an FP32 unscaled state_dict, or "
                            "regenerate the checkpoint."
                        )
                    if name == "master_param" and not value.is_floating_point():
                        raise NotImplementedError(
                            "Loading non-floating master_param optimizer state is not "
                            "supported by NPU FusedAdam fallback. Checkpoints that use "
                            "store_param_remainders must be regenerated as FP32 "
                            "master_param state."
                        )
                    saved_state_tensors[(param, name)] = value.detach().clone()

        result = super().load_state_dict(state_dict)
        self._scales = {}
        self._state_quantizers = {}

        for param in saved_to_current_params.values():
            state = self.state.get(param)
            if not isinstance(state, dict):
                continue

            for name in ("exp_avg", "exp_avg_sq"):
                value = saved_state_tensors.get((param, name))
                if value is not None:
                    self.set_scaled_state(param, name, value)

            master_param = saved_state_tensors.get((param, "master_param"))
            if self.master_weights and param.dtype in (torch.float16, torch.bfloat16):
                # Rebuild the master copy from the model weight when absent.
                if master_param is None:
                    master_param = param.detach().float().clone()
                self.set_scaled_state(param, "master_param", master_param)
            else:
                state.pop("master_param", None)
                self._scale_dict(param).pop("master_param", None)
        return result

    @staticmethod
    def _scalar_to_float(value, name: str) -> float:
        """Convert a scalar optimizer hyperparameter to a Python float.

        Raises:
            ValueError: If ``value`` is a tensor with more than one element.
        """

        if torch.is_tensor(value):
            if value.numel() != 1:
                raise ValueError(f"Optimizer parameter {name} must be a scalar.")
            return float(value.detach().item())
        return float(value)

    @staticmethod
    def _step_to_int(step) -> int:
        """Convert a scalar optimizer step to a Python int.

        Raises:
            ValueError: If ``step`` is a tensor with more than one element.
        """

        if torch.is_tensor(step):
            if step.numel() != 1:
                raise ValueError("Optimizer step must be a scalar.")
            return int(step.detach().item())
        return int(step)

    @staticmethod
    def _has_sparse_grad(grad: torch.Tensor) -> bool:
        """Return True for COO or compressed sparse gradients."""

        if grad.is_sparse:
            return True
        # ``is_sparse_csr`` is a property on some torch versions, a method on others.
        is_sparse_csr = getattr(grad, "is_sparse_csr", False)
        if callable(is_sparse_csr):
            is_sparse_csr = is_sparse_csr()
        return bool(is_sparse_csr)

    @staticmethod
    def _check_param_dtype(param: torch.Tensor) -> None:
        """Validate parameter dtype supported by this fallback."""

        supported_dtypes = (torch.float32, torch.float16, torch.bfloat16)
        if param.dtype not in supported_dtypes:
            raise RuntimeError(
                "FusedAdam fallback supports parameters in torch.float32, "
                "torch.float16, and torch.bfloat16 only, but got "
                f"{param.dtype}."
            )

    @staticmethod
    def _check_state_dtype(dtype: torch.dtype, name: str) -> None:
        """Validate optimizer state dtype supported by this fallback."""

        supported_dtypes = (torch.float32, torch.float16, torch.bfloat16, torch.uint8)
        if dtype not in supported_dtypes:
            raise NotImplementedError(
                "NPU FusedAdam fallback supports "
                f"{name}=torch.float32, torch.float16, torch.bfloat16, "
                "or torch.uint8 only; "
                f"got {dtype}."
            )

    @staticmethod
    def _is_dtensor_param(param) -> bool:
        """Return whether a parameter is a DTensor model weight."""

        return isinstance(param, DTensor)

    @staticmethod
    def _is_quantized_model_param(param) -> bool:
        """Return whether a parameter is a TE QuantizedTensor model weight."""

        return isinstance(param, QuantizedTensor)

    @staticmethod
    def _is_float8_model_param(param) -> bool:
        """Return whether a parameter itself is a Float8Tensor model weight."""

        return isinstance(param, Float8Tensor)

    def _check_unsupported_param_type(self, param: torch.Tensor) -> None:
        """Reject advanced model weight types that the fallback cannot update."""

        if self._is_dtensor_param(param):
            raise NotImplementedError(
                "DTensor parameters are not supported by NPU FusedAdam fallback yet."
            )
        # Checked before the generic QuantizedTensor case for a clearer message.
        if self._is_float8_model_param(param):
            raise NotImplementedError(
                "FP8 model weights are not supported by NPU FusedAdam fallback yet. "
                "FP8 optimizer state is supported separately from FP8 model parameters."
            )
        if self._is_quantized_model_param(param):
            raise NotImplementedError(
                "QuantizedTensor model weights are not supported by NPU fallback yet."
            )

    def _check_param_group_supported(self, group: dict) -> None:
        """Reject unsupported parameter combinations in a param group."""

        has_fp16 = False
        has_bf16 = False
        for param in group["params"]:
            self._check_unsupported_param_type(param)
            has_fp16 = has_fp16 or param.dtype == torch.float16
            has_bf16 = has_bf16 or param.dtype == torch.bfloat16
        if has_fp16 and has_bf16:
            raise RuntimeError(
                "FusedAdam fallback does not support mixing torch.float16 and "
                "torch.bfloat16 parameters in the same param group."
            )

    @staticmethod
    def _float8_state_available() -> bool:
        """Return whether the NPU Float8Tensor state path can be used."""

        if not hasattr(torch_npu, "npu_dynamic_quant"):
            return False
        # NOTE(review): the bfloat16 comparison presumably guards against an
        # environment that aliases float8_e4m3fn to bfloat16 — confirm.
        return hasattr(torch, "float8_e4m3fn") and torch.float8_e4m3fn != torch.bfloat16

    @staticmethod
    def _is_float8_state(value) -> bool:
        """Return whether a state value is a TransformerEngineNPU Float8Tensor."""

        return isinstance(value, Float8Tensor)

    @staticmethod
    def _float8_state_device(value) -> Optional[torch.device]:
        """Return the device backing a TransformerEngineNPU Float8Tensor state.

        Returns ``None`` for non-Float8 values, or when no device can be found
        from ``value.device`` or the ``_data`` / ``_transpose`` buffers.
        """

        if not FusedAdam._is_float8_state(value):
            return None
        try:
            return torch.device(value.device)
        except (AttributeError, TypeError, RuntimeError):
            # Fall back to inspecting the raw storage buffers below.
            pass
        for attr_name in ("_data", "_transpose"):
            tensor = getattr(value, attr_name, None)
            if torch.is_tensor(tensor):
                return tensor.device
        return None

    @staticmethod
    def _float8_quantizer_device(quantizer) -> Optional[torch.device]:
        """Return the device backing a Float8 optimizer-state quantizer."""

        for attr_name in ("scale", "amax"):
            tensor = getattr(quantizer, attr_name, None)
            if torch.is_tensor(tensor):
                return tensor.device
        return None

    @staticmethod
    def _devices_match(left: torch.device, right: torch.device) -> bool:
        """Compare devices while treating implicit device index as compatible."""

        left = torch.device(left)
        right = torch.device(right)
        if left == right:
            return True
        return left.type == right.type and (left.index is None or right.index is None)

    def _state_dtype_for_name(self, name: str) -> torch.dtype:
        """Return configured storage dtype for an optimizer state tensor.

        Raises:
            KeyError: For an unknown state name.
        """

        if name == "exp_avg":
            return self.exp_avg_dtype
        if name == "exp_avg_sq":
            return self.exp_avg_sq_dtype
        if name == "master_param":
            return self.master_weight_dtype
        raise KeyError(f"Unknown optimizer state name: {name}")

    def _scale_dict(self, param: torch.Tensor) -> dict:
        """Return (creating if needed) the per-state scale dictionary for a parameter."""

        return self._scales.setdefault(param, {})

    def _quantizer_dict(self, param: torch.Tensor) -> dict:
        """Return (creating if needed) the per-state FP8 quantizer dictionary."""

        return self._state_quantizers.setdefault(param, {})

    def _state_quantizer(self, param: torch.Tensor, state_name: str):
        """Return the Float8CurrentScalingQuantizer for an FP8 optimizer state.

        A cached quantizer is reused only if its device matches the parameter's
        device; otherwise a new one is created on ``param.device`` and cached.
        """

        if not self._float8_state_available():
            raise NotImplementedError(
                "FP8 optimizer state requires TransformerEngineNPU Float8Tensor, "
                "torch_npu.npu_dynamic_quant, and torch.float8_e4m3fn support."
            )

        quantizers = self._quantizer_dict(param)
        quantizer = quantizers.get(state_name)
        if quantizer is not None:
            quantizer_device = self._float8_quantizer_device(quantizer)
            if quantizer_device is None or not self._devices_match(
                quantizer_device,
                param.device,
            ):
                quantizer = None

        if quantizer is None:
            quantizer = Float8CurrentScalingQuantizer(
                fp8_dtype=torch.float8_e4m3fn,
                rowwise=True,
                columnwise=False,
                device=param.device,
            )
            quantizers[state_name] = quantizer
        return quantizer

    @staticmethod
    def _fp8_quant_input_dtype(param: torch.Tensor) -> torch.dtype:
        """Return dense input dtype accepted by NPU FP8 dynamic quantization.

        fp16/bf16 parameters keep their dtype; anything else uses bf16.
        """

        if param.dtype in (torch.float16, torch.bfloat16):
            return param.dtype
        return torch.bfloat16

    def _to_fp8_quant_input(
        self,
        param: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Cast FP32 optimizer state to the dtype accepted by the FP8 quantizer."""

        return value.to(
            device=param.device,
            dtype=self._fp8_quant_input_dtype(param),
            non_blocking=True,
        )

    def _state_to_fp8_storage_value(
        self,
        state_name: str,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Map an fp32 optimizer state to the value stored in FP8."""

        if (
            state_name == "exp_avg_sq"
            and self.exp_avg_sq_dtype == torch.uint8
            and self.fp8_exp_avg_sq_storage == "sqrt"
        ):
            # NPU-specific internal representation: storing sqrt(exp_avg_sq)
            # reduces denominator drift while preserving the public fp32
            # exp_avg_sq semantics through get_unscaled_state/state_dict.
            return value.clamp_min(0.0).sqrt()
        return value

    def _state_from_fp8_storage_value(
        self,
        state_name: str,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Map a dequantized FP8 storage value back to optimizer-state semantics."""

        if (
            state_name == "exp_avg_sq"
            and self.exp_avg_sq_dtype == torch.uint8
            and self.fp8_exp_avg_sq_storage == "sqrt"
        ):
            # Inverse of the sqrt storage; the floor keeps exp_avg_sq strictly
            # positive after FP8 rounding.
            value = value.clamp_min(0.0)
            return value.mul(value).clamp_min(torch.finfo(torch.float32).tiny)
        return value

    def _set_unit_state_scale(self, param: torch.Tensor, state_name: str) -> None:
        """Record a no-op fp32 scale for states that are stored directly."""

        self._scale_dict(param)[state_name] = torch.ones(
            (),
            dtype=torch.float32,
            device=param.device,
        )

    def _ensure_state_scale(self, param: torch.Tensor, state_name: str) -> torch.Tensor:
        """Return the fp32 scalar scale for a state tensor, creating a unit scale if absent."""

        scales = self._scale_dict(param)
        scale = scales.get(state_name)
        if scale is None:
            scale = torch.ones((), dtype=torch.float32, device=param.device)
            scales[state_name] = scale
        elif scale.device != param.device:
            scale = scale.detach().to(device=param.device, dtype=torch.float32)
            scales[state_name] = scale
        return scale

    def _require_state_scale(
        self,
        param: torch.Tensor,
        state_name: str,
    ) -> torch.Tensor:
        """Return an existing state scale, preserving it across device moves.

        Raises:
            RuntimeError: If no scale was recorded; a scaled state cannot be
                decoded without it.
        """

        scales = self._scale_dict(param)
        scale = scales.get(state_name)
        if scale is None:
            raise RuntimeError(
                f"Missing scale for low-precision optimizer state {state_name}. "
                "Cannot safely rebuild scaled optimizer state."
            )
        if scale.device != param.device:
            scale = scale.detach().to(device=param.device, dtype=torch.float32)
            scales[state_name] = scale
        return scale

    def _get_scaled_state_as_unscaled(
        self,
        param: torch.Tensor,
        state_name: str,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Return a stored low-precision state as unscaled fp32 data on ``param.device``.

        fp32/bf16 values are only cast (the result may alias ``value``).
        Float8 and other low-precision values are multiplied by their recorded
        scale.
        """

        if self._is_float8_state(value):
            scale = self._require_state_scale(param, state_name)
            storage_value = (
                value.dequantize(dtype=torch.float32)
                .detach()
                .float()
                .to(device=param.device)
                .mul(scale)
            )
            return self._state_from_fp8_storage_value(state_name, storage_value)

        if value.dtype in (torch.float32, torch.bfloat16):
            return value.detach().to(device=param.device, dtype=torch.float32)

        scale = self._require_state_scale(param, state_name)
        return value.detach().to(device=param.device, dtype=torch.float32).mul(scale)

    def _compute_state_scale(
        self,
        unscaled_state: torch.Tensor,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Compute a scalar fp32 scale for low-precision state storage.

        The scale maps ``max(|unscaled_state|)`` onto the storage range of
        ``dtype`` (``self.dtype_to_range_map``, falling back to half of
        ``torch.finfo(dtype).max``).  A unit scale is returned when no scaling
        is needed or possible: for fp32/bf16 storage, for empty tensors, and
        when the maximum magnitude is zero or non-finite.
        """

        def _unit_scale() -> torch.Tensor:
            return torch.ones((), dtype=torch.float32, device=unscaled_state.device)

        if dtype in (torch.float32, torch.bfloat16):
            return _unit_scale()
        # ``max()`` raises on zero-element tensors (e.g. empty parameter shards).
        if unscaled_state.numel() == 0:
            return _unit_scale()

        max_abs = unscaled_state.detach().abs().max()
        # One device-to-host sync covers both the finiteness and the zero check.
        max_abs_value = max_abs.item()
        if not math.isfinite(max_abs_value) or max_abs_value == 0.0:
            return _unit_scale()

        storage_amax = self.dtype_to_range_map.get(dtype)
        if storage_amax is None:
            storage_amax = torch.finfo(dtype).max * 0.5
        scale = (max_abs / storage_amax).to(dtype=torch.float32)
        return scale.clamp_min(torch.finfo(torch.float32).tiny)

    def get_unscaled_state(self, param: torch.Tensor, state_name: str) -> torch.Tensor:
        """Return optimizer state as unscaled fp32 data.

        For fp32 and bf16 states the result is a fresh copy, so in-place math
        on it does not touch the stored state.
        """

        state = self.state[param]
        value = state[state_name]
        if self._is_float8_state(value):
            return self._get_scaled_state_as_unscaled(param, state_name, value)
        if value.dtype == torch.float32:
            return value.detach().float().clone()
        if value.dtype == torch.bfloat16:
            self._set_unit_state_scale(param, state_name)
            return value.detach().float().clone()

        return self._get_scaled_state_as_unscaled(param, state_name, value)

    def set_scaled_state(
        self,
        param: torch.Tensor,
        state_name: str,
        unscaled_state: torch.Tensor,
    ) -> None:
        """Store fp32 state in the configured dtype with TE-style scaling.

        * fp32/bf16: stored directly with a unit scale.
        * uint8: stored as FP8 via the per-state quantizer.  The existing
          Float8 state is updated in place when there is one.
        * other dtypes (fp16): divided by a computed scale before the cast.
        """

        state_dtype = self._state_dtype_for_name(state_name)
        value = unscaled_state.detach().to(
            device=param.device,
            dtype=torch.float32,
            non_blocking=True,
        )

        if state_dtype in (torch.float32, torch.bfloat16):
            self.state[param][state_name] = value.to(dtype=state_dtype)
            self._set_unit_state_scale(param, state_name)
            return
        if state_dtype == torch.uint8:
            quantizer = self._state_quantizer(param, state_name)
            old_state = self.state[param].get(state_name)
            storage_value = self._state_to_fp8_storage_value(state_name, value)
            scale = self._compute_state_scale(storage_value, state_dtype)
            self._scale_dict(param)[state_name] = scale
            scaled_value = storage_value.div(scale)
            quant_input = self._to_fp8_quant_input(param, scaled_value)
            with torch.no_grad():
                if self._is_float8_state(old_state):
                    self.state[param][state_name] = quantizer.update_quantized(
                        quant_input,
                        old_state,
                    )
                else:
                    self.state[param][state_name] = quantizer(quant_input)
            return

        scale = self._compute_state_scale(value, state_dtype)
        self._scale_dict(param)[state_name] = scale
        scaled_value = value.div(scale)
        self.state[param][state_name] = scaled_value.to(dtype=state_dtype)

    def _init_or_fix_state(self, param: torch.Tensor) -> dict:
        """Create missing state and keep buffers in configured storage dtypes.

        Existing states that have the wrong dtype or device are decoded to
        fp32 and re-encoded.  States already in the right form only get their
        scale (and, for FP8, quantizer) validated.

        Raises:
            RuntimeError: For legacy plain ``uint8`` FP8 state, or when a scaled
                state has no recorded scale.
        """

        state = self.state[param]

        for name in ("exp_avg", "exp_avg_sq"):
            value = state.get(name)
            state_dtype = self._state_dtype_for_name(name)
            if value is None:
                # First step for this parameter: start from zero moments.
                self.set_scaled_state(
                    param,
                    name,
                    torch.zeros_like(
                        param,
                        dtype=torch.float32,
                        memory_format=torch.preserve_format,
                    ),
                )
            elif state_dtype == torch.uint8:
                if self._is_float8_state(value):
                    state_device = self._float8_state_device(value)
                    if state_device is None or not self._devices_match(
                        state_device,
                        param.device,
                    ):
                        # FP8 state lives on another device: decode and
                        # re-quantize from scratch on the parameter's device.
                        fp32_value = self._get_scaled_state_as_unscaled(
                            param,
                            name,
                            value,
                        )
                        state.pop(name, None)
                        self.set_scaled_state(param, name, fp32_value)
                    else:
                        self._require_state_scale(param, name)
                        self._state_quantizer(param, name)
                elif torch.is_tensor(value) and value.dtype == torch.uint8:
                    raise RuntimeError(
                        "legacy plain uint8 FP8 optimizer state is not supported. "
                        "Please load from an FP32 unscaled state_dict, or regenerate "
                        "the checkpoint."
                    )
                else:
                    unscaled_value = self._get_scaled_state_as_unscaled(
                        param,
                        name,
                        value,
                    )
                    self.set_scaled_state(param, name, unscaled_value)
            elif value.dtype != state_dtype or value.device != param.device:
                unscaled_value = self._get_scaled_state_as_unscaled(param, name, value)
                self.set_scaled_state(param, name, unscaled_value)
            elif state_dtype in (torch.float32, torch.bfloat16):
                self._set_unit_state_scale(param, name)
            else:
                self._require_state_scale(param, name)
        if self.master_weights and param.dtype in (torch.float16, torch.bfloat16):
            master_param = state.get("master_param")
            if master_param is None:
                self.set_scaled_state(
                    param,
                    "master_param",
                    param.detach().float().clone(),
                )
            else:
                master_dtype = self._state_dtype_for_name("master_param")
                if master_param.dtype != master_dtype or master_param.device != param.device:
                    master_fp32 = self._get_scaled_state_as_unscaled(
                        param,
                        "master_param",
                        master_param,
                    )
                    self.set_scaled_state(param, "master_param", master_fp32)
                elif master_dtype == torch.bfloat16:
                    self._set_unit_state_scale(param, "master_param")
                elif master_dtype == torch.float16:
                    self._require_state_scale(param, "master_param")
                else:
                    self._ensure_state_scale(param, "master_param")
        return state

    def _max_param_state_step(self, group: dict) -> int:
        """Return max legacy per-parameter step from a param group (0 if none)."""

        max_step = 0
        for param in group["params"]:
            state = self.state.get(param)
            if state is None or "step" not in state:
                continue
            max_step = max(max_step, self._step_to_int(state["step"]))
        return max_step

    def _advance_group_step(self, group: dict) -> int:
        """Advance and return the TE-style param-group step.

        A missing group step is seeded from legacy per-parameter steps.  A
        tensor step is incremented in place.
        """

        if "step" not in group:
            group["step"] = self._max_param_state_step(group)

        step = group["step"]
        if torch.is_tensor(step):
            step.add_(1)
            return self._step_to_int(step)

        group["step"] = int(step) + 1
        return group["step"]

    @torch.no_grad()
    def step(self, closure=None, grad_scaler=None):
        """Perform a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grad_scaler (optional): Accepted for TE FusedAdam API compatibility.
                The fallback assumes gradients have already been unscaled by the
                caller, matching the non-capturable upstream path.

        Returns:
            The closure's loss, or ``None`` when no closure is given.
        """

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if len(group["params"]) == 0:
                continue
            if group.get("amsgrad", False):
                raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
            self._check_param_group_supported(group)

            step = self._advance_group_step(group)
            lr = self._scalar_to_float(group["lr"], "lr")
            eps = self._scalar_to_float(group["eps"], "eps")
            weight_decay = self._scalar_to_float(group["weight_decay"], "weight_decay")
            beta1, beta2 = group["betas"]
            beta1 = self._scalar_to_float(beta1, "beta1")
            beta2 = self._scalar_to_float(beta2, "beta2")

            # Bias corrections depend only on the group step and betas, so they
            # are computed once per group rather than once per parameter.
            if bool(group["bias_correction"]):
                step_size = lr / (1.0 - beta1**step)
                denom_correction = math.sqrt(1.0 - beta2**step)
            else:
                step_size = lr
                denom_correction = None

            for param in group["params"]:
                self._check_param_dtype(param)
                if self.use_decoupled_grad:
                    grad = getattr(param, "decoupled_grad", None)
                else:
                    grad = param.grad
                if grad is None:
                    continue
                if self._has_sparse_grad(grad):
                    raise RuntimeError("FusedAdam does not support sparse gradients.")

                self._init_or_fix_state(param)
                exp_avg = self.get_unscaled_state(param, "exp_avg")
                exp_avg_sq = self.get_unscaled_state(param, "exp_avg_sq")

                grad_fp32 = grad.detach().float()
                using_master_param = self.master_weights and param.dtype in (
                    torch.float16,
                    torch.bfloat16,
                )
                if using_master_param:
                    param_fp32 = self.get_unscaled_state(param, "master_param")
                else:
                    # For fp32 params ``.float()`` returns the same storage, so
                    # the updates below act on the parameter directly.
                    param_fp32 = param.detach().float()

                if weight_decay != 0.0:
                    if self.adam_w_mode:
                        # Decoupled decay: p <- p * (1 - lr * wd).
                        param_fp32.add_(param_fp32, alpha=-lr * weight_decay)
                    else:
                        # L2 regularization folded into the gradient.
                        grad_fp32 = grad_fp32.add(param_fp32, alpha=weight_decay)

                exp_avg.mul_(beta1).add_(grad_fp32, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(
                    grad_fp32,
                    grad_fp32,
                    value=1.0 - beta2,
                )

                denom = exp_avg_sq.sqrt()
                if denom_correction is not None:
                    denom.div_(denom_correction)
                denom.add_(eps)

                param_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                self.set_scaled_state(param, "exp_avg", exp_avg)
                self.set_scaled_state(param, "exp_avg_sq", exp_avg_sq)
                if using_master_param:
                    self.set_scaled_state(param, "master_param", param_fp32)
                param.copy_(param_fp32.to(dtype=param.dtype))

        return loss