import sys
import unittest
from unittest.mock import patch, MagicMock, PropertyMock
import torch
import importlib.util

# Install a stand-in for ``torch_npu`` BEFORE importing the module under test.
# The tests patch ``mindie_llm.runtime.utils.weight_prefetcher.torch_npu``, so
# that module presumably imports ``torch_npu`` at module level; a stub that is
# registered only after the import cannot help on machines without it.
# The stub is used only when ``torch_npu`` is neither already loaded nor
# installed, so a real installation is never shadowed by the mock.
if 'torch_npu' not in sys.modules and importlib.util.find_spec('torch_npu') is None:
    sys.modules['torch_npu'] = MagicMock()

import mindie_llm.runtime.utils.weight_prefetcher as weight_prefetcher  # noqa: E402
class TestWeightPrefetcher(unittest.TestCase):
    """
    Test cases for WeightPrefetchMethod and helper functions.

    Each test runs with the ``torch_npu`` and ``torch.npu`` names referenced
    by the ``weight_prefetcher`` module replaced by mocks, so the logic is
    exercised against simulated NPU interfaces.
    """

    def setUp(self) -> None:
        """
        Set up test fixtures before each test method.

        Resets the module-level prefetcher, then patches the NPU entry points.
        Every patch is undone through ``addCleanup`` rather than ``tearDown``:
        cleanups still run when ``setUp`` itself fails part-way, whereas
        ``tearDown`` is skipped in that case and would leak an active patch.
        """
        weight_prefetcher.weight_prefetcher.disable_weight_prefetch()
        weight_prefetcher.weight_prefetcher.prefetch_weights = {}

        self.patcher_torch_npu = patch(
            'mindie_llm.runtime.utils.weight_prefetcher.torch_npu'
        )
        self.mock_torch_npu = self.patcher_torch_npu.start()
        self.addCleanup(self.patcher_torch_npu.stop)

        # create=True lets the patch succeed even when ``torch`` exposes no
        # ``npu`` attribute (presumably the case when torch_npu is stubbed);
        # it is a no-op when the attribute already exists.
        self.patcher_torch_npu_ns = patch(
            'mindie_llm.runtime.utils.weight_prefetcher.torch.npu',
            create=True,
        )
        self.mock_torch_npu_ns = self.patcher_torch_npu_ns.start()
        self.addCleanup(self.patcher_torch_npu_ns.stop)

        # Registered last, so it runs first (cleanups are LIFO) while the
        # mocks are still active: a test that fails after enabling the global
        # prefetcher must not leave it enabled for other test modules.
        self.addCleanup(
            weight_prefetcher.weight_prefetcher.disable_weight_prefetch
        )

        self.mock_stream = MagicMock()
        self.mock_torch_npu_ns.Stream.return_value = self.mock_stream
        self.mock_torch_npu_ns.current_stream.return_value = MagicMock()

    def test_enable_disable_prefetch(self) -> None:
        """
        Test enabling and disabling the prefetch mechanism.

        Verifies the enabled flag toggles, that enabling creates exactly one
        stream, and that disabling clears ``prefetch_stream``.
        """
        prefetcher = weight_prefetcher.WeightPrefetchMethod()
        self.assertFalse(prefetcher.is_prefetch_enabled())

        prefetcher.enable_weight_prefetch()
        self.assertTrue(prefetcher.is_prefetch_enabled())
        self.mock_torch_npu_ns.Stream.assert_called_once()

        prefetcher.disable_weight_prefetch()
        self.assertFalse(prefetcher.is_prefetch_enabled())
        self.assertIsNone(prefetcher.prefetch_stream)

    def test_maybe_npu_prefetch_logic(self) -> None:
        """
        Test the internal helper function _maybe_npu_prefetch.

        Verifies the size forwarded to ``npu_prefetch`` for three inputs:
        ``max_size=0``, a ``max_size`` larger than the tensor, and a
        ``max_size`` smaller than the tensor. Each case runs in its own
        ``subTest`` so one failing case does not hide the others.
        """
        dummy_tensor = torch.randn(5, 5)
        dummy_dep = torch.randn(1)
        total_size = dummy_tensor.element_size() * dummy_tensor.numel()

        # (label, max_size passed in, size expected at npu_prefetch)
        cases = (
            ('max_size of zero uses the full tensor size', 0, total_size),
            ('oversized max_size is capped', total_size + 100, total_size),
            ('smaller max_size is passed through', total_size - 10, total_size - 10),
        )
        for label, max_size, expected_size in cases:
            with self.subTest(label):
                self.mock_torch_npu.npu_prefetch.reset_mock()
                weight_prefetcher._maybe_npu_prefetch(
                    dummy_tensor, dummy_dep, max_size=max_size
                )
                self.mock_torch_npu.npu_prefetch.assert_called_with(
                    dummy_tensor, dummy_dep, expected_size, 0
                )

    def test_prefetch_preprocess_stream_logic(self) -> None:
        """
        Test _prefetch_preprocess function.

        Verifies that the prefetch stream waits on the current (compute)
        stream and that ``npu_prefetch`` is invoked.
        """
        dummy_weight = torch.randn(10, 10)
        dummy_flag = torch.randn(1)
        calc_stream = MagicMock()
        prefetch_stream = MagicMock()
        self.mock_torch_npu_ns.current_stream.return_value = calc_stream
        # Make the stream usable in a ``with`` statement, yielding itself,
        # in case the implementation enters it directly.
        prefetch_stream.__enter__ = MagicMock(return_value=prefetch_stream)
        prefetch_stream.__exit__ = MagicMock(return_value=None)

        weight_prefetcher._prefetch_preprocess(
            weight=dummy_weight,
            start_flag=dummy_flag,
            max_weight_size=1024,
            weight_prefetch_stream=prefetch_stream
        )

        prefetch_stream.wait_stream.assert_called_once_with(calc_stream)
        self.mock_torch_npu.npu_prefetch.assert_called()

    def test_prefetch_weight_preprocess_integration(self) -> None:
        """
        Test WeightPrefetchMethod.prefetch_weight_preprocess.

        Verifies that the size reaching ``npu_prefetch`` (third positional
        argument) equals the tensor's byte size scaled by ``ratio``.
        """
        prefetcher = weight_prefetcher.WeightPrefetchMethod()
        prefetcher.enable_weight_prefetch()
        dummy_weight = torch.randn(10, 10)
        dummy_flag = torch.randn(1)
        ratio = 0.5
        expected_size = int(
            dummy_weight.element_size() * dummy_weight.numel() * ratio
        )

        prefetcher.prefetch_weight_preprocess(
            weight=dummy_weight,
            start_flag=dummy_flag,
            ratio=ratio
        )

        self.mock_torch_npu.npu_prefetch.assert_called()
        positional_args = self.mock_torch_npu.npu_prefetch.call_args[0]
        self.assertEqual(positional_args[2], expected_size)

    def test_prefetch_weight_postprocess(self) -> None:
        """
        Test WeightPrefetchMethod.prefetch_weight_postprocess.

        Verifies that the current stream waits on the prefetch stream.
        """
        prefetcher = weight_prefetcher.WeightPrefetchMethod()
        prefetcher.enable_weight_prefetch()
        calc_stream = MagicMock()
        self.mock_torch_npu_ns.current_stream.return_value = calc_stream

        prefetcher.prefetch_weight_postprocess()

        calc_stream.wait_stream.assert_called_once_with(prefetcher.prefetch_stream)

    def test_global_instance_state(self) -> None:
        """
        Test the global weight_prefetcher instance behavior.

        Verifies the module-level instance starts disabled (after setUp's
        reset), reports enabled after enabling, and disabled again after
        disabling.
        """
        global_prefetcher = weight_prefetcher.weight_prefetcher
        self.assertFalse(global_prefetcher.is_prefetch_enabled())

        global_prefetcher.enable_weight_prefetch()
        self.assertTrue(global_prefetcher.is_prefetch_enabled())

        global_prefetcher.disable_weight_prefetch()
        self.assertFalse(global_prefetcher.is_prefetch_enabled())
# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()