# -------------------------------------------------------------------------
# This file is part of the MindStudio project.
# Copyright (c) 2025 Huawei Technologies Co.,Ltd.
#
# MindStudio is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
#          http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
# -------------------------------------------------------------------------
# pylint: disable=duplicate-code

import shutil
import tempfile
import unittest
from dataclasses import asdict
from pathlib import Path
from unittest.mock import patch

from ms_serviceparam_optimizer.train.state_param import StateParam


class TestStateParam(unittest.TestCase):
    """Tests for the StateParam configuration class."""

    def setUp(self):
        # Give every test its own fresh temporary directory. This isolates runs
        # from the working directory, from a shared /tmp path, and from each
        # other (including concurrent or previously aborted runs).
        self.test_dir = Path(tempfile.mkdtemp(prefix="test_state_param_"))

    def tearDown(self):
        # Remove everything the test (and StateParam) created under test_dir.
        # ignore_errors keeps a cleanup hiccup from masking the real test result.
        shutil.rmtree(self.test_dir, ignore_errors=True)

    def test_initialization(self):
        """StateParam derives model/step/bak directories from base_path and creates them."""
        base_path = self.test_dir / "test"
        param = StateParam(
            title="Test Config", base_path=base_path, xgb_model_train_param={'learning_rate': 0.1}
        )

        # The derived paths sit directly under base_path
        self.assertEqual(param.model_dir, base_path / "model")
        self.assertEqual(param.step_dir, base_path / "step")
        self.assertEqual(param.bak_dir, base_path / "bak")

        # The directories were actually created on disk
        self.assertTrue(param.model_dir.exists())
        self.assertTrue(param.step_dir.exists())
        self.assertTrue(param.bak_dir.exists())

        # The explicitly supplied training parameter is retained
        self.assertEqual(param.xgb_model_train_param['learning_rate'], 0.1)

    def test_asdict(self):
        """dataclasses.asdict exposes the configured fields, and paths stay Path objects."""
        param = StateParam(title="Test Config", base_path=self.test_dir)

        param_dict = asdict(param)

        self.assertEqual(param_dict['title'], "Test Config")
        self.assertEqual(param_dict['predict_field'], "model_execute_time")
        self.assertIsInstance(param_dict['model_dir'], Path)

    def test_default_initialization(self):
        """Initialization with only base_path uses the default values."""
        sp = StateParam(base_path=self.test_dir)

        # Basic attribute defaults
        self.assertEqual(sp.title, "MixModel")
        self.assertEqual(sp.base_path, self.test_dir)
        self.assertEqual(sp.predict_field, "model_execute_time")
        self.assertTrue(sp.save_model)
        self.assertFalse(sp.shuffle)

        # The working directories are created
        self.assertTrue(sp.model_dir.exists())
        self.assertTrue(sp.step_dir.exists())
        self.assertTrue(sp.bak_dir.exists())

        # The XGBoost model is saved to and loaded from the same file in model_dir
        self.assertEqual(sp.xgb_model_save_model_path, sp.model_dir / "xgb_model.ubj")
        self.assertEqual(sp.xgb_model_load_model_path, sp.model_dir / "xgb_model.ubj")

    def test_custom_initialization(self):
        """Explicitly supplied arguments override the defaults."""
        custom_sp = StateParam(
            base_path=self.test_dir / "custom",
            title="CustomModel",
            predict_field="custom_field",
            save_model=False,
            # Use the non-default value so the override is actually exercised
            shuffle=True,
            plot_pred_and_real=False,
            op_algorithm="scale",
        )

        # Custom attribute values are kept
        self.assertEqual(custom_sp.title, "CustomModel")
        self.assertEqual(custom_sp.predict_field, "custom_field")
        self.assertFalse(custom_sp.save_model)
        self.assertTrue(custom_sp.shuffle)
        self.assertFalse(custom_sp.plot_pred_and_real)
        self.assertEqual(custom_sp.op_algorithm, "scale")

        # The custom base path is used and created
        self.assertEqual(custom_sp.base_path, self.test_dir / "custom")
        self.assertTrue((self.test_dir / "custom").exists())

    @patch("ms_serviceparam_optimizer.train.state_param.Path.mkdir")
    def test_directory_creation_failure(self, mock_mkdir):
        """An OSError from directory creation propagates out of the constructor."""
        mock_mkdir.side_effect = OSError("Permission denied")

        with self.assertRaises(OSError):
            StateParam(base_path=self.test_dir / "invalid")

    def test_xgb_params(self):
        """The XGBoost training and update parameter dicts are retained."""
        sp = StateParam(
            base_path=self.test_dir,
            xgb_model_train_param={"max_depth": 6, "eta": 0.3},
            xgb_model_update_param={"updater": "refresh"},
        )

        # Training parameters
        self.assertEqual(sp.xgb_model_train_param["max_depth"], 6)
        self.assertEqual(sp.xgb_model_train_param["eta"], 0.3)

        # Update parameters
        self.assertEqual(sp.xgb_model_update_param["updater"], "refresh")


if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()