# -------------------------------------------------------------------------
# This file is part of the MindStudio project.
# Copyright (c) 2025 Huawei Technologies Co.,Ltd.
#
# MindStudio is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
#          http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
# -------------------------------------------------------------------------
import unittest
from unittest.mock import patch, MagicMock
import pandas as pd
from ms_serviceparam_optimizer.inference.utils import (
    PreprocessTool,
    OperatorProcessingConfig,
    _preprocess_dataframe,
    RowData,
    OpData,
)


class TestPreprocessTool(unittest.TestCase):
    def setUp(self):
        # 测试用请求数据
        self.sample_request_row = (("req1", 10, 2, 5), ("req2", 15, 3, 8))
        self.request_columns = ("request_id", "input_length", "block_count", "output_length")

        # 测试用算子数据
        self.sample_op_row = (
            ("MatMul", "2", "float32;float32", "100,200;50,60", "200,100", "0.5", "0.1"),
            ("Add", "1", "float32", "30,40", "30,40", "0.2", "0.05"),
        )
        self.op_columns = (
            "op_name",
            "call_count",
            "input_dtype",
            "input_shape",
            "output_shape",
            "execute_delta",
            "execute_delta_ratio",
        )

    def test_generate_data_basic(self):
        """测试基础数据生成"""
        row = ("1", "2.5", "text")
        columns = ("id", "value", "desc")

        result, new_columns = PreprocessTool.generate_data(row, columns)

        self.assertEqual(result, (1.0, 2.5, "text"))
        self.assertEqual(new_columns, columns)

    def test_generate_data_with_request_info_by_df(self):
        """测试DataFrame模式的请求信息生成"""
        # 构造
        row = ((1015, 19, 0),)
        columns = ("input_length", "need_blocks", "output_length")

        result, new_columns = PreprocessTool.generate_data_with_request_info_by_df(row, columns)

        self.assertEqual(result[0], 0.0)  # 总输出长度
        self.assertIn("total_output_length", new_columns)

    @patch('ms_serviceparam_optimizer.inference.utils.PreprocessTool.get_op_in_origin_row_index')
    @patch('ms_serviceparam_optimizer.inference.utils.PreprocessTool.get_all_op_input_ratio')
    @patch('ms_serviceparam_optimizer.inference.utils.PreprocessTool.get_all_op_execute_delta_ratio')
    @patch('ms_serviceparam_optimizer.inference.utils.PreprocessTool.get_label_hist_value')
    def test_generate_data_with_op_info_use_ratio_empty_input(
        self,
        mock_get_label_hist_value,
        mock_get_all_op_execute_delta_ratio,
        mock_get_all_op_input_ratio,
        mock_get_op_in_origin_row_index,
    ):
        """测试算子信息生成"""
        # 模拟输入
        origin_row = ()
        origin_index = ()

        # 模拟返回值
        mock_get_op_in_origin_row_index.return_value = {}
        mock_get_all_op_input_ratio.return_value = {}
        mock_get_all_op_execute_delta_ratio.return_value = {}
        mock_get_label_hist_value.return_value = {}

        # 调用方法
        radio_result = PreprocessTool.generate_data_with_op_info_use_ratio(origin_row, origin_index)
        result = PreprocessTool.generate_data_with_op_info(origin_row, origin_index)
        # 验证结果
        self.assertEqual(sum(radio_result[0]), 0)
        self.assertEqual(sum(result[0]), 0)

    def test_generate_data_with_struct_info(self):
        """测试结构信息转换(比率放大)"""
        row = (0.15, 100)
        columns = ("hit_rate", "count")

        result, new_columns = PreprocessTool.generate_data_with_struct_info(row, columns)

        self.assertEqual(result[0], 15)  # 0.15 -> 15%
        self.assertEqual(result[1], 100.0)

    def test_generate_data_with_model_config(self):
        """测试模型配置信息生成"""
        row = (
            {"kv_quant_type": "int8", "group_size": 64},  # quantization_config
            ["Transformer"],  # architectures
            True,  # quantize
        )
        columns = ("quantization_config", "architectures", "quantize")

        result, new_columns = PreprocessTool.generate_data_with_model_config(row, columns)

        # 验证量化配置
        self.assertEqual(result[new_columns.index("kv_quant_type")], "int8")

    def test_get_all_op_input_ratio_empty_input(self):
        # 测试空输入的情况
        input_data = ()
        input_index = ("op_name", "call_count", "input_shape")
        result = PreprocessTool.get_all_op_input_ratio(input_data, input_index)
        self.assertEqual(result, {})

    def test_get_all_op_input_ratio_single_input(self):
        # 测试单个输入的情况
        input_data = (("conv2d", "1", "1,2,3"),)
        input_index = ("op_name", "call_count", "input_shape")
        result = PreprocessTool.get_all_op_input_ratio(input_data, input_index)
        self.assertEqual(result, {"conv2d": {0: [100.0]}})

    def test_get_all_op_input_ratio_multiple_inputs(self):
        # 测试多个输入的情况
        input_data = (("conv2d", "2", "1,2,3"), ("maxpool", "1", "2,3"))
        input_index = ("op_name", "call_count", "input_shape")
        result = PreprocessTool.get_all_op_input_ratio(input_data, input_index)
        self.assertEqual(result, {"conv2d": {0: [50.0, 50.0]}, 'maxpool': {0: [50.0]}})

    def test_generate_data_with_request_info_by_df_negative_value(self):
        """测试负值检查机制"""
        row = ((-1, -2, -3), (-4, -5, -6))  # 添加output_length列的负值
        column = ("output_length", "hist1", "hist2")
        with self.assertRaises(ValueError):
            PreprocessTool.generate_data_with_request_info_by_df(row, column)


