import unittest
from unittest.mock import MagicMock, patch, Mock
import torch
from mindie_llm.runtime.utils.loader.weight_utils import WeightsFileHandler
class TestWeightsFileHandler(unittest.TestCase):
    """Test cases for WeightsFileHandler.

    Filesystem access and safetensors I/O are replaced with mocks throughout,
    so no real weight files are required. Tests that are not about
    ``_get_weight_filenames`` / ``_load_weight_file_routing`` themselves build
    the handler through ``_make_handler``, which stubs both methods for the
    remainder of the test.
    """

    # Import path of the module under test; used to build ``patch`` targets.
    _MODULE = 'mindie_llm.runtime.utils.loader.weight_utils'

    def setUp(self):
        """Set up values shared by every test."""
        self.model_path = "/fake/model/path"
        self.extension = ".safetensors"

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _make_handler(self, filenames=None, routing=None):
        """Construct a handler with filename discovery and routing load stubbed.

        Both stubs stay active until the test finishes (stopped via
        ``addCleanup``), so any later call on the handler sees the same stubs
        the constructor did.

        Args:
            filenames: Return value of the stubbed ``_get_weight_filenames``;
                defaults to an empty list.
            routing: Return value of the stubbed ``_load_weight_file_routing``;
                defaults to an empty dict.

        Returns:
            A freshly constructed ``WeightsFileHandler``.
        """
        stubs = {
            '_get_weight_filenames': [] if filenames is None else filenames,
            '_load_weight_file_routing': {} if routing is None else routing,
        }
        for attribute, value in stubs.items():
            patcher = patch.object(WeightsFileHandler, attribute, return_value=value)
            patcher.start()
            self.addCleanup(patcher.stop)
        return WeightsFileHandler(self.model_path, self.extension)

    @staticmethod
    def _make_bare_handler(**attributes):
        """Create a handler without running ``__init__``, then set ``attributes`` on it.

        Lets private helpers be called directly without triggering the
        constructor's filename discovery and routing load.
        """
        handler = WeightsFileHandler.__new__(WeightsFileHandler)
        for name, value in attributes.items():
            setattr(handler, name, value)
        return handler

    @staticmethod
    def _make_context_mock(keys=None):
        """Return a MagicMock usable in a ``with`` statement that yields itself.

        Args:
            keys: If given, the value returned by the mock's ``keys()`` method.
        """
        context = MagicMock()
        context.__enter__ = Mock(return_value=context)
        context.__exit__ = Mock(return_value=False)
        if keys is not None:
            context.keys.return_value = keys
        return context

    @staticmethod
    def _make_directory_mock(file_names=(), exists=True, is_dir=True):
        """Return a mock ``Path`` whose ``glob()`` yields mock files named ``file_names``."""
        directory = MagicMock()
        directory.exists.return_value = exists
        directory.is_dir.return_value = is_dir
        files = []
        for file_name in file_names:
            mock_file = MagicMock()
            # Assigned after construction: MagicMock(name=...) only sets the repr name.
            mock_file.name = file_name
            files.append(mock_file)
        directory.glob.return_value = files
        return directory

    # ------------------------------------------------------------------
    # Construction and properties
    # ------------------------------------------------------------------

    @patch(f'{_MODULE}.WeightsFileHandler._load_weight_file_routing')
    @patch(f'{_MODULE}.WeightsFileHandler._get_weight_filenames')
    def test_init(self, mock_get_weight_filenames, mock_load_routing):
        """Test initialization."""
        mock_filenames = ["file1.safetensors", "file2.safetensors"]
        mock_routing = {"tensor1": "file1.safetensors", "tensor2": "file2.safetensors"}
        mock_get_weight_filenames.return_value = mock_filenames
        mock_load_routing.return_value = mock_routing

        handler = WeightsFileHandler(self.model_path, self.extension)

        self.assertEqual(handler._handlers, {})
        self.assertEqual(handler._filenames, mock_filenames)
        self.assertEqual(handler._routing, mock_routing)
        mock_get_weight_filenames.assert_called_once_with(self.model_path, self.extension)
        mock_load_routing.assert_called_once()

    def test_extension_property(self):
        """Test extension property."""
        handler = self._make_handler()
        self.assertEqual(handler.extension, ".safetensors")

    # ------------------------------------------------------------------
    # release_file_handler
    # ------------------------------------------------------------------

    def test_release_file_handler_with_handlers(self):
        """Test release_file_handler when handlers exist."""
        handler = self._make_handler()
        handler._handlers = {
            "file1.safetensors": MagicMock(),
            "file2.safetensors": MagicMock(),
        }

        handler.release_file_handler()

        self.assertEqual(handler._handlers, {})

    def test_release_file_handler_without_handlers(self):
        """Test release_file_handler when no handlers exist."""
        handler = self._make_handler()
        self.assertEqual(handler._handlers, {})

        handler.release_file_handler()

        self.assertEqual(handler._handlers, {})

    def test_release_file_handler_multiple_calls(self):
        """Test release_file_handler can be called multiple times."""
        handler = self._make_handler()
        handler._handlers = {"file1.safetensors": MagicMock()}

        for _ in range(3):
            handler.release_file_handler()

        self.assertEqual(handler._handlers, {})

    # ------------------------------------------------------------------
    # get_tensor
    # ------------------------------------------------------------------

    @patch(f'{_MODULE}.safetensors.safe_open')
    def test_get_tensor_single_file(self, mock_safe_open):
        """Test get_tensor with single file."""
        mock_tensor = torch.tensor([1.0, 2.0, 3.0])
        mock_file_handler = MagicMock()
        mock_file_handler.get_tensor.return_value = mock_tensor
        mock_safe_open.return_value = mock_file_handler
        handler = self._make_handler(["file1.safetensors"], {"tensor1": "file1.safetensors"})

        result = handler.get_tensor("tensor1")

        mock_safe_open.assert_called_once_with("file1.safetensors", framework="pytorch")
        mock_file_handler.get_tensor.assert_called_once_with("tensor1")
        self.assertTrue(torch.allclose(result, mock_tensor))
        self.assertEqual(len(handler._handlers), 1)
        self.assertIn("file1.safetensors", handler._handlers)

    @patch(f'{_MODULE}.safetensors.safe_open')
    def test_get_tensor_multiple_files(self, mock_safe_open):
        """Test get_tensor with multiple files."""
        mock_routing = {
            "tensor1": "file1.safetensors",
            "tensor2": "file2.safetensors",
        }
        mock_tensor1 = torch.tensor([1.0, 2.0])
        mock_tensor2 = torch.tensor([3.0, 4.0])
        mock_file_handler1 = MagicMock()
        mock_file_handler2 = MagicMock()
        mock_file_handler1.get_tensor.return_value = mock_tensor1
        mock_file_handler2.get_tensor.return_value = mock_tensor2
        mock_safe_open.side_effect = [mock_file_handler1, mock_file_handler2]
        handler = self._make_handler(["file1.safetensors", "file2.safetensors"], mock_routing)

        result1 = handler.get_tensor("tensor1")
        result2 = handler.get_tensor("tensor2")

        self.assertEqual(mock_safe_open.call_count, 2)
        self.assertEqual(len(handler._handlers), 2)
        self.assertTrue(torch.allclose(result1, mock_tensor1))
        self.assertTrue(torch.allclose(result2, mock_tensor2))

    @patch(f'{_MODULE}.safetensors.safe_open')
    def test_get_tensor_handler_caching(self, mock_safe_open):
        """Test that file handlers are cached and reused."""
        mock_routing = {"tensor1": "file1.safetensors", "tensor2": "file1.safetensors"}
        mock_tensor1 = torch.tensor([1.0, 2.0])
        mock_tensor2 = torch.tensor([3.0, 4.0])
        mock_file_handler = MagicMock()
        mock_file_handler.get_tensor.side_effect = [mock_tensor1, mock_tensor2]
        mock_safe_open.return_value = mock_file_handler
        handler = self._make_handler(["file1.safetensors"], mock_routing)

        result1 = handler.get_tensor("tensor1")
        result2 = handler.get_tensor("tensor2")

        # Both tensors live in the same file, so it must be opened only once.
        mock_safe_open.assert_called_once_with("file1.safetensors", framework="pytorch")
        self.assertEqual(mock_file_handler.get_tensor.call_count, 2)
        self.assertTrue(torch.allclose(result1, mock_tensor1))
        self.assertTrue(torch.allclose(result2, mock_tensor2))
        self.assertEqual(len(handler._handlers), 1)

    def test_get_tensor_tensor_not_found(self):
        """Test get_tensor raises ValueError when tensor is not found."""
        handler = self._make_handler(["file1.safetensors"], {"tensor1": "file1.safetensors"})

        with self.assertRaises(ValueError) as context:
            handler.get_tensor("nonexistent_tensor")

        self.assertIn("Weight file was not found", str(context.exception))
        self.assertIn("nonexistent_tensor", str(context.exception))

    @patch(f'{_MODULE}.safetensors.safe_open')
    def test_get_tensor_after_release_file_handler(self, mock_safe_open):
        """Test get_tensor after release_file_handler creates new handlers."""
        mock_tensor1 = torch.tensor([1.0, 2.0])
        mock_tensor2 = torch.tensor([3.0, 4.0])
        mock_file_handler1 = MagicMock()
        mock_file_handler2 = MagicMock()
        mock_file_handler1.get_tensor.return_value = mock_tensor1
        mock_file_handler2.get_tensor.return_value = mock_tensor2
        mock_safe_open.side_effect = [mock_file_handler1, mock_file_handler2]
        handler = self._make_handler(["file1.safetensors"], {"tensor1": "file1.safetensors"})

        result1 = handler.get_tensor("tensor1")
        self.assertEqual(len(handler._handlers), 1)
        mock_safe_open.assert_called_once()

        handler.release_file_handler()
        self.assertEqual(len(handler._handlers), 0)

        # After release, the file must be reopened rather than served from the cache.
        result2 = handler.get_tensor("tensor1")
        self.assertEqual(mock_safe_open.call_count, 2)
        self.assertEqual(len(handler._handlers), 1)
        self.assertTrue(torch.allclose(result1, mock_tensor1))
        self.assertTrue(torch.allclose(result2, mock_tensor2))

    # ------------------------------------------------------------------
    # _get_filename / _get_handler
    # ------------------------------------------------------------------

    def test_get_filename_valid_tensor(self):
        """Test _get_filename with valid tensor name."""
        handler = self._make_handler(["file1.safetensors"], {"tensor1": "file1.safetensors"})

        filename, tensor_name = handler._get_filename("tensor1")

        self.assertEqual(filename, "file1.safetensors")
        self.assertEqual(tensor_name, "tensor1")

    def test_get_filename_invalid_tensor(self):
        """Test _get_filename with invalid tensor name."""
        handler = self._make_handler(["file1.safetensors"], {"tensor1": "file1.safetensors"})

        with self.assertRaises(ValueError) as context:
            handler._get_filename("nonexistent_tensor")

        self.assertIn("Weight file was not found", str(context.exception))
        self.assertIn("nonexistent_tensor", str(context.exception))

    @patch(f'{_MODULE}.safetensors.safe_open')
    def test_get_handler_new_file(self, mock_safe_open):
        """Test _get_handler creates new handler for new file."""
        mock_file_handler = MagicMock()
        mock_safe_open.return_value = mock_file_handler
        handler = self._make_handler(["file1.safetensors"])

        result = handler._get_handler("file1.safetensors")

        mock_safe_open.assert_called_once_with("file1.safetensors", framework="pytorch")
        self.assertEqual(result, mock_file_handler)
        self.assertEqual(handler._handlers["file1.safetensors"], mock_file_handler)

    @patch(f'{_MODULE}.safetensors.safe_open')
    def test_get_handler_existing_file(self, mock_safe_open):
        """Test _get_handler returns cached handler for existing file."""
        mock_file_handler = MagicMock()
        mock_safe_open.return_value = mock_file_handler
        handler = self._make_handler(["file1.safetensors"])

        result1 = handler._get_handler("file1.safetensors")
        result2 = handler._get_handler("file1.safetensors")

        mock_safe_open.assert_called_once()
        self.assertEqual(result1, result2)
        self.assertEqual(result1, mock_file_handler)

    # ------------------------------------------------------------------
    # _get_weight_filenames
    # ------------------------------------------------------------------

    @patch(f'{_MODULE}.Path')
    def test_get_weight_filenames_directory_exists(self, mock_path_class):
        """Test _get_weight_filenames when directory exists."""
        mock_path_class.return_value = self._make_directory_mock([
            "model-00001-of-00002.safetensors",
            "model-00002-of-00002.safetensors",
        ])
        handler = self._make_bare_handler(quantize=None)

        result = handler._get_weight_filenames(self.model_path, self.extension)

        self.assertEqual(len(result), 2)
        self.assertIsInstance(result[0], str)
        self.assertIsInstance(result[1], str)

    @patch(f'{_MODULE}.Path')
    def test_get_weight_filenames_no_files(self, mock_path_class):
        """Test _get_weight_filenames raises error when no files found."""
        mock_path_class.return_value = self._make_directory_mock([])
        # Built outside assertRaises so only the call under test may satisfy it.
        handler = self._make_bare_handler(quantize=None)

        with self.assertRaises(FileNotFoundError) as context:
            handler._get_weight_filenames(self.model_path, self.extension)

        self.assertIn("No local weights found", str(context.exception))
        self.assertIn(self.extension, str(context.exception))

    @patch(f'{_MODULE}.os.path.isfile')
    @patch(f'{_MODULE}.Path')
    def test_get_weight_filenames_without_index_file(self, mock_path_class, mock_isfile):
        """Test _get_weight_filenames without index file."""
        mock_path_class.return_value = self._make_directory_mock([
            "model-00001-of-00002.safetensors",
            "model-00002-of-00002.safetensors",
        ])
        mock_isfile.return_value = False
        handler = self._make_bare_handler(quantize=None)

        result = handler._get_weight_filenames(self.model_path, self.extension)

        self.assertEqual(len(result), 2)

    @patch(f'{_MODULE}.logger')
    @patch(f'{_MODULE}.json.load')
    @patch(f'{_MODULE}.safe_open')
    @patch(f'{_MODULE}.os.path.isfile')
    @patch(f'{_MODULE}.Path')
    def test_get_weight_filenames_with_index_file(
            self, mock_path_class, mock_isfile, mock_safe_open, mock_json_load, mock_logger):
        """Test _get_weight_filenames with index file filtering."""
        # The third file is not listed in the index's weight_map and is expected to be filtered out.
        mock_path_class.return_value = self._make_directory_mock([
            "model-00001-of-00002.safetensors",
            "model-00002-of-00002.safetensors",
            "model-00003-of-00002.safetensors",
        ])
        mock_isfile.return_value = True
        mock_safe_open.return_value = self._make_context_mock()
        mock_json_load.return_value = {
            "weight_map": {
                "tensor1": "model-00001-of-00002.safetensors",
                "tensor2": "model-00002-of-00002.safetensors",
            }
        }
        handler = self._make_bare_handler(quantize=None)

        result = handler._get_weight_filenames(self.model_path, self.extension)

        self.assertEqual(len(result), 2)
        mock_logger.info.assert_called()

    @patch(f'{_MODULE}.Path')
    def test_get_weight_filenames_path_not_exists(self, mock_path_class):
        """Test _get_weight_filenames raises error when path doesn't exist."""
        mock_path_class.return_value = self._make_directory_mock(exists=False)
        handler = self._make_bare_handler(quantize=None)

        with self.assertRaises(FileNotFoundError) as context:
            handler._get_weight_filenames(self.model_path, self.extension)

        self.assertIn("not exists or not a directory", str(context.exception))

    @patch(f'{_MODULE}.Path')
    def test_get_weight_filenames_not_directory(self, mock_path_class):
        """Test _get_weight_filenames raises error when path is not a directory."""
        mock_path_class.return_value = self._make_directory_mock(is_dir=False)
        handler = self._make_bare_handler(quantize=None)

        with self.assertRaises(FileNotFoundError) as context:
            handler._get_weight_filenames(self.model_path, self.extension)

        self.assertIn("not exists or not a directory", str(context.exception))

    # ------------------------------------------------------------------
    # _load_weight_file_routing
    # ------------------------------------------------------------------

    @patch(f'{_MODULE}.safetensors.safe_open')
    @patch(f'{_MODULE}.check_path_permission')
    @patch(f'{_MODULE}.standardize_path')
    def test_load_weight_file_routing(self, mock_standardize_path, mock_check_permission, mock_safe_open):
        """Test _load_weight_file_routing."""
        mock_standardize_path.side_effect = lambda x, **kwargs: x
        mock_safe_open.side_effect = [
            self._make_context_mock(keys=["tensor1", "tensor2"]),
            self._make_context_mock(keys=["tensor3", "tensor4"]),
        ]
        handler = self._make_bare_handler(_filenames=["file1.safetensors", "file2.safetensors"])

        routing = handler._load_weight_file_routing()

        self.assertEqual(mock_standardize_path.call_count, 2)
        self.assertEqual(mock_check_permission.call_count, 2)
        self.assertEqual(len(routing), 4)
        self.assertEqual(routing["tensor1"], "file1.safetensors")
        self.assertEqual(routing["tensor2"], "file1.safetensors")
        self.assertEqual(routing["tensor3"], "file2.safetensors")
        self.assertEqual(routing["tensor4"], "file2.safetensors")

    @patch(f'{_MODULE}.safetensors.safe_open')
    @patch(f'{_MODULE}.check_path_permission')
    @patch(f'{_MODULE}.standardize_path')
    def test_load_weight_file_routing_duplicate_tensor(
            self, mock_standardize_path, mock_check_permission, mock_safe_open):
        """Test _load_weight_file_routing raises error for duplicate tensor."""
        mock_standardize_path.side_effect = lambda x, **kwargs: x
        # "tensor2" appears in both files.
        mock_safe_open.side_effect = [
            self._make_context_mock(keys=["tensor1", "tensor2"]),
            self._make_context_mock(keys=["tensor2", "tensor3"]),
        ]
        handler = self._make_bare_handler(_filenames=["file1.safetensors", "file2.safetensors"])

        with self.assertRaises(ValueError) as context:
            handler._load_weight_file_routing()

        self.assertIn("Weight was found in multiple files", str(context.exception))

    @patch(f'{_MODULE}.safetensors.safe_open')
    @patch(f'{_MODULE}.check_path_permission')
    @patch(f'{_MODULE}.standardize_path')
    def test_load_weight_file_routing_empty_files(
            self, mock_standardize_path, mock_check_permission, mock_safe_open):
        """Test _load_weight_file_routing with empty files."""
        mock_standardize_path.side_effect = lambda x, **kwargs: x
        mock_safe_open.return_value = self._make_context_mock(keys=[])
        handler = self._make_bare_handler(_filenames=["file1.safetensors"])

        routing = handler._load_weight_file_routing()

        self.assertEqual(len(routing), 0)

    @patch(f'{_MODULE}.safetensors.safe_open')
    @patch(f'{_MODULE}.check_path_permission')
    @patch(f'{_MODULE}.standardize_path')
    def test_load_weight_file_routing_single_file(
            self, mock_standardize_path, mock_check_permission, mock_safe_open):
        """Test _load_weight_file_routing with single file."""
        mock_standardize_path.side_effect = lambda x, **kwargs: x
        mock_safe_open.return_value = self._make_context_mock(keys=["tensor1", "tensor2", "tensor3"])
        handler = self._make_bare_handler(_filenames=["file1.safetensors"])

        routing = handler._load_weight_file_routing()

        self.assertEqual(len(routing), 3)
        self.assertEqual(routing["tensor1"], "file1.safetensors")
        self.assertEqual(routing["tensor2"], "file1.safetensors")
        self.assertEqual(routing["tensor3"], "file1.safetensors")
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()