import os
import string

import torch
import torch.distributed as dist
import pandas as pd
from tqdm import tqdm
from megatron.core import mpu

from mindspeed_mm.utils.security_utils.input_filter import sanitize_dataframe
from mindspeed_mm.utils.security_utils.validate_path import normalize_path
from mindspeed_mm.tasks.evaluation.utils.analysis_utils import mmmu_open_question_preprocess, report_acc, eval_vanilla, \
    track_progress_rich
from mindspeed_mm.tasks.evaluation.utils.string_utils import logger_rank_0
from mindspeed_mm.tasks.evaluation.utils.file_utils import load_json, save_json
from mindspeed_mm.tasks.evaluation.eval_prompt.build_prompt_base import BasePromptTemplate
from mindspeed_mm.tasks.evaluation.eval_datasets.datasets_base import BaseEvalDataset


class BaseEvalImpl:
    """Run an evaluation dataset through an inference pipeline across the data-parallel group.

    Each data-parallel rank takes a strided shard of ``dataset.data``. ``__call__`` runs the
    inference pipeline over that shard and writes the responses to a per-rank JSON file,
    ``gather_result`` merges the per-rank files into one ``.xlsx`` on global rank 0, and
    ``analyse_result`` scores the predictions and writes an accuracy report.
    """

    def __init__(self, dataset: BaseEvalDataset, inference_pipeline, args, model_prompt_template=None, drop_last=False):
        """Resolve output paths and select the rows this data-parallel rank will evaluate.

        Args:
            dataset: Evaluation dataset. This class uses ``len(dataset)``, ``dataset.data``
                (a DataFrame with an ``index`` column), ``dataset.build_prompt`` and
                ``dataset.dump_image``.
            inference_pipeline: Either an object exposing ``evaluate(instruct)`` or a callable
                accepting ``prompt``, ``images`` and ``return_ids`` keyword arguments.
            args: Namespace providing ``result_output_path``, ``evaluation_model`` and
                ``evaluation_dataset``.
            model_prompt_template: Optional model-specific template; when it reports a custom
                prompt for the dataset it replaces ``dataset.build_prompt``.
            drop_last: Dropping the tail is not implemented. The flag only selects which error
                message is raised when the dataset size is not a multiple of the world size.

        Raises:
            ValueError: If ``len(dataset)`` is not evenly divisible by the data-parallel
                world size.
        """
        self.rank = mpu.get_data_parallel_group().rank()
        self.world_size = mpu.get_data_parallel_world_size()

        self.output_path = normalize_path(args.result_output_path)[0]
        os.makedirs(self.output_path, exist_ok=True)
        model_name = args.evaluation_model
        self.dataset = dataset
        self.inference_pipeline = inference_pipeline
        self.model_prompt_template = model_prompt_template

        self.dataset_name = args.evaluation_dataset
        self.supported_types = ['text', 'image', 'video']
        self.result_path = os.path.join(self.output_path, model_name + "_" + self.dataset_name + ".xlsx")
        self.prev_file = os.path.join(self.output_path, model_name + "_" + self.dataset_name + "_PREV.json")

        # Save the results on each rank.
        self.out_file = self._out_file(self.rank)
        self.report_file = self.result_path.replace('.xlsx', '_acc.xlsx')

        # Chunk the data for multi-rank evaluation later
        sheet_indices = self._shard_indices(len(dataset), self.rank, self.world_size, drop_last)
        print(f"Rank {self.rank} of the data parallel group has {len(sheet_indices)} evaluation data ")
        self.data = dataset.data.iloc[sheet_indices]
        self.data_indices = list(self.data['index'])  # Each rank processes its own index
        # Entries already present in prev_file are skipped by __call__.
        self.result = load_json(self.prev_file) if os.path.exists(self.prev_file) else {}

    @staticmethod
    def _shard_indices(data_len, rank, world_size, drop_last=False):
        """Return the strided row positions ``rank, rank + world_size, ...`` for one rank.

        Raises:
            ValueError: If ``data_len`` is not a multiple of ``world_size``.
        """
        remainder = data_len % world_size
        if remainder != 0:
            if drop_last:
                raise ValueError("drop_last must be false now")
            print(f"the length of dataset: {data_len}, world_size: {world_size}, remainder: {remainder}")
            raise ValueError('The length of the dataset must be divided evenly by the world_size.')
        return list(range(rank, data_len, world_size))

    @staticmethod
    def _is_result_writer():
        """Return True on the last pipeline stage with tensor-model-parallel rank 0."""
        return mpu.is_pipeline_last_stage() and mpu.get_tensor_model_parallel_rank() == 0

    def __call__(self):
        """
        Call pipeline to start evaluation and write evaluation results to file

        Raises:
            ValueError: If a prompt cannot be converted to a list of dicts, contains an
                unsupported item type, or (fallback pipeline only) contains an item that is
                neither text nor image.
            KeyError: On the writer rank, if a sample of this shard ended up without a
                stored response.
        """
        # Drop the samples that already have a stored result.
        data = self.data[~self.data['index'].isin(self.result)]
        for i in tqdm(range(len(data))):
            row = data.iloc[i]
            data_index = row['index']
            if data_index in self.result:
                # If data_index is already in result, skip it
                continue

            if self.model_prompt_template and self.model_prompt_template.is_use_custom_prompt(self.dataset_name):
                prompt = self.model_prompt_template.build_prompt(row, self.dataset.dump_image, self.dataset_name)
            else:
                prompt = self.dataset.build_prompt(row)

            BasePromptTemplate.check_content_type(prompt)
            instruct = BasePromptTemplate.preprocess_content(prompt)
            if instruct is None or BasePromptTemplate.check_content_type(instruct) != 'ListOfDict':
                raise ValueError(f'Invalid instruct: {instruct}. Only list of dict (ListOfDict) is supported.')
            for item in instruct:
                if item['type'] not in self.supported_types:
                    raise ValueError(
                        f'The instruct type is {item["type"]},only support {", ".join(self.supported_types)}')

            # To wait for all processes to finish processing the data
            # NOTE(review): this is a collective call issued once per sample. Ranks that resume
            # from prev_file with different numbers of remaining samples would reach it an
            # unequal number of times - confirm resume is only used with matching shards.
            dist.barrier()
            if hasattr(self.inference_pipeline, "evaluate"):
                response = self.inference_pipeline.evaluate(instruct)
            else:
                # Plain pipelines take one concatenated text prompt plus a list of images.
                text_parts, images = [], []
                for item in instruct:
                    item_type = item['type']
                    if item_type == 'text':
                        text_parts.append(item['value'])
                    elif item_type == 'image':
                        images.append(item['value'])
                    else:
                        raise ValueError("unsupported instruct type")
                response = self.inference_pipeline(prompt=''.join(text_parts), images=images, return_ids=True)
            torch.cuda.empty_cache()
            if response is not None:
                print(response, flush=True)
                self.result[data_index] = response

            # Checkpoint the accumulated results to the per-rank JSON file every 20 steps
            if (i + 1) % 20 == 0 and self._is_result_writer():
                save_json(self.result, self.out_file)

        if self._is_result_writer():
            # Persist only this shard's entries; a missing entry raises KeyError here.
            res = {k: self.result[k] for k in self.data_indices}
            save_json(res, self.out_file)

    def gather_result(self):
        """
        Get the calculation results on each rank

        After a barrier, global rank 0 merges every per-rank JSON file, stores the values as a
        string ``prediction`` column on ``self.dataset.data`` (modified in place, with the
        ``image`` column removed if present), writes the frame to ``self.result_path`` and
        deletes the per-rank files.

        Raises:
            ValueError: If a dataset index has no entry in the merged results.
        """
        if self.world_size > 0:
            print("rank:", self.rank, " finish evaluate")
            dist.barrier()
        if dist.get_rank() == 0:
            data_all = {}
            for i in range(self.world_size):
                data_all.update(load_json(self._out_file(i)))

            data = self.dataset.data
            predictions = []
            for x in data['index']:
                if x not in data_all:
                    raise ValueError(f"{x} not found in data_all")
                predictions.append(str(data_all[x]))
            data['prediction'] = predictions
            if 'image' in data:
                data.pop('image')

            # save to xlsx
            data.to_excel(self.result_path, index=False, engine='xlsxwriter')

            for i in range(self.world_size):
                os.remove(self._out_file(i))
        if self.world_size > 0:
            dist.barrier()

    def analyse_result(self):
        """Score the gathered predictions on global rank 0 and write the accuracy report.

        Reads ``self.result_path``, attaches the ground-truth answer of every sample, scores
        each row with ``eval_vanilla`` (progress is saved to ``self.prev_file``), then rewrites
        ``self.result_path`` with ``hit`` and ``log`` columns and writes the output of
        ``report_acc`` to ``self.report_file``. Both frames pass through ``sanitize_dataframe``
        before being written.

        Raises:
            ValueError: If the result file contains an index that is not in the dataset.
        """
        if self.world_size > 0:
            dist.barrier()
        if dist.get_rank() == 0:
            data = pd.read_excel(self.result_path)
            data = data.sort_values(by='index')
            data['prediction'] = [str(x) for x in data['prediction']]

            uppercase_letters = set(string.ascii_uppercase)
            # Lower-case every column name except single uppercase letters
            # (presumably the answer-option columns - confirm against the dataset schema).
            for k in data.keys():
                data[k if k in uppercase_letters else k.lower()] = data.pop(k)

            meta = self.dataset.data
            meta_q_map = {x: y for x, y in zip(meta['index'], meta['question'])}
            data['question']  # the result file is required to carry a question column

            for k in data['index']:
                if k not in meta_q_map:
                    raise ValueError(
                        f'eval_file should be the same as or a subset of dataset {self.dataset_name}'
                    )

            result = {}
            answer_map = {i: c for i, c in zip(meta['index'], meta['answer'])}

            if 'MMMU' in self.dataset_name:
                data = mmmu_open_question_preprocess(data)
                # Any answer that is not a single uppercase letter is replaced by 'A'.
                answer_map = {k: (v if v in uppercase_letters else 'A') for k, v in answer_map.items()}

            # Copy so the column assignments below act on an independent frame, not on a
            # filtered view of the frame read from disk.
            data = data[data['index'].isin(answer_map)].copy()
            data['GT'] = [answer_map[idx] for idx in data['index']]

            # Every row is scored.
            items = [data.iloc[i] for i in range(len(data))]
            tups = [dict(item=x, dataset_name=self.dataset_name) for x in items]
            keys = [x['index'] for x in items]

            nproc = 1

            if tups:
                res = track_progress_rich(eval_vanilla, tups, nproc=nproc, chunksize=nproc, save=self.prev_file,
                                          keys=keys)
                result = load_json(self.prev_file)
                for k, v in zip(keys, res):
                    result[k] = v
            data['hit'] = [result.get(i, {}).get('hit', None) for i in data['index']]
            data['log'] = [result.get(i, {}).get('log', None) for i in data['index']]
            if 'GT' in data:
                data.pop('GT')
            cleaned_data = sanitize_dataframe(data)
            cleaned_data.to_excel(self.result_path, index=False, engine='xlsxwriter')
            acc = report_acc(data)
            cleaned_acc = sanitize_dataframe(acc)
            cleaned_acc.to_excel(self.report_file, index=False, engine='xlsxwriter')
            logger_rank_0(f"save acc file to {self.report_file}")
        if self.world_size > 0:
            dist.barrier()

    def _out_file(self, rank):
        """Return the per-rank result path ``<output_path>/<rank>_<world_size>_<dataset_name>.json``."""
        return os.path.join(self.output_path, f'{rank}_{self.world_size}_{self.dataset_name}.json')