"""Basic FusedAdam/FusedSGD tests for the NPU Python fallback optimizers."""
from __future__ import annotations
import copy
import inspect
import pytest
import torch
from utils import npu_available
# Skip every test in this module unless npu_available() reports a device.
pytestmark = pytest.mark.skipif(not npu_available(), reason="NPU device is required")

# Start from None placeholders so the names always exist at module level;
# they are rebound to the real classes only when an NPU is available.
FusedAdam = None
FusedSGD = None
Float8CurrentScalingQuantizer = None
Float8Tensor = None

if npu_available():
    from transformer_engine.pytorch.optimizers import FusedAdam, FusedSGD
    from transformer_engine.pytorch.tensor.float8_tensor import (
        Float8CurrentScalingQuantizer,
        Float8Tensor,
    )
def _device():
    """Return the ``torch.device`` handle for the NPU backend."""
    npu = torch.device("npu")
    return npu
def _randn(shape, dtype=torch.float32):
    """Sample a normal tensor on the CPU and move it to the NPU.

    Args:
        shape: Shape passed to ``torch.randn``.
        dtype: Element type of the sampled tensor.
    """
    host_tensor = torch.randn(shape, dtype=dtype, device="cpu")
    return host_tensor.to(device=_device())
def _make_linear():
    """Build a float32 ``Linear(8, 4)`` on the CPU and move it to the NPU."""
    host_layer = torch.nn.Linear(8, 4, device="cpu", dtype=torch.float32)
    return host_layer.to(device=_device())
def _clone_model(model):
    """Return a deep copy of ``model`` placed on the NPU."""
    duplicate = copy.deepcopy(model)
    return duplicate.to(device=_device())
def _assert_models_close(model, ref_model, rtol=1e-5, atol=1e-6):
    """Assert that two models hold element-wise close parameters.

    Parameters are paired in ``parameters()`` order and compared on the CPU
    in float32 with ``torch.testing.assert_close``.

    Args:
        model: Module under test.
        ref_model: Reference module.
        rtol: Relative tolerance forwarded to ``assert_close``.
        atol: Absolute tolerance forwarded to ``assert_close``.

    Raises:
        AssertionError: If the parameter counts differ or any pair of
            parameters is not close within the tolerances.
    """
    params = list(model.parameters())
    ref_params = list(ref_model.parameters())
    # zip() stops at the shorter sequence, so without this check a model with
    # missing or extra parameters would be reported as matching.
    if len(params) != len(ref_params):
        raise AssertionError(
            f"Parameter count mismatch: {len(params)} vs {len(ref_params)}"
        )
    for param, ref_param in zip(params, ref_params):
        torch.testing.assert_close(
            param.detach().to(device="cpu", dtype=torch.float32),
            ref_param.detach().to(device="cpu", dtype=torch.float32),
            rtol=rtol,
            atol=atol,
        )
def _max_param_diff(model, ref_model):
    """Return the largest absolute element-wise parameter difference.

    Parameters are paired in ``parameters()`` order and compared on the CPU
    in float32.

    Args:
        model: Module under test.
        ref_model: Reference module.

    Returns:
        The maximum absolute difference as a Python float, or ``0.0`` when
        there is nothing to compare.

    Raises:
        ValueError: If the two models expose a different number of parameters.
    """
    params = list(model.parameters())
    ref_params = list(ref_model.parameters())
    # zip() stops at the shorter sequence, which would hide unmatched parameters.
    if len(params) != len(ref_params):
        raise ValueError(
            f"Parameter count mismatch: {len(params)} vs {len(ref_params)}"
        )
    max_diff = 0.0
    for param, ref_param in zip(params, ref_params):
        diff = (
            param.detach().to(device="cpu", dtype=torch.float32)
            - ref_param.detach().to(device="cpu", dtype=torch.float32)
        ).abs()
        # max() over a zero-element tensor raises, so skip empty parameters.
        if diff.numel() == 0:
            continue
        max_diff = max(max_diff, float(diff.max().item()))
    return max_diff
def _assert_adam_reference_close(model, ref_model):
    """Compare two models with the looser tolerances used for Adam references.

    Delegates to ``_assert_models_close`` with ``rtol=1e-4`` and ``atol=1e-5``.
    """
    adam_tolerances = {"rtol": 1e-4, "atol": 1e-5}
    _assert_models_close(model, ref_model, **adam_tolerances)
# Cached probe result: None until the first probe, then True/False.
_FP8_STATE_AVAILABLE = None


def _probe_fp8_state_support():
    """Run the FP8 optimizer-state probe once and return its boolean outcome.

    Returns False when FusedAdam is missing or reports no float8 state
    support, or when the probe raises one of the tolerated exception types.
    """
    if FusedAdam is None:
        return False
    if not FusedAdam._float8_state_available():
        return False
    try:
        # Step 1: quantize a tensor and update it in place through the quantizer.
        source = torch.zeros(8, dtype=torch.bfloat16, device=_device())
        quantizer = Float8CurrentScalingQuantizer(
            fp8_dtype=torch.float8_e4m3fn,
            rowwise=True,
            columnwise=False,
            device=source.device,
        )
        quantized = quantizer(source)
        requantized = quantizer.update_quantized(source + 1.0, quantized)

        # Step 2: take one optimizer step with uint8 (FP8) moment dtypes.
        probe_param = torch.nn.Parameter(
            torch.zeros(8, dtype=torch.float32, device=_device())
        )
        probe_optimizer = FusedAdam(
            [probe_param],
            lr=1e-3,
            exp_avg_dtype=torch.uint8,
            exp_avg_sq_dtype=torch.uint8,
        )
        probe_param.grad = torch.ones_like(probe_param)
        probe_optimizer.step()
        probe_state = probe_optimizer.state[probe_param]

        # The tuple is built eagerly so every check is evaluated.
        checks = (
            isinstance(quantized, Float8Tensor),
            isinstance(requantized, Float8Tensor),
            quantized.dequantize(dtype=torch.float32).dtype == torch.float32,
            isinstance(probe_state["exp_avg"], Float8Tensor),
            isinstance(probe_state["exp_avg_sq"], Float8Tensor),
        )
        return all(checks)
    except (AttributeError, NotImplementedError, RuntimeError, TypeError, ValueError):
        return False


