# Copyright (c) 2025, Huawei Technologies Co., Ltd
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0  (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from msprof_analyze.cluster_analyse.prof_bean.step_trace_time_bean import StepTraceTimeBean
from msprof_analyze.prof_common.utils import convert_to_float
from msprof_analyze.prof_common.file_manager import FileManager
from msprof_analyze.cluster_analyse.common_func.time_range_calculator import RangeCaculator
from msprof_analyze.prof_common.db_manager import DBManager
from msprof_analyze.prof_common.logger import get_logger
from msprof_analyze.cluster_analyse.common_func.table_constant import TableConstant
from msprof_analyze.cluster_analyse.common_func.time_range_calculator import CommunicationTimeRange
from msprof_analyze.prof_common.constant import Constant

logger = get_logger()


class MsprofStepTraceTimeAdapter:
    COMPUTE = "Computing"
    COMM_NOT_OVERLAP = "Communication(Not Overlapped)"
    OVERLAPPED = "Overlapped"
    COMMUNICATION = "Communication"
    FREE = "Free"
    STAGE = "Stage"
    BUBBLE = "Bubble"
    COMM_NOT_OVERLAP_EXCLUDE_RECEIVE = "Communication(Not Overlapped and Exclude Receive)"
    PREPARE = "Preparing"
    STEP = "Step"

    def __init__(self, file_path):
        self.file_path = file_path
        self._data = {self.STEP: None, self.COMPUTE: 0, self.COMM_NOT_OVERLAP: 0, self.OVERLAPPED: 0,
                      self.COMMUNICATION: 0, self.FREE: 0, self.STAGE: 0, self.BUBBLE: 0,
                      self.COMM_NOT_OVERLAP_EXCLUDE_RECEIVE: 0, self.PREPARE: 0}

    def generate_step_trace_time_data(self):
        json_str = []
        for file_path in self.file_path:
            json_str.extend(FileManager.read_json_file(file_path))
        receive_comm = []
        analysis_data = {}
        for data in json_str:
            event_name = data.get("name", "")
            if event_name in {self.COMMUNICATION, self.COMPUTE, self.FREE, self.COMM_NOT_OVERLAP}:
                analysis_data.setdefault(event_name, []).append(data)
            elif event_name.startswith('hcom_receive'):
                receive_comm.append(data)
        for event_type, event_list in analysis_data.items():
            self._data[event_type] = sum((convert_to_float(event.get("dur", 0)) for event in event_list))
        self._data[self.BUBBLE] = sum((convert_to_float(event.get("dur", 0)) for event in receive_comm))
        self._data[self.COMM_NOT_OVERLAP_EXCLUDE_RECEIVE] = self._data[self.COMM_NOT_OVERLAP] - self._data[self.BUBBLE]
        self._data[self.OVERLAPPED] = self._data[self.COMMUNICATION] - self._data[self.COMM_NOT_OVERLAP]
        e2e_time = self._data[self.FREE] + self._data[self.COMPUTE] + self._data[self.COMM_NOT_OVERLAP]
        self._data[self.STAGE] = e2e_time - self._data[self.BUBBLE]
        return [StepTraceTimeBean(self._data)]


class MsprofStepTraceTimeDBAdapter(MsprofStepTraceTimeAdapter):
    OP_NAME = 0
    START_NS = 1
    END_NS = 2

    def __init__(self, file_path):
        super().__init__(file_path)
        self.task_db_con = None
        self.task_db_curs = None
        self.string_id_map = None
        self.compute_task_info = None
        self.communication_op_info = None

    def generate_step_trace_time_data(self):
        try:
            self._init_task_info_from_db()
        except Exception as err:
            logger.error(err)
            DBManager.destroy_db_connect(self.task_db_con, self.task_db_curs)
            return []
        origin_compute_data = self._get_compute_data()
        origin_communication_data, bubble_data = self._get_communication_data()
        compute_data = RangeCaculator.merge_continuous_intervals(origin_compute_data)
        self._data[self.COMPUTE] = sum(data.end_ts - data.start_ts for data in compute_data)
        communication_data = RangeCaculator.merge_continuous_intervals(origin_communication_data)
        self._data[self.COMMUNICATION] = sum(data.end_ts - data.start_ts for data in communication_data)
        pure_communication_data, free_data = RangeCaculator.compute_pipeline_overlap(communication_data, compute_data)
        self._data[self.COMM_NOT_OVERLAP] = sum(data.end_ts - data.start_ts for data in pure_communication_data)
        self._data[self.FREE] = sum(data.end_ts - data.start_ts for data in free_data)
        self._data[self.BUBBLE] = sum(data.end_ts - data.start_ts for data in bubble_data)
        self._data[self.COMM_NOT_OVERLAP_EXCLUDE_RECEIVE] = self._data[self.COMM_NOT_OVERLAP] - self._data[self.BUBBLE]
        self._data[self.OVERLAPPED] = self._data[self.COMMUNICATION] - self._data[self.COMM_NOT_OVERLAP]
        e2e_time = self._data[self.FREE] + self._data[self.COMPUTE] + self._data[self.COMM_NOT_OVERLAP]
        self._data[self.STAGE] = e2e_time - self._data[self.BUBBLE]
        return [[self._data[self.STEP], self._data[self.COMPUTE] / Constant.NS_TO_US,
                 self._data[self.COMM_NOT_OVERLAP] / Constant.NS_TO_US, self._data[self.OVERLAPPED] / Constant.NS_TO_US,
                 self._data[self.COMMUNICATION] / Constant.NS_TO_US, self._data[self.FREE] / Constant.NS_TO_US,
                 self._data[self.STAGE] / Constant.NS_TO_US, self._data[self.BUBBLE] / Constant.NS_TO_US,
                 self._data[self.COMM_NOT_OVERLAP_EXCLUDE_RECEIVE] / Constant.NS_TO_US,
                 self._data[self.PREPARE] / Constant.NS_TO_US]]

    def _init_task_info_from_db(self):
        db_path = self.file_path.get(Constant.PROFILER_DB_PATH)
        conn, curs = DBManager.create_connect_db(db_path)
        if not (conn and curs):
            logger.warning(f"Failed to connect to db file: {db_path}")
            return
        self.task_db_con = conn
        self.task_db_curs = curs
        if DBManager.judge_table_exists(curs, TableConstant.TABLE_STRING_IDS):
            sql = "select id, value from {}".format(TableConstant.TABLE_STRING_IDS)
            string_id_data = DBManager.fetch_all_data(curs, sql, is_dict=False)
            self.string_id_map = {data[0]: data[1] for data in string_id_data}
        if DBManager.judge_table_exists(curs, TableConstant.TABLE_COMPUTE_TASK_INFO):
            sql = f"select TASK.startNs, TASK.endNs from {TableConstant.TABLE_COMPUTE_TASK_INFO} JOIN " \
                  f"{TableConstant.TABLE_TASK} on {TableConstant.TABLE_TASK}.globalTaskId = " \
                  f"{TableConstant.TABLE_COMPUTE_TASK_INFO}.globalTaskId"
            self.compute_task_info = DBManager.fetch_all_data(curs, sql, is_dict=False)
        if DBManager.judge_table_exists(curs, TableConstant.TABLE_COMMUNICATION_OP):
            sql = "select opName, startNs, endNs from {}".format(TableConstant.TABLE_COMMUNICATION_OP)
            self.communication_op_info = DBManager.fetch_all_data(curs, sql, is_dict=False)
        DBManager.destroy_db_connect(conn, curs)

    def _get_communication_data(self):
        communication_data = []
        bubble_data = []
        for op_info in self.communication_op_info:
            op_start_time = op_info[self.START_NS]
            time_range = RangeCaculator.generate_time_range(
                op_start_time, op_info[self.END_NS], class_range=CommunicationTimeRange)
            communication_data.append(time_range)
            op_name = self.string_id_map.get(op_info[self.OP_NAME], '')
            if op_name.startswith('hcom_receive'):
                bubble_data.append(time_range)
        return communication_data, bubble_data

    def _get_compute_data(self):
        compute_data = []
        for compute_task in self.compute_task_info:
            compute_data.append(RangeCaculator.generate_time_range(compute_task[0], compute_task[1]))
        return compute_data