import csv
import json
import os
from mindspeed.auto_settings.module.parse.profiling_parse.profiling_constant import SpecialOperatorName
from mindspeed.auto_settings.module.parse.profiling_parse.profiling_constant import NumberConstant
from mindspeed.auto_settings.utils.file_utils import check_file_size


class FileAnalyseTool:
    """
        support csv and json parse
    """

    @classmethod
    def analyse_csv_info(cls, file_path: str, csv_name: str):
        csv_path = os.path.join(file_path, csv_name)
        try:
            with open(csv_path, newline='') as csvfile:
                check_file_size(csvfile)
                reader = csv.DictReader(csvfile)
                csv_details = list(reader)

        except FileNotFoundError as e:
            raise f"'Please check file name, {e}"
        except csv.Error as e:
            raise f"An error occurred while reading the CSV file: {e}"
        return csv_details

    @classmethod
    def analyse_json_info(cls, file_path: str, json_name: str):
        json_path = os.path.join(file_path, json_name)
        json_details = {"p2p": {}, "collective": {}}
        try:
            with open(json_path, 'r') as f:
                check_file_size(f)
                details = json.load(f)
            details_value = list(details.values())[0]
            for name, info in details_value.get('p2p', {}).items():
                comm_name = name.split("@")[0]
                json_details['p2p'][comm_name] = info["Communication Time Info"]
            for name, info in details_value.get('collective', {}).items():
                comm_name = name.split("@")[0]
                json_details['collective'][comm_name] = info["Communication Time Info"]
        except KeyError as e:
            raise f"'Please check file name, {e}"
        except Exception as e:
            raise f"Read communication file error: {e}"

        return json_details


class StructureAnalyseTool:
    """
        support structure parse
    """

    def __init__(self, rank_file_path, memory_details):
        self._rank_file_path = rank_file_path
        self._memory_details = memory_details
        self.fw_norm_op = SpecialOperatorName.FW_RMS_NORM_TYPE
        self.bw_norm_op = SpecialOperatorName.BW_RMS_NORM_TYPE
        self._search_special_norm_op()

    def analyse_norm_op(self):
        """ Analyse the norm op details in kernel_details.csv. """
        fw_norm_op_idx_list = []
        bw_norm_op_idx_list = []
        matmul_total_time = 0
        mc2_total_time = 0
        for idx, row in enumerate(self._memory_details):
            if "Name" not in row or "Type" not in row:
                continue
            if row["Type"] == "MatMulCommon":
                time = float(row["Duration(us)"]) / NumberConstant.CONVERSION_TIME
                matmul_total_time += time
                mc2_total_time += time
            if row["Type"] == "AllGatherMatmul" or row["Type"] == "MatmulReduceScatter":
                mc2_total_time += float(row["Duration(us)"]) / NumberConstant.CONVERSION_TIME
            if row["Type"] == self.fw_norm_op:
                fw_norm_op_idx_list.append(idx)
            elif row["Type"] == self.bw_norm_op:
                bw_norm_op_idx_list.append(idx)
        if fw_norm_op_idx_list == [] and bw_norm_op_idx_list == []:
            #通过layernorm找算子
            fw_norm_op_idx_list, bw_norm_op_idx_list = self.analyse_layer_norm_op()
        if bw_norm_op_idx_list == []:
            bw_norm_op_idx_list = [fw_norm_op_idx_list[-1], fw_norm_op_idx_list[-1], fw_norm_op_idx_list[-1]]

        return fw_norm_op_idx_list, bw_norm_op_idx_list, matmul_total_time, mc2_total_time

    def analyse_layer_norm_op(self):
        """ Analyse the norm op details in kernel_details.csv. """
        op_idx_list = []
        op_type_list = []
        for idx, row in enumerate(self._memory_details):
            if "Name" not in row or "Type" not in row:
                continue
            if "LayerNorm" in row["Type"] or "FlashAttentionScore" in row["Type"]:
                op_idx_list.append(idx)
                op_type_list.append(row["Type"])
        fw_norm_op_idx_list = []
        bw_norm_op_idx_list = []
        write_flag = False
        for idx, op_type in enumerate(op_type_list):
            if idx + 2 < len(op_type_list):
                if op_type == "LayerNormV3" and op_type_list[idx + 1] == "FlashAttentionScore" \
                    and op_type_list[idx + 2] == "FlashAttentionScore" or \
                len(op_type_list) <= 4:
                    write_flag = True
            if idx > 2:
                if op_type == "LayerNormGradV3" and op_type_list[idx - 1] == "FlashAttentionScoreGrad" \
                    and op_type_list[idx - 2] == "FlashAttentionScoreGrad":
                    bw_norm_op_idx_list.append(op_idx_list[idx])
                    write_flag = False
            if not write_flag:
                continue
            if "LayerNormV" in op_type:
                fw_norm_op_idx_list.append(op_idx_list[idx])
            if "LayerNormGradV" in op_type:
                bw_norm_op_idx_list.append(op_idx_list[idx])
        return fw_norm_op_idx_list, bw_norm_op_idx_list

    def get_fw_norm_op(self):
        return self.fw_norm_op

    def _search_special_norm_op(self):
        """ Special norm op: rms_norm, layer_norm, rms_norm_grad """
        op_statistic_details = FileAnalyseTool.analyse_csv_info(self._rank_file_path, 'op_statistic.csv')
        for op in op_statistic_details:
            if SpecialOperatorName.FW_LAYER_NORM_TYPE in op['OP Type']:
                self.fw_norm_op = SpecialOperatorName.FW_LAYER_NORM_TYPE
                self.bw_norm_op = SpecialOperatorName.BW_LAYER_NORM_TYPE
                break