def _fp8_state_available():
    """Report whether Float8Tensor optimizer state works here, caching the answer."""
    global _FP8_STATE_AVAILABLE
    if _FP8_STATE_AVAILABLE is None:
        _FP8_STATE_AVAILABLE = _probe_fp8_state_support()
    return _FP8_STATE_AVAILABLE
def _skip_if_fp8_state_unavailable():
    """Skip the calling test when FP8 optimizer state cannot be used."""
    fp8_usable = _fp8_state_available() and Float8Tensor is not None
    if fp8_usable:
        return
    pytest.skip("NPU Float8Tensor optimizer state is not available in this environment")
def _make_param(dtype=torch.float32, values=None):
    """Create an NPU ``Parameter`` from ``values`` cast to ``dtype``.

    Args:
        dtype: Target element type of the parameter.
        values: Initial data; defaults to eight evenly spaced points in
            [-0.5, 0.5] created on the CPU.

    Returns:
        A parameter that requires grad only when its dtype is floating point.
    """
    if values is None:
        values = torch.linspace(-0.5, 0.5, steps=8, dtype=torch.float32, device="cpu")
    data = values.to(device=_device(), dtype=dtype)
    return torch.nn.Parameter(data, requires_grad=torch.is_floating_point(data))
def _param_grad_like(param, offset=0):
    """Build a deterministic gradient shaped like ``param``.

    The values are evenly spaced in [0.1, 0.8], shifted by ``offset * 0.01``,
    and returned on ``param``'s device with ``param``'s dtype.
    """
    ramp = torch.linspace(
        0.1, 0.8, steps=param.numel(), dtype=torch.float32, device="cpu"
    )
    shaped = ramp.reshape_as(param)
    shaped += offset * 0.01
    return shaped.to(device=param.device, dtype=param.dtype)
def _adam_param_step(param, optimizer, step_idx=0):
    """Assign a deterministic gradient to ``param``, step, then clear gradients.

    ``step_idx`` is forwarded as the offset of ``_param_grad_like``.
    """
    deterministic_grad = _param_grad_like(param, step_idx)
    param.grad = deterministic_grad
    optimizer.step()
    optimizer.zero_grad()
def _first_optimizer_state(state_dict):
    """Return the first per-parameter entry of ``state_dict["state"]``.

    Raises ``StopIteration`` when the state mapping is empty.
    """
    entries = iter(state_dict["state"].values())
    return next(entries)
def _manual_adam_update(
    param,
    grad,
    exp_avg,
    exp_avg_sq,
    *,
    lr,
    betas,
    eps,
    weight_decay,
    adam_w_mode,
    bias_correction,
    step,
):
    """Compute one reference Adam/AdamW step in float32 without mutating inputs.

    Args:
        param: Current parameter values.
        grad: Gradient for this step.
        exp_avg: First-moment estimate before the step.
        exp_avg_sq: Second-moment estimate before the step.
        lr: Learning rate.
        betas: ``(beta1, beta2)`` decay rates of the two moments.
        eps: Constant added to the denominator.
        weight_decay: Decay coefficient; ``0.0`` disables decay.
        adam_w_mode: When true the decay is applied directly to the
            parameter; otherwise it is folded into the gradient.
        bias_correction: Whether to apply the ``1 - beta**step`` corrections.
        step: Step count used by the bias correction.

    Returns:
        Tuple ``(new_param, new_exp_avg, new_exp_avg_sq)``.
    """
    beta1, beta2 = betas
    new_param = param.float().clone()
    working_grad = grad.float().clone()

    decay_enabled = weight_decay != 0.0
    if decay_enabled and adam_w_mode:
        # Decoupled decay: shrink the parameter itself.
        new_param.add_(new_param, alpha=-lr * weight_decay)
    elif decay_enabled:
        # L2 decay: add the scaled parameter to the gradient.
        working_grad = working_grad.add(new_param, alpha=weight_decay)

    first_moment = exp_avg.mul(beta1)
    first_moment = first_moment.add(working_grad, alpha=1.0 - beta1)
    second_moment = exp_avg_sq.mul(beta2)
    second_moment = second_moment.addcmul(working_grad, working_grad, value=1.0 - beta2)

    denom = second_moment.sqrt()
    step_size = lr
    if bias_correction:
        step_size = lr / (1.0 - beta1**step)
        denom = denom / ((1.0 - beta2**step) ** 0.5)

    new_param.addcdiv_(first_moment, denom.add(eps), value=-step_size)
    return new_param, first_moment, second_moment
def _mse_step(model, optimizer, x, target):
    """Run one training step on the MSE loss and return the loss tensor.

    Gradients are cleared first, then the loss is back-propagated and the
    optimizer is stepped.
    """
    optimizer.zero_grad()
    prediction = model(x)
    loss = torch.nn.functional.mse_loss(prediction, target)
    loss.backward()
    optimizer.step()
    return loss
def _backward_only(model, x, target):
    """Back-propagate the MSE loss without stepping any optimizer.

    Returns the loss tensor; the model's parameters end up with gradients.
    """
    prediction = model(x)
    loss = torch.nn.functional.mse_loss(prediction, target)
    loss.backward()
    return loss
def _sparse_grad_like(param):
    """Build a sparse COO gradient with a single 1 at index 0.

    The tensor takes ``param``'s shape and dtype and lives on the NPU. The
    index layout ``[[0]]`` addresses a one-dimensional shape.
    """
    npu = _device()
    nonzero_index = torch.tensor([[0]], dtype=torch.long, device=npu)
    nonzero_value = torch.ones(1, dtype=param.dtype, device=npu)
    return torch.sparse_coo_tensor(
        nonzero_index, nonzero_value, size=param.shape, device=npu
    )
def test_import_fused_sgd():
    """FusedSGD is importable as a class; print where it was loaded from."""
    assert inspect.isclass(FusedSGD)
    module_name = FusedSGD.__module__
    print("FusedSGD module:", module_name)
    source_file = inspect.getfile(FusedSGD)
    print("FusedSGD file:", source_file)
