"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
统一的 pytest 配置文件,包含所有 core 测试目录的通用配置。
"""
import sys
if sys.version_info < (3, 9):
    # ``typing.Annotated`` only exists from Python 3.9 onwards; on older
    # interpreters backfill it from ``typing_extensions`` so code that refers
    # to ``typing.Annotated`` keeps working.
    import typing
    import typing_extensions

    setattr(typing, "Annotated", typing_extensions.Annotated)
import pytest
from unittest.mock import Mock, patch, MagicMock
from testing_utils.mock import mock_kia_library, mock_security_library, mock_init_config
# Run the testing_utils.mock installers at conftest import time, in this
# exact order: init config first, then the KIA library, then the security
# library.
for _install_mock in (mock_init_config, mock_kia_library, mock_security_library):
    _install_mock()
del _install_mock
def _mock_check_dirpath_before_read(path):
    """Stand-in for ``check_dirpath_before_read`` that performs no validation.

    Args:
        path: The directory path the real helper would have validated.

    Returns:
        ``path`` itself, unchanged (the very same object).
    """
    # Nothing is checked, so tests may point at arbitrary directories.
    validated = path
    return validated
_SECURITY_PATH_MODULE = 'msmodelslim.utils.security.path'
# Register a MagicMock for the module only when nothing is registered yet
# (an earlier stub or the real, already-imported module is kept as-is).
_security_path = sys.modules.setdefault(_SECURITY_PATH_MODULE, MagicMock())
# Whatever object is registered, force its validator to the pass-through
# version. Note: code that already did
# ``from ...path import check_dirpath_before_read`` keeps its own reference.
_security_path.check_dirpath_before_read = _mock_check_dirpath_before_read
# Stub out ``wcmatch`` unless it has already been imported. The check is on
# sys.modules, not on installation: an installed but not-yet-imported wcmatch
# is shadowed by the MagicMock as well.
sys.modules.setdefault('wcmatch', MagicMock())
@pytest.fixture
def sample_torch_tensor():
    """Provide a standard 2x3 float tensor shared by observer/quantizer tests.

    Returns:
        ``tensor([[1., 2., 3.], [4., 5., 6.]])`` in torch's default floating dtype.
    """
    import torch

    # Float endpoints make arange use the default floating dtype, as
    # torch.tensor does for Python floats; 7.0 is exclusive, so values are 1..6.
    flat = torch.arange(1.0, 7.0)
    return flat.reshape(2, 3)
@pytest.fixture
def mock_dataset_loader():
    """DatasetLoader mock that always serves the same calibration data.

    Returns:
        A ``Mock`` whose ``get_dataset_by_name(...)`` returns
        ``[{"input_ids": [1, 2, 3]}]`` regardless of the arguments.
    """
    calibration_rows = [{"input_ids": [1, 2, 3]}]
    return Mock(**{"get_dataset_by_name.return_value": calibration_rows})
@pytest.fixture
def mock_context_factory():
    """IContextFactory mock whose ``create`` returns a real non-distributed context.

    A real ``ContextFactory`` builds one context with ``is_distributed=False``
    (expected to be a LocalDictContext). The returned ``Mock``'s
    ``create(...)`` hands back that same instance on every call, whatever
    arguments are passed.
    """
    from msmodelslim.core.context.context_factory import ContextFactory

    local_context = ContextFactory().create(is_distributed=False)
    stub_factory = Mock()
    stub_factory.configure_mock(**{"create.return_value": local_context})
    return stub_factory
@pytest.fixture
def mock_pipeline_interface():
    """Minimal PipelineInterface stub.

    Returns:
        A ``Mock`` with ``model_type == "test"``. ``init_model(...)`` always
        returns the same placeholder ``Mock``, and ``handle_dataset(...)``
        returns ``[[{"input_ids": [1]}]]``.
    """
    return Mock(
        model_type="test",
        **{
            "init_model.return_value": Mock(),
            "handle_dataset.return_value": [[{"input_ids": [1]}]],
        },
    )
@pytest.fixture
def mock_torch():
    """Replace the ``torch`` module with a mock whose NPU backend reports unavailable.

    While the fixture is active, ``sys.modules['torch']`` is the mock. When
    the test ends, ``sys.modules`` is restored. Only ``import torch``
    statements executed while the fixture is active pick up the mock; modules
    that bound the real ``torch`` earlier keep that reference.

    Yields:
        The ``MagicMock`` standing in for ``torch``:

        - ``torch.device(...)`` returns a ``Mock``.
        - ``torch.manual_seed(...)`` returns ``None``.
        - ``torch.npu`` is a ``Mock`` whose ``is_available()`` returns
          ``False``.
        - ``torch.npu.manual_seed``, ``manual_seed_all`` and
          ``set_compile_mode`` return ``None``.
        - ``torch.npu.Stream(...)`` returns a ``Mock``.
    """
    patched_torch = MagicMock()
    patched_torch.device.return_value = Mock()
    patched_torch.manual_seed.return_value = None

    mock_npu = Mock()
    mock_npu.manual_seed.return_value = None
    mock_npu.manual_seed_all.return_value = None
    mock_npu.Stream.return_value = Mock()
    mock_npu.set_compile_mode.return_value = None
    # Never report an NPU as present, so code under test cannot misdetect one.
    mock_npu.is_available.return_value = False
    patched_torch.npu = mock_npu

    # unittest.mock.patch() needs a dotted 'module.attribute' target; a bare
    # module name like 'torch' raises TypeError. Swap the module through
    # sys.modules instead; patch.dict restores the original mapping on exit.
    with patch.dict(sys.modules, {'torch': patched_torch}):
        yield patched_torch