"""Python fallback implementation of FusedAdam for Ascend NPU."""
from __future__ import annotations
import math
import os
import warnings
from collections.abc import Iterable
from typing import Optional
import torch
import torch_npu
from torch.distributed._tensor import DTensor
from ..quantized_tensor import QuantizedTensor
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Tensor
class FusedAdam(torch.optim.Optimizer):
"""Adam/AdamW optimizer with the Transformer Engine FusedAdam interface.
This NPU implementation intentionally does not depend on NVIDIA's
``transformer_engine_torch`` CUDA extension. It uses ordinary PyTorch tensor
operations as a correctness-first fallback, while keeping the public
constructor and ``step`` signatures compatible with TE/Apex FusedAdam.
"""
def __init__(
    self,
    params: Iterable[torch.nn.Parameter | dict],
    lr: float = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0.0,
    amsgrad: bool = False,
    *,
    bias_correction=True,
    adam_w_mode=True,
    capturable=False,
    master_weights=False,
    master_weight_dtype=torch.float32,
    exp_avg_dtype=torch.float32,
    exp_avg_sq_dtype=torch.float32,
    use_decoupled_grad=False,
    store_param_remainders=False,
    set_grad_none: Optional[bool] = None,
):
    """Validate the hyperparameters and set up the fallback optimizer.

    Args:
        params: Parameters or param-group dicts, forwarded to
            ``torch.optim.Optimizer``.
        lr: Learning rate; scalar tensors are accepted and stored as a
            Python float in the group defaults.
        betas: Pair ``(beta1, beta2)``; each must lie in ``[0, 1)``.
        eps: Non-negative term added to the denominator.
        weight_decay: Non-negative weight-decay coefficient.
        amsgrad: Must be false.
        bias_correction: Stored in the group defaults.
        adam_w_mode: Stored as ``1`` (truthy) or ``0``.
        capturable: Must be false.
        master_weights: Whether a ``master_param`` state is kept for
            float16/bfloat16 parameters.
        master_weight_dtype: ``torch.float32`` or ``torch.float16`` when
            ``master_weights`` is enabled.
        exp_avg_dtype: Storage dtype of the first moment.
        exp_avg_sq_dtype: Storage dtype of the second moment.
        use_decoupled_grad: Read gradients from ``param.decoupled_grad``
            instead of ``param.grad``.
        store_param_remainders: Must be false.
        set_grad_none: Deprecated default for ``zero_grad``.

    Raises:
        RuntimeError: If ``amsgrad`` is requested.
        ValueError: If ``betas`` is not a pair or a hyperparameter is out
            of range.
        NotImplementedError: For ``capturable``, ``store_param_remainders``,
            an unsupported state/master dtype, or ``torch.uint8`` state
            when the Float8 state path is unavailable.
    """
    if amsgrad:
        raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
    lr_value = self._scalar_to_float(lr, "lr")
    eps_value = self._scalar_to_float(eps, "eps")
    weight_decay_value = self._scalar_to_float(weight_decay, "weight_decay")
    # Reject a malformed betas sequence here; otherwise extra entries pass
    # construction and only fail when step() unpacks the pair.
    if len(betas) != 2:
        raise ValueError(
            f"Invalid betas: expected a pair (beta1, beta2), got {betas}"
        )
    beta1_value = self._scalar_to_float(betas[0], "beta1")
    beta2_value = self._scalar_to_float(betas[1], "beta2")
    if lr_value < 0.0:
        raise ValueError(f"Invalid learning rate: {lr}")
    if eps_value < 0.0:
        raise ValueError(f"Invalid epsilon value: {eps}")
    if not 0.0 <= beta1_value < 1.0:
        raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
    if not 0.0 <= beta2_value < 1.0:
        raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
    if weight_decay_value < 0.0:
        raise ValueError(f"Invalid weight_decay value: {weight_decay}")
    if capturable:
        raise NotImplementedError(
            "NPU FusedAdam fallback does not support capturable=True yet."
        )
    if master_weights and master_weight_dtype not in (torch.float32, torch.float16):
        raise NotImplementedError(
            "NPU FusedAdam fallback currently supports only "
            "master_weight_dtype=torch.float32 or torch.float16."
        )
    self._check_state_dtype(exp_avg_dtype, "exp_avg_dtype")
    self._check_state_dtype(exp_avg_sq_dtype, "exp_avg_sq_dtype")
    if torch.uint8 in (exp_avg_dtype, exp_avg_sq_dtype) and not self._float8_state_available():
        raise NotImplementedError(
            "NPU FusedAdam fallback requires TransformerEngineNPU Float8Tensor, "
            "torch_npu.npu_dynamic_quant, and torch.float8_e4m3fn support when "
            "exp_avg_dtype or exp_avg_sq_dtype is torch.uint8."
        )
    if store_param_remainders:
        raise NotImplementedError(
            "NPU FusedAdam fallback does not support store_param_remainders=True yet."
        )
    # Only lr is normalised to a float; betas, eps and weight_decay are
    # stored as given and converted again on every step().
    defaults = {
        "lr": lr_value,
        "bias_correction": bias_correction,
        "betas": betas,
        "eps": eps,
        "weight_decay": weight_decay,
    }
    super().__init__(params, defaults)
    self.adam_w_mode = 1 if adam_w_mode else 0
    self.capturable = capturable
    self.master_weights = master_weights
    self.master_weight_dtype = master_weight_dtype
    self.exp_avg_dtype = exp_avg_dtype
    self.exp_avg_sq_dtype = exp_avg_sq_dtype
    self.use_decoupled_grad = use_decoupled_grad
    self.store_param_remainders = False
    # Target absolute maximum used when scaling states into these dtypes.
    self.dtype_to_range_map = {
        torch.float16: torch.finfo(torch.float16).max / 2.0,
        torch.uint8: 448.0,
    }
    # Per-parameter caches: state name -> scale tensor / FP8 quantizer.
    self._scales = {}
    self._state_quantizers = {}
    # Any value other than "sqrt" makes the FP8 second moment be stored
    # directly (see the _state_*_fp8_storage_value helpers).
    self.fp8_exp_avg_sq_storage = os.getenv(
        "TE_NPU_FUSED_ADAM_FP8_EXP_AVG_SQ_STORAGE",
        "sqrt",
    ).lower()
    self.set_grad_none = set_grad_none
    if self.set_grad_none is not None:
        # stacklevel=2 attributes the warning to the code constructing the
        # optimizer rather than to this module.
        warnings.warn(
            "set_grad_none kwarg in FusedAdam constructor is deprecated. "
            "Use set_to_none kwarg in zero_grad instead.",
            DeprecationWarning,
            stacklevel=2,
        )