def test_sgd_matches_torch_sgd():
    """FusedSGD with momentum tracks torch.optim.SGD over three MSE steps."""
    torch.manual_seed(1234)
    fused_model = _make_linear()
    torch_model = _clone_model(fused_model)
    sgd_options = dict(
        lr=0.01,
        momentum=0.9,
        dampening=0.0,
        weight_decay=0.01,
        nesterov=False,
    )
    fused_opt = FusedSGD(fused_model.parameters(), wd_after_momentum=False, **sgd_options)
    torch_opt = torch.optim.SGD(torch_model.parameters(), **sgd_options)
    inputs = _randn((3, 8))
    labels = _randn((3, 4))
    for _step in range(3):
        _mse_step(fused_model, fused_opt, inputs, labels)
        _mse_step(torch_model, torch_opt, inputs, labels)
    _assert_models_close(fused_model, torch_model)
def test_sgd_nesterov_matches_torch_sgd():
    """FusedSGD with Nesterov momentum tracks torch.optim.SGD over three steps."""
    torch.manual_seed(1234)
    fused_model = _make_linear()
    torch_model = _clone_model(fused_model)
    sgd_options = dict(
        lr=0.01,
        momentum=0.9,
        dampening=0.0,
        weight_decay=0.01,
        nesterov=True,
    )
    fused_opt = FusedSGD(fused_model.parameters(), wd_after_momentum=False, **sgd_options)
    torch_opt = torch.optim.SGD(torch_model.parameters(), **sgd_options)
    inputs = _randn((3, 8))
    labels = _randn((3, 4))
    for _step in range(3):
        _mse_step(fused_model, fused_opt, inputs, labels)
        _mse_step(torch_model, torch_opt, inputs, labels)
    _assert_models_close(fused_model, torch_model)
def test_sgd_multiple_param_groups():
    """Per-group weight decay in FusedSGD matches torch.optim.SGD."""
    torch.manual_seed(1234)
    fused_model = _make_linear()
    torch_model = _clone_model(fused_model)
    shared_options = dict(lr=0.01, momentum=0.9, dampening=0.0, nesterov=False)

    def _groups(module):
        # Weight decays, bias does not.
        return [
            {"params": [module.weight], "weight_decay": 0.01},
            {"params": [module.bias], "weight_decay": 0.0},
        ]

    fused_opt = FusedSGD(_groups(fused_model), wd_after_momentum=False, **shared_options)
    torch_opt = torch.optim.SGD(_groups(torch_model), **shared_options)
    inputs = _randn((3, 8))
    labels = _randn((3, 4))
    for _step in range(3):
        _mse_step(fused_model, fused_opt, inputs, labels)
        _mse_step(torch_model, torch_opt, inputs, labels)
    _assert_models_close(fused_model, torch_model)
def test_sgd_state_dict_load_state_dict():
    """A FusedSGD state dict loads into a new optimizer that can keep stepping."""
    torch.manual_seed(1234)
    model = _make_linear()
    sgd_options = dict(lr=0.01, momentum=0.9, weight_decay=0.01)
    source_opt = FusedSGD(model.parameters(), **sgd_options)
    inputs = _randn((3, 8))
    labels = _randn((3, 4))
    for _step in range(2):
        _mse_step(model, source_opt, inputs, labels)

    saved = source_opt.state_dict()
    expected_entries = len(saved["state"])

    restored_opt = FusedSGD(model.parameters(), **sgd_options)
    restored_opt.load_state_dict(saved)
    assert len(restored_opt.state_dict()["state"]) == expected_entries

    # One more step must not add or drop per-parameter state entries.
    _mse_step(model, restored_opt, inputs, labels)
    assert len(restored_opt.state_dict()["state"]) == expected_entries
def test_sgd_zero_grad():
    """FusedSGD.zero_grad drops grads by default and zeros them on request."""
    torch.manual_seed(1234)
    model = _make_linear()
    optimizer = FusedSGD(model.parameters(), lr=0.01, momentum=0.9)
    inputs = _randn((3, 8))
    labels = _randn((3, 4))

    # Default call and explicit set_to_none=True both remove the gradients.
    for zero_grad_kwargs in ({}, {"set_to_none": True}):
        _backward_only(model, inputs, labels)
        optimizer.zero_grad(**zero_grad_kwargs)
        assert all(p.grad is None for p in model.parameters())

    # set_to_none=False keeps the tensors and fills them with zeros.
    _backward_only(model, inputs, labels)
    optimizer.zero_grad(set_to_none=False)
    for p in model.parameters():
        assert p.grad is not None
        torch.testing.assert_close(p.grad, torch.zeros_like(p.grad))
def test_sgd_skip_none_grad():
    """FusedSGD leaves a parameter untouched when its grad is None."""
    torch.manual_seed(1234)
    model = _make_linear()
    optimizer = FusedSGD(model.parameters(), lr=0.01, momentum=0.9)
    inputs = _randn((3, 8))
    labels = _randn((3, 4))
    _backward_only(model, inputs, labels)

    weight_snapshot = model.weight.detach().clone()
    bias_snapshot = model.bias.detach().clone()
    model.weight.grad = None  # only the bias still has a gradient
    optimizer.step()

    torch.testing.assert_close(model.weight, weight_snapshot)
    bias_change = (model.bias.detach() - bias_snapshot).abs().max().item()
    assert bias_change > 0.0
def test_sgd_sparse_grad_raises():
    """FusedSGD.step rejects a sparse gradient with a RuntimeError."""
    dense_values = torch.ones(4, device=_device())
    param = torch.nn.Parameter(dense_values)
    optimizer = FusedSGD([param], lr=0.01)
    param.grad = _sparse_grad_like(param)
    with pytest.raises(RuntimeError, match="sparse gradients"):
        optimizer.step()
def test_import_fused_adam():
    """FusedAdam is importable as a class; print where it was loaded from."""
    assert inspect.isclass(FusedAdam)
    module_name = FusedAdam.__module__
    print("FusedAdam module:", module_name)
    source_file = inspect.getfile(FusedAdam)
    print("FusedAdam file:", source_file)
