import os
import math
import copy
import argparse
import json
import csv
import tabulate

from ais_bench.benchmark.configs.datasets.needlebench_v2.needlebench_v2_4k.needlebench_v2_multi_reasoning_4k import language
from ais_bench.benchmark.utils.logging.exceptions import ParameterValueError
from ais_bench.benchmark.cli.config_manager import CustomConfigChecker
from ais_bench.benchmark.utils.logging.logger import AISLogger
from ais_bench.benchmark.datasets.utils.lmm_judge import get_lmm_point_list
from mmengine.config import Config

logger = AISLogger(__name__)


def load_config(config_path: str) -> Config:
    """Load and validate configuration file"""
    if not os.path.exists(config_path):
        raise ParameterValueError(f"Config path: {config_path} is not exist!")
    try:
        config = Config.fromfile(config_path, format_python_code=False)
    except BaseException as e:
        raise RuntimeError(f"Fail to load config {config_path}, failed reason: {e}")
    CustomConfigChecker(config, config_path).check()
    return config

class GEditEvalResultParser:
    def __init__(self, args):
        self.config = load_config(args.config_path)
        self.output_dir = args.timestamp_path
        self.paths_map = dict(
            org_pred_path = [],
            sc_judge_pred_path = [],
            pq_judge_pred_path = [],
        )
        for comb in self.config["model_dataset_combinations"]:
            model_abbr = comb["models"][0]["abbr"]
            dataset_org_abbr = comb["datasets"][0]["abbr"]
            dataset_sc_abbr = f'{comb["datasets"][0]["abbr"]}-{comb["datasets"][0]["judge_infer_cfg"]["judge_model"]["abbr"]}'
            dataset_pq_abbr = f'{comb["datasets"][1]["abbr"]}-{comb["datasets"][1]["judge_infer_cfg"]["judge_model"]["abbr"]}'
            self.paths_map["org_pred_path"].append(os.path.join(self.output_dir, "predictions", model_abbr, f"{dataset_org_abbr}.jsonl"))
            self.paths_map["sc_judge_pred_path"].append(os.path.join(self.output_dir, "predictions", model_abbr, f"{dataset_sc_abbr}.jsonl"))
            self.paths_map["pq_judge_pred_path"].append(os.path.join(self.output_dir, "predictions", model_abbr, f"{dataset_pq_abbr}.jsonl"))

    def parse_results(self):
        logger.info(f"Start parse judge infer result from: {self.output_dir}")
        org_pred_data_list = self._load_and_merge_jsonl("org_pred_path")
        org_pred_data_dict = {item["uuid"]: item for item in org_pred_data_list}
        sc_judge_pred_data_list = self._load_and_merge_jsonl("sc_judge_pred_path")
        sc_judge_pred_data_dict = {item["gold"]: item for item in sc_judge_pred_data_list}
        pq_judge_pred_data_list = self._load_and_merge_jsonl("pq_judge_pred_path")
        pq_judge_pred_data_dict = {item["gold"]: item for item in pq_judge_pred_data_list}

        self.all_data_results = {}

        for uuid in org_pred_data_dict.keys():
            id = org_pred_data_dict[uuid]["id"]
            output_img_path = org_pred_data_dict[uuid]["prediction"]
            sc_point = self._calc_meta_point(sc_judge_pred_data_dict[uuid]["prediction"])
            pq_point = self._calc_meta_point(pq_judge_pred_data_dict[uuid]["prediction"])
            o_point = math.sqrt(sc_point * pq_point)
            question, case_language = self._get_question_and_language(org_pred_data_dict[uuid]["origin_prompt"][0]["prompt"])
            self.all_data_results[id] = {
                "uuid": uuid,
                "question": question,
                "language": case_language,
                "output_img_path": output_img_path,
                "SC_point": sc_point,
                "PQ_point": pq_point,
                "O_point": o_point,
            }
        logger.info(f"Finish parsing results")

    def dump_result_csv(self):
        save_path = os.path.join(self.output_dir, "results")
        logger.info(f"Start dumping details ......")
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        with open(os.path.join(save_path, "gedit_gathered_result.csv"), "w", encoding="utf-8", newline="") as f:
            fieldnames = ["id", "uuid", "question", "language", "output_img_path", "SC_point", "PQ_point", "O_point"]
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            for id in self.all_data_results.keys():
                writer.writerow(self.all_data_results[id])
        logger.info(f"Finish dumping csv to: {os.path.join(save_path, 'gedit_gathered_result.csv')}")

    def display_results(self):
        evaluate_result_list = []

        lang_count = {"zh": 0, "en": 0}
        for id in self.all_data_results.keys():
            lang_count[self.all_data_results[id]["language"]] += 1

        for lang in ["zh", "en"]:
            sc_point_sum = 0
            pq_point_sum = 0
            o_point_sum = 0
            count = 0
            for id in self.all_data_results.keys():
                if self.all_data_results[id]["language"] == lang:
                    sc_point_sum += self.all_data_results[id]["SC_point"]
                    pq_point_sum += self.all_data_results[id]["PQ_point"]
                    o_point_sum += self.all_data_results[id]["O_point"]
                    count += 1
            if count > 0:
                evaluate_result_list.append(copy.deepcopy([lang, sc_point_sum / count, pq_point_sum / count, o_point_sum / count]))

        sc_point_sum = 0
        pq_point_sum = 0
        o_point_sum = 0
        count = len(self.all_data_results)
        for id in self.all_data_results.keys():
            sc_point_sum += self.all_data_results[id]["SC_point"]
            pq_point_sum += self.all_data_results[id]["PQ_point"]
            o_point_sum += self.all_data_results[id]["O_point"]
        evaluate_result_list.append(copy.deepcopy(["all case", sc_point_sum / count, pq_point_sum / count, o_point_sum / count]))

        print(tabulate.tabulate(evaluate_result_list,
                                headers=["language", "SC_point", "PQ_point", "O_point"],
                                floatfmt=".4f"))

    def _calc_meta_point(self, pred):
        results_list = get_lmm_point_list(pred)
        if not results_list:
            return 0
        try:
            point_list = json.loads(results_list)
        except BaseException as e:
            raise RuntimeError(f"Illegal prediction: {pred}")
        return min(point_list)

    def _get_question_and_language(self, input_prompt):
        question = ""
        for item in input_prompt:
            if item.get("type") == "text":
                question = item.get("text")
        # Check if question contains Chinese characters
        if any("\u4e00" <= char <= "\u9fff" for char in question):
            return question, "zh"
        else:
            return question, "en"

    def _load_and_merge_jsonl(self, path_kind: "org_pred_path"):
        merged_data = []
        start_index = 0
        for path in self.paths_map[path_kind]:
            offset_index = copy.deepcopy(start_index)
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    data = json.loads(line)
                    data["id"] = data["id"] + offset_index
                    start_index += 1
                    merged_data.append(data)
        return merged_data


def main():
    """Main function"""
    parser = argparse.ArgumentParser(description="Display inference results for gedit dataset")
    parser.add_argument("--config_path", default="./multi_device_run_qwen_image_edit.py", help="Configuration file path")
    parser.add_argument("--timestamp_path", help="Result timestamp path")

    args = parser.parse_args()
    eval_parser = GEditEvalResultParser(args)
    eval_parser.parse_results()
    eval_parser.dump_result_csv()
    eval_parser.display_results()


if __name__ == "__main__":
    main()