from mindspeed.auto_settings.utils.logger import get_logger
from mindspeed.auto_settings.module.operator.operator_elemental import DictShape, ProfileList
from mindspeed.auto_settings.module.operator.operator_shape_cal import cal_operator_flops


class DictNoteShape(DictShape):
    def __init__(self):
        super(DictNoteShape, self).__init__()
        self.tp = 0
        self.cp = 0
        self.ep = 0
        self.type = ""
        self.input_shape = ""
        self.output_shape = ""
        self.duration = 0.0
        self.num_experts = 0
        self.seq_length = 0
        self.flops = 0.0
        self.jit = 0

    def change_profile_into_dictshape(self, item, config_info):
        flops = cal_operator_flops(item.input_shapes.replace('"', ''),
                                   item.output_shapes.replace('"', ''),
                                   item.type)
        self.name = item.name
        self.type = item.type
        self.accelerator_core = item.accelerator_core
        self.index_name = ''
        self.tp = config_info.tp
        self.cp = config_info.cp
        self.ep = config_info.ep
        self.jit = config_info.jit
        self.num_experts = config_info.num_experts or 1
        self.seq_length = config_info.seq_length
        self.input_shape = item.input_shapes.replace('"', '')
        self.output_shape = item.output_shapes.replace('"', '')
        self.duration = float(item.duration_us)
        self.flops = flops


class OperatorNoteList:
    def __init__(self):
        self.operator_note_list = []
        self.operator_note_dict = {}
        self.seq_length = 0
        self._logger = get_logger('operator_note_list')

    def get_operator_note(self, block):
        self.get_operator_list(block.origin_profile_data_list)
        self.get_note_in_list(block)
        self.get_note_dict()

    def get_note_in_list(self, block):
        for (index, operator_note) in enumerate(self.operator_note_list):
            tp = block.origin_profile_data_list.data_list[index].config_info.tp
            cp = block.origin_profile_data_list.data_list[index].config_info.cp
            ep = block.origin_profile_data_list.data_list[index].config_info.ep
            num_experts = block.origin_profile_data_list.data_list[index].config_info.num_experts
            # 基础模块对齐
            operator_note.reset_index_name(operator_note.fw, block.base_block.fw)
            operator_note.reset_index_name(operator_note.bw, block.base_block.bw)
            # CP基础模块对齐
            if cp > 1:
                _, cp_fw_index = operator_note.reset_index_name(operator_note.fw, block.cp_block.fw)
                _, cp_re_index = operator_note.reset_index_name(operator_note.bw, block.cp_block.re)
                _, cp_bw_index = operator_note.reset_index_name(operator_note.bw, block.cp_block.bw)
                if cp > 2:
                    operator_note.fw = block.cp_block.reset_index_diff_cp(operator_note.fw, block.cp_block.diff_list.fw,
                                                                          cp_fw_index, cp / 2)
                    operator_note.bw = block.cp_block.reset_index_diff_cp(operator_note.bw, block.cp_block.diff_list.re,
                                                                          cp_re_index, cp / 2)
                    operator_note.bw = block.cp_block.reset_index_diff_cp(operator_note.bw, block.cp_block.diff_list.bw,
                                                                          cp_bw_index, cp / 2)
            # EP模块对齐
            if num_experts:
                if num_experts // ep >= 2:
                    operator_note.fw = block.ep_block.reset_index_diff_ep(operator_note.fw, block.ep_block.fw,
                                                                          (num_experts / ep) - 1)
                    operator_note.bw = block.ep_block.reset_index_diff_ep(operator_note.bw, block.ep_block.re,
                                                                          (num_experts / ep) - 1)
                    operator_note.bw = block.ep_block.reset_index_diff_ep(operator_note.bw, block.ep_block.bw,
                                                                          (num_experts / ep) - 1)

    def get_note_dict(self):
        for operator_note in self.operator_note_list:
            operator_list = operator_note.fw + operator_note.bw
            for operator in operator_list:
                dict_exist = False
                if operator.index_name in self.operator_note_dict.keys():
                    for dict_temp in self.operator_note_dict[operator.index_name]:
                        if dict_temp == operator:
                            dict_exist = True
                    if not dict_exist:
                        self.operator_note_dict[operator.index_name].append(operator)
                else:
                    self.operator_note_dict[operator.index_name] = [operator]

    def get_operator_list(self, origin_profile_data_list):
        self.seq_length = origin_profile_data_list.data_list[0].config_info.seq_length
        for (index, origin_profile_data) in enumerate(origin_profile_data_list.data_list):
            operator_note = ProfileList()
            self.change_profile_list_into_dict_shape_list(origin_profile_data.profile_list, operator_note,
                                                          origin_profile_data.config_info)
            self.operator_note_list.append(operator_note)

    @staticmethod
    def change_profile_list_into_dict_shape_list(change_profile_list, change_operator_list, config_info):
        for (index, item) in enumerate(change_profile_list.fw):
            dict_shape_fw = DictNoteShape()
            dict_shape_fw.change_profile_into_dictshape(item, config_info)
            change_operator_list.fw.append(dict_shape_fw)
        for (index, item) in enumerate(change_profile_list.bw):
            dict_shape_bw = DictNoteShape()
            dict_shape_bw.change_profile_into_dictshape(item, config_info)
            change_operator_list.bw.append(dict_shape_bw)