def test_fused_adam_signature():
    """FusedAdam.__init__ exposes the expected argument names."""
    init_params = inspect.signature(FusedAdam.__init__).parameters
    required_names = {
        "params",
        "lr",
        "betas",
        "eps",
        "weight_decay",
        "bias_correction",
        "adam_w_mode",
        "master_weights",
        "master_weight_dtype",
        "exp_avg_dtype",
        "exp_avg_sq_dtype",
        "use_decoupled_grad",
        "store_param_remainders",
        "set_grad_none",
    }
    missing_names = required_names.difference(init_params)
    assert not missing_names
    assert init_params["master_weights"].default is False
def test_fused_adam_signature_matches_te():
    """FusedAdam.__init__ has the expected argument order and argument kinds."""
    init_params = inspect.signature(FusedAdam.__init__).parameters
    # The first seven names may be passed positionally.
    positional_names = [
        "self",
        "params",
        "lr",
        "betas",
        "eps",
        "weight_decay",
        "amsgrad",
    ]
    # Everything after them must be passed by keyword.
    keyword_only_names = [
        "bias_correction",
        "adam_w_mode",
        "capturable",
        "master_weights",
        "master_weight_dtype",
        "exp_avg_dtype",
        "exp_avg_sq_dtype",
        "use_decoupled_grad",
        "store_param_remainders",
        "set_grad_none",
    ]
    assert list(init_params) == positional_names + keyword_only_names
    for arg_name in positional_names:
        assert init_params[arg_name].kind is inspect.Parameter.POSITIONAL_OR_KEYWORD
    for arg_name in keyword_only_names:
        assert init_params[arg_name].kind is inspect.Parameter.KEYWORD_ONLY
def test_fused_adam_defaults_match_te():
    """Every FusedAdam.__init__ default equals the expected value."""
    init_params = inspect.signature(FusedAdam.__init__).parameters
    expected_defaults = dict(
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.0,
        amsgrad=False,
        bias_correction=True,
        adam_w_mode=True,
        capturable=False,
        master_weights=False,
        master_weight_dtype=torch.float32,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.float32,
        use_decoupled_grad=False,
        store_param_remainders=False,
        set_grad_none=None,
    )
    for arg_name in expected_defaults:
        actual_default = init_params[arg_name].default
        assert actual_default == expected_defaults[arg_name]
def test_capturable_explicitly_not_supported_on_npu():
    """Constructing FusedAdam with capturable=True raises NotImplementedError."""
    expectation = pytest.raises(NotImplementedError, match="capturable=True")
    with expectation:
        FusedAdam(_make_linear().parameters(), lr=1e-3, capturable=True)
def test_store_param_remainders_explicitly_not_supported_on_npu():
    """Constructing FusedAdam with store_param_remainders=True raises NotImplementedError."""
    expectation = pytest.raises(NotImplementedError, match="store_param_remainders=True")
    with expectation:
        FusedAdam(_make_linear().parameters(), lr=1e-3, store_param_remainders=True)
def test_master_weight_dtype_fp16_explicit_behavior():
    """float16 master weights are stored as float16 and read back as float32."""
    half = torch.float16
    model = _make_linear().to(dtype=half)
    optimizer = FusedAdam(
        model.parameters(),
        lr=1e-3,
        master_weights=True,
        master_weight_dtype=half,
    )
    inputs = _randn((3, 8), dtype=half)
    labels = _randn((3, 4), dtype=half)
    _mse_step(model, optimizer, inputs, labels)
    for p in model.parameters():
        stored_master = optimizer.state[p]["master_param"]
        assert stored_master.dtype == torch.float16
        unscaled_master = optimizer.get_unscaled_state(p, "master_param")
        assert unscaled_master.dtype == torch.float32
@pytest.mark.parametrize("state_dtype_name", ["exp_avg_dtype", "exp_avg_sq_dtype"])
def test_unsupported_state_dtype_errors_include_argument_name(state_dtype_name):
    """An int8 moment dtype is rejected and the error names the offending argument."""
    bad_dtype_kwargs = {state_dtype_name: torch.int8}
    expectation = pytest.raises(NotImplementedError, match=state_dtype_name)
    with expectation:
        FusedAdam(_make_linear().parameters(), lr=1e-3, **bad_dtype_kwargs)
@pytest.mark.parametrize("dtype", [torch.float64, torch.int32])
def test_unsupported_parameter_dtype_errors_explicitly(dtype):
    """Stepping a parameter of an unsupported dtype raises a descriptive RuntimeError."""
    data = torch.ones(4, dtype=dtype, device=_device())
    # Integer tensors cannot require grad, so only floating dtypes get one.
    wants_grad = torch.is_floating_point(data)
    param = torch.nn.Parameter(data, requires_grad=wants_grad)
    optimizer = FusedAdam([param], lr=1e-3)
    if param.requires_grad:
        param.grad = torch.ones_like(param)
    expectation = pytest.raises(
        RuntimeError,
        match="torch.float32, torch.float16, and torch.bfloat16",
    )
    with expectation:
        optimizer.step()
def test_adamw_matches_manual_adamw_reference():
    """Five FusedAdam steps in AdamW mode match the manual reference update."""
    param = _make_param(dtype=torch.float32)
    expected_param = param.detach().float().clone()
    first_moment = torch.zeros_like(expected_param)
    second_moment = torch.zeros_like(expected_param)
    adam_options = dict(
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.01,
        bias_correction=True,
        adam_w_mode=True,
    )
    optimizer = FusedAdam([param], **adam_options)
    for step in range(1, 6):
        step_grad = _param_grad_like(param, step)
        param.grad = step_grad
        optimizer.step()
        expected_param, first_moment, second_moment = _manual_adam_update(
            expected_param,
            step_grad,
            first_moment,
            second_moment,
            step=step,
            **adam_options,
        )
        optimizer.zero_grad()
    torch.testing.assert_close(param.detach().float(), expected_param, rtol=1e-6, atol=1e-7)
def test_adam_l2_matches_manual_adam_reference():
    """Five FusedAdam steps with adam_w_mode=False match the manual reference update."""
    param = _make_param(dtype=torch.float32)
    expected_param = param.detach().float().clone()
    first_moment = torch.zeros_like(expected_param)
    second_moment = torch.zeros_like(expected_param)
    adam_options = dict(
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.01,
        bias_correction=True,
        adam_w_mode=False,
    )
    optimizer = FusedAdam([param], **adam_options)
    for step in range(1, 6):
        step_grad = _param_grad_like(param, step)
        param.grad = step_grad
        optimizer.step()
        expected_param, first_moment, second_moment = _manual_adam_update(
            expected_param,
            step_grad,
            first_moment,
            second_moment,
            step=step,
            **adam_options,
        )
        optimizer.zero_grad()
    torch.testing.assert_close(param.detach().float(), expected_param, rtol=1e-6, atol=1e-7)
