"""
Test module for testing the AdamW interface used for MindFormers.
How to run this:
pytest tests/st/test_optim/test_adamw.py
"""
import pytest
import numpy as np
import mindspore as ms
from mindspore import Parameter, Tensor, dtype as mstype
from mindspore.nn import Cell
from tests.st.test_optim.optimizer_util import (
build_network,
FakeNet,
default_fc1_weight_adamw_m,
default_fc2_weight_adamw_m,
default_fc1_weight_adamw_v,
default_fc2_weight_adamw_v
)
from mindformers.core.optim.adamw import AdamW, _check_param_value
# Run every test in this module in graph mode; use the named constant instead of the bare 0.
ms.set_context(mode=ms.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
class TestAdamW:
    """Compares AdamW optimizer states against stored reference arrays."""

    def test_computation(self):
        """
        Feature: Trainer.train()
        Description: Call build_network with a grouped AdamW config (weight_decay=0.1) on FakeNet and
            read the exp_avg / exp_avg_sq states at indices 0 and 2.
        Expectation: Each state matches its reference array within atol=1e-4.
        """
        optim_config = {'type': 'AdamW', "weight_decay": 0.1}
        _, adamw_cell = build_network(optim_config, FakeNet(), is_group=True)

        # (actual state, expected reference) pairs for the fc1 and fc2 weights.
        checks = (
            (adamw_cell.exp_avg[0], default_fc1_weight_adamw_m),
            (adamw_cell.exp_avg[2], default_fc2_weight_adamw_m),
            (adamw_cell.exp_avg_sq[0], default_fc1_weight_adamw_v),
            (adamw_cell.exp_avg_sq[2], default_fc2_weight_adamw_v),
        )
        for state, reference in checks:
            assert np.allclose(state.asnumpy(), reference, atol=1.e-4)
class SimpleNet(Cell):
    """Minimal Cell with a [2, 3] float32 weight of ones and a [3] float32 bias of zeros."""

    def __init__(self):
        super().__init__()
        weight_init = Tensor(np.ones([2, 3]), mstype.float32)
        bias_init = Tensor(np.zeros([3]), mstype.float32)
        self.weight = Parameter(weight_init, name="weight")
        self.bias = Parameter(bias_init, name="bias")

    def construct(self, x):
        """Return ``x * weight + bias``."""
        scaled = x * self.weight
        return scaled + self.bias
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_init():
    """
    Feature: AdamW optimizer initialization
    Description: Build AdamW from SimpleNet's trainable parameters without overriding any hyper-parameter
    Expectation: The optimizer is created and beta1, beta2 and eps hold 0.9, 0.999 and 1e-8
    """
    net = SimpleNet()
    optimizer = AdamW(net.trainable_params())
    assert optimizer is not None
    assert np.allclose(optimizer.beta1.asnumpy(), np.array([0.9]))
    assert np.allclose(optimizer.beta2.asnumpy(), np.array([0.999]))
    # np.allclose defaults to atol=1e-8, which is as large as eps itself and would accept any
    # value in [0, 2e-8]. Use a purely relative tolerance so the check is meaningful.
    assert np.allclose(optimizer.eps.asnumpy(), np.array([1e-8]), rtol=1e-5, atol=0.0)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_init_with_custom_params():
    """
    Feature: AdamW optimizer initialization with custom parameters
    Description: Build AdamW with explicit learning_rate, betas, eps and weight_decay
    Expectation: beta1, beta2 and eps reflect the custom betas=(0.8, 0.99) and eps=1e-7
    """
    net = SimpleNet()
    learning_rate = 0.005
    betas = (0.8, 0.99)
    eps = 1e-7
    weight_decay = 0.01
    optimizer = AdamW(
        net.trainable_params(),
        learning_rate=learning_rate,
        betas=betas,
        eps=eps,
        weight_decay=weight_decay,
    )
    assert np.allclose(optimizer.beta1.asnumpy(), np.array([betas[0]]))
    assert np.allclose(optimizer.beta2.asnumpy(), np.array([betas[1]]))
    # The default atol=1e-8 is 10% of eps here; drop it and rely on a relative tolerance only.
    assert np.allclose(optimizer.eps.asnumpy(), np.array([eps]), rtol=1e-5, atol=0.0)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_init_with_swap():
    """
    Feature: AdamW optimizer initialization with swap parameter
    Description: Build AdamW with swap=True and inspect the device of every moment state
    Expectation: optimizer.swap is True and every entry of exp_avg / exp_avg_sq reports device 'CPU'
    """
    optimizer = AdamW(SimpleNet().trainable_params(), swap=True)
    assert optimizer.swap is True
    assert all(state.device == 'CPU' for state in optimizer.exp_avg)
    assert all(state.device == 'CPU' for state in optimizer.exp_avg_sq)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_init_with_group_params():
    """
    Feature: AdamW optimizer initialization with group parameters
    Description: Pass two parameter groups (weight and bias) that carry their own lr and weight_decay
    Expectation: AdamW is constructed without error
    """
    net = SimpleNet()
    weight_group = {'params': [net.weight], 'lr': 0.001, 'weight_decay': 0.01}
    bias_group = {'params': [net.bias], 'lr': 0.0001, 'weight_decay': 0.0}
    optimizer = AdamW([weight_group, bias_group])
    assert optimizer is not None
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_param_value():
    """
    Feature: _check_param_value function
    Description: Call _check_param_value with betas=(0.9, 0.999), eps=1e-8 and weight_decay=0.01
    Expectation: The call returns without raising
    """
    valid_args = ((0.9, 0.999), 1e-8, 0.01)
    _check_param_value(*valid_args, "AdamW")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_param_value_invalid_betas_type():
    """
    Feature: _check_param_value function parameter validation
    Description: Pass a string as betas while eps and weight_decay stay valid
    Expectation: TypeError is raised
    """
    bad_betas = "invalid"
    with pytest.raises(TypeError):
        _check_param_value(bad_betas, 1e-8, 0.01, "AdamW")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_param_value_invalid_betas_length():
    """
    Feature: _check_param_value function parameter validation
    Description: Pass a three-element betas tuple while eps and weight_decay stay valid
    Expectation: ValueError is raised
    """
    too_many_betas = (0.9, 0.999, 0.9999)
    with pytest.raises(ValueError):
        _check_param_value(too_many_betas, 1e-8, 0.01, "AdamW")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_param_value_invalid_beta1():
    """
    Feature: _check_param_value function parameter validation
    Description: Pass beta1 == 1.0 together with a valid beta2, eps and weight_decay
    Expectation: ValueError is raised
    """
    out_of_range_betas = (1.0, 0.999)
    with pytest.raises(ValueError):
        _check_param_value(out_of_range_betas, 1e-8, 0.01, "AdamW")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_param_value_invalid_beta2():
    """
    Feature: _check_param_value function parameter validation
    Description: Pass beta2 == 0.0 together with a valid beta1, eps and weight_decay
    Expectation: ValueError is raised
    """
    out_of_range_betas = (0.9, 0.0)
    with pytest.raises(ValueError):
        _check_param_value(out_of_range_betas, 1e-8, 0.01, "AdamW")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_param_value_invalid_eps():
    """
    Feature: _check_param_value function parameter validation
    Description: Pass eps == 0.0 together with valid betas and weight_decay
    Expectation: ValueError is raised
    """
    zero_eps = 0.0
    with pytest.raises(ValueError):
        _check_param_value((0.9, 0.999), zero_eps, 0.01, "AdamW")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_param_value_invalid_weight_decay():
    """
    Feature: _check_param_value function parameter validation
    Description: Pass a string as weight_decay together with valid betas and eps
    Expectation: TypeError is raised
    """
    bad_weight_decay = "invalid"
    with pytest.raises(TypeError):
        _check_param_value((0.9, 0.999), 1e-8, bad_weight_decay, "AdamW")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_construct():
    """
    Feature: AdamW optimizer construct method
    Description: Call a default AdamW with all-ones float32 gradients shaped like SimpleNet's parameters
    Expectation: The call returns a non-None result
    """
    net = SimpleNet()
    optimizer = AdamW(net.trainable_params())
    # One gradient per parameter: [2, 3] for the weight, [3] for the bias.
    gradients = tuple(Tensor(np.ones(shape), mstype.float32) for shape in ([2, 3], [3]))
    assert optimizer(gradients) is not None
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_construct_with_group_lr():
    """
    Feature: AdamW optimizer construct method with group learning rates
    Description: Build AdamW from two parameter groups with different 'lr' values and run one call
        with all-ones gradients
    Expectation: The call returns a non-None result
    """
    net = SimpleNet()
    weight_group = {'params': [net.weight], 'lr': 0.001}
    bias_group = {'params': [net.bias], 'lr': 0.0001}
    optimizer = AdamW([weight_group, bias_group])
    weight_grad = Tensor(np.ones([2, 3]), mstype.float32)
    bias_grad = Tensor(np.ones([3]), mstype.float32)
    step_output = optimizer((weight_grad, bias_grad))
    assert step_output is not None
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_clone_state():
    """
    Feature: AdamW optimizer clone_state method
    Description: Call clone_state("test", "zeros") on a default AdamW
    Expectation: One clone per optimizer parameter, each with the source parameter's shape and dtype float32
    """
    optimizer = AdamW(SimpleNet().trainable_params())
    clones = optimizer.clone_state("test", "zeros")
    assert len(clones) == len(optimizer.parameters)
    # Lengths are equal at this point, so zip pairs every clone with its source parameter.
    for clone, source in zip(clones, optimizer.parameters):
        assert clone.shape == source.shape
        assert clone.dtype == mstype.float32
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_clone_state_with_swap():
    """
    Feature: AdamW optimizer clone_state method with swap=True
    Description: Call clone_state("test", "zeros") on an AdamW built with swap=True
    Expectation: One clone per optimizer parameter and every clone reports device 'CPU'
    """
    optimizer = AdamW(SimpleNet().trainable_params(), swap=True)
    clones = optimizer.clone_state("test", "zeros")
    assert len(clones) == len(optimizer.parameters)
    assert all(clone.device == 'CPU' for clone in clones)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_with_dynamic_weight_decay():
    """
    Feature: AdamW optimizer with dynamic weight decay
    Description: Build AdamW with weight_decay=1.0
    Expectation: AdamW is constructed without error

    NOTE(review): the value passed is a plain float, not a schedule object, so this may not
    exercise a dynamic weight-decay path - confirm the intent.
    """
    optimizer = AdamW(SimpleNet().trainable_params(), weight_decay=1.0)
    assert optimizer is not None
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_step():
    """
    Feature: AdamW optimizer step execution
    Description: Run one AdamW call with all-ones gradients on SimpleNet's parameters
    Expectation: Every trainable parameter differs from the value it held before the step
    """
    net = SimpleNet()
    # Snapshot the starting values. Comparing against ones (as before) was trivially true for the
    # zero-initialised bias and therefore never proved that the bias had been updated.
    initial_values = [param.asnumpy().copy() for param in net.trainable_params()]
    optimizer = AdamW(net.trainable_params())
    gradients = (
        Tensor(np.ones([2, 3]), mstype.float32),
        Tensor(np.ones([3]), mstype.float32),
    )
    optimizer(gradients)
    for param, before in zip(net.trainable_params(), initial_values):
        assert not np.array_equal(param.asnumpy(), before)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_with_zero_gradients():
    """
    Feature: AdamW optimizer with zero gradients
    Description: Run one AdamW call with all-zero gradients and weight_decay=0.1
    Expectation: The weight moves away from its initial value even though the gradients are zero,
        and the bias stays finite
    """
    net = SimpleNet()
    weight_before = net.weight.asnumpy().copy()
    optimizer = AdamW(net.trainable_params(), weight_decay=0.1)
    gradients = (
        Tensor(np.zeros([2, 3]), mstype.float32),
        Tensor(np.zeros([3]), mstype.float32),
    )
    optimizer(gradients)
    # Compare the weight with its own starting value rather than a hard-coded array of ones.
    assert not np.allclose(net.weight.asnumpy(), weight_before)
    # The bias starts at zero, so the former "not close to ones" check was trivially true and a
    # multiplicative decay cannot be observed on it. Check instead that the zero-gradient update
    # did not produce NaN/inf.
    assert np.all(np.isfinite(net.bias.asnumpy()))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_invalid_learning_rate():
    """
    Feature: AdamW optimizer with invalid learning rate
    Description: Build AdamW with the string "invalid" as learning_rate
    Expectation: ValueError is raised
    """
    net = SimpleNet()
    bad_kwargs = {"learning_rate": "invalid"}
    with pytest.raises(ValueError):
        AdamW(net.trainable_params(), **bad_kwargs)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_large_weight_decay():
    """
    Feature: AdamW optimizer with large weight decay
    Description: Run one AdamW call with weight_decay=0.1 and all-ones gradients
    Expectation: Every element of every parameter is smaller than it was before the step
    """
    net = SimpleNet()
    # Snapshot the starting values. "param < ones" was trivially true for the zero-initialised
    # bias, so the old check only constrained the weight.
    initial_values = [param.asnumpy().copy() for param in net.trainable_params()]
    optimizer = AdamW(net.trainable_params(), weight_decay=0.1)
    gradients = (
        Tensor(np.ones([2, 3]), mstype.float32),
        Tensor(np.ones([3]), mstype.float32),
    )
    optimizer(gradients)
    for param, before in zip(net.trainable_params(), initial_values):
        assert np.all(param.asnumpy() < before)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_small_eps():
    """
    Feature: AdamW optimizer with small epsilon
    Description: Build AdamW with eps=1e-12 and run one call with all-ones gradients
    Expectation: The call returns a non-None result
    """
    optimizer = AdamW(SimpleNet().trainable_params(), eps=1e-12)
    grad_shapes = ([2, 3], [3])
    gradients = tuple(Tensor(np.ones(shape), mstype.float32) for shape in grad_shapes)
    step_output = optimizer(gradients)
    assert step_output is not None
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_with_tuple_params():
    """
    Feature: AdamW optimizer with tuple parameters
    Description: Pass SimpleNet's trainable parameters to AdamW as a tuple
    Expectation: AdamW is constructed without error
    """
    trainable = SimpleNet().trainable_params()
    optimizer = AdamW(tuple(trainable))
    assert optimizer is not None
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_global_step_increase():
    """
    Feature: AdamW optimizer global step
    Description: Inspect the optimizer's global_step attribute, then run one call with all-ones gradients
    Expectation: global_step exists and is a Tensor; the call returns a non-None result

    NOTE(review): despite the test name, the value of global_step after the call is not asserted.
    """
    optimizer = AdamW(SimpleNet().trainable_params())
    weight_grad = Tensor(np.ones([2, 3]), mstype.float32)
    bias_grad = Tensor(np.ones([3]), mstype.float32)

    assert hasattr(optimizer, 'global_step')
    assert isinstance(optimizer.global_step, Tensor)

    step_output = optimizer((weight_grad, bias_grad))
    assert step_output is not None
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_with_mixed_precision_params():
    """
    Feature: AdamW optimizer with mixed precision parameters
    Description: Build AdamW from one float32 and one float16 parameter and run one call with
        gradients of matching shape and dtype
    Expectation: AdamW is constructed and the call returns a non-None result
    """
    fp32_param = Parameter(Tensor(np.ones([2, 3]), mstype.float32), name="fp32_param")
    fp16_param = Parameter(Tensor(np.ones([3]), mstype.float16), name="fp16_param")
    optimizer = AdamW([fp32_param, fp16_param])
    assert optimizer is not None

    fp32_grad = Tensor(np.ones([2, 3]), mstype.float32)
    fp16_grad = Tensor(np.ones([3]), mstype.float16)
    assert optimizer((fp32_grad, fp16_grad)) is not None
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_optim_filter():
    """
    Feature: AdamW optimizer with optim_filter
    Description: Set optim_filter to (True, True) on a default AdamW and run one call with
        all-ones gradients
    Expectation: The call returns a non-None result
    """
    optimizer = AdamW(SimpleNet().trainable_params())
    weight_grad = Tensor(np.ones([2, 3]), mstype.float32)
    bias_grad = Tensor(np.ones([3]), mstype.float32)
    optimizer.optim_filter = (True, True)
    step_output = optimizer((weight_grad, bias_grad))
    assert step_output is not None
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_construct_without_group():
    """
    Feature: AdamW optimizer construct method without group
    Description: Force is_group to False on a default AdamW and run one call with all-ones gradients
    Expectation: The call returns a non-None result
    """
    optimizer = AdamW(SimpleNet().trainable_params())
    optimizer.is_group = False
    gradients = tuple(Tensor(np.ones(shape), mstype.float32) for shape in ([2, 3], [3]))
    assert optimizer(gradients) is not None
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_with_large_lr():
    """
    Feature: AdamW optimizer with large learning rate
    Description: Build AdamW with learning_rate=1.0 and run one call with all-ones gradients
    Expectation: The call returns a non-None result
    """
    optimizer = AdamW(SimpleNet().trainable_params(), learning_rate=1.0)
    weight_grad = Tensor(np.ones([2, 3]), mstype.float32)
    bias_grad = Tensor(np.ones([3]), mstype.float32)
    step_output = optimizer((weight_grad, bias_grad))
    assert step_output is not None
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_with_tensor_lr():
    """
    Feature: AdamW optimizer with Tensor learning rate
    Description: Pass a one-element float32 Tensor holding 0.001 as learning_rate
    Expectation: AdamW is constructed without error
    """
    trainable = SimpleNet().trainable_params()
    optimizer = AdamW(trainable, learning_rate=Tensor(np.array([0.001]), mstype.float32))
    assert optimizer is not None
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_with_iterable_lr():
    """
    Feature: AdamW optimizer with Iterable learning rate
    Description: Pass the list [0.001, 0.0009, 0.0008] as learning_rate
    Expectation: AdamW is constructed without error
    """
    trainable = SimpleNet().trainable_params()
    optimizer = AdamW(trainable, learning_rate=[0.001, 0.0009, 0.0008])
    assert optimizer is not None
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_optim_filter_false():
    """
    Feature: AdamW optimizer with optim_filter=False
    Description: Set optim_filter to (False, False) on a default AdamW and run one call with
        all-ones gradients
    Expectation: The call returns a non-None result with as many entries as there are gradients
    """
    optimizer = AdamW(SimpleNet().trainable_params())
    weight_grad = Tensor(np.ones([2, 3]), mstype.float32)
    bias_grad = Tensor(np.ones([3]), mstype.float32)
    gradients = (weight_grad, bias_grad)
    optimizer.optim_filter = (False, False)
    step_output = optimizer(gradients)
    assert step_output is not None
    assert len(step_output) == len(gradients)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_get_weight_decay_and_lr():
    """
    Feature: AdamW optimizer get_weight_decay and get_lr methods
    Description: Build AdamW with weight_decay=0.01 and learning_rate=0.001, then call both getters
    Expectation: get_weight_decay() and get_lr() each return a non-None value
    """
    hyper_params = {"weight_decay": 0.01, "learning_rate": 0.001}
    optimizer = AdamW(SimpleNet().trainable_params(), **hyper_params)
    assert optimizer.get_weight_decay() is not None
    assert optimizer.get_lr() is not None
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_clone_state_with_cloned_obj():
    """
    Feature: AdamW optimizer clone_state method with existing cloned_obj
    Description: Call clone_state twice ("first" then "second") on the same optimizer
    Expectation: Both calls return one clone per parameter, and every parameter's
        param_info.cloned_obj holds at least two entries afterwards
    """
    optimizer = AdamW(SimpleNet().trainable_params())
    clones_by_prefix = [optimizer.clone_state(prefix, "zeros") for prefix in ("first", "second")]
    expected_count = len(optimizer.parameters)
    for clones in clones_by_prefix:
        assert len(clones) == expected_count
    for source in optimizer.parameters:
        info = source.param_info
        assert hasattr(info, "cloned_obj")
        assert len(info.cloned_obj) >= 2
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_construct_all_branches():
    """
    Feature: AdamW optimizer construct method branches
    Description: Run the optimizer three times with all-ones gradients: ungrouped with
        is_group=False, grouped with is_group_lr=True, and grouped with is_group_lr=False
    Expectation: Every call returns a non-None result
    """
    net = SimpleNet()
    gradients = tuple(Tensor(np.ones(shape), mstype.float32) for shape in ([2, 3], [3]))

    # Case 1: flat parameter list with grouping switched off.
    flat_optimizer = AdamW(net.trainable_params())
    flat_optimizer.is_group = False
    assert flat_optimizer(gradients) is not None

    # Cases 2 and 3: grouped parameters, toggling is_group_lr on the same optimizer.
    grouped_optimizer = AdamW([
        {'params': [net.weight], 'lr': 0.001},
        {'params': [net.bias], 'lr': 0.0001},
    ])
    grouped_optimizer.is_group = True
    for group_lr_flag in (True, False):
        grouped_optimizer.is_group_lr = group_lr_flag
        assert grouped_optimizer(gradients) is not None
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_check_param_value_more_cases():
    """
    Feature: _check_param_value function with more cases
    Description: Call _check_param_value with weight_decay=0.0, with betas close to the ends of
        the open interval, and with a larger eps
    Expectation: None of the calls raises
    """
    # (betas, eps) combinations; weight_decay is 0.0 for all of them.
    accepted_cases = (
        ((0.9, 0.999), 1e-8),
        ((0.0001, 0.9999), 1e-8),
        ((0.0001, 0.9999), 1e-3),
    )
    for betas, eps in accepted_cases:
        _check_param_value(betas, eps, 0.0, "AdamW")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_clone_state_with_ones_init():
    """
    Feature: AdamW clone_state method with ones init
    Description: Call clone_state("adam_m_ones", "ones") on a default AdamW
    Expectation: One clone is returned per optimizer parameter
    """
    optimizer = AdamW(SimpleNet().trainable_params())
    clones = optimizer.clone_state("adam_m_ones", "ones")
    assert len(clones) == len(optimizer.parameters)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_with_large_learning_rate():
    """
    Feature: AdamW optimizer with large learning rate
    Description: Run one AdamW call with learning_rate=1.0 and all-ones gradients
    Expectation: Every parameter moves away from its pre-step value by more than the 0.1 tolerance
    """
    net = SimpleNet()
    # Snapshot the starting values. Comparing against ones was trivially satisfied by the
    # zero-initialised bias, so the old check said nothing about how far the bias moved.
    initial_values = [param.asnumpy().copy() for param in net.trainable_params()]
    optimizer = AdamW(net.trainable_params(), learning_rate=1.0)
    gradients = (
        Tensor(np.ones([2, 3]), mstype.float32),
        Tensor(np.ones([3]), mstype.float32),
    )
    optimizer(gradients)
    for param, before in zip(net.trainable_params(), initial_values):
        assert not np.allclose(param.asnumpy(), before, atol=0.1)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_adamw_with_very_small_learning_rate():
    """
    Feature: AdamW optimizer with very small learning rate
    Description: Build AdamW with learning_rate=1e-10 and run one call with all-ones gradients
    Expectation: The call returns a non-None result
    """
    optimizer = AdamW(SimpleNet().trainable_params(), learning_rate=1e-10)
    weight_grad = Tensor(np.ones([2, 3]), mstype.float32)
    bias_grad = Tensor(np.ones([3]), mstype.float32)
    assert optimizer((weight_grad, bias_grad)) is not None