"""Test save/load `metadata.json` info."""
import os
import json
import shutil
import pytest
import mindspore as ms
from mindformers.checkpoint.sharded_tensor import get_sharded_tensor_from_strategy_metadata
from mindformers.checkpoint.metadata import save_metadata, load_metadata, get_metadata_of_checkpoint
from mindformers.checkpoint.checkpoint import save_metadata_json
from mindformers.checkpoint.utils import (
get_checkpoint_iter_dir,
get_metadata_filename,
get_checkpoint_name,
FileType
)
# Layout built from a (2, 2, 2) device matrix whose axes are aliased "dp", "sp", "mp";
# `A` is the result of calling that layout with the ("dp", "mp") axis names.
AA = ms.parallel.Layout((2, 2, 2), ("dp", "sp", "mp"))
A = AA("dp", "mp")

# Mock strategy info for ranks 0 and 1. Both ranks carry the same two entries:
# one model weight and its "adam_m." counterpart.
# NOTE(review): each value looks like [layout, dtype name, global shape] -- confirm
# against get_sharded_tensor_from_strategy_metadata.
GLOBAL_STRATEGY_INFO = {
    rank: {
        "decoder.layers.0.input_layernorm.weight": [A, 'float32', [3584]],
        "adam_m.decoder.layers.0.input_layernorm.weight": [A, 'float32', [3584]],
    }
    for rank in (0, 1)
}

# Keys treated as model weights; any other key is handled as optimizer state
# by save_metadata_without_npu.
MODEL_KEYS = ["decoder.layers.0.input_layernorm.weight"]
USER_PREFIX = "my_test_net"

# Scratch directory (relative to the working directory) used by the save/load test.
CHECKPOINT_ROOT_DIR = "./output_megatron_format_metadata"
ITERATION_WITH_OPTIMIZER = 1
ITERATION_WITHOUT_OPTIMIZER = 2

# Expected result of os.path.exists() once the scratch directory has been removed.
NOT_EXISTS = False
def save_metadata_without_npu(global_strategy_info, model_keys, user_prefix, metadata_file_path, save_optimizer):
    """Save a mock `metadata.json` for two simulated ranks, without real NPUs.

    Args:
        global_strategy_info: Mapping indexed by rank id (ranks 0 and 1 are read). Each
            value is forwarded as `param_infos` to `get_sharded_tensor_from_strategy_metadata`.
        model_keys: Iterable of keys treated as model weights. It is materialized once,
            so a one-shot iterable is also accepted.
        user_prefix: Prefix forwarded to `get_checkpoint_name`.
        metadata_file_path: Destination path forwarded to `save_metadata`.
        save_optimizer: When falsy, a filter keeping only keys in `model_keys` is passed
            to the sharded-tensor helper and every entry is mapped to a MODEL file.
            When truthy, no filter is passed and entries whose key is not in
            `model_keys` are mapped to an OPTIMIZER file.
    """
    npu_nums = 2

    # Materialize once instead of rebuilding a list for every membership test.
    model_key_names = tuple(model_keys)
    filter_func = None if save_optimizer else (lambda x: x in model_key_names)

    sharded_tensor_metas = {}
    param_file_mappings = []
    for cur_npu_rank in range(npu_nums):
        cur_rank_sharded_tensors = get_sharded_tensor_from_strategy_metadata(
            param_infos=global_strategy_info[cur_npu_rank],
            cur_npu_rank=cur_npu_rank,
            filter_func=filter_func
        )
        sharded_tensor_metas[cur_npu_rank] = cur_rank_sharded_tensors

        # Only the tensor objects are needed; the mapping keys are not used here.
        for sharded_tensor in cur_rank_sharded_tensors.values():
            is_optimizer_state = save_optimizer and sharded_tensor.key not in model_key_names
            file_type = FileType.OPTIMIZER if is_optimizer_state else FileType.MODEL
            ckpt_name = get_checkpoint_name(None, user_prefix, cur_npu_rank, npu_nums, file_type)
            param_file_mappings.append(
                (
                    ckpt_name + '.safetensors',
                    cur_npu_rank,
                    [cur_npu_rank],
                    (sharded_tensor.key, sharded_tensor.global_offset)
                )
            )

    save_metadata(sharded_tensor_metas, param_file_mappings, metadata_file_path)