@pytest.mark.parametrize("adam_w_mode", [False, True])
def test_adam_bias_correction_false_matches_manual_reference(adam_w_mode):
    """Without bias correction, four FusedAdam steps match the manual reference."""
    # One option set feeds both the optimizer and the reference update.
    adam_options = dict(
        lr=1e-3,
        betas=(0.8, 0.95),
        eps=1e-6,
        weight_decay=0.01,
        bias_correction=False,
        adam_w_mode=adam_w_mode,
    )
    param = _make_param(dtype=torch.float32)
    optimizer = FusedAdam([param], **adam_options)
    expected_param = param.detach().float().clone()
    first_moment = torch.zeros_like(expected_param)
    second_moment = torch.zeros_like(expected_param)
    for step_idx in range(1, 5):
        step_grad = _param_grad_like(param, step_idx)
        param.grad = step_grad
        optimizer.step()
        expected_param, first_moment, second_moment = _manual_adam_update(
            expected_param,
            step_grad,
            first_moment,
            second_moment,
            step=step_idx,
            **adam_options,
        )
    torch.testing.assert_close(param.detach().float(), expected_param, rtol=1e-6, atol=1e-7)
def test_adam_multiple_param_groups():
    """Two FusedAdam groups with different weight decay each match their reference."""
    weight = _make_param(dtype=torch.float32)
    bias = _make_param(
        dtype=torch.float32,
        values=torch.linspace(0.25, 0.95, steps=8, dtype=torch.float32, device="cpu"),
    )
    shared_options = dict(
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        bias_correction=True,
        adam_w_mode=True,
    )

    def _tracker(param, weight_decay, grad_offset):
        # Reference parameter and moments for one optimizer group.
        reference = param.detach().float().clone()
        return {
            "param": param,
            "weight_decay": weight_decay,
            "grad_offset": grad_offset,
            "ref": reference,
            "exp_avg": torch.zeros_like(reference),
            "exp_avg_sq": torch.zeros_like(reference),
        }

    # The bias gradient uses a shifted offset so the two groups differ.
    trackers = [_tracker(weight, 0.01, 0), _tracker(bias, 0.0, 10)]
    optimizer = FusedAdam(
        [{"params": [t["param"]], "weight_decay": t["weight_decay"]} for t in trackers],
        **shared_options,
    )

    for step in range(1, 6):
        step_grads = [
            _param_grad_like(t["param"], step + t["grad_offset"]) for t in trackers
        ]
        for t, step_grad in zip(trackers, step_grads):
            t["param"].grad = step_grad
        optimizer.step()
        for t, step_grad in zip(trackers, step_grads):
            t["ref"], t["exp_avg"], t["exp_avg_sq"] = _manual_adam_update(
                t["ref"],
                step_grad,
                t["exp_avg"],
                t["exp_avg_sq"],
                weight_decay=t["weight_decay"],
                step=step,
                **shared_options,
            )
        optimizer.zero_grad()

    for group in optimizer.param_groups[:2]:
        assert group["step"] == 5
    for t in trackers:
        torch.testing.assert_close(
            t["param"].detach().float(),
            t["ref"],
            rtol=1e-6,
            atol=1e-7,
        )
def test_adam_state_dict_load_state_dict():
    """A FusedAdam state dict loads into a new optimizer that can keep stepping."""
    torch.manual_seed(1234)
    model = _make_linear()
    adam_options = dict(
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.01,
        bias_correction=True,
        adam_w_mode=True,
    )
    source_opt = FusedAdam(model.parameters(), **adam_options)
    inputs = _randn((3, 8))
    labels = _randn((3, 4))
    for _step in range(2):
        _mse_step(model, source_opt, inputs, labels)

    saved = source_opt.state_dict()
    expected_entries = len(saved["state"])

    restored_opt = FusedAdam(model.parameters(), **adam_options)
    restored_opt.load_state_dict(saved)
    assert len(restored_opt.state_dict()["state"]) == expected_entries

    # One more step must not add or drop per-parameter state entries.
    _mse_step(model, restored_opt, inputs, labels)
    assert len(restored_opt.state_dict()["state"]) == expected_entries
def test_adam_group_step_saved_and_loaded():
    """The group-level step counter survives a state_dict round trip."""
    source_param = _make_param()
    source_opt = FusedAdam([source_param], lr=1e-3)
    for step_idx in (0, 1):
        _adam_param_step(source_param, source_opt, step_idx)
    saved = source_opt.state_dict()
    assert saved["param_groups"][0]["step"] == 2

    restored_param = _make_param()
    restored_opt = FusedAdam([restored_param], lr=1e-3)
    restored_opt.load_state_dict(saved)
    restored_group = restored_opt.param_groups[0]
    assert restored_group["step"] == 2

    # The counter keeps advancing from the restored value.
    _adam_param_step(restored_param, restored_opt, 2)
    assert restored_opt.param_groups[0]["step"] == 3
def test_adam_zero_grad():
    """FusedAdam.zero_grad drops grads by default and zeros them on request."""
    torch.manual_seed(1234)
    model = _make_linear()
    optimizer = FusedAdam(model.parameters(), lr=1e-3)
    inputs = _randn((3, 8))
    labels = _randn((3, 4))

    # Default call and explicit set_to_none=True both remove the gradients.
    for zero_grad_kwargs in ({}, {"set_to_none": True}):
        _backward_only(model, inputs, labels)
        optimizer.zero_grad(**zero_grad_kwargs)
        assert all(p.grad is None for p in model.parameters())

    # set_to_none=False keeps the tensors and fills them with zeros.
    _backward_only(model, inputs, labels)
    optimizer.zero_grad(set_to_none=False)
    for p in model.parameters():
        assert p.grad is not None
        torch.testing.assert_close(p.grad, torch.zeros_like(p.grad))
