"""Unit tests for FSDP checkpoint path, metadata, and key migration helpers."""
import os
import tempfile
import pytest
class TestCheckpointNameAndTracker:
    """Tests for checkpoint naming, tracker-file paths, and tracker-file parsing.

    The helpers under test live in ``mindspeed_mm.fsdp.checkpoint.utils`` and
    are imported inside each test rather than at module level.
    """

    @staticmethod
    def _write_tracker(directory, contents):
        """Write *contents* to ``latest.txt`` inside *directory*.

        Returns the path of the written file. The file is closed (and so
        flushed) before the path is handed back to the caller.
        """
        tracker = os.path.join(directory, "latest.txt")
        with open(tracker, "w", encoding="utf-8") as handle:
            handle.write(contents)
        return tracker

    @pytest.mark.parametrize(
        "iteration,expected_suffix",
        [
            (1, "iter_0000001"),
            (7, "iter_0000007"),
            (10, "iter_0000010"),
            (42, "iter_0000042"),
            (99, "iter_0000099"),
            (100, "iter_0000100"),
            (321, "iter_0000321"),
            (999, "iter_0000999"),
            (1000, "iter_0001000"),
            (2024, "iter_0002024"),
            (9999, "iter_0009999"),
            (10000, "iter_0010000"),
            (123456, "iter_0123456"),
            (999999, "iter_0999999"),
            (1000000, "iter_1000000"),
        ],
    )
    def test_get_checkpoint_name_zero_pads_iterations(self, iteration, expected_suffix):
        """The checkpoint name ends with ``iter_`` plus the iteration padded to seven digits."""
        from mindspeed_mm.fsdp.checkpoint.utils import get_checkpoint_name

        assert get_checkpoint_name("/tmp/checkpoints", iteration).endswith(expected_suffix)

    @pytest.mark.parametrize(
        "root",
        [
            "/tmp/checkpoints",
            "/tmp/checkpoints/",
            "relative/checkpoints",
            ".",
            "",
        ],
    )
    def test_get_checkpoint_tracker_filename_uses_latest_checkpoint_file(self, root):
        """The tracker path is ``latest_checkpointed_iteration.txt`` joined onto the root."""
        from mindspeed_mm.fsdp.checkpoint.utils import get_checkpoint_tracker_filename

        tracker = get_checkpoint_tracker_filename(root)

        assert tracker == os.path.join(root, "latest_checkpointed_iteration.txt")

    def test_get_checkpoint_name_release_ignores_iteration(self):
        """With ``release=True`` the name is ``<root>/release`` whatever the iteration."""
        from mindspeed_mm.fsdp.checkpoint.utils import get_checkpoint_name

        expected = os.path.join("/tmp/checkpoints", "release")

        assert get_checkpoint_name("/tmp/checkpoints", 1, release=True) == expected
        assert get_checkpoint_name("/tmp/checkpoints", 999999, release=True) == expected

    @pytest.mark.parametrize(
        "metadata,expected_iteration,expected_release",
        [
            ("1", 1, False),
            ("7", 7, False),
            ("10\n", 10, False),
            (" 42 ", 42, False),
            ("\t99\n", 99, False),
            ("123456", 123456, False),
            ("release", 0, True),
            (" release\n", 0, True),
        ],
    )
    def test_read_metadata_accepts_iteration_or_release(self, metadata, expected_iteration, expected_release):
        """Integer text (surrounding whitespace allowed) or ``release`` is parsed successfully."""
        from mindspeed_mm.fsdp.checkpoint.utils import read_metadata

        with tempfile.TemporaryDirectory() as temp_dir:
            tracker = self._write_tracker(temp_dir, metadata)
            iteration, release = read_metadata(tracker)

        assert iteration == expected_iteration
        assert release is expected_release

    @pytest.mark.parametrize(
        "metadata,expected_iteration",
        [
            ("0", 0),
            ("-1", -1),
        ],
    )
    def test_read_metadata_accepts_non_positive_iterations(self, metadata, expected_iteration):
        """Zero and negative integers are returned as plain (non-release) iterations.

        These values are kept apart from the rejection cases because the
        parser is expected to return them rather than raise.
        """
        from mindspeed_mm.fsdp.checkpoint.utils import read_metadata

        with tempfile.TemporaryDirectory() as temp_dir:
            tracker = self._write_tracker(temp_dir, metadata)
            iteration, release = read_metadata(tracker)

        assert iteration == expected_iteration
        assert release is False

    @pytest.mark.parametrize(
        "metadata",
        [
            "",
            "latest",
            "Release",
            "release-candidate",
            "1.5",
            "iter_0000001",
            "None",
            "null",
        ],
    )
    def test_read_metadata_rejects_invalid_non_release_values(self, metadata):
        """Text that is neither an integer nor exactly ``release`` raises ``ValueError``."""
        from mindspeed_mm.fsdp.checkpoint.utils import read_metadata

        with tempfile.TemporaryDirectory() as temp_dir:
            tracker = self._write_tracker(temp_dir, metadata)
            with pytest.raises(ValueError, match="Invalid metadata file"):
                read_metadata(tracker)