def zero_grad(self, set_to_none: Optional[bool] = None) -> None:
    """Clear the gradients that this optimizer consumes.

    The effective ``set_to_none`` is the deprecated constructor option
    ``set_grad_none`` when that was given, otherwise the argument, and
    ``True`` when neither was supplied. With ``use_decoupled_grad`` the
    ``decoupled_grad`` attribute of each parameter is cleared; otherwise
    the base-class implementation handles ``.grad``.

    Raises:
        ValueError: If ``set_to_none`` contradicts ``set_grad_none``.
    """
    legacy_choice = self.set_grad_none
    if legacy_choice is not None:
        contradicts = set_to_none is not None and set_to_none != legacy_choice
        if contradicts:
            raise ValueError(
                f"Called zero_grad with set_to_none={set_to_none}, "
                "but FusedAdam was initialized with "
                f"set_grad_none={self.set_grad_none}"
            )
        set_to_none = legacy_choice
    elif set_to_none is None:
        set_to_none = True

    if not self.use_decoupled_grad:
        super().zero_grad(set_to_none=set_to_none)
        return

    every_param = (p for grp in self.param_groups for p in grp["params"])
    if set_to_none:
        for p in every_param:
            p.decoupled_grad = None
        return
    for p in every_param:
        buffer = getattr(p, "decoupled_grad", None)
        if buffer is not None:
            buffer.zero_()
def state_dict(self):
    """Return the base-class state dict with states exported as unscaled FP32.

    For every parameter that has a dict-valued state, ``exp_avg``,
    ``exp_avg_sq`` and ``master_param`` (whichever exist) are replaced by
    the result of ``get_unscaled_state``. The per-parameter entry is copied
    first so the replacement is made on the copy.
    """
    exported = super().state_dict()
    exported_states = exported["state"]
    group_pairs = zip(self.param_groups, exported["param_groups"])
    for live_group, packed_group in group_pairs:
        for param, param_id in zip(live_group["params"], packed_group["params"]):
            packed_entry = exported_states.get(param_id)
            live_entry = self.state.get(param)
            both_dicts = isinstance(packed_entry, dict) and isinstance(live_entry, dict)
            if not both_dicts:
                continue
            entry_copy = dict(packed_entry)
            exported_states[param_id] = entry_copy
            for name in ("exp_avg", "exp_avg_sq", "master_param"):
                if name in live_entry:
                    entry_copy[name] = self.get_unscaled_state(param, name)
    return exported