def test_adam_set_grad_none_constructor_warns():
    """Passing set_grad_none to the constructor emits a DeprecationWarning."""
    expectation = pytest.warns(DeprecationWarning, match="set_grad_none")
    with expectation:
        FusedAdam(_make_linear().parameters(), lr=1e-3, set_grad_none=True)
def test_adam_set_grad_none_conflicting_zero_grad_raises():
    """zero_grad(set_to_none=False) conflicts with a constructor set_grad_none=True."""
    deprecation = pytest.warns(DeprecationWarning, match="set_grad_none")
    with deprecation:
        optimizer = FusedAdam(_make_linear().parameters(), lr=1e-3, set_grad_none=True)
    conflict = pytest.raises(ValueError, match="set_grad_none=True")
    with conflict:
        optimizer.zero_grad(set_to_none=False)
def test_adam_zero_grad_default_is_set_to_none_true():
    """A bare zero_grad() call replaces every gradient with None."""
    model = _make_linear()
    optimizer = FusedAdam(model.parameters(), lr=1e-3)
    for p in model.parameters():
        p.grad = torch.ones_like(p)
    optimizer.zero_grad()
    remaining = [p.grad for p in model.parameters()]
    assert all(g is None for g in remaining)
def test_adam_zero_grad_false_zeros_regular_grads_without_decoupled_grad():
    """Without decoupled grads, zero_grad(set_to_none=False) zeros ``.grad`` in place."""
    model = _make_linear()
    optimizer = FusedAdam(model.parameters(), lr=1e-3, use_decoupled_grad=False)
    for p in model.parameters():
        p.grad = torch.ones_like(p)
    optimizer.zero_grad(set_to_none=False)
    for p in model.parameters():
        assert p.grad is not None
        expected_zeros = torch.zeros_like(p.grad)
        torch.testing.assert_close(p.grad, expected_zeros)
def test_adam_zero_grad_decoupled_default_clears_only_decoupled_grad():
    """With decoupled grads, zero_grad() clears ``decoupled_grad`` and leaves ``.grad``."""
    model = _make_linear()
    optimizer = FusedAdam(model.parameters(), lr=1e-3, use_decoupled_grad=True)
    grad_snapshots = []
    for p in model.parameters():
        p.grad = torch.ones_like(p)
        p.decoupled_grad = torch.full_like(p, 2.0)
        grad_snapshots.append(p.grad.detach().clone())
    optimizer.zero_grad()
    for p, snapshot in zip(model.parameters(), grad_snapshots):
        assert p.decoupled_grad is None
        torch.testing.assert_close(p.grad, snapshot)
def test_adam_zero_grad_decoupled_false_zeros_only_decoupled_grad():
    """With decoupled grads, set_to_none=False zeros ``decoupled_grad`` and leaves ``.grad``."""
    model = _make_linear()
    optimizer = FusedAdam(model.parameters(), lr=1e-3, use_decoupled_grad=True)
    grad_snapshots = []
    for p in model.parameters():
        p.grad = torch.ones_like(p)
        p.decoupled_grad = torch.full_like(p, 2.0)
        grad_snapshots.append(p.grad.detach().clone())
    optimizer.zero_grad(set_to_none=False)
    for p, snapshot in zip(model.parameters(), grad_snapshots):
        torch.testing.assert_close(p.decoupled_grad, torch.zeros_like(p))
        torch.testing.assert_close(p.grad, snapshot)
def test_adam_zero_grad_decoupled_false_missing_decoupled_grad_is_noop():
    """zero_grad(set_to_none=False) does not create a missing ``decoupled_grad``."""
    model = _make_linear()
    optimizer = FusedAdam(model.parameters(), lr=1e-3, use_decoupled_grad=True)
    attr = "decoupled_grad"
    for p in model.parameters():
        p.grad = torch.ones_like(p)
        # Make sure the attribute is absent before zero_grad runs.
        if hasattr(p, attr):
            delattr(p, attr)
    optimizer.zero_grad(set_to_none=False)
    for p in model.parameters():
        assert not hasattr(p, "decoupled_grad")
def test_adam_zero_grad_decoupled_false_after_set_to_none_is_noop():
    """After clearing to None, zero_grad(set_to_none=False) leaves ``decoupled_grad`` None."""
    model = _make_linear()
    optimizer = FusedAdam(model.parameters(), lr=1e-3, use_decoupled_grad=True)
    grad_snapshots = []
    for p in model.parameters():
        p.grad = torch.ones_like(p)
        p.decoupled_grad = torch.full_like(p, 2.0)
        grad_snapshots.append(p.grad.detach().clone())
    # First clear to None, then ask for zeroing of what is no longer there.
    for set_to_none in (True, False):
        optimizer.zero_grad(set_to_none=set_to_none)
    for p, snapshot in zip(model.parameters(), grad_snapshots):
        assert p.decoupled_grad is None
        torch.testing.assert_close(p.grad, snapshot)
def test_adam_skip_none_grad():
    """FusedAdam leaves a grad-less parameter untouched but still counts the step."""
    torch.manual_seed(1234)
    model = _make_linear()
    optimizer = FusedAdam(model.parameters(), lr=1e-3)
    inputs = _randn((3, 8))
    labels = _randn((3, 4))
    _backward_only(model, inputs, labels)

    weight_snapshot = model.weight.detach().clone()
    bias_snapshot = model.bias.detach().clone()
    model.weight.grad = None  # only the bias still has a gradient
    optimizer.step()

    torch.testing.assert_close(model.weight, weight_snapshot)
    bias_change = (model.bias.detach() - bias_snapshot).abs().max().item()
    assert bias_change > 0.0
    assert optimizer.param_groups[0]["step"] == 1