class TestBaseLayerKeyMigration:
    """Tests for removing ``.base_layer`` segments from state-dict keys and restoring them.

    ``remove_base_layer_keys`` is expected to rewrite the dict it is given and
    return an ``{original_key: rewritten_key}`` mapping;
    ``restore_base_layer_keys`` is expected to undo that rewrite using the
    mapping.
    """

    @pytest.mark.parametrize(
        "state_dict,expected_mapping,expected_keys",
        [
            (
                {"model.layer.base_layer.weight": 1},
                {"model.layer.base_layer.weight": "model.layer.weight"},
                {"model.layer.weight"},
            ),
            (
                {"model.layer.base_layer.bias": 2},
                {"model.layer.base_layer.bias": "model.layer.bias"},
                {"model.layer.bias"},
            ),
            (
                {"a.base_layer.b.base_layer.c": 3},
                {"a.base_layer.b.base_layer.c": "a.b.c"},
                {"a.b.c"},
            ),
            (
                {"no_base_layer.weight": 4},
                {},
                {"no_base_layer.weight"},
            ),
            (
                {"encoder.0.base_layer.weight": 5, "encoder.1.weight": 6},
                {"encoder.0.base_layer.weight": "encoder.0.weight"},
                {"encoder.0.weight", "encoder.1.weight"},
            ),
            (
                {"lm_head.base_layer.weight": 7, "lm_head.base_layer.bias": 8},
                {
                    "lm_head.base_layer.weight": "lm_head.weight",
                    "lm_head.base_layer.bias": "lm_head.bias",
                },
                {"lm_head.weight", "lm_head.bias"},
            ),
        ],
    )
    def test_remove_base_layer_keys_rewrites_matching_keys_in_place(
        self,
        state_dict,
        expected_mapping,
        expected_keys,
    ):
        """Matching keys are renamed on the dict that was passed in, values intact."""
        from mindspeed_mm.fsdp.checkpoint.utils import remove_base_layer_keys

        # pytest hands parametrize values to the test without copying them, so
        # mutate a private copy instead of the shared parameter object.
        working = dict(state_dict)
        expected_state = {expected_mapping.get(key, key): value for key, value in state_dict.items()}

        mapping = remove_base_layer_keys(working)

        assert mapping == expected_mapping
        # The helper returns only the mapping, so the renamed keys can only
        # appear here if the argument itself was modified.
        assert set(working) == expected_keys
        assert working == expected_state

    @pytest.mark.parametrize(
        "bad_state_dict",
        [
            None,
            [],
            (),
            "model.layer.base_layer.weight",
            123,
            object(),
        ],
    )
    def test_remove_base_layer_keys_returns_empty_mapping_for_non_dict_inputs(self, bad_state_dict):
        """Anything that is not a dict yields an empty mapping instead of an error."""
        from mindspeed_mm.fsdp.checkpoint.utils import remove_base_layer_keys

        assert remove_base_layer_keys(bad_state_dict) == {}

    def test_remove_base_layer_keys_preserves_values_when_keys_move(self):
        """A renamed key keeps the very same value object; other entries are untouched."""
        from mindspeed_mm.fsdp.checkpoint.utils import remove_base_layer_keys

        value = object()
        state_dict = {
            "module.base_layer.weight": value,
            "module.other.weight": "kept",
        }

        mapping = remove_base_layer_keys(state_dict)

        assert mapping == {"module.base_layer.weight": "module.weight"}
        assert state_dict["module.weight"] is value
        assert state_dict["module.other.weight"] == "kept"

    def test_restore_base_layer_keys_moves_rewritten_keys_back(self):
        """Removing and then restoring returns the dict to its original contents."""
        from mindspeed_mm.fsdp.checkpoint.utils import remove_base_layer_keys, restore_base_layer_keys

        state_dict = {
            "model.layers.0.base_layer.weight": "w0",
            "model.layers.0.base_layer.bias": "b0",
            "model.layers.1.weight": "w1",
        }

        mapping = remove_base_layer_keys(state_dict)
        restore_base_layer_keys(state_dict, mapping)

        assert state_dict == {
            "model.layers.0.base_layer.weight": "w0",
            "model.layers.0.base_layer.bias": "b0",
            "model.layers.1.weight": "w1",
        }

    def test_restore_base_layer_keys_is_noop_for_invalid_inputs(self):
        """Non-dict state dicts or mappings neither raise nor modify their arguments."""
        from mindspeed_mm.fsdp.checkpoint.utils import restore_base_layer_keys

        # Non-dict state dicts paired with a valid mapping.
        restore_base_layer_keys(None, {"a": "b"})
        not_a_dict = []
        restore_base_layer_keys(not_a_dict, {"a": "b"})
        assert not_a_dict == []

        # Valid state dict paired with non-dict mappings.
        for bad_mapping in (None, []):
            state_dict = {"a": 1}
            restore_base_layer_keys(state_dict, bad_mapping)
            assert state_dict == {"a": 1}

    def test_restore_base_layer_keys_ignores_mapping_entries_not_present(self):
        """Mapping entries whose rewritten key is absent from the dict are skipped."""
        from mindspeed_mm.fsdp.checkpoint.utils import restore_base_layer_keys

        state_dict = {"kept.weight": 1}

        restore_base_layer_keys(
            state_dict,
            {
                "missing.base_layer.weight": "missing.weight",
                "other.base_layer.bias": "other.bias",
            },
        )

        assert state_dict == {"kept.weight": 1}

    def test_restore_base_layer_keys_handles_partial_restore(self):
        """Only keys named in the mapping are renamed back; the rest stay as they are."""
        from mindspeed_mm.fsdp.checkpoint.utils import restore_base_layer_keys

        state_dict = {
            "layer.weight": "restored",
            "layer.bias": "left alone because missing from reverse mapping",
        }
        mapping = {
            "layer.base_layer.weight": "layer.weight",
        }

        restore_base_layer_keys(state_dict, mapping)

        assert state_dict == {
            "layer.base_layer.weight": "restored",
            "layer.bias": "left alone because missing from reverse mapping",
        }

    @pytest.mark.parametrize(
        "keys",
        [
            ["a.base_layer.weight"],
            ["a.base_layer.weight", "a.base_layer.bias"],
            ["a.base_layer.weight", "b.base_layer.weight", "c.weight"],
            ["prefix.base_layer.inner.base_layer.weight", "untouched"],
            ["adapter.base_layer.lora_A.weight", "adapter.base_layer.lora_B.weight"],
        ],
    )
    def test_remove_then_restore_round_trips_key_sets(self, keys):
        """For each key set, remove followed by restore reproduces the original dict."""
        from mindspeed_mm.fsdp.checkpoint.utils import remove_base_layer_keys, restore_base_layer_keys

        # Give every key a distinct value so a swapped value would be noticed.
        original = {key: idx for idx, key in enumerate(keys)}
        state_dict = dict(original)

        mapping = remove_base_layer_keys(state_dict)
        restore_base_layer_keys(state_dict, mapping)

        assert state_dict == original