def load_state_dict(self, state_dict):
    """Load a state dict and re-encode states in the configured dtypes.

    Saved ``exp_avg``, ``exp_avg_sq`` and ``master_param`` values are
    captured before the base class loads ``state_dict`` (Float8 values are
    dequantized to FP32, plain tensors are cloned) and are written back
    afterwards through ``set_scaled_state``. Cached scales and quantizers
    are dropped. NOTE(review): capturing first presumably protects against
    dtype casting inside the base implementation; confirm against
    ``torch.optim.Optimizer.load_state_dict``.

    Raises:
        RuntimeError: If a moment is saved as a plain ``torch.uint8`` tensor.
        NotImplementedError: If ``master_param`` is not floating point.
    """
    # Saved parameter ids map onto current parameters positionally.
    id_to_param = {
        saved_id: param
        for saved_group, live_group in zip(
            state_dict["param_groups"],
            self.param_groups,
        )
        for saved_id, param in zip(saved_group["params"], live_group["params"])
    }

    captured = {}
    for saved_id, param in id_to_param.items():
        entry = state_dict["state"].get(saved_id)
        if not isinstance(entry, dict):
            continue
        for name in ("exp_avg", "exp_avg_sq", "master_param"):
            stored = entry.get(name)
            if self._is_float8_state(stored):
                captured[(param, name)] = (
                    stored.dequantize(dtype=torch.float32).detach().clone()
                )
                continue
            if not torch.is_tensor(stored):
                continue
            is_moment = name in ("exp_avg", "exp_avg_sq")
            if is_moment and stored.dtype == torch.uint8:
                raise RuntimeError(
                    "legacy plain uint8 FP8 optimizer state is not supported. "
                    "Please load from an FP32 unscaled state_dict, or "
                    "regenerate the checkpoint."
                )
            if name == "master_param" and not stored.is_floating_point():
                raise NotImplementedError(
                    "Loading non-floating master_param optimizer state is not "
                    "supported by NPU FusedAdam fallback. Checkpoints that use "
                    "store_param_remainders must be regenerated as FP32 "
                    "master_param state."
                )
            captured[(param, name)] = stored.detach().clone()

    outcome = super().load_state_dict(state_dict)
    self._scales = {}
    self._state_quantizers = {}

    half_dtypes = (torch.float16, torch.bfloat16)
    for param in id_to_param.values():
        loaded = self.state.get(param)
        if not isinstance(loaded, dict):
            continue
        for name in ("exp_avg", "exp_avg_sq"):
            moment = captured.get((param, name))
            if moment is not None:
                self.set_scaled_state(param, name, moment)
        master = captured.get((param, "master_param"))
        keeps_master = self.master_weights and param.dtype in half_dtypes
        if not keeps_master:
            # Drop master weights this configuration does not keep.
            loaded.pop("master_param", None)
            self._scale_dict(param).pop("master_param", None)
            continue
        if master is None:
            master = param.detach().float().clone()
        self.set_scaled_state(param, "master_param", master)
    return outcome
