# Copyright (c) 2024, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0  (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from msprof_analyze.compare_tools.compare_backend.utils.excel_config import ExcelConfig
from msprof_analyze.compare_tools.compare_backend.utils.common_func import calculate_diff_ratio
from msprof_analyze.prof_common.constant import Constant


class CommunicationInfo:
    __slots__ = ['comm_op_name', 'task_name', 'calls', 'total_duration', 'avg_duration', 'max_duration', 'min_duration']

    def __init__(self, name: str, data_list: list, is_task: bool):
        self.comm_op_name = None
        self.task_name = None
        self.calls = None
        self.total_duration = 0
        self.avg_duration = None
        self.max_duration = None
        self.min_duration = None
        if data_list:
            self.comm_op_name = "|" if is_task else name
            self.task_name = name if is_task else None
            self.calls = len(data_list)
            self.total_duration = sum(data_list)
            self.avg_duration = sum(data_list) / len(data_list)
            self.max_duration = max(data_list)
            self.min_duration = min(data_list)


class CommunicationBean:
    __slots__ = ['_name', '_base_comm', '_comparison_comm']
    TABLE_NAME = Constant.COMMUNICATION_TABLE
    HEADERS = ExcelConfig.HEADERS.get(TABLE_NAME)
    OVERHEAD = ExcelConfig.OVERHEAD.get(TABLE_NAME)

    def __init__(self, name: str, base_comm_data: dict, comparison_comm_data: dict):
        self._name = name
        self._base_comm = base_comm_data
        self._comparison_comm = comparison_comm_data

    @property
    def rows(self):
        rows = []
        base_comm = CommunicationInfo(self._name, self._base_comm.get("comm_list", []), is_task=False)
        comparison_comm = CommunicationInfo(self._name, self._comparison_comm.get("comm_list", []), is_task=False)
        rows.append(self._get_row(base_comm, comparison_comm, is_task=False))

        base_task = self._base_comm.get("comm_task", {})
        comparison_task = self._comparison_comm.get("comm_task", {})
        if not base_task and not comparison_task:
            return rows

        for task_name, task_list in base_task.items():
            base_task_info = CommunicationInfo(task_name, task_list, is_task=True)
            comparison_task_info = CommunicationInfo("", [], is_task=True)
            for _task_name, _task_list in comparison_task.items():
                comparison_task_info = CommunicationInfo(_task_name, _task_list, is_task=True)
                comparison_task.pop(_task_name, None)
                break
            rows.append(self._get_row(base_task_info, comparison_task_info, is_task=True))
        for task_name, task_list in comparison_task.items():
            base_task_info = CommunicationInfo("", [], is_task=True)
            comparison_task_info = CommunicationInfo(task_name, task_list, is_task=True)
            rows.append(self._get_row(base_task_info, comparison_task_info, is_task=True))

        return rows

    @classmethod
    def _get_row(cls, base_info: CommunicationInfo, comparison_info: CommunicationInfo, is_task: bool) -> list:
        row = [
            None, base_info.comm_op_name, base_info.task_name, base_info.calls, base_info.total_duration,
            base_info.avg_duration, base_info.max_duration, base_info.min_duration, comparison_info.comm_op_name,
            comparison_info.task_name, comparison_info.calls, comparison_info.total_duration,
            comparison_info.avg_duration, comparison_info.max_duration, comparison_info.min_duration
        ]
        diff_fields = [None, None] if is_task else calculate_diff_ratio(base_info.total_duration,
                                                                        comparison_info.total_duration)
        row.extend(diff_fields)
        return row