class TestGenerateDataWithRequestInfo(unittest.TestCase):
    def setUp(self):
        self.output_length_field = "output_length"
        self.total_output_length = "total_output_length"

    def test_generate_data_with_request_info_normal_case(self):
        # 准备测试数据
        row = (
            MagicMock(**{self.output_length_field: 5}),
            MagicMock(**{self.output_length_field: 15}),
            MagicMock(**{self.output_length_field: 25}),
        )
        column = (self.output_length_field, "input_length", "need_blocks")

        # 模拟get_field_bins_count的返回值
        with patch('ms_serviceparam_optimizer.inference.common.get_field_bins_count') as mock_get_field_bins_count:
            mock_get_field_bins_count.side_effect = [
                [0, 2, 1],  # 对于input_length的返回值
                [1, 0, 1],  # 对于need_blocks的返回值
                [0, 4, 1],  # 对于output_length的返回值
            ]

            # 调用方法
            result = PreprocessTool.generate_data_with_request_info(row, column)

            # 验证结果
            self.assertEqual(sum(result[0]), 54)

    def test_generate_data_with_request_info_empty_row(self):
        # 测试空row的情况
        row = ()
        column = (self.output_length_field, "input_length", "need_blocks")

        # 预期行为:返回空的元组
        result = PreprocessTool.generate_data_with_request_info(row, column)
        self.assertEqual(sum(result[0]), 0)

    def test_generate_data_with_request_info_empty_column(self):
        # 测试空column的情况
        row = (
            MagicMock(**{self.output_length_field: 5}),
            MagicMock(**{self.output_length_field: 15}),
            MagicMock(**{self.output_length_field: 25}),
        )
        column = ()

        # 预期行为:返回空的元组
        result = PreprocessTool.generate_data_with_request_info(row, column)
        self.assertEqual(result[0], ())
        self.assertEqual(result[1], ())

    def test_get_all_op_execute_delta_ratio_empty_input(self):
        # 测试空输入
        input_data = ()
        input_index = ("op_name", "call_count", "execute_delta")
        result = PreprocessTool.get_all_op_execute_delta_ratio(input_data, input_index)
        self.assertEqual(result, {})

    def test_get_all_op_execute_delta_ratio_single_input(self):
        # 测试单个输入
        input_data = (("op1", 1, 1.0),)
        input_index = ("op_name", "call_count", "execute_delta")
        result = PreprocessTool.get_all_op_execute_delta_ratio(input_data, input_index)
        self.assertEqual(result, {"op1": [1.0]})

    def test_get_all_op_execute_delta_ratio_multiple_inputs(self):
        # 测试多个输入
        input_data = (("op1", 2, 2.0), ("op2", 1, 1.0))
        input_index = ("op_name", "call_count", "execute_delta")
        result = PreprocessTool.get_all_op_execute_delta_ratio(input_data, input_index)
        self.assertEqual(result, {'op1': [0.6666666666666666, 0.6666666666666666], 'op2': [0.3333333333333333]})

    def test_get_all_op_execute_delta_ratio_same_op_name(self):
        # 测试相同操作名
        input_data = (("op1", 2, 2.0), ("op1", 1, 1.0))
        input_index = ("op_name", "call_count", "execute_delta")
        result = PreprocessTool.get_all_op_execute_delta_ratio(input_data, input_index)
        self.assertEqual(result, {'op1': [0.6666666666666666, 0.6666666666666666, 0.3333333333333333]})

    def test_get_all_op_execute_delta_ratio_different_field(self):
        # 测试不同的字段
        input_data = (("op1", 2, 2.0), ("op2", 1, 1.0))
        input_index = ("op_name", "call_count", "execute_delta")
        result = PreprocessTool.get_all_op_execute_delta_ratio(input_data, input_index, field="execute_delta")
        self.assertEqual(result, {'op1': [0.6666666666666666, 0.6666666666666666], 'op2': [0.3333333333333333]})


