"""
算子预估
"""
import math

from mindspeed.auto_settings.config.model_config import get_model_config
from mindspeed.auto_settings.utils.logger import get_logger
from mindspeed.auto_settings.module.operator.operator import Operator
from mindspeed.auto_settings.module.communication.communication import Communication
from mindspeed.auto_settings.utils.utils import get_num_warmup_micro_batches


class TimeCost:
    """Estimate the time cost of a candidate parallel configuration.

    The estimate is assembled from two sub-models: ``Operator`` (operator
    time) and ``Communication`` (communication time). Both are fitted from
    profiling results via :meth:`train_models`.
    """

    def __init__(self):
        self.logger = get_logger("TimeCost")
        self.operator = Operator()
        self.communication = Communication()

    def train_models(self, profile_results):
        """Fit the operator model, then the communication model, on ``profile_results``."""
        self.operator.train_models(profile_results)
        self.communication.train_models(profile_results)

    def get_communication_time(self, search_cfg):
        """Return the communication-time info reported by the communication model for ``search_cfg``."""
        return self.communication.get_communication_time(search_cfg)

    def get_operator_time(self, search_cfg):
        """Return the operator-time info reported by the operator model for ``search_cfg``."""
        return self.operator.get_operator_info(search_cfg)

    def get_time_cost(self, search_cfg, memory_info):
        """Return the final time estimate for ``search_cfg``, taking virtual pipeline stages into account.

        The estimate is built as follows:

        1. ``operator_time * layer_num + communication_time``, where
           ``layer_num = ceil(micro_batch_num * num_layers / pp)``.
        2. The result is divided by ``1 - bubble_ratio`` with
           ``bubble_ratio = (pp - 1) / (micro_batch_num * vp + pp - 1)``.
        3. ``pp_time * micro_batch_size + dp_time`` is added.
        4. The forward time (operator forward + forward communication) is
           subtracted once per micro batch for every layer listed as
           ``layer_calculate``, or for every layer of the pipeline stage when
           ``need_recompute`` is false.

        Args:
            search_cfg: Candidate configuration. The parallel sizes, the
                micro batch size, ``num_layers`` and
                ``num_layers_per_virtual_pipeline_stage`` are read from it.
            memory_info: Mapping with the keys ``"need_recompute"`` and
                ``"layer_calculate"``.

        Returns:
            dict with the keys ``"use_mc2"`` (forwarded from the
            communication model), ``"total_time"`` (the log banners label the
            values as milliseconds) and ``"num_layers"`` (``0`` without
            recompute, otherwise layers per stage minus ``layer_calculate``;
            presumably the number of recomputed layers -- TODO confirm).
        """
        tp = search_cfg.tensor_model_parallel_size
        dp = search_cfg.data_parallel_size
        pp = search_cfg.pipeline_model_parallel_size
        layers_per_virtual_stage = search_cfg.num_layers_per_virtual_pipeline_stage
        # Without virtual pipelining every pipeline stage holds a single model chunk.
        vp = search_cfg.num_layers // (pp * layers_per_virtual_stage) if layers_per_virtual_stage else 1
        cp = search_cfg.context_parallel_size
        ep = search_cfg.expert_model_parallel_size or 1

        # Fetch the model configuration once and reuse it below.
        model_cfg = get_model_config()
        total_layers = model_cfg.num_layers
        global_batch_size = model_cfg.global_batch_size
        # The per-layer times are scaled relative to a micro batch size of 1.
        model_micro_batch_size = 1
        search_micro_batch_size = search_cfg.micro_batch_size
        micro_batch_num = global_batch_size / (dp * search_micro_batch_size)
        layer_num = math.ceil(micro_batch_num * (total_layers / pp))
        search_model_mbs_ratio = search_micro_batch_size / model_micro_batch_size
        bubble_ratio = (pp - 1) / (micro_batch_num * vp + pp - 1)

        operator_info = self.get_operator_time(search_cfg)
        operator_time = operator_info["operator_time"]
        operator_fw_time = operator_info["operator_fw_time"]

        communication_info = self.get_communication_time(search_cfg)
        use_mc2 = communication_info["use_mc2"]
        fw_communication_time = communication_info["fw_communication_time"]
        communication_time = communication_info["communication_time"]
        pp_time = communication_info["pp_time"]
        dp_time = communication_info["dp_time"]
        tp_time = communication_info["tp_time"]
        cp_time = communication_info["cp_time"]
        ep_time = communication_info["ep_time"]

        fw_performance = operator_fw_time + fw_communication_time
        total_operator_time = operator_time * layer_num
        total_time = total_operator_time + communication_time

        self.logger.debug('global_batch_size : {}, num_layers : {}, search_micro_batch_size : {}, operator_time : {}, '
                          'layer_num : {}'.format(global_batch_size, total_layers, search_micro_batch_size,
                                                  operator_time, layer_num))
        # Inflate the busy time so that the bubble accounts for ``bubble_ratio`` of the result.
        total_time = total_time / (1 - bubble_ratio)
        bubble_time = total_time * bubble_ratio
        total_time = total_time + pp_time * search_model_mbs_ratio + dp_time

        need_recompute = memory_info["need_recompute"]
        layer_calculate = memory_info["layer_calculate"]
        _, total_num_micro_batches = get_num_warmup_micro_batches(search_cfg, model_cfg)
        # NOTE(review): ``search_cfg.pp`` is presumably the same value as
        # ``pipeline_model_parallel_size`` read above -- confirm.
        layers_per_stage = model_cfg.num_layers // search_cfg.pp

        total_communication_time = communication_time + pp_time * search_model_mbs_ratio + dp_time
        self._log_time_breakdown(
            (tp, dp, pp, vp, cp, ep),
            (total_operator_time, total_communication_time, bubble_time, total_time, operator_fw_time),
            (tp_time, dp_time, pp_time, cp_time, ep_time),
            bubble_time,
            layer_num,
            search_model_mbs_ratio,
        )
        self.logger.debug(f"before recompute, perf = {total_time}")
        self.logger.debug(f"success enter recompute_solver and tp = {search_cfg.tensor_model_parallel_size} "
                          f"pp = {search_cfg.pipeline_model_parallel_size} "
                          f"layers_per_vpp={search_cfg.num_layers_per_virtual_pipeline_stage} "
                          f"dp = {search_cfg.data_parallel_size} cp = {search_cfg.context_parallel_size} "
                          f"ep = {search_cfg.expert_model_parallel_size} zero = {search_cfg.use_distributed_optimizer}")

        if not need_recompute:
            # No recompute: drop the forward time of every layer of the stage.
            total_time = total_time - total_num_micro_batches * layers_per_stage * fw_performance
            return {
                "use_mc2": use_mc2,
                "total_time": total_time,
                "num_layers": 0
            }

        # NOTE(review): the original code branched on ``search_cfg.layers_per_vpp``
        # here, but both branches computed this exact expression. Confirm that the
        # virtual-pipeline case was not meant to use a different formula.
        time_cost = total_num_micro_batches * layer_calculate * fw_performance
        total_time = total_time - time_cost
        return {
            "use_mc2": use_mc2,
            "total_time": total_time,
            "num_layers": layers_per_stage - layer_calculate
        }

    def _log_time_breakdown(self, parallel_sizes, summary_times, comm_times, bubble_time,
                            layer_num, search_model_mbs_ratio):
        """Write the debug tables showing how the estimated time is composed.

        Args:
            parallel_sizes: ``(tp, dp, pp, vp, cp, ep)``.
            summary_times: ``(operator, communication, bubble, total, forward
                operator)`` times for the first table.
            comm_times: ``(tp_time, dp_time, pp_time, cp_time, ep_time)`` as
                read from the communication model.
            bubble_time: Bubble time shown in the communication tables.
            layer_num: Factor applied to the tp/cp/ep times in the end-to-end table.
            search_model_mbs_ratio: Micro-batch-size factor applied alongside
                ``layer_num`` in the end-to-end table.
        """
        tp_time, dp_time, pp_time, cp_time, ep_time = comm_times
        comm_titles = ('tp_time', 'dp_time', 'pp_time', 'bubble', 'cp_time', 'ep_time')
        comm_title_tplt = "{0:<4}\t{1:<4}\t{2:<4}\t{3:<4}\t{4:<4}\t{5:<4}"

        self.logger.debug("******************   total_time(ms)  ***********************")
        tplt = "{0:<2}\t{1:<2}\t{2:<2}\t{3:<2}\t{4:<2}\t{5:<2}\t{6:<8}\t{7:<10}\t{8:<8}\t{9:<8}\t{10:<8}\t{11:<8}"
        # chr(12288) is the ideographic space that fills the trailing header column.
        self.logger.debug(tplt.format('tp', 'dp', 'pp', 'vp', 'cp', 'ep', 'operator_time',
                                      'comm_time', 'bubble_time', 'total_time', 'fw_time', chr(12288)))
        tplt = "{0:<2}\t{1:<2}\t{2:<2}\t{3:<2}\t{4:<2}\t{5:<2}\t{6:8.2f}\t{7:8.2f}\t{8:8.2f}\t{9:8.2f}\t{10:8.2f}"
        self.logger.debug(tplt.format(*parallel_sizes, *summary_times))

        self.logger.debug("*******   each layer mbs communication time(ms)  ********")
        self.logger.debug(comm_title_tplt.format(*comm_titles))
        tplt = "{0:4.2f}\t{1:4.2f}\t{2:4.2f}\t{3:4.2f}\t{4:4.2f}\t{5:4.2f}"
        self.logger.debug(tplt.format(tp_time, dp_time, pp_time, bubble_time, cp_time, ep_time))

        self.logger.debug("end-to-end, each*(global_batch_size / (dp *pp))* num_layers")
        self.logger.debug(comm_title_tplt.format(*comm_titles))
        tplt = "{0:4.0f}\t{1:4.2f}\t{2:4.2f}\t{3:4.2f}\t{4:4.2f}\t{5:4.2f}"
        self.logger.debug(tplt.format(tp_time * layer_num * search_model_mbs_ratio, dp_time,
                                      pp_time, bubble_time, cp_time * layer_num * search_model_mbs_ratio,
                                      ep_time * layer_num * search_model_mbs_ratio))