def test_adam_none_grad_uses_group_step_when_grad_returns():
    """Steps taken without a gradient still advance the group step used later."""
    param = _make_param()
    optimizer = FusedAdam([param], lr=1e-3, betas=(0.9, 0.999), bias_correction=True)
    group = optimizer.param_groups[0]
    initial_values = param.detach().clone()

    # Step 1: no gradient, so the counter advances and the parameter does not move.
    optimizer.step()
    assert group["step"] == 1
    torch.testing.assert_close(param.detach(), initial_values)

    # Step 2: first real gradient; the reference uses step=2 for bias correction.
    first_grad = _param_grad_like(param, 0)
    param.grad = first_grad
    optimizer.step()
    assert group["step"] == 2
    zero_moment = torch.zeros_like(initial_values, dtype=torch.float32)
    expected_param, _, _ = _manual_adam_update(
        initial_values,
        first_grad,
        zero_moment,
        torch.zeros_like(initial_values, dtype=torch.float32),
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.0,
        adam_w_mode=True,
        bias_correction=True,
        step=2,
    )
    torch.testing.assert_close(param.detach().float(), expected_param, rtol=1e-6, atol=1e-7)

    # Step 3: gradient removed again.
    param.grad = None
    optimizer.step()
    assert group["step"] == 3

    # Step 4: gradient returns and the parameter moves.
    values_before_return = param.detach().clone()
    param.grad = _param_grad_like(param, 2)
    optimizer.step()
    assert group["step"] == 4
    movement = (param.detach() - values_before_return).abs().max().item()
    assert movement > 0.0
def test_adam_sparse_grad_raises():
    """FusedAdam.step rejects a sparse gradient with a RuntimeError."""
    dense_values = torch.ones(4, device=_device())
    param = torch.nn.Parameter(dense_values)
    optimizer = FusedAdam([param], lr=1e-3)
    param.grad = _sparse_grad_like(param)
    with pytest.raises(RuntimeError, match="sparse gradients"):
        optimizer.step()
def test_adam_capturable_explicitly_not_supported_on_npu():
    """FusedAdam(capturable=True) is rejected with NotImplementedError."""
    expectation = pytest.raises(NotImplementedError, match="capturable=True")
    with expectation:
        FusedAdam(_make_linear().parameters(), lr=1e-3, capturable=True)
def test_adam_store_param_remainders_explicitly_not_supported_on_npu():
    """FusedAdam(store_param_remainders=True) is rejected with NotImplementedError."""
    expectation = pytest.raises(NotImplementedError, match="store_param_remainders=True")
    with expectation:
        FusedAdam(_make_linear().parameters(), lr=1e-3, store_param_remainders=True)
def test_adam_unsupported_master_weight_dtype_raises_on_master_path():
    """bfloat16 master weights for a float16 model are rejected at construction."""
    half_model = _make_linear().to(dtype=torch.float16)
    master_options = dict(
        lr=1e-3,
        master_weights=True,
        master_weight_dtype=torch.bfloat16,
    )
    with pytest.raises(NotImplementedError, match="master_weight_dtype"):
        FusedAdam(half_model.parameters(), **master_options)
def test_adam_use_decoupled_grad():
    """With use_decoupled_grad=True the step consumes ``decoupled_grad``, not ``.grad``."""
    torch.manual_seed(1234)
    model = _make_linear()
    optimizer = FusedAdam(
        model.parameters(),
        lr=1e-3,
        adam_w_mode=True,
        use_decoupled_grad=True,
    )
    tracked = list(model.parameters())
    snapshots = [p.detach().clone() for p in tracked]
    try:
        for p in tracked:
            p.grad = None
            p.decoupled_grad = torch.ones_like(p)
        optimizer.step()
        for snapshot, p in zip(snapshots, tracked):
            movement = (p.detach() - snapshot).abs().max().item()
            assert movement > 0.0
    finally:
        # Remove the ad-hoc attribute again, whatever the outcome.
        for p in tracked:
            if hasattr(p, "decoupled_grad"):
                delattr(p, "decoupled_grad")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_adam_fp16_bf16_params_step_with_fp32_default_states(dtype):
    """Low-precision parameters step with float32 moments and no master copy."""
    param = _make_param(dtype=dtype)
    optimizer = FusedAdam([param], lr=1e-3)
    _adam_param_step(param, optimizer, 0)
    param_state = optimizer.state[param]
    assert param.dtype == dtype
    assert torch.isfinite(param.detach().float()).all()
    for moment_key in ("exp_avg", "exp_avg_sq"):
        assert param_state[moment_key].dtype == torch.float32
    assert "master_param" not in param_state
@pytest.mark.parametrize(
    ("exp_avg_dtype", "exp_avg_sq_dtype"),
    [
        (torch.float16, torch.float16),
        (torch.bfloat16, torch.bfloat16),
        (torch.float16, torch.float32),
        (torch.float32, torch.bfloat16),
    ],
)
def test_adam_low_precision_state_checkpoint_semantics(exp_avg_dtype, exp_avg_sq_dtype):
    """Low-precision moments checkpoint as float32 and reload in their stored dtype."""
    moment_dtypes = {"exp_avg": exp_avg_dtype, "exp_avg_sq": exp_avg_sq_dtype}
    optimizer_options = dict(
        lr=1e-3,
        exp_avg_dtype=exp_avg_dtype,
        exp_avg_sq_dtype=exp_avg_sq_dtype,
    )

    param = _make_param(dtype=torch.float16)
    optimizer = FusedAdam([param], **optimizer_options)
    _adam_param_step(param, optimizer, 0)

    # Live state keeps the requested dtype; the unscaled view is float32.
    live_state = optimizer.state[param]
    for key, requested in moment_dtypes.items():
        assert live_state[key].dtype == requested
    for key in moment_dtypes:
        assert optimizer.get_unscaled_state(param, key).dtype == torch.float32

    # bfloat16 moments are expected to carry a scale of exactly one.
    for key, requested in moment_dtypes.items():
        if requested == torch.bfloat16:
            torch.testing.assert_close(
                optimizer._scales[param][key],
                torch.ones((), device=param.device, dtype=torch.float32),
            )

    saved = optimizer.state_dict()
    saved_entry = _first_optimizer_state(saved)
    for key in moment_dtypes:
        assert saved_entry[key].dtype == torch.float32

    restored_param = _make_param(dtype=torch.float16)
    restored_opt = FusedAdam([restored_param], **optimizer_options)
    restored_opt.load_state_dict(saved)
    restored_state = restored_opt.state[restored_param]
    for key, requested in moment_dtypes.items():
        assert restored_state[key].dtype == requested
    _adam_param_step(restored_param, restored_opt, 1)