class TestProcessOperatorInfo(unittest.TestCase):
    def setUp(self):
        self.config = OperatorProcessingConfig(
            origin_row=[],
            origin_index=[],
            op_index_on_origin_rows=[],
            dtype_category=[],
            op_input_param_expected={},
            op_output_expected={},
            op_execute_delta_field=[],
            op_delta_expected={},
            op="",
        )

    def test_op_name(self):
        new_row = []
        result = PreprocessTool.process_operator_info("op_name", "value", self.config, new_row)
        self.assertEqual(result, [1])

    def test_call_count(self):
        self.config.origin_row = [["10"], ["20"]]
        self.config.origin_index = ["call_count"]
        self.config.op_index_on_origin_rows = [0, 1]
        new_row = []
        result = PreprocessTool.process_operator_info("call_count", "value", self.config, new_row)
        self.assertEqual(result, [30])

    def test_input_dtype(self):
        self.config.origin_row = [["input_dtype", "int;float"], ["input_dtype", "float"]]
        self.config.origin_index = ["input_dtype"]
        self.config.op_index_on_origin_rows = [0, 1]
        self.config.dtype_category = ["int", "float"]
        new_row = []
        result = PreprocessTool.process_operator_info("input_dtype__int", "value", self.config, new_row)
        self.assertEqual(result, [0])

    def test_input_size(self):
        self.config.op_input_param_expected = {0: "10"}
        new_row = []
        result = PreprocessTool.process_operator_info("input_size__0", "value", self.config, new_row)
        self.assertEqual(result, ["10"])

    def test_output_dtype(self):
        self.config.origin_row = [["output_dtype", "int;float"], ["output_dtype", "float"]]
        self.config.origin_index = ["output_dtype"]
        self.config.op_index_on_origin_rows = [0, 1]
        self.config.dtype_category = ["int", "float"]
        new_row = []
        result = PreprocessTool.process_operator_info("output_dtype__int", "value", self.config, new_row)
        self.assertEqual(result, [0])

    def test_output_size(self):
        self.config.op_output_expected = {0: "10"}
        new_row = []
        result = PreprocessTool.process_operator_info("output_size__0", "value", self.config, new_row)
        self.assertEqual(result, ["10"])

    def test_op_execute_delta_field(self):
        self.config.op_execute_delta_field = ["field"]
        self.config.op_delta_expected = {"field": {"op": "10"}}
        self.config.op = "op"
        new_row = []
        result = PreprocessTool.process_operator_info("field", "value", self.config, new_row)
        self.assertEqual(result, ["10"])