@staticmethod
def _scalar_to_float(value, name: str) -> float:
"""Convert a scalar optimizer hyperparameter to a Python float."""
if torch.is_tensor(value):
if value.numel() != 1:
raise ValueError(f"Optimizer parameter {name} must be a scalar.")
return float(value.detach().item())
return float(value)
@staticmethod
def _step_to_int(step) -> int:
"""Convert a scalar optimizer step to a Python int."""
if torch.is_tensor(step):
if step.numel() != 1:
raise ValueError("Optimizer step must be a scalar.")
return int(step.detach().item())
return int(step)
@staticmethod
def _has_sparse_grad(grad: torch.Tensor) -> bool:
"""Return True for COO or compressed sparse gradients."""
if grad.is_sparse:
return True
is_sparse_csr = getattr(grad, "is_sparse_csr", False)
if callable(is_sparse_csr):
is_sparse_csr = is_sparse_csr()
return bool(is_sparse_csr)
@staticmethod
def _check_param_dtype(param: torch.Tensor) -> None:
"""Validate parameter dtype supported by this fallback."""
supported_dtypes = (torch.float32, torch.float16, torch.bfloat16)
if param.dtype not in supported_dtypes:
raise RuntimeError(
"FusedAdam fallback supports parameters in torch.float32, "
"torch.float16, and torch.bfloat16 only, but got "
f"{param.dtype}."
)
@staticmethod
def _check_state_dtype(dtype: torch.dtype, name: str) -> None:
"""Validate optimizer state dtype supported by this fallback."""
supported_dtypes = (torch.float32, torch.float16, torch.bfloat16, torch.uint8)
if dtype not in supported_dtypes:
raise NotImplementedError(
"NPU FusedAdam fallback supports "
f"{name}=torch.float32, torch.float16, torch.bfloat16, "
f"or torch.uint8 only; "
f"got {dtype}."
)
@staticmethod
def _is_dtensor_param(param) -> bool:
"""Return whether a parameter is a DTensor model weight."""
return isinstance(param, DTensor)
@staticmethod
def _is_quantized_model_param(param) -> bool:
    """Tell whether ``param`` is an instance of TE ``QuantizedTensor``."""
    is_quantized = isinstance(param, QuantizedTensor)
    return is_quantized
@staticmethod
def _is_float8_model_param(param) -> bool:
    """Tell whether the parameter object itself is a ``Float8Tensor``."""
    is_float8 = isinstance(param, Float8Tensor)
    return is_float8
def _check_unsupported_param_type(self, param: torch.Tensor) -> None:
    """Refuse parameter types this fallback cannot update.

    The checks run in order: DTensor, Float8Tensor, then any other
    QuantizedTensor.

    Raises:
        NotImplementedError: For the first matching unsupported type.
    """
    rejections = (
        (
            self._is_dtensor_param,
            "DTensor parameters are not supported by NPU FusedAdam fallback yet.",
        ),
        (
            self._is_float8_model_param,
            "FP8 model weights are not supported by NPU FusedAdam fallback yet. "
            "FP8 optimizer state is supported separately from FP8 model parameters.",
        ),
        (
            self._is_quantized_model_param,
            "QuantizedTensor model weights are not supported by NPU fallback yet.",
        ),
    )
    for matches, message in rejections:
        if matches(param):
            raise NotImplementedError(message)
def _check_param_group_supported(self, group: dict) -> None:
    """Validate every parameter of ``group`` and forbid fp16/bf16 mixing.

    Parameters are visited in order; the mixing error is raised as soon as
    both half-precision dtypes have been seen.

    Raises:
        NotImplementedError: Propagated from the parameter-type check.
        RuntimeError: If the group holds both float16 and bfloat16 parameters.
    """
    half_dtypes_seen = set()
    for param in group["params"]:
        self._check_unsupported_param_type(param)
        if param.dtype in (torch.float16, torch.bfloat16):
            half_dtypes_seen.add(param.dtype)
        if len(half_dtypes_seen) == 2:
            raise RuntimeError(
                "FusedAdam fallback does not support mixing torch.float16 and "
                "torch.bfloat16 parameters in the same param group."
            )
@staticmethod
def _float8_state_available() -> bool:
    """Tell whether FP8 optimizer state can be used in this environment.

    Requires ``torch_npu.npu_dynamic_quant`` and a ``torch.float8_e4m3fn``
    attribute that is not equal to ``torch.bfloat16``.
    """
    if not hasattr(torch_npu, "npu_dynamic_quant"):
        return False
    if not hasattr(torch, "float8_e4m3fn"):
        return False
    return torch.float8_e4m3fn != torch.bfloat16
@staticmethod
def _is_float8_state(value) -> bool:
    """Tell whether a stored state value is a ``Float8Tensor``."""
    is_float8 = isinstance(value, Float8Tensor)
    return is_float8
@staticmethod
def _float8_state_device(value) -> Optional[torch.device]:
    """Find the device of a ``Float8Tensor`` state, or ``None``.

    ``value.device`` is tried first. If reading or converting it fails, the
    first tensor found among the ``_data`` and ``_transpose`` attributes
    supplies the device. Non-Float8 values yield ``None``.
    """
    if not FusedAdam._is_float8_state(value):
        return None
    try:
        return torch.device(value.device)
    except (AttributeError, TypeError, RuntimeError):
        pass
    for attr_name in ("_data", "_transpose"):
        backing = getattr(value, attr_name, None)
        if torch.is_tensor(backing):
            return backing.device
    return None