@pytest.mark.parametrize("param_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("master_weight_dtype", [torch.float32, torch.float16])
def test_adam_master_weights_checkpoint_semantics(param_dtype, master_weight_dtype):
    """Master weights checkpoint as float32 and reload in their configured dtype."""
    optimizer_options = dict(
        lr=1e-3,
        master_weights=True,
        master_weight_dtype=master_weight_dtype,
    )

    param = _make_param(dtype=param_dtype)
    optimizer = FusedAdam([param], **optimizer_options)
    _adam_param_step(param, optimizer, 0)

    live_state = optimizer.state[param]
    assert live_state["master_param"].dtype == master_weight_dtype
    assert optimizer.get_unscaled_state(param, "master_param").dtype == torch.float32
    # The model parameter should follow the master copy up to low-precision rounding.
    master_as_param_dtype = optimizer.get_unscaled_state(param, "master_param").to(
        dtype=param_dtype
    )
    torch.testing.assert_close(
        param.detach(),
        master_as_param_dtype,
        rtol=1e-2,
        atol=1e-2,
    )

    saved = optimizer.state_dict()
    saved_entry = _first_optimizer_state(saved)
    assert saved_entry["master_param"].dtype == torch.float32

    restored_param = _make_param(dtype=param_dtype)
    restored_opt = FusedAdam([restored_param], **optimizer_options)
    restored_opt.load_state_dict(saved)
    restored_state = restored_opt.state[restored_param]
    assert restored_state["master_param"].dtype == master_weight_dtype
    restored_unscaled = restored_opt.get_unscaled_state(restored_param, "master_param")
    assert restored_unscaled.dtype == torch.float32
    _adam_param_step(restored_param, restored_opt, 1)
def test_adam_load_state_dict_rejects_store_param_remainders_master_param():
    """Loading a checkpoint whose master_param is int16 raises NotImplementedError."""
    source_param = _make_param(dtype=torch.bfloat16)
    source_opt = FusedAdam([source_param], lr=1e-3, master_weights=True)
    _adam_param_step(source_param, source_opt, 0)

    saved = source_opt.state_dict()
    saved_entry = _first_optimizer_state(saved)
    # Replace the master weights with an int16 tensor of the same shape.
    int16_master = torch.zeros_like(saved_entry["master_param"], dtype=torch.int16)
    saved_entry["master_param"] = int16_master

    restored_param = _make_param(dtype=torch.bfloat16)
    restored_opt = FusedAdam([restored_param], lr=1e-3, master_weights=True)
    with pytest.raises(NotImplementedError, match="store_param_remainders"):
        restored_opt.load_state_dict(saved)
@pytest.mark.parametrize(
    ("exp_avg_dtype", "exp_avg_sq_dtype"),
    [
        (torch.uint8, torch.float32),
        (torch.float32, torch.uint8),
        (torch.uint8, torch.uint8),
    ],
)
def test_adam_fp8_state_checkpoint_semantics(exp_avg_dtype, exp_avg_sq_dtype):
    """uint8-requested moments live as Float8Tensor and checkpoint as plain float32."""
    _skip_if_fp8_state_unavailable()
    fp8_first_moment = exp_avg_dtype == torch.uint8
    fp8_second_moment = exp_avg_sq_dtype == torch.uint8
    optimizer_options = dict(
        lr=1e-3,
        exp_avg_dtype=exp_avg_dtype,
        exp_avg_sq_dtype=exp_avg_sq_dtype,
    )

    param = _make_param(dtype=torch.float32)
    optimizer = FusedAdam([param], **optimizer_options)
    _adam_param_step(param, optimizer, 0)
    live_state = optimizer.state[param]

    if fp8_first_moment:
        assert isinstance(live_state["exp_avg"], Float8Tensor)
        assert optimizer.get_unscaled_state(param, "exp_avg").dtype == torch.float32
    else:
        assert live_state["exp_avg"].dtype == torch.float32

    if fp8_second_moment:
        assert isinstance(live_state["exp_avg_sq"], Float8Tensor)
        assert optimizer.fp8_exp_avg_sq_storage == "sqrt"
        assert optimizer.get_unscaled_state(param, "exp_avg_sq").dtype == torch.float32
    else:
        assert live_state["exp_avg_sq"].dtype == torch.float32

    # Checkpoints hold float32 tensors that are not Float8Tensor instances.
    saved = optimizer.state_dict()
    saved_entry = _first_optimizer_state(saved)
    assert saved_entry["exp_avg"].dtype == torch.float32
    assert saved_entry["exp_avg_sq"].dtype == torch.float32
    assert not isinstance(saved_entry["exp_avg"], Float8Tensor)
    assert not isinstance(saved_entry["exp_avg_sq"], Float8Tensor)

    restored_param = _make_param(dtype=torch.float32)
    restored_opt = FusedAdam([restored_param], **optimizer_options)
    restored_opt.load_state_dict(saved)
    restored_state = restored_opt.state[restored_param]
    if fp8_first_moment:
        assert isinstance(restored_state["exp_avg"], Float8Tensor)
    if fp8_second_moment:
        assert isinstance(restored_state["exp_avg_sq"], Float8Tensor)
    _adam_param_step(restored_param, restored_opt, 1)
def test_adam_load_state_dict_rejects_legacy_plain_uint8_fp8_state():
    """Loading a checkpoint whose exp_avg is a raw uint8 tensor raises RuntimeError."""
    _skip_if_fp8_state_unavailable()
    source_param = _make_param(dtype=torch.float32)
    source_opt = FusedAdam([source_param], lr=1e-3)
    _adam_param_step(source_param, source_opt, 0)

    saved = source_opt.state_dict()
    saved_entry = _first_optimizer_state(saved)
    # Swap the first moment for a plain uint8 tensor of the same shape.
    raw_uint8_moment = torch.zeros_like(saved_entry["exp_avg"], dtype=torch.uint8)
    saved_entry["exp_avg"] = raw_uint8_moment

    restored_param = _make_param(dtype=torch.float32)
    restored_opt = FusedAdam(
        [restored_param],
        lr=1e-3,
        exp_avg_dtype=torch.uint8,
        exp_avg_sq_dtype=torch.float32,
    )
    with pytest.raises(RuntimeError, match="legacy plain uint8"):
        restored_opt.load_state_dict(saved)