class TestPreprocessDataFrame(unittest.TestCase):
    def setUp(self):
        self.df = pd.DataFrame(
            {
                'A': ["1", "2", "3"],
                'B': ["4", "5", "6"],
            }
        )

    def test_all_columns_valid(self):
        # 测试所有检查的列都有效的情况
        result = _preprocess_dataframe(self.df, check_columns=['A', 'B'])
        self.assertTrue(result)

    def test_some_columns_invalid(self):
        # 测试一些检查的列包含无效值的情况
        self.df.loc[0, 'A'] = "+A3"  # 添加一个无效值
        result = _preprocess_dataframe(self.df, check_columns=['A', 'B'])
        self.assertFalse(result)

    @patch('ms_serviceparam_optimizer.inference.utils.logger')
    def test_warning_called(self, mock_warning):
        # 测试当列包含无效值时,logger.warning 是否被调用
        self.df.loc[0, 'A'] = "=A2+A3"  # 添加一个无效值
        _preprocess_dataframe(self.df, check_columns=['A', 'B'])
        mock_warning.warning.assert_called_once_with("Column A contains malicious values")


class TestProcessRowData(unittest.TestCase):
    def setUp(self):
        # 创建模拟数据
        self.row_data = RowData(
            origin_row=[["input_dtype", "int32;float32"], ["output_dtype", "int32;float32"]],
            origin_index=["input_dtype", "output_dtype"],
            op_index_on_origin_rows=[0, 1],
            dtype_category=["int32", "float32"],
        )

        self.op_data = OpData(
            op="test_op",
            op_input_param_hist_ratio={"test_op": {"0__input_size__1": 0}},
            op_output_hist_ratio={"test_op": {"0__output_size__1": 0}},
            op_delta_hist_ratio={"test_op": {"0__execute_delta__1": 0}},
        )

    def test_op_name(self):
        new_row = []
        self.row_data.dtype_category = ["int32", "float32"]
        self.row_data.op_index_on_origin_rows = [0, 1]
        self.row_data.origin_row = [["input_dtype", "int32;float32"], ["output_dtype", "int32;float32"]]
        self.row_data.origin_index = ["input_dtype", "output_dtype"]
        PreprocessTool.process_row_data("op_name", self.row_data, self.op_data, new_row)
        self.assertEqual(new_row, [1])

    def test_input_dtype(self):
        new_row = []
        PreprocessTool.process_row_data("input_dtype__int32", self.row_data, self.op_data, new_row)
        self.assertEqual(new_row, [0])

    def test_input_size(self):
        new_row = []
        PreprocessTool.process_row_data("input_size__0__input_size__1", self.row_data, self.op_data, new_row)
        self.assertEqual(new_row, [0])

    def test_output_dtype(self):
        new_row = []
        PreprocessTool.process_row_data("output_dtype__int32", self.row_data, self.op_data, new_row)
        self.assertEqual(new_row, [2])

    def test_output_size(self):
        new_row = []
        PreprocessTool.process_row_data("output_size__0__output_size__1", self.row_data, self.op_data, new_row)
        self.assertEqual(new_row, [0])

    def test_execute_delta(self):
        new_row = []
        PreprocessTool.process_row_data("execute_delta__0__execute_delta__1", self.row_data, self.op_data, new_row)
        self.assertEqual(new_row, [0])