@staticmethod
def _float8_quantizer_device(quantizer) -> Optional[torch.device]:
"""Return the device backing a Float8 optimizer-state quantizer."""
for attr_name in ("scale", "amax"):
tensor = getattr(quantizer, attr_name, None)
if torch.is_tensor(tensor):
return tensor.device
return None
@staticmethod
def _devices_match(left: torch.device, right: torch.device) -> bool:
"""Compare devices while treating implicit device index as compatible."""
left = torch.device(left)
right = torch.device(right)
if left == right:
return True
return left.type == right.type and (left.index is None or right.index is None)
def _state_dtype_for_name(self, name: str) -> torch.dtype:
"""Return configured storage dtype for an optimizer state tensor."""
if name == "exp_avg":
return self.exp_avg_dtype
if name == "exp_avg_sq":
return self.exp_avg_sq_dtype
if name == "master_param":
return self.master_weight_dtype
raise KeyError(f"Unknown optimizer state name: {name}")
def _scale_dict(self, param: torch.Tensor) -> dict:
"""Return per-state scale dictionary for a parameter."""
return self._scales.setdefault(param, {})
def _quantizer_dict(self, param: torch.Tensor) -> dict:
"""Return per-state FP8 quantizer dictionary for a parameter."""
return self._state_quantizers.setdefault(param, {})
def _state_quantizer(self, param: torch.Tensor, state_name: str):
    """Get the cached FP8 quantizer for a state, building one when needed.

    A cached quantizer is reused only when its device can be determined and
    matches ``param.device``; otherwise a new
    ``Float8CurrentScalingQuantizer`` replaces it in the cache.

    Raises:
        NotImplementedError: If the Float8 state path is unavailable.
    """
    if not self._float8_state_available():
        raise NotImplementedError(
            "FP8 optimizer state requires TransformerEngineNPU Float8Tensor, "
            "torch_npu.npu_dynamic_quant, and torch.float8_e4m3fn support."
        )
    cache = self._quantizer_dict(param)
    cached = cache.get(state_name)
    if cached is not None:
        cached_device = self._float8_quantizer_device(cached)
        if cached_device is not None and self._devices_match(
            cached_device,
            param.device,
        ):
            return cached
    fresh = Float8CurrentScalingQuantizer(
        fp8_dtype=torch.float8_e4m3fn,
        rowwise=True,
        columnwise=False,
        device=param.device,
    )
    cache[state_name] = fresh
    return fresh
@staticmethod
def _fp8_quant_input_dtype(param: torch.Tensor) -> torch.dtype:
"""Return dense input dtype accepted by NPU FP8 dynamic quantization."""
if param.dtype in (torch.float16, torch.bfloat16):
return param.dtype
return torch.bfloat16
def _to_fp8_quant_input(
    self,
    param: torch.Tensor,
    value: torch.Tensor,
) -> torch.Tensor:
    """Move ``value`` to ``param.device`` in the FP8 quantizer's input dtype."""
    input_dtype = self._fp8_quant_input_dtype(param)
    return value.to(device=param.device, dtype=input_dtype, non_blocking=True)
def _state_to_fp8_storage_value(
self,
state_name: str,
value: torch.Tensor,
) -> torch.Tensor:
"""Map an fp32 optimizer state to the value stored in FP8."""
if (
state_name == "exp_avg_sq"
and self.exp_avg_sq_dtype == torch.uint8
and self.fp8_exp_avg_sq_storage == "sqrt"
):
return value.clamp_min(0.0).sqrt()
return value
def _state_from_fp8_storage_value(
self,
state_name: str,
value: torch.Tensor,
) -> torch.Tensor:
"""Map a dequantized FP8 storage value back to optimizer-state semantics."""
if (
state_name == "exp_avg_sq"
and self.exp_avg_sq_dtype == torch.uint8
and self.fp8_exp_avg_sq_storage == "sqrt"
):
value = value.clamp_min(0.0)
return value.mul(value).clamp_min(torch.finfo(torch.float32).tiny)
return value
def _set_unit_state_scale(self, param: torch.Tensor, state_name: str) -> None:
    """Store a scalar FP32 scale of one for a state that needs no scaling."""
    unit_scale = torch.ones((), dtype=torch.float32, device=param.device)
    self._scale_dict(param)[state_name] = unit_scale
