# Copyright (c) 2024, Huawei Technologies Co., Ltd.  All rights reserved.
from mindspeed.auto_settings.config.search_config import SearchConfig
from mindspeed.auto_settings.module.communication.comm_perf_predictor import CommPerfPredictor, SimpleParallelCfg
from mindspeed.auto_settings.module.communication.communication_profile import Mc2ProfileTimeInfo
from mindspeed.auto_settings.module.parse.profiling_parse.profiling_constant import NumberConstant


class DebugMc2Comm:
    def __init__(self):
        self.comm_time_y = 0
        self.comm_x = 0
        self.cfg = SearchConfig()
        self.cfg_no = 0


class Mc2CommPerfPredictor(CommPerfPredictor):
    def __init__(self, hard_info):
        super(Mc2CommPerfPredictor, self).__init__(hard_info)
        self.is_mc2_modeling = False
        self.xs = []
        self.ys = []
        self.cfgs = []
        self.w = 0
        self.b = 0

    def get_communication_info_from_profile(
        self, mc2_profile_time_info: Mc2ProfileTimeInfo, hcom_info_tage_id, index
    ):
        mc2_res = hcom_info_tage_id[index][1]
        mat_res = hcom_info_tage_id[index - 1][1]
        mc2_profile_time_info.matmul_compute_time = mat_res.matmul_total_time[0]
        mc2_profile_time_info.total_comm_time = mc2_res.mc2_total_time[0]

    def receive_samples_from_profiling(
        self, config_no, model_config: SearchConfig, mc2_profile_time_info: Mc2ProfileTimeInfo
    ):
        if not self.is_mc2_modeling:
            return
        config = model_config
        tp = config.tp
        cp = config.cp
        pp = config.pp
        s = config.seq_length / NumberConstant.CONVERSION_TIME
        hccs_x = s / (tp * cp) * pp
        hccs_time = (
            mc2_profile_time_info.total_comm_time - mc2_profile_time_info.matmul_compute_time
        )
        cfg = SimpleParallelCfg(config_no, tp, cp, '', pp, '', '')
        self.xs.append(hccs_x)
        self.ys.append(hccs_time)
        self.cfgs.append(cfg)

        debug_info = DebugMc2Comm()
        debug_info.comm_x = hccs_x
        debug_info.comm_time_y = hccs_time
        debug_info.cfg = config
        debug_info.cfg_no = config_no
        self.debug_info_list.append(debug_info)

    def fit(self):
        if not self.is_mc2_modeling:
            return
        sum_x = 0
        sum_time = 0
        for index, x in enumerate(self.xs):
            sum_x += x
            sum_time += self.ys[index]

        self.w = sum_time / sum_x

    def debug(self, config_list):
        if not self.is_mc2_modeling:
            return
        self.logger.debug("===============================================================================")
        mc2lt = "{0:<9}\t{1:<9}\t{2:<1}\t{3:<1}\t{4:<1}\t{5:<1}\t{6:<1}"
        self.logger.debug(f"******************   MC2 modeling   ***********************")
        self.logger.debug(mc2lt.format("MC2_x", "MC2_time", "No", "tp", "dp", "pp", "cp", "ep", chr(12288)))

        for debug_info in self.debug_info_list:
            if debug_info.cfg.use_ascend_mc2:
                self.logger.debug(
                    mc2lt.format(
                        round(debug_info.comm_x, 2), round(debug_info.comm_time_y, 2),
                        debug_info.cfg_no, debug_info.cfg.tp, debug_info.cfg.dp, debug_info.cfg.pp,
                        debug_info.cfg.cp, debug_info.cfg.ep, chr(12288),
                    )
                )
        self.logger.debug(f"------model parameters of MC2-------------------")
        mc2lt = "{0:<9}\t{1:<9}"
        self.logger.debug(mc2lt.format("mc2_w", "mc2_b", chr(12288)))
        self.logger.debug(mc2lt.format(round(self.w, 3), round(self.b, 3), chr(12288)))
        self.logger.debug("===============================================================================")

    def predict(self, search_cfg: SearchConfig):
        if not self.is_mc2_modeling:
            return
        tp = search_cfg.tensor_model_parallel_size
        cp = search_cfg.context_parallel_size
        pp = search_cfg.pipeline_model_parallel_size
        s = search_cfg.seq_length / NumberConstant.CONVERSION_TIME
        mc2_time = 0
        if tp > 1:
            mc2_time = self.w * (s / (tp * cp) * pp) + self.b
        return mc2_time