#  -*- coding: utf-8 -*-
# -------------------------------------------------------------------------
# This file is part of the MindStudio project.
# Copyright (c) 2025-2026 Huawei Technologies Co.,Ltd.
#
# MindStudio is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
#          `http://license.coscl.org.cn/MulanPSL2`
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
# -------------------------------------------------------------------------

import pytest
import os
from unittest.mock import Mock, patch

from msmodelslim.utils.cache.pth import (
    load_cached_data, InputCapture, DumperManager, to_device
)
from msmodelslim.utils.exception import SchemaValidateError


class TestLoadCachedData:
    """Tests for the load_cached_data function."""

    @pytest.fixture
    def mock_pth_file_path(self):
        """Relative path of a (non-existent) cache file used as the target."""
        return os.path.join("test", "cache", "calib_data.pth")

    @pytest.fixture
    def mock_generate_func(self):
        """Stand-in for the callable that produces calibration data."""
        return Mock()

    @pytest.fixture
    def mock_model(self):
        """Stand-in for the model whose inputs would be captured."""
        return Mock()

    @pytest.fixture
    def mock_dump_config(self):
        """Dump configuration stub exposing capture_mode='args'."""
        return Mock(capture_mode="args")

    def test_load_cached_data_file_exists(self, mock_pth_file_path, mock_generate_func, mock_model, mock_dump_config):
        """When the cache file exists, its content is loaded and returned."""
        cached_payload = {"test": "data"}

        with patch('os.path.exists', return_value=True), \
                patch('msmodelslim.utils.cache.pth.safe_torch_load', return_value=cached_payload) as load_spy, \
                patch('msmodelslim.utils.cache.pth.get_valid_read_path', return_value=mock_pth_file_path), \
                patch('msmodelslim.utils.cache.pth.get_logger'):
            outcome = load_cached_data(
                mock_pth_file_path,
                mock_generate_func,
                mock_model,
                mock_dump_config
            )

            assert outcome == cached_payload
            load_spy.assert_called_once_with(mock_pth_file_path)

    def test_load_cached_data_file_not_exists(self, mock_pth_file_path, mock_generate_func, mock_model,
                                              mock_dump_config):
        """When the cache file is missing, data is generated, dumped, then loaded."""
        generated_payload = {"generated": "data"}
        dumper_instance = Mock()

        with patch('os.path.exists', return_value=False), \
                patch('msmodelslim.utils.cache.pth.DumperManager', return_value=dumper_instance) as dumper_cls, \
                patch('msmodelslim.utils.cache.pth.safe_torch_load', return_value=generated_payload), \
                patch('msmodelslim.utils.cache.pth.get_logger'):
            outcome = load_cached_data(
                mock_pth_file_path,
                mock_generate_func,
                mock_model,
                mock_dump_config
            )

            assert outcome == generated_payload
            # The generator must run exactly once to produce the data.
            mock_generate_func.assert_called_once()
            # A dumper is built for the model with the configured mode and saved to the cache path.
            dumper_cls.assert_called_once_with(mock_model, capture_mode="args")
            dumper_instance.save.assert_called_once_with(mock_pth_file_path)


class TestInputCapture:
    """Tests for the InputCapture class."""

    @pytest.fixture(autouse=True)
    def _isolate_capture_records(self):
        """Reset InputCapture before and after every test in this class.

        add_record/get_all/reset are invoked on the class itself, so captured
        records outlive a single test. Clearing them on both sides keeps these
        tests, and tests in other classes, independent of execution order.
        """
        InputCapture.reset()
        yield
        InputCapture.reset()

    def test_input_capture_reset(self):
        """reset() removes every previously added record."""
        InputCapture.add_record({"test": "data"})
        assert len(InputCapture.get_all()) > 0

        InputCapture.reset()
        assert len(InputCapture.get_all()) == 0

    def test_input_capture_add_and_get_record(self):
        """add_record() stores a record that get_all() then returns."""
        test_record = {"test_key": "test_value"}
        InputCapture.add_record(test_record)

        all_records = InputCapture.get_all()
        assert len(all_records) == 1
        assert all_records[0] == test_record

    def test_input_capture_capture_forward_inputs_args_mode(self):
        """capture_forward_inputs (args mode) wraps a plain function and records one call."""

        def test_function(arg1, arg2, kwarg1="default"):
            return arg1 + arg2

        wrapped_func = InputCapture.capture_forward_inputs(test_function, capture_mode="args")

        # The wrapper must pass the call through unchanged.
        result = wrapped_func(10, 20, kwarg1="custom")
        assert result == 30

        captured = InputCapture.get_all()
        assert len(captured) == 1
        # In args mode only positional arguments are expected to be captured (not keyword ones).
        # The exact record layout is not pinned here; only check the two positional args are present.
        assert len(captured[0]) >= 2

    def test_input_capture_capture_forward_inputs_method(self):
        """capture_forward_inputs wraps an unbound method invoked with an explicit instance."""

        class TestClass:
            def test_method(self, arg1, arg2):
                return arg1 + arg2

        wrapped_method = InputCapture.capture_forward_inputs(TestClass.test_method, capture_mode="args")

        obj = TestClass()
        result = wrapped_method(obj, 15, 25)
        assert result == 40

        captured = InputCapture.get_all()
        assert len(captured) == 1
        # The record layout (e.g. whether `self` is dropped) is not pinned here;
        # only check that at least the two positional arguments were recorded.
        assert len(captured[0]) >= 2

    def test_input_capture_capture_forward_inputs_invalid_mode(self):
        """The decorator itself accepts any capture_mode and still returns a callable."""

        def test_function():
            pass

        # InputCapture.capture_forward_inputs does not validate capture_mode;
        # validation happens in the DumperManager constructor.
        wrapped_func = InputCapture.capture_forward_inputs(test_function, capture_mode="invalid_mode")
        assert wrapped_func is not None
        assert callable(wrapped_func)