def _ensure_state_scale(self, param: torch.Tensor, state_name: str) -> torch.Tensor:
    """Fetch the scale of a state, defaulting to one and following the param's device.

    A missing scale is created as a scalar FP32 one. A scale on another
    device is replaced by an FP32 copy on ``param.device``.
    """
    scales = self._scale_dict(param)
    current = scales.get(state_name)
    if current is not None and current.device == param.device:
        return current
    if current is None:
        replacement = torch.ones((), dtype=torch.float32, device=param.device)
    else:
        replacement = current.detach().to(device=param.device, dtype=torch.float32)
    scales[state_name] = replacement
    return replacement
def _require_state_scale(
    self,
    param: torch.Tensor,
    state_name: str,
) -> torch.Tensor:
    """Fetch the recorded scale of a state, moving it to the param's device.

    Unlike ``_ensure_state_scale`` a missing scale is an error here.

    Raises:
        RuntimeError: If no scale has been recorded for ``state_name``.
    """
    scales = self._scale_dict(param)
    recorded = scales.get(state_name)
    if recorded is None:
        raise RuntimeError(
            f"Missing scale for low-precision optimizer state {state_name}. "
            "Cannot safely rebuild scaled optimizer state."
        )
    if recorded.device == param.device:
        return recorded
    relocated = recorded.detach().to(device=param.device, dtype=torch.float32)
    scales[state_name] = relocated
    return relocated
def _get_scaled_state_as_unscaled(
    self,
    param: torch.Tensor,
    state_name: str,
    value: torch.Tensor,
) -> torch.Tensor:
    """Decode a stored state value into unscaled FP32 on ``param.device``.

    Float8 values are dequantized, multiplied by the recorded scale, and
    mapped back through ``_state_from_fp8_storage_value``. float32 and
    bfloat16 values are only cast. Any other dtype is cast and multiplied by
    the recorded scale.

    Raises:
        RuntimeError: Propagated from ``_require_state_scale`` when a scaled
            state has no recorded scale.
    """
    if self._is_float8_state(value):
        fp8_scale = self._require_state_scale(param, state_name)
        dequantized = value.dequantize(dtype=torch.float32).detach().float()
        rescaled = dequantized.to(device=param.device).mul(fp8_scale)
        return self._state_from_fp8_storage_value(state_name, rescaled)
    if value.dtype in (torch.float32, torch.bfloat16):
        return value.detach().to(device=param.device, dtype=torch.float32)
    scale = self._require_state_scale(param, state_name)
    widened = value.detach().to(device=param.device, dtype=torch.float32)
    return widened.mul(scale)
def _compute_state_scale(
self,
unscaled_state: torch.Tensor,
dtype: torch.dtype,
) -> torch.Tensor:
"""Compute a scalar fp32 scale for low-precision state storage."""
if dtype in (torch.float32, torch.bfloat16):
return torch.ones((), dtype=torch.float32, device=unscaled_state.device)
max_abs = unscaled_state.detach().abs().max()
if (not torch.isfinite(max_abs)) or max_abs.item() == 0.0:
return torch.ones((), dtype=torch.float32, device=unscaled_state.device)
storage_amax = self.dtype_to_range_map.get(dtype)
if storage_amax is None:
storage_amax = torch.finfo(dtype).max * 0.5
scale = (max_abs / storage_amax).to(dtype=torch.float32)
return scale.clamp_min(torch.finfo(torch.float32).tiny)
def get_unscaled_state(self, param: torch.Tensor, state_name: str) -> torch.Tensor:
    """Read an optimizer state of ``param`` as unscaled FP32 data.

    float32 and bfloat16 states are returned as FP32 clones; reading a
    bfloat16 state also records a unit scale for it. Float8 and other
    scaled states are decoded by ``_get_scaled_state_as_unscaled``.
    """
    stored = self.state[param][state_name]
    if not self._is_float8_state(stored):
        if stored.dtype == torch.float32:
            return stored.detach().float().clone()
        if stored.dtype == torch.bfloat16:
            self._set_unit_state_scale(param, state_name)
            return stored.detach().float().clone()
    return self._get_scaled_state_as_unscaled(param, state_name, stored)