def _assert_input_layernorm_shards(sharded_tensors):
    """Assert the shapes and offset of every shard stored for the input layernorm weight."""
    for sharded_tensor in sharded_tensors['decoder.layers.0.input_layernorm.weight']:
        assert sharded_tensor.local_shape == (1792,)
        assert sharded_tensor.global_shape == (3584,)
        assert sharded_tensor.global_offset in [(0,), (1,)]


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_save_and_load_metadata_case():
    """
    Feature: Save metadata with and without optimizer state, then load both back.
    Description: Save 'metadata.json' twice in succession (one iteration with optimizer entries,
        one without) and check that both files exist. Then load each file and check the sharded
        tensor info and the parameter-to-file mappings.
    Expectation: No error is reported, and the scratch checkpoint directory is removed afterwards.
    """
    try:
        # Iteration saved together with optimizer entries.
        has_optimizer_checkpoint_path = get_checkpoint_iter_dir(CHECKPOINT_ROOT_DIR, ITERATION_WITH_OPTIMIZER)
        os.makedirs(has_optimizer_checkpoint_path, exist_ok=True)
        has_optimizer_metadata_file_path = get_metadata_filename(CHECKPOINT_ROOT_DIR, ITERATION_WITH_OPTIMIZER)
        save_metadata_without_npu(
            global_strategy_info=GLOBAL_STRATEGY_INFO,
            model_keys=MODEL_KEYS,
            user_prefix=USER_PREFIX,
            metadata_file_path=has_optimizer_metadata_file_path,
            save_optimizer=True
        )
        assert os.path.isfile(has_optimizer_metadata_file_path)

        # Iteration saved with model weights only.
        no_optimizer_checkpoint_path = get_checkpoint_iter_dir(CHECKPOINT_ROOT_DIR, ITERATION_WITHOUT_OPTIMIZER)
        os.makedirs(no_optimizer_checkpoint_path, exist_ok=True)
        no_optimizer_metadata_file_path = get_metadata_filename(CHECKPOINT_ROOT_DIR, ITERATION_WITHOUT_OPTIMIZER)
        save_metadata_without_npu(
            global_strategy_info=GLOBAL_STRATEGY_INFO,
            model_keys=MODEL_KEYS,
            user_prefix=USER_PREFIX,
            metadata_file_path=no_optimizer_metadata_file_path,
            save_optimizer=False
        )
        assert os.path.isfile(no_optimizer_metadata_file_path)

        # Load the metadata that includes optimizer entries.
        has_optimizer_sharded_tensors, has_optimizer_param_file_mappings = load_metadata(
            get_metadata_filename(CHECKPOINT_ROOT_DIR, ITERATION_WITH_OPTIMIZER)
        )
        _assert_input_layernorm_shards(has_optimizer_sharded_tensors)
        adam_input_layernorm = has_optimizer_sharded_tensors['adam_m.decoder.layers.0.input_layernorm.weight']
        assert adam_input_layernorm is not None
        adam_mapping_0 = has_optimizer_param_file_mappings[
            "('adam_m.decoder.layers.0.input_layernorm.weight', (0,))"
        ][0]
        assert adam_mapping_0["storage_rank"] == 0
        assert adam_mapping_0["file_name"] == "my_test_net-opt-0000000-0000002.safetensors"
        assert "rank_group" in adam_mapping_0

        # Load the model-only metadata; no optimizer ("adam") key may appear in it.
        no_optimizer_sharded_tensors, no_optimizer_param_file_mappings = load_metadata(
            get_metadata_filename(CHECKPOINT_ROOT_DIR, ITERATION_WITHOUT_OPTIMIZER)
        )
        for k in no_optimizer_sharded_tensors:
            assert "adam" not in k
        _assert_input_layernorm_shards(no_optimizer_sharded_tensors)
        decoder_mapping_1 = no_optimizer_param_file_mappings[
            "('decoder.layers.0.input_layernorm.weight', (1,))"
        ][0]
        assert decoder_mapping_1["storage_rank"] == 1
        assert decoder_mapping_1["file_name"] == "my_test_net-model-0000001-0000002.safetensors"
    finally:
        # Always remove the scratch directory so a failed assertion does not leave
        # stale files in the working directory for later runs.
        if os.path.exists(CHECKPOINT_ROOT_DIR):
            shutil.rmtree(CHECKPOINT_ROOT_DIR)

    assert not os.path.exists(CHECKPOINT_ROOT_DIR)
    assert not os.path.exists(has_optimizer_metadata_file_path)
    assert not os.path.exists(no_optimizer_metadata_file_path)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_save_metadata_json_with_none(tmp_path):
    """
    Feature: save_metadata_json called with None in place of the sharded tensor metas.
    Description: Call save_metadata_json with None as its first argument and a target path
        inside pytest's temporary directory.
    Expectation: No file exists at the target path afterwards.
    """
    target_path = os.path.join(tmp_path, "metadata.json")

    save_metadata_json(
        None,  # stands in for the sharded tensor metas
        [],
        "test",
        target_path,
    )

    metadata_written = os.path.exists(target_path)
    assert not metadata_written
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_metadata_of_checkpoint():
    """
    Feature: get_metadata_of_checkpoint.
    Description: Write a hand-built metadata.json into a scratch checkpoint directory and read it
        back through get_metadata_of_checkpoint.
    Expectation: Both returned values are dicts, and the storage entry for the mocked weight keeps
        its rank_group, file_name and storage_rank fields.
    """
    storage_key = "('decoder.final_layernorm.weight', (0,))"

    # Pieces of the mock metadata, assembled bottom-up; key order matches the file layout.
    weight_properties = {
        "dtype": "Float32",
        "replica_id": [0, 1, 2, 3],
        "allow_shape_mismatch": False,
        "allow_to_save": True
    }
    weight_layout = {
        "device_matrix": [1, 1, 4],
        "tensor_map": [-1],
        "interleaved_parallel": False,
        "alias_name": ["a", "b", "c"],
        "rank_list": [0, 1, 2, 3]
    }
    weight_meta = {
        "properties": weight_properties,
        "global_shape": [896],
        "axis_fragmentations": [1],
        "layout": weight_layout,
        "chunk": [{"global_offset": [0], "local_shape": [896]}]
    }
    storage_entry = {
        "file_name": "model-0000003-0000004.safetensors",
        "storage_rank": 3,
        "rank_group": [3, 4, 5]
    }
    metadata_content = {
        "state_dict_metadata": {"decoder.final_layernorm.weight": weight_meta},
        "storage_data": {storage_key: [storage_entry]}
    }

    checkpoint_dir = "./test_checkpoint_metadata"
    os.makedirs(checkpoint_dir, exist_ok=True)
    try:
        metadata_file_path = os.path.join(checkpoint_dir, "metadata.json")
        with open(metadata_file_path, "w", encoding='utf-8') as metadata_file:
            json.dump(metadata_content, metadata_file)

        tensor_metas, file_mappings = get_metadata_of_checkpoint(checkpoint_dir)

        assert isinstance(tensor_metas, dict)
        assert isinstance(file_mappings, dict)
        assert storage_key in file_mappings

        loaded_entry = file_mappings[storage_key][0]
        assert "rank_group" in loaded_entry
        expected_fields = (
            ("rank_group", [3, 4, 5]),
            ("file_name", "model-0000003-0000004.safetensors"),
            ("storage_rank", 3),
        )
        for field_name, expected_value in expected_fields:
            assert loaded_entry[field_name] == expected_value
    finally:
        # Remove the scratch directory whether or not the assertions passed.
        if os.path.exists(checkpoint_dir):
            shutil.rmtree(checkpoint_dir)