class TestDumperManager:
    """Tests for the DumperManager class."""

    @pytest.fixture(autouse=True)
    def _isolate_capture_records(self):
        """Reset InputCapture before and after every test in this class.

        Several tests below add records through the class-level InputCapture
        API; without cleanup those records would leak into later tests.
        """
        InputCapture.reset()
        yield
        InputCapture.reset()

    @pytest.fixture
    def mock_module(self):
        """Stand-in module whose `forward` attribute the dumper hooks."""
        return Mock()

    def test_dumper_manager_initialization(self, mock_module):
        """The constructor stores the module and mode and remembers the original forward."""
        dumper = DumperManager(mock_module, capture_mode="args")

        assert dumper.module is mock_module
        assert dumper.capture_mode == "args"
        assert dumper.old_forward is not None

    def test_dumper_manager_initialization_invalid_capture_mode(self, mock_module):
        """An unknown capture_mode is rejected with SchemaValidateError."""
        with pytest.raises(SchemaValidateError, match="Invalid capture_mode: 'invalid_mode'"):
            DumperManager(mock_module, capture_mode="invalid_mode")

    def test_dumper_manager_save(self, mock_module):
        """save() writes through torch.save and clears the remembered forward."""
        dumper = DumperManager(mock_module, capture_mode="args")
        InputCapture.add_record({"test": "data"})

        with patch('msmodelslim.utils.cache.pth.torch.save') as mock_torch_save, \
                patch('msmodelslim.utils.cache.pth.get_logger'):
            dumper.save("/test/output.pth")

            mock_torch_save.assert_called_once()
            # Restoration is checked through old_forward being cleared rather than by
            # comparing mock_module.forward against the original object reference.
            assert dumper.old_forward is None

    def test_dumper_manager_reset(self, mock_module):
        """reset() on the dumper clears the InputCapture records."""
        dumper = DumperManager(mock_module, capture_mode="args")

        InputCapture.add_record({"test": "data"})
        assert len(InputCapture.get_all()) > 0

        dumper.reset()
        assert len(InputCapture.get_all()) == 0

    def test_dumper_manager_add_hook(self, mock_module):
        """Construction replaces module.forward with a wrapper exposing __wrapped__."""
        dumper = DumperManager(mock_module, capture_mode="args")

        assert mock_module.forward != dumper.old_forward
        assert hasattr(mock_module.forward, '__wrapped__')


class TestToDevice:
    """Tests for the to_device function."""

    @pytest.fixture
    def mock_torch(self):
        """Patch the module's torch so that any Mock instance counts as a torch.Tensor."""
        with patch('msmodelslim.utils.cache.pth.torch') as patched_torch:
            patched_torch.Tensor = Mock
            yield patched_torch

    @staticmethod
    def _recursive_call_count(container):
        """Run the real to_device on `container` and count its recursive calls.

        `to_device` in this test module is the original function imported at the
        top of the file, while the patch replaces the name inside the library
        module, so only calls that go through the module-level name hit the spy.
        """
        with patch('msmodelslim.utils.cache.pth.to_device', return_value="device_data") as spy:
            to_device(container, "cpu")
            return spy.call_count

    def test_to_device_dict(self, mock_torch):
        """A dict is handled by recursing into its entries."""
        assert self._recursive_call_count({"key1": Mock(), "key2": Mock()}) >= 2

    def test_to_device_list(self, mock_torch):
        """A list is handled by recursing into its items."""
        assert self._recursive_call_count([Mock(), Mock()]) >= 2

    def test_to_device_tuple(self, mock_torch):
        """A tuple is handled by recursing into its items."""
        assert self._recursive_call_count((Mock(), Mock())) >= 2

    def test_to_device_tensor(self, mock_torch):
        """A tensor-like object is moved with .to(device) and the moved value returned."""
        fake_tensor = Mock()
        fake_tensor.to.return_value = "moved_tensor"

        assert to_device(fake_tensor, "cpu") == "moved_tensor"
        fake_tensor.to.assert_called_once_with("cpu")

    def test_to_device_other_types(self, mock_torch):
        """Values that are neither containers nor tensors are returned unchanged."""
        payload = "string_data"
        assert to_device(payload, "cpu") == payload

    def test_to_device_recursion_depth_limit(self, mock_torch):
        """Nesting deeper than MAX_RECURSION_DEPTH (20) raises RecursionError."""
        # Build the chain inside-out: 25 levels of {"nested": ...} ending in {}.
        deep_data = {}
        for _ in range(25):
            deep_data = {"nested": deep_data}

        with pytest.raises(RecursionError, match="Maximum recursion depth 20 exceeded"):
            to_device(deep_data, "cpu")