def set_scaled_state(
    self,
    param: torch.Tensor,
    state_name: str,
    unscaled_state: torch.Tensor,
) -> None:
    """Write an unscaled state into ``self.state`` using its configured dtype.

    * float32 / bfloat16: cast and stored with a unit scale.
    * uint8: mapped by ``_state_to_fp8_storage_value``, divided by a fresh
      scale and quantized; an existing Float8 state is updated through
      ``update_quantized``, otherwise a new quantized value is created.
    * anything else (float16): divided by a fresh scale, then cast.
    """
    target_dtype = self._state_dtype_for_name(state_name)
    fp32_value = unscaled_state.detach().to(
        device=param.device,
        dtype=torch.float32,
        non_blocking=True,
    )

    if target_dtype in (torch.float32, torch.bfloat16):
        self.state[param][state_name] = fp32_value.to(dtype=target_dtype)
        self._set_unit_state_scale(param, state_name)
        return

    if target_dtype != torch.uint8:
        plain_scale = self._compute_state_scale(fp32_value, target_dtype)
        self._scale_dict(param)[state_name] = plain_scale
        self.state[param][state_name] = fp32_value.div(plain_scale).to(
            dtype=target_dtype
        )
        return

    # FP8 path: obtain the quantizer first so an unavailable backend fails
    # before any state is touched.
    quantizer = self._state_quantizer(param, state_name)
    previous = self.state[param].get(state_name)
    storage_value = self._state_to_fp8_storage_value(state_name, fp32_value)
    fp8_scale = self._compute_state_scale(storage_value, target_dtype)
    self._scale_dict(param)[state_name] = fp8_scale
    quant_input = self._to_fp8_quant_input(param, storage_value.div(fp8_scale))
    with torch.no_grad():
        if self._is_float8_state(previous):
            quantized = quantizer.update_quantized(quant_input, previous)
        else:
            quantized = quantizer(quant_input)
        self.state[param][state_name] = quantized
def _init_or_fix_state(self, param: torch.Tensor) -> dict:
    """Make the state of ``param`` complete and stored as configured.

    Missing moments are created as zeros. Existing moments (and the master
    weight, when kept) are re-encoded if their dtype or device no longer
    matches; otherwise only their scale bookkeeping is checked or refreshed.

    Returns:
        The state dict of ``param``.

    Raises:
        RuntimeError: If an FP8-configured moment is stored as a plain
            ``torch.uint8`` tensor, or a required scale is missing.
    """
    state = self.state[param]
    unscaled_dtypes = (torch.float32, torch.bfloat16)

    for name in ("exp_avg", "exp_avg_sq"):
        current = state.get(name)
        target_dtype = self._state_dtype_for_name(name)

        if current is None:
            zeros = torch.zeros_like(
                param,
                dtype=torch.float32,
                memory_format=torch.preserve_format,
            )
            self.set_scaled_state(param, name, zeros)
            continue

        if target_dtype == torch.uint8:
            if self._is_float8_state(current):
                current_device = self._float8_state_device(current)
                on_param_device = current_device is not None and self._devices_match(
                    current_device,
                    param.device,
                )
                if on_param_device:
                    self._require_state_scale(param, name)
                    self._state_quantizer(param, name)
                else:
                    # Rebuild the FP8 state on the parameter's device.
                    decoded = self._get_scaled_state_as_unscaled(
                        param,
                        name,
                        current,
                    )
                    state.pop(name, None)
                    self.set_scaled_state(param, name, decoded)
                continue
            if torch.is_tensor(current) and current.dtype == torch.uint8:
                raise RuntimeError(
                    "legacy plain uint8 FP8 optimizer state is not supported. "
                    "Please load from an FP32 unscaled state_dict, or regenerate "
                    "the checkpoint."
                )
            decoded = self._get_scaled_state_as_unscaled(param, name, current)
            self.set_scaled_state(param, name, decoded)
            continue

        if current.dtype != target_dtype or current.device != param.device:
            decoded = self._get_scaled_state_as_unscaled(param, name, current)
            self.set_scaled_state(param, name, decoded)
        elif target_dtype in unscaled_dtypes:
            self._set_unit_state_scale(param, name)
        else:
            self._require_state_scale(param, name)

    keeps_master = self.master_weights and param.dtype in (
        torch.float16,
        torch.bfloat16,
    )
    if not keeps_master:
        return state

    master = state.get("master_param")
    if master is None:
        self.set_scaled_state(
            param,
            "master_param",
            param.detach().float().clone(),
        )
        return state

    master_dtype = self._state_dtype_for_name("master_param")
    if master.dtype != master_dtype or master.device != param.device:
        decoded_master = self._get_scaled_state_as_unscaled(
            param,
            "master_param",
            master,
        )
        self.set_scaled_state(param, "master_param", decoded_master)
    elif master_dtype == torch.bfloat16:
        self._set_unit_state_scale(param, "master_param")
    elif master_dtype == torch.float16:
        self._require_state_scale(param, "master_param")
    else:
        self._ensure_state_scale(param, "master_param")
    return state
