import sys
import threading
import unittest
from collections import namedtuple
from unittest.mock import MagicMock, Mock, patch

import torch

# Register the torch_npu stub BEFORE importing project modules.
# The tests patch `model_cache_pool.torch_npu.*`, so that module holds a
# module-level `torch_npu` reference, presumably bound at import time.
# Installing the stub afterwards is too late on hosts without torch_npu (the
# import has already failed). It would also swap the module out globally
# after the code under test had bound the real one.
sys.modules['torch_npu'] = MagicMock()

from mindie_llm.runtime.utils.distributed.model_cache_pool import (  # noqa: E402
    ModelCachePool, GROUP_KEY
)
from mindie_llm.runtime.utils.cache_spec import CacheGroupInfo, CacheType  # noqa: E402
class TestModelCachePool(unittest.TestCase):
    """Unit tests for ``ModelCachePool``.

    Device allocation (``torch_npu.empty_with_format`` / ``torch.empty``) and
    the attention registry (``get_global_attn_dict``) are patched per test.
    The pool's internal attributes are seeded directly, so individual methods
    can be exercised without an NPU.
    """

    # Per-block byte sizes used for the fixture groups: product of the block
    # dims times 2 (presumably the float16 element size used in these tests).
    _TOKEN_BLOCK_BYTES = 16 * 32 * 128 * 2
    _SLIDING_BLOCK_BYTES = 32 * 16 * 64 * 2
    _SEQUENCE_BLOCK_BYTES = 64 * 8 * 32 * 2
    _ONE_GIB = 1024 * 1024 * 1024

    def setUp(self):
        """Reset singleton state for test isolation."""
        # Snapshot the class-level singleton state so tearDown can put it back.
        # Otherwise the mock-backed pool built by the last test outlives this
        # class and leaks into other test modules in the same interpreter.
        self._saved_instance = getattr(ModelCachePool, '_instance', None)
        self._saved_lock = getattr(ModelCachePool, '_lock', None)
        ModelCachePool._instance = None
        ModelCachePool._lock = threading.Lock()

    def tearDown(self):
        """Restore the singleton state captured in setUp."""
        ModelCachePool._instance = self._saved_instance
        if self._saved_lock is not None:
            ModelCachePool._lock = self._saved_lock

    # ------------------------------------------------------------------ helpers

    def _create_mock_attn(self, shape=(16, 32, 128), ratio=1.0, cache_type=CacheType.TOKEN,
                          dtype=torch.float16, acl_format=29):
        """Create a mock attention layer.

        ``get_cache_spec()`` returns an object holding single-entry lists for
        ``ratio``/``type``/``dtype``/``shape`` and a scalar ``format``.
        ``bind_model_cache`` and ``clear`` are plain Mocks, so calls to them
        can be asserted.
        """
        mock_attn = Mock()
        mock_attn.get_cache_spec.return_value = Mock(
            ratio=[ratio],
            type=[cache_type],
            dtype=[dtype],
            shape=[shape],
            format=acl_format,
        )
        mock_attn.bind_model_cache = Mock()
        mock_attn.clear = Mock()
        return mock_attn

    @staticmethod
    def _make_group(block_size, cache_type, bytes_of_blocks, num_blocks):
        """Return a ``(GROUP_KEY, CacheGroupInfo)`` pair with ratio 1.0."""
        key = GROUP_KEY(ratio=1.0, block_size=block_size, type=cache_type)
        info = CacheGroupInfo(
            ratio=1.0, block_size=block_size, type=cache_type,
            bytes_of_blocks=bytes_of_blocks,
            num_blocks=num_blocks,
        )
        return key, info

    @staticmethod
    def _make_initialized_pool(max_batch_size=8):
        """Return the singleton pool flagged as initialized with a batch limit."""
        pool = ModelCachePool()
        pool.initialized = True
        pool._max_batch_size = max_batch_size
        return pool

    @staticmethod
    def _make_mock_tensor(data_ptr):
        """Return a mock tensor with a fixed ``data_ptr()`` whose ``fill_`` returns itself."""
        tensor = MagicMock()
        tensor.data_ptr.return_value = data_ptr
        tensor.fill_ = MagicMock(return_value=tensor)
        return tensor

    def _make_single_cache_pool(self, cache_entry):
        """Return an initialized pool with one TOKEN group and one cache entry.

        ``cache_entry`` is stored verbatim as ``_caches[0][0]``.
        - Its first field is the block count (see test_allocate_device_cache_flow).
        - The remaining fields used here are: shape, an int whose meaning is
          not visible in this file, dtype, and the ACL format (2 = ND, 29 = NZ).
        """
        pool = self._make_initialized_pool()
        pool._device_caches_addrs = []
        pool._caches = [[cache_entry]]
        key, info = self._make_group(16, CacheType.TOKEN, self._TOKEN_BLOCK_BYTES, 0)
        pool._group_keys = [key]
        pool._groups = {key: info}
        pool._device = 'cpu'
        return pool

    # -------------------------------------------------------------------- tests

    def test_singleton_pattern(self):
        """Test singleton pattern"""
        instance1 = ModelCachePool()
        instance2 = ModelCachePool()
        self.assertIs(instance1, instance2)

    @patch('mindie_llm.runtime.utils.distributed.model_cache_pool.get_global_attn_dict')
    def test_initialize_and_update_caches_info(self, mock_get_attn_dict):
        """Test initialize and _update_caches_info"""
        mock_get_attn_dict.return_value = {'layer_0': self._create_mock_attn()}
        pool = ModelCachePool()
        pool.initialize(Mock(), device='cpu', max_batch_size=8)
        self.assertTrue(pool.initialized)
        self.assertEqual(len(pool._groups), 1)
        self.assertEqual(len(pool._caches), 1)
        self.assertEqual(pool._max_batch_size, 8)

    def test_get_groups_info(self):
        """Test get_groups_info"""
        pool = ModelCachePool()
        pool._groups = {
            GROUP_KEY(ratio=1.0, block_size=16, type=CacheType.TOKEN):
                CacheGroupInfo(ratio=1.0, block_size=16, type=CacheType.TOKEN, num_blocks=100)
        }
        groups = pool.get_groups_info()
        self.assertEqual(len(groups), 1)
        self.assertEqual(groups[0].num_blocks, 100)

    @patch('mindie_llm.runtime.utils.distributed.model_cache_pool.get_global_attn_dict')
    @patch('mindie_llm.runtime.utils.distributed.model_cache_pool.torch_npu.empty_with_format')
    def test_warmup_device_cache(self, mock_empty, mock_get_attn_dict):
        """Test warmup_device_cache"""
        mock_attn = self._create_mock_attn()
        mock_get_attn_dict.return_value = {'layer_0': mock_attn}
        mock_empty.return_value = self._make_mock_tensor(12345)
        pool = self._make_single_cache_pool((100, (16, 32, 128), 0, torch.float16, 29))

        cache_size = pool.warmup_device_cache(device_mem=self._ONE_GIB)

        self.assertGreater(cache_size, 0)
        self.assertEqual(len(pool._device_caches_addrs), 1)
        mock_attn.bind_model_cache.assert_called_once()

    @patch('mindie_llm.runtime.utils.distributed.model_cache_pool.get_global_attn_dict')
    def test_cal_num_blocks_with_token(self, mock_get_attn_dict):
        """Test _cal_num_blocks with TOKEN cache type"""
        mock_get_attn_dict.return_value = {}
        pool = self._make_initialized_pool()
        key, info = self._make_group(16, CacheType.TOKEN, self._TOKEN_BLOCK_BYTES, 0)
        pool._groups = {key: info}

        pool._cal_num_blocks(self._ONE_GIB)

        self.assertGreater(pool._groups[key].num_blocks, 0)

    @patch('mindie_llm.runtime.utils.distributed.model_cache_pool.get_global_attn_dict')
    def test_calculate_groups_info(self, mock_get_attn_dict):
        """Test calculate_groups_info"""
        mock_get_attn_dict.return_value = {}
        pool = self._make_initialized_pool()
        key, info = self._make_group(16, CacheType.TOKEN, self._TOKEN_BLOCK_BYTES, 50)
        pool._groups = {key: info}

        groups_info = pool.calculate_groups_info(device_mem=self._ONE_GIB)

        self.assertEqual(len(groups_info), 1)
        self.assertGreater(groups_info[0].num_blocks, 0)
        # The pool's own group must keep its original block count.
        self.assertEqual(pool._groups[key].num_blocks, 50)

    @patch('mindie_llm.runtime.utils.distributed.model_cache_pool.get_global_attn_dict')
    @patch('mindie_llm.runtime.utils.distributed.model_cache_pool.torch_npu.empty_with_format')
    def test_allocate_device_cache_flow(self, mock_empty, mock_get_attn_dict):
        """Test allocate_device_cache"""
        mock_attn = self._create_mock_attn()
        mock_get_attn_dict.return_value = {'layer_0': mock_attn}
        mock_empty.return_value = self._make_mock_tensor(54321)
        pool = self._make_single_cache_pool((0, (16, 32, 128), 0, torch.float16, 2))

        pool.allocate_device_cache(device_mem=self._ONE_GIB, is_dmi=False)

        # The block count (first tuple field) is rewritten by the allocation.
        updated_num_blocks = pool._caches[0][0][0]
        self.assertGreater(updated_num_blocks, 0)
        mock_empty.assert_called_once()
        mock_attn.bind_model_cache.assert_called_once()

    @patch('mindie_llm.runtime.utils.distributed.model_cache_pool.get_global_attn_dict')
    def test_clear_method(self, mock_get_attn_dict):
        """Test _clear method"""
        mock_attn1 = self._create_mock_attn()
        mock_attn2 = self._create_mock_attn()
        mock_get_attn_dict.return_value = {'layer_0': mock_attn1, 'layer_1': mock_attn2}
        pool = ModelCachePool()

        pool._clear()

        mock_attn1.clear.assert_called_once()
        mock_attn2.clear.assert_called_once()

    def test_get_caches_info(self):
        """Test get_caches_info"""
        pool = ModelCachePool()
        pool._caches = [[(100, (16, 32, 128), 0, torch.float16, 29)]]
        result = pool.get_caches_info()
        self.assertEqual(result, pool._caches)

    def test_get_caches_addrs(self):
        """Test get_caches_addrs"""
        pool = ModelCachePool()
        pool._device_caches_addrs = [[12345, 67890]]
        result = pool.get_caches_addrs()
        self.assertEqual(result, pool._device_caches_addrs)

    @patch('mindie_llm.runtime.utils.distributed.model_cache_pool.get_global_attn_dict')
    def test_cal_num_blocks_sliding_window_and_sequence(self, mock_get_attn_dict):
        """Test _cal_num_blocks with SLIDING_WINDOW and SEQUENCE types"""
        mock_get_attn_dict.return_value = {}
        pool = self._make_initialized_pool()
        token_key, token_info = self._make_group(
            16, CacheType.TOKEN, self._TOKEN_BLOCK_BYTES, 0)
        sliding_key, sliding_info = self._make_group(
            32, CacheType.SLIDING_WINDOW, self._SLIDING_BLOCK_BYTES, 0)
        sequence_key, sequence_info = self._make_group(
            64, CacheType.SEQUENCE, self._SEQUENCE_BLOCK_BYTES, 0)
        pool._groups = {
            token_key: token_info,
            sliding_key: sliding_info,
            sequence_key: sequence_info,
        }

        pool._cal_num_blocks(self._ONE_GIB)

        self.assertGreater(pool._groups[token_key].num_blocks, 0)
        # NOTE(review): these fixed counts presumably scale with
        # _max_batch_size (8). The factor 12 and the "+ 2" come from
        # _cal_num_blocks' sizing rules; confirm them against the implementation.
        self.assertEqual(pool._groups[sliding_key].num_blocks, 12 * 8 + 2)
        self.assertEqual(pool._groups[sequence_key].num_blocks, 8 + 2)

    @patch('mindie_llm.runtime.utils.distributed.model_cache_pool.get_global_attn_dict')
    def test_calculate_groups_info_with_sliding_window(self, mock_get_attn_dict):
        """Test calculate_groups_info with SLIDING_WINDOW"""
        mock_get_attn_dict.return_value = {}
        pool = self._make_initialized_pool()
        token_key, token_info = self._make_group(
            16, CacheType.TOKEN, self._TOKEN_BLOCK_BYTES, 50)
        sliding_key, sliding_info = self._make_group(
            32, CacheType.SLIDING_WINDOW, self._SLIDING_BLOCK_BYTES, 100)
        pool._groups = {token_key: token_info, sliding_key: sliding_info}

        groups_info = pool.calculate_groups_info(device_mem=self._ONE_GIB)

        self.assertEqual(len(groups_info), 2)
        reported_sliding = next(g for g in groups_info if g.type == CacheType.SLIDING_WINDOW)
        self.assertEqual(reported_sliding.num_blocks, 12 * 8)
        # The pool's own groups must keep their original block counts.
        self.assertEqual(pool._groups[token_key].num_blocks, 50)
        self.assertEqual(pool._groups[sliding_key].num_blocks, 100)

    @patch('mindie_llm.runtime.utils.distributed.model_cache_pool.get_global_attn_dict')
    def test_cal_num_blocks_oom_exception(self, mock_get_attn_dict):
        """Test OOM exception in _cal_num_blocks"""
        mock_get_attn_dict.return_value = {}
        pool = self._make_initialized_pool()
        key, info = self._make_group(16, CacheType.TOKEN, self._TOKEN_BLOCK_BYTES, 0)
        pool._groups = {key: info}

        with self.assertRaises(RuntimeError) as cm:
            pool._cal_num_blocks(device_mem=100)

        # Inspect the message only after the context exits. Any statement
        # placed after the raising call inside `with` would never run.
        message = str(cm.exception)
        self.assertIn("Npu out of memory", message)
        self.assertIn("negative number", message)

    @patch('mindie_llm.runtime.utils.distributed.model_cache_pool.get_global_attn_dict')
    @patch('mindie_llm.runtime.utils.distributed.model_cache_pool.torch.empty')
    def test_create_aligned_tensor_nd_format(self, mock_torch_empty, mock_get_attn_dict):
        """Test _create_aligned_tensor with ND format (format=2)"""
        mock_tensor = MagicMock()
        mock_tensor.data_ptr.return_value = 1048576
        # Slicing, contiguous() and view() all hand back the same mock, so any
        # post-processing of the raw buffer still yields this tensor.
        mock_tensor.__getitem__.return_value = mock_tensor
        mock_tensor.contiguous.return_value = mock_tensor
        mock_tensor.view.return_value = mock_tensor
        mock_torch_empty.return_value = mock_tensor
        pool = ModelCachePool()
        pool.initialized = True
        pool._device = 'cpu'

        result = pool._create_aligned_tensor(
            target_shape=(100, 16, 32, 128),
            dtype=torch.float16,
            device='cpu',
            format=2
        )

        mock_torch_empty.assert_called_once()
        self.assertIs(result, mock_tensor)

    @patch('mindie_llm.runtime.utils.distributed.model_cache_pool.get_global_attn_dict')
    @patch('mindie_llm.runtime.utils.distributed.model_cache_pool.torch_npu.empty_with_format')
    def test_create_aligned_tensor_nz_format(self, mock_empty_with_format, mock_get_attn_dict):
        """Test _create_aligned_tensor with NZ format (format != 2)"""
        mock_tensor = MagicMock()
        mock_empty_with_format.return_value = mock_tensor
        pool = ModelCachePool()
        pool.initialized = True
        pool._device = 'cpu'

        result = pool._create_aligned_tensor(
            target_shape=(100, 16, 32, 128),
            dtype=torch.float16,
            device='cpu',
            format=29
        )

        mock_empty_with_format.assert_called_once()
        self.assertIs(result, mock_tensor)

    @patch('mindie_llm.runtime.utils.distributed.model_cache_pool.get_global_attn_dict')
    @patch('mindie_llm.runtime.utils.distributed.model_cache_pool.torch.empty')
    @patch('mindie_llm.runtime.utils.distributed.model_cache_pool.torch_npu.empty_with_format')
    def test_allocate_device_cache_with_dmi(self, mock_npu_empty, mock_torch_empty, mock_get_attn_dict):
        """Test allocate_device_cache with is_dmi=True"""
        mock_attn = self._create_mock_attn()
        mock_get_attn_dict.return_value = {'layer_0': mock_attn}
        mock_torch_empty.return_value = self._make_mock_tensor(54321)
        pool = self._make_single_cache_pool((0, (16, 32, 128), 0, torch.float16, 2))

        pool.allocate_device_cache(device_mem=self._ONE_GIB, is_dmi=True)

        # With is_dmi=True the plain torch.empty path is expected, not torch_npu.
        mock_torch_empty.assert_called_once()
        mock_npu_empty.assert_not_called()
        mock_attn.bind_model_cache.assert_called_once()
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()