"""Test checkpoint resharding."""
import operator
from functools import reduce
from unittest.mock import MagicMock
import os
import pytest
import numpy as np
import mindspore as ms
from mindspore import Tensor, Parameter
from mindspore.parallel import Layout
from mindformers.checkpoint.reshard import (
smart_slice,
balance_load,
infer_slice_area_by_rank,
ReshardHandler,
ReshardLoader
)
from mindformers.checkpoint.sharded_tensor import build_sharded_tensor
from mindformers.checkpoint.utils import get_sharded_tensor_shard_id
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_smart_slice_full_slice_tensor():
"""
Feature: smart_slice function - full slice scenario
Description: Test smart_slice with a complete slice (no actual slicing needed)
Expectation: Returns the original tensor when slice covers entire tensor
"""
tensor = Tensor(np.arange(24).reshape(4, 6), dtype=ms.float32)
slice_ranges = [(0, 4), (0, 6)]
result = smart_slice(tensor, slice_ranges, load_from_multi_rank=False)
assert result is tensor
assert np.array_equal(result.asnumpy(), tensor.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_smart_slice_partial_slice_tensor():
"""
Feature: smart_slice function - partial slice scenario
Description: Test smart_slice with partial slicing
Expectation: Returns correctly sliced tensor
"""
tensor = Tensor(np.arange(24).reshape(4, 6), dtype=ms.float32)
slice_ranges = [(1, 3), (2, 5)]
result = smart_slice(tensor, slice_ranges, load_from_multi_rank=False)
expected = tensor[1:3, 2:5]
assert np.array_equal(result.asnumpy(), expected.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_smart_slice_full_slice_with_multi_rank():
"""
Feature: smart_slice function - full slice with load_from_multi_rank=True
Description: Test smart_slice forces slicing even for full slice when load_from_multi_rank=True
Expectation: Returns numpy array even for full slice
"""
tensor = Tensor(np.arange(24).reshape(4, 6), dtype=ms.float32)
slice_ranges = [(0, 4), (0, 6)]
result = smart_slice(tensor, slice_ranges, load_from_multi_rank=True)
assert isinstance(result, np.ndarray)
assert np.array_equal(result, tensor.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_smart_slice_partial_slice_with_multi_rank():
"""
Feature: smart_slice function - partial slice with load_from_multi_rank=True
Description: Test smart_slice with partial slicing and multi-rank loading
Expectation: Returns numpy array with correct slice
"""
tensor = Tensor(np.arange(24).reshape(4, 6), dtype=ms.float32)
slice_ranges = [(1, 3), (2, 5)]
result = smart_slice(tensor, slice_ranges, load_from_multi_rank=True)
assert isinstance(result, np.ndarray)
expected = tensor.asnumpy()[1:3, 2:5]
assert np.array_equal(result, expected)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_smart_slice_parameter_type():
"""
Feature: smart_slice function - Parameter type
Description: Test smart_slice with Parameter type input
Expectation: Handles Parameter type correctly
"""
param = Parameter(Tensor(np.arange(12).reshape(3, 4), dtype=ms.float32))
slice_ranges = [(0, 2), (1, 3)]
result = smart_slice(param, slice_ranges, load_from_multi_rank=False)
expected = param[0:2, 1:3]
assert np.array_equal(result.asnumpy(), expected.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_smart_slice_numpy_array():
"""
Feature: smart_slice function - numpy array input
Description: Test smart_slice with numpy array input
Expectation: Handles numpy array correctly
"""
arr = np.arange(24).reshape(4, 6)
slice_ranges = [(1, 3), (2, 5)]
result = smart_slice(arr, slice_ranges, load_from_multi_rank=False)
expected = arr[1:3, 2:5]
assert np.array_equal(result, expected)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_smart_slice_dimension_mismatch():
"""
Feature: smart_slice function - dimension mismatch error
Description: Test smart_slice raises ValueError when slice dimension doesn't match tensor dimension
Expectation: Raises ValueError with appropriate message
"""
tensor = Tensor(np.arange(24).reshape(4, 6), dtype=ms.float32)
slice_ranges = [(0, 4), (0, 6), (0, 2)]
with pytest.raises(ValueError) as exc_info:
smart_slice(tensor, slice_ranges, load_from_multi_rank=False)
assert "dimension count" in str(exc_info.value).lower()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_smart_slice_3d_tensor():
"""
Feature: smart_slice function - 3D tensor slicing
Description: Test smart_slice with 3D tensor
Expectation: Correctly slices 3D tensor
"""
tensor = Tensor(np.arange(60).reshape(3, 4, 5), dtype=ms.float32)
slice_ranges = [(1, 3), (0, 2), (2, 4)]
result = smart_slice(tensor, slice_ranges, load_from_multi_rank=False)
expected = tensor[1:3, 0:2, 2:4]
assert np.array_equal(result.asnumpy(), expected.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_smart_slice_single_dimension():
"""
Feature: smart_slice function - 1D tensor slicing
Description: Test smart_slice with 1D tensor
Expectation: Correctly slices 1D tensor
"""
tensor = Tensor(np.arange(10), dtype=ms.float32)
slice_ranges = [(2, 7)]
result = smart_slice(tensor, slice_ranges, load_from_multi_rank=False)
expected = tensor[2:7]
assert np.array_equal(result.asnumpy(), expected.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balance_load_basic():
"""
Feature: balance_load function - basic load balancing
Description: Test balance_load distributes parameters evenly
Expectation: Parameters are distributed across groups with balanced sizes
"""
params = [
{"size": 100, "name": "param1"},
{"size": 200, "name": "param2"},
{"size": 150, "name": "param3"},
{"size": 50, "name": "param4"},
]
num_groups = 2
result = balance_load(params, num_groups)
assert len(result) == num_groups
total_params = sum(len(group) for group in result)
assert total_params == len(params)
group_sizes = [sum(p["size"] for p in group) for group in result]
max_size = max(group_sizes)
min_size = min(group_sizes)
avg_size = sum(group_sizes) / len(group_sizes)
assert (max_size - min_size) <= avg_size * 0.5
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balance_load_single_group():
"""
Feature: balance_load function - single group
Description: Test balance_load with single group
Expectation: All parameters go to single group
"""
params = [
{"size": 100, "name": "param1"},
{"size": 200, "name": "param2"},
]
num_groups = 1
result = balance_load(params, num_groups)
assert len(result) == 1
assert len(result[0]) == len(params)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balance_load_more_groups_than_params():
"""
Feature: balance_load function - more groups than parameters
Description: Test balance_load when num_groups > len(params)
Expectation: Some groups will be empty
"""
params = [
{"size": 100, "name": "param1"},
{"size": 200, "name": "param2"},
]
num_groups = 5
result = balance_load(params, num_groups)
assert len(result) == num_groups
non_empty_groups = [g for g in result if len(g) > 0]
assert len(non_empty_groups) == len(params)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balance_load_large_imbalance():
"""
Feature: balance_load function - large size imbalance
Description: Test balance_load with very different parameter sizes
Expectation: Large parameters are distributed first to balance load
"""
params = [
{"size": 1000, "name": "large1"},
{"size": 1000, "name": "large2"},
{"size": 10, "name": "small1"},
{"size": 10, "name": "small2"},
{"size": 10, "name": "small3"},
]
num_groups = 2
result = balance_load(params, num_groups)
group_sizes = [sum(p["size"] for p in group) for group in result]
assert abs(group_sizes[0] - group_sizes[1]) <= 20
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balance_load_empty_params():
"""
Feature: balance_load function - empty parameter list
Description: Test balance_load with empty parameter list
Expectation: Returns empty groups
"""
params = []
num_groups = 3
result = balance_load(params, num_groups)
assert len(result) == num_groups
assert all(len(group) == 0 for group in result)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balance_load_sorted_order():
"""
Feature: balance_load function - sorting order
Description: Test that balance_load sorts parameters by size (descending)
Expectation: Largest parameters are assigned first
"""
params = [
{"size": 50, "name": "small"},
{"size": 300, "name": "large"},
{"size": 100, "name": "medium"},
]
num_groups = 2
result = balance_load(params, num_groups)
first_group_sizes = [p["size"] for p in result[0]]
assert 300 in first_group_sizes
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_balance_load_equal_sizes():
"""
Feature: balance_load function - equal sized parameters
Description: Test balance_load with parameters of equal size
Expectation: Parameters distributed evenly
"""
params = [
{"size": 100, "name": f"param{i}"} for i in range(6)
]
num_groups = 3
result = balance_load(params, num_groups)
assert all(len(group) == 2 for group in result)
assert all(sum(p["size"] for p in group) == 200 for group in result)
def get_slice_data(full_data, offset):
area = ()
for begin, end in offset:
area += (slice(begin, end),)
return full_data[area]
def reshard_tensor_func(param_name, full_shape, from_layout, to_layout, to_rank_id):
"""reshard tensor and verify"""
reshard = ReshardHandler(param_name, full_shape, from_layout, to_layout, to_rank_id)
reshard.infer_all_tensor_offset()
all_offset = reshard.global_union_area_map
ele_num = reduce(operator.mul, full_shape)
full_data = np.array(range(ele_num), np.int32).reshape(full_shape)
from_tensor_map = {}
for rank, offset in all_offset.items():
from_tensor_map[rank] = Tensor(get_slice_data(full_data, offset))
actual_result = reshard.get_real_tensor(from_tensor_map).asnumpy()
to_layout_dict = to_layout.to_dict()
to_offset = infer_slice_area_by_rank(
to_layout_dict['device_matrix'],
to_layout_dict['tensor_map'],
to_layout_dict['rank_list'].index(to_rank_id),
full_shape
)
expect_result = get_slice_data(full_data, to_offset)
assert np.all(actual_result == expect_result)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_reshard_between_fully_shard():
"""
Feature: Tensor resharding between fully sharded modes.
Description: Test bidirectional tensor resharding between two different fully sharded layouts.
Expectation: The reshard_tensor_func executes successfully without throwing exceptions,
completing bidirectional tensor resharding between the two fully sharded layouts.
"""
param_name = "weight"
full_shape = (64, 64)
layout_0 = Layout((2, 2, 2, 2), ('dp', 'cp', 'rep', 'tp'), rank_list=list(range(16, 32)))
from_layout = layout_0(('rep', 'cp'), 'tp')
layout_1 = Layout((2, 4), ('dp', 'tp'), rank_list=list(range(8, 16)))
to_layout = layout_1('dp', 'tp')
from_rank_id = 19
to_rank_id = 14
reshard_tensor_func(param_name, full_shape, from_layout, to_layout, to_rank_id)
reshard_tensor_func(param_name, full_shape, to_layout, from_layout, from_rank_id)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_reshard_between_fully_shard_and_not_shard():
"""
Feature: Tensor resharding between fully sharded and non-sharded modes.
Description: Test bidirectional tensor resharding between a fully sharded layout and a non-sharded layout.
Expectation: The reshard_tensor_func executes successfully without throwing exceptions,
completing bidirectional tensor resharding between the fully sharded and non-sharded layouts.
"""
param_name = "weight"
full_shape = (64, 64)
layout_0 = Layout((2, 4, 4), ('dp', 'cp', 'tp'), rank_list=list(range(32, 64)))
from_layout = layout_0(('cp', 'dp'), 'tp')
layout_1 = Layout((8,), ('dp',), rank_list=list(range(0, 8)))
to_layout = layout_1('None', 'None')
from_rank_id = 35
to_rank_id = 7
reshard_tensor_func(param_name, full_shape, from_layout, to_layout, to_rank_id)
reshard_tensor_func(param_name, full_shape, to_layout, from_layout, from_rank_id)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_reshard_between_not_fully_shard():
"""
Feature: Tensor resharding between non-fully sharded modes.
Description: Test bidirectional tensor resharding between two different non-fully sharded layouts.
Expectation: The reshard_tensor_func executes successfully without throwing exceptions,
completing bidirectional tensor resharding between the two non-fully sharded layouts.
"""
param_name = "weight"
full_shape = (64, 64)
layout_0 = Layout((2, 4, 4), ('dp', 'cp', 'tp'), rank_list=list(range(32, 64)))
from_layout = layout_0('cp', 'tp')
layout_1 = Layout((2, 4, 4), ('dp', 'cp', 'tp'), rank_list=list(range(0, 32)))
to_layout = layout_1('tp', 'dp')
from_rank_id = 35
to_rank_id = 7
reshard_tensor_func(param_name, full_shape, from_layout, to_layout, to_rank_id)
reshard_tensor_func(param_name, full_shape, to_layout, from_layout, from_rank_id)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_from_tensor_map_missing_rank():
"""
Feature: Exception handling for incomplete from_tensor_map.
Description: Test the scenario where from_tensor_map is missing a required rank.
Expectation: A ValueError exception is raised when calling get_real_tensor()
with an incomplete from_tensor_map that lacks a necessary rank entry.
"""
param_name = "weight"
full_shape = (64, 64)
layout_0 = Layout((2, 4, 4), ('dp', 'cp', 'tp'), rank_list=list(range(32, 64)))
from_layout = layout_0(('cp', 'dp'), 'tp')
layout_1 = Layout((8,), ('dp',), rank_list=list(range(0, 8)))
to_layout = layout_1('None', 'None')
to_rank_id = 7
reshard = ReshardHandler(param_name, full_shape, from_layout, to_layout, to_rank_id)
reshard.infer_all_tensor_offset()
all_offset = reshard.global_union_area_map
ele_num = reduce(operator.mul, full_shape)
full_data = np.array(range(ele_num), np.int32).reshape(full_shape)
from_tensor_map = {}
pop_rank = 0
for rank, offset in all_offset.items():
pop_rank = rank
from_tensor_map[rank] = Tensor(get_slice_data(full_data, offset))
from_tensor_map.pop(pop_rank)
with pytest.raises(ValueError):
reshard.get_real_tensor(from_tensor_map)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_from_tensor_map_has_unexpected_data():
"""
Feature: Exception handling for invalid data in from_tensor_map.
Description: Test the scenario where a rank in from_tensor_map contains data with an unexpected shape.
Expectation: A ValueError exception is raised when calling get_real_tensor()
with from_tensor_map containing a rank with data that has an unexpected shape.
"""
param_name = "weight"
full_shape = (64, 64)
layout_0 = Layout((2, 4, 4), ('dp', 'cp', 'tp'), rank_list=list(range(32, 64)))
from_layout = layout_0(('cp', 'dp'), 'tp')
layout_1 = Layout((8,), ('dp',), rank_list=list(range(0, 8)))
to_layout = layout_1('None', 'None')
to_rank_id = 7
reshard = ReshardHandler(param_name, full_shape, from_layout, to_layout, to_rank_id)
reshard.infer_all_tensor_offset()
all_offset = reshard.global_union_area_map
ele_num = reduce(operator.mul, full_shape)
full_data = np.array(range(ele_num), np.int32).reshape(full_shape)
from_tensor_map = {}
modify_rank = 0
for rank, offset in all_offset.items():
modify_rank = rank
from_tensor_map[rank] = Tensor(get_slice_data(full_data, offset))
modify_shape = from_tensor_map[modify_rank].shape
modify_shape = modify_shape[:-1] + (modify_shape[-1] + 2,)
from_tensor_map[modify_rank] = Tensor(np.zeros(modify_shape, full_data.dtype))
with pytest.raises(ValueError):
reshard.get_real_tensor(from_tensor_map)
class TestReshardLoader:
"""Test class for ReshardLoader functionality."""
@pytest.fixture
def mock_checkpoint_dir(self, tmp_path):
"""Create a mock checkpoint directory with safetensors files."""
checkpoint_dir = os.path.join(tmp_path, "checkpoint")
os.makedirs(checkpoint_dir, exist_ok=True)
for i in range(2):
file_path = os.path.join(checkpoint_dir, f"model-{i:08d}.safetensors")
with open(file_path, 'wb') as f:
f.write(b"mock_safetensors_data")
return checkpoint_dir
@pytest.fixture
def simple_sharded_tensor(self):
"""Create a simple ShardedTensor for testing."""
layout = Layout((1,), ('dp',), rank_list=[0])
simple_layout = layout('None', 'None')
return build_sharded_tensor(
param_name="test.weight",
param_dtype=ms.float32,
local_shape=(10, 20),
global_shape=(10, 20),
global_offset=(0, 0),
axis_fragmentations=(1, 1),
layout=simple_layout
)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_reshard_loader_init_without_template(self, mock_checkpoint_dir, simple_sharded_tensor):
"""
Feature: ReshardLoader initialization without template
Description: Test ReshardLoader initialization for self-trained weights (no template)
Expectation: ReshardLoader is initialized successfully with correct bidirectional mapping
"""
dst_metas = {
"test.weight": simple_sharded_tensor
}
src_metas = {
"test.weight": [simple_sharded_tensor]
}
mapping_key = get_sharded_tensor_shard_id("test.weight", (0, 0))
param_file_mappings = {
mapping_key: [{
"file_name": "model-00000000.safetensors",
"storage_rank": 0,
"rank_group": [0]
}]
}
loader = ReshardLoader(
checkpoint_dir=mock_checkpoint_dir,
dst_sharded_tensor_metas=dst_metas,
src_sharded_tensor_metas=src_metas,
param_file_mappings=param_file_mappings,
reshard_worker_num=1,
template=None
)
assert loader.checkpoint_dir == mock_checkpoint_dir
assert loader.dst_metas == dst_metas
assert loader.src_metas == src_metas
assert loader.template is None
assert loader.get_dst_name("test.weight") == "test.weight"
assert loader.get_src_names("test.weight") == ["test.weight"]
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_reshard_loader_bidirectional_mapping_without_template(self, mock_checkpoint_dir, simple_sharded_tensor):
"""
Feature: ReshardLoader bidirectional mapping without template
Description: Test bidirectional mapping construction for self-trained weights
Expectation: src_to_dst and dst_to_src mappings are correctly built
"""
dst_metas = {
"param1.weight": simple_sharded_tensor,
"param2.weight": simple_sharded_tensor
}
src_metas = {
"param1.weight": [simple_sharded_tensor],
"param2.weight": [simple_sharded_tensor]
}
param_file_mappings = {}
loader = ReshardLoader(
checkpoint_dir=mock_checkpoint_dir,
dst_sharded_tensor_metas=dst_metas,
src_sharded_tensor_metas=src_metas,
param_file_mappings=param_file_mappings,
reshard_worker_num=1,
template=None
)
assert loader.get_dst_name("param1.weight") == "param1.weight"
assert loader.get_dst_name("param2.weight") == "param2.weight"
assert loader.get_src_names("param1.weight") == ["param1.weight"]
assert loader.get_src_names("param2.weight") == ["param2.weight"]
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_reshard_loader_bidirectional_mapping_with_template(self, mock_checkpoint_dir, simple_sharded_tensor):
"""
Feature: ReshardLoader bidirectional mapping with template
Description: Test bidirectional mapping construction for HuggingFace weights with template
Expectation: src_to_dst and dst_to_src mappings are correctly built using template
"""
mock_template = MagicMock()
mock_template.get_mf_name = MagicMock(
side_effect=lambda x: ("qkv.weight",)
if "q_proj" in x
or "k_proj" in x
or "v_proj" in x
else ("other.weight",)
)
mock_template.check_weights_for_experts = MagicMock(return_value=False)
mock_template.check_weights_for_qkv = MagicMock(return_value=False)
mock_template.stack_hf_experts_weight = MagicMock(
side_effect=lambda dst_name, num_experts, src_tensor: src_tensor)
dst_metas = {
"qkv.weight": simple_sharded_tensor
}
src_metas = {
"q_proj.weight": [simple_sharded_tensor],
"k_proj.weight": [simple_sharded_tensor],
"v_proj.weight": [simple_sharded_tensor]
}
param_file_mappings = {}
loader = ReshardLoader(
checkpoint_dir=mock_checkpoint_dir,
dst_sharded_tensor_metas=dst_metas,
src_sharded_tensor_metas=src_metas,
param_file_mappings=param_file_mappings,
reshard_worker_num=1,
template=mock_template
)
src_names = loader.get_src_names("qkv.weight")
assert "q_proj.weight" in src_names
assert "k_proj.weight" in src_names
assert "v_proj.weight" in src_names
assert loader.get_dst_name("q_proj.weight") == "qkv.weight"
assert loader.get_dst_name("k_proj.weight") == "qkv.weight"
assert loader.get_dst_name("v_proj.weight") == "qkv.weight"
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_reshard_loader_get_dst_name_nonexistent(self, mock_checkpoint_dir, simple_sharded_tensor):
"""
Feature: ReshardLoader get_dst_name for nonexistent parameter
Description: Test get_dst_name returns None for parameters not in mapping
Expectation: Returns None for nonexistent parameters
"""
dst_metas = {
"test.weight": simple_sharded_tensor
}
src_metas = {
"test.weight": [simple_sharded_tensor]
}
param_file_mappings = {}
loader = ReshardLoader(
checkpoint_dir=mock_checkpoint_dir,
dst_sharded_tensor_metas=dst_metas,
src_sharded_tensor_metas=src_metas,
param_file_mappings=param_file_mappings,
reshard_worker_num=1,
template=None
)
assert loader.get_dst_name("nonexistent.weight") is None
assert loader.get_src_names("nonexistent.weight") is None
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_reshard_loader_missing_src_in_dst(self, mock_checkpoint_dir, simple_sharded_tensor):
"""
Feature: ReshardLoader handles missing source parameters
Description: Test ReshardLoader initialization when src_metas has parameters not in dst_metas
Expectation: Parameters not in dst_metas are excluded from mapping
"""
dst_metas = {
"param1.weight": simple_sharded_tensor
}
src_metas = {
"param1.weight": [simple_sharded_tensor],
"param2.weight": [simple_sharded_tensor]
}
param_file_mappings = {}
loader = ReshardLoader(
checkpoint_dir=mock_checkpoint_dir,
dst_sharded_tensor_metas=dst_metas,
src_sharded_tensor_metas=src_metas,
param_file_mappings=param_file_mappings,
reshard_worker_num=1,
template=None
)
assert loader.get_dst_name("param1.weight") == "param1.weight"
assert loader.get_dst_name("param2.weight") is None