def _max_param_state_step(self, group: dict) -> int:
"""Return max legacy per-parameter step from a param group."""
max_step = 0
for param in group["params"]:
state = self.state.get(param)
if state is None or "step" not in state:
continue
max_step = max(max_step, self._step_to_int(state["step"]))
return max_step
def _advance_group_step(self, group: dict) -> int:
"""Advance and return the TE-style param-group step."""
if "step" not in group:
group["step"] = self._max_param_state_step(group)
step = group["step"]
if torch.is_tensor(step):
step.add_(1)
return self._step_to_int(step)
group["step"] = int(step) + 1
return group["step"]
@torch.no_grad()
def step(self, closure=None, grad_scaler=None):
    """Apply one Adam/AdamW update to every parameter that has a gradient.

    Arguments:
        closure (callable, optional): Re-evaluates the model with gradients
            enabled and returns the loss.
        grad_scaler (optional): Accepted for TE FusedAdam API compatibility
            and not used; gradients are taken as already unscaled.

    Returns:
        The value returned by ``closure``, or ``None`` without a closure.

    Raises:
        RuntimeError: For an ``amsgrad`` group, an unsupported parameter
            dtype, or a sparse gradient.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    half_dtypes = (torch.float16, torch.bfloat16)
    for group in self.param_groups:
        if len(group["params"]) == 0:
            continue
        if group.get("amsgrad", False):
            raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
        self._check_param_group_supported(group)
        # The group step advances even if no parameter has a gradient.
        step_count = self._advance_group_step(group)

        lr = self._scalar_to_float(group["lr"], "lr")
        eps = self._scalar_to_float(group["eps"], "eps")
        weight_decay = self._scalar_to_float(group["weight_decay"], "weight_decay")
        raw_beta1, raw_beta2 = group["betas"]
        beta1 = self._scalar_to_float(raw_beta1, "beta1")
        beta2 = self._scalar_to_float(raw_beta2, "beta2")
        bias_correction = bool(group["bias_correction"])

        for param in group["params"]:
            self._check_param_dtype(param)
            grad = (
                getattr(param, "decoupled_grad", None)
                if self.use_decoupled_grad
                else param.grad
            )
            if grad is None:
                continue
            if self._has_sparse_grad(grad):
                raise RuntimeError("FusedAdam does not support sparse gradients.")

            self._init_or_fix_state(param)
            first_moment = self.get_unscaled_state(param, "exp_avg")
            second_moment = self.get_unscaled_state(param, "exp_avg_sq")
            grad_fp32 = grad.detach().float()

            has_master = self.master_weights and param.dtype in half_dtypes
            if has_master:
                weights = self.get_unscaled_state(param, "master_param")
            else:
                weights = param.detach().float()

            if weight_decay != 0.0:
                if self.adam_w_mode:
                    # Decoupled decay shrinks the weights directly.
                    weights.add_(weights, alpha=-lr * weight_decay)
                else:
                    # L2 mode folds the decay term into the gradient.
                    grad_fp32 = grad_fp32.add(weights, alpha=weight_decay)

            first_moment.mul_(beta1).add_(grad_fp32, alpha=1.0 - beta1)
            second_moment.mul_(beta2).addcmul_(
                grad_fp32,
                grad_fp32,
                value=1.0 - beta2,
            )

            denom = second_moment.sqrt()
            step_size = lr
            if bias_correction:
                step_size = lr / (1.0 - beta1**step_count)
                denom.div_(math.sqrt(1.0 - beta2**step_count))
            denom.add_(eps)
            weights.addcdiv_(first_moment, denom, value=-step_size)

            self.set_scaled_state(param, "exp_avg", first_moment)
            self.set_scaled_state(param, "exp_avg_sq", second_moment)
            if has_master:
                self.set_scaled_state(param, "master_param", weights)
            param.copy_(weights.to(dtype=param.dtype))
    return loss