# Copyright Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
import glob
import json
import os
import stat
from pathlib import Path
import math
import re
import collections

import torch
import torch.nn.functional as F
import torch_npu
import pandas as pd
import numpy as np
from transformers import PreTrainedModel, AutoTokenizer
from tqdm import tqdm

from security import json_safe_load, json_safe_dump, get_valid_path, get_valid_write_path, get_valid_read_path
from precision_tool import logger
from precision_tool import truthfulqa_eval
from ascend_utils.common.security.type import check_type


supported_dataset = ["boolq", "ceval_0_shot", "ceval_5_shot", "humaneval", "truthfulqa", "mmlu"]

supported_hardware = ["npu"]

TENSOR_TYPE_PYTORCH = "pt"
DATASET_HUMAN_EVAL = "humaneval"

CalcParam = collections.namedtuple('CalcParam', ['tag', 'frame', 'idx', 'scores_true', 
                                   'scores_false', 'ref_true', 'ref_best'])


def is_running_on_npu(device_name):
    return device_name.lower() == "npu"


class PrecisionTest:
    def __init__(self, model, tokenizer, dataset, batch_size, hardware_type,
                 tokenizer_return_type_id=False, shot=5):
        """
        @param model:
            llm to run the test, should be an instance of transformers.PreTrainedModel
        @param dataset:
            dataset to test precision
        @param batch_size:
            batch_size to run inference
        @param hardware_type:
            currently only npu is supported
        @param tokenizer_return_type_id:
            tokenizer return token type id
        @param shot
            shot to test precision
        """
        self.logger = logger
        self.logger.info("Ready to init precision tool.")
        self.dataset = dataset
        if not self.__verify_dataset():
            raise TypeError("Dataset setting failed.")
        self.model = model
        if not self.__verify_model():
            raise TypeError("Invalid model.")
        self.device = model.device
        self.tokenizer_return_type_id = tokenizer_return_type_id
        check_type(self.tokenizer_return_type_id, bool, param_name="tokenizer_return_type_id")
        self.tokenizer = tokenizer
        if not self.__verify_tokenizer():
            raise TypeError("Tokenizer must be matched with model.")
        self.batch_size = batch_size
        if not self.__verify_batch_size():
            raise TypeError("Batch size must be a positive integer.")
        self.shot = shot
        if not self.__verify_shot():
            raise TypeError("shot must be a nonnegative integer.")
        self.hardware_type = hardware_type
        if self.hardware_type not in supported_hardware:
            raise TypeError(f"Hardware type {self.hardware_type} is not supported.")
        self.script_path = os.path.dirname(os.path.abspath(__file__))
        self.dataset_path = os.path.join(self.script_path, "dataset", self.dataset)
        if not os.path.exists(self.dataset_path):
            raise EnvironmentError(f"Dataset was not found, valid path should be '{self.dataset_path}")
        self.dataset_path = get_valid_path(self.dataset_path)
        self.result_file = ""
        self.logger.info("Precision test was inited.")
        
    @staticmethod
    def __postprocess(text: str, options: str, cushion=True) -> str:
        patterns = [
            f'答案是?\s?([{options}])',
            f'答案是?\s?:([{options}])',
            f'答案是?\s?:([{options}])',
            f'答案应该?是\s?([{options}])',
            f'答案应该?选\s?([{options}])',
            f'答案为\s?([{options}])',
            f'答案选\s?([{options}])',
            f'选择?\s?([{options}])',
            f'故选?\s?([{options}])'
            f'只有选?项?\s?([{options}])\s?是?对',
            f'只有选?项?\s?([{options}])\s?是?错',
            f'只有选?项?\s?([{options}])\s?不?正确',
            f'只有选?项?\s?([{options}])\s?错误',
            f'说法不?对选?项?的?是\s?([{options}])',
            f'说法不?正确选?项?的?是\s?([{options}])',
            f'说法错误选?项?的?是\s?([{options}])',
            f'([{options}])\s?是正确的',
            f'([{options}])\s?是正确答案',
            f'选项\s?([{options}])\s?正确',
            f'所以答\s?([{options}])',
            f'所以\s?([{options}][.。$]?$)',
            f'所有\s?([{options}][.。$]?$)',
            f'[\s,::,]([{options}])[。,,\.]?$',
            f'[\s,,::][故即]([{options}])[。\.]?$',
            f'[\s,,::]因此([{options}])[。\.]?$',
            f'[是为。]\s?([{options}])[。\.]?$',
            f'因此\s?([{options}])[。\.]?$',
            f'显然\s?([{options}])[。\.]?$',
            f'答案是\s?(\S+)(?:。|$)',
            f'答案应该是\s?(\S+)(?:。|$)',
            f'答案为\s?(\S+)(?:。|$)',
            f'[Tt]he answer is \(?([{options}])\)?',
            f'[Tt]he answer is option \(?([{options}])\)?',
            f'[Tt]he correct answer is \(?([{options}])\)?',
            f'[Tt]he correct answer is option \(?([{options}])\)?',
            f'[Tt]he answer to the question is \(?([{options}])\)?',
            f'^选项\s?([{options}])',
            f'^([{options}])\s?选?项',
            f'(\s|^)[{options}][\s。,,::\.$]',
            f'(\s|^)[{options}](\s|$)',
            f'1.\s?(.*?)$',
            f'1.\s?([{options}])[.。$]?$',
        ]
        cushion_patterns = [
            f'([{options}]):',
            f'[{options}]',
        ]

        if cushion:
            patterns.extend(cushion_patterns)
        for pattern in patterns:
            match = re.search(pattern, text)
            if not match:
                continue 
            outputs = match.group(0)
            for i in options:
                if i in outputs:
                    return i
        return ''

    def test(self):
        self.logger.info("Begin to run precision test.")
        torch.manual_seed(1)
        self.__run_test()
        self.logger.info("Precision test has finished.")

    def __verify_batch_size(self):
        if not isinstance(self.batch_size, int) or isinstance(self.batch_size, bool):
            self.logger.error("Batch size must be an integer")
            return False
        if not self.batch_size > 0:
            self.logger.error("Batch size must be positive")
            return False
        return True

    def __verify_dataset(self):
        if not isinstance(self.dataset, str):
            self.logger.error("Dataset should be a string")
            return False
        if self.dataset not in supported_dataset:
            self.logger.error("Dataset %r is not supported. Use %s instead." % (self.dataset, supported_dataset))
            return False
        if self.dataset == DATASET_HUMAN_EVAL:
            try:
                from human_eval.evaluation import evaluate_functional_correctness
            except Exception as e:
                self.logger.error(f"Import human_eval test package failed. Error info is as below. \n {e}")
                return False
        return True

    def __verify_model(self):
        if not isinstance(self.model, PreTrainedModel):
            self.logger.error(f"Model is not a valid model, type is {type(self.model)}")
            return False
        return True

    def __verify_tokenizer(self):
        text = "hello world"
        inputs = (self.tokenizer(text, return_tensors=TENSOR_TYPE_PYTORCH, 
                    return_token_type_ids=self.tokenizer_return_type_id).to(self.model.device))
        try:
            with torch.no_grad():
                self.model(**inputs)
        except Exception as e:
            self.logger.error(f"{e}")
            self.logger.error("Model and tokenizer are not compatible. Cannot simply use model API.")
            self.logger.error(f"Model type is {type(self.model)}")
            self.logger.error(f"Tokenizer type is {type(self.tokenizer)}")
            return False
        try:
            with torch.no_grad():
                self.model.generate(**inputs, do_sample=False, max_new_tokens=1)
        except Exception as e:
            self.logger.error(f"{e}")
            self.logger.error("Model and tokenizer are not compatible. Cannot simply use generate API.")
            self.logger.error(f"Model type is {type(self.model)}")
            self.logger.error(f"Tokenizer type is {type(self.tokenizer)}")
            return False
        self.logger.info("Model and Tokenizer are compatible.")
        return True

    def __verify_shot(self):
        if not isinstance(self.shot, int) or isinstance(self.shot, bool):
            self.logger.error("shot must be an integer, can not be {}.".format(type(self.shot).__name__))
            return False
        if not self.shot >= 0:
            self.logger.error("shot must be nonnegative integer, can not be {}.".format(self.shot))
            return False
        return True

    def __run_test(self):
        if self.dataset == 'ceval_0_shot':
            self.__run_full_dataset_ceval_0_shot()
        elif self.dataset == 'ceval_5_shot':
            self.__run_full_dataset_ceval_5_shot()
        elif self.dataset == 'boolq':
            self.__run_full_dataset_boolq()
        elif self.dataset == 'humaneval':
            self.__run_full_dataset_humaneval()
        elif self.dataset == 'truthfulqa':
            self.__run_full_dataset_truthfulqa()
        elif self.dataset == 'mmlu':
            self.__run_full_dataset_mmlu()
        else:
            raise TypeError("Unsupported dataset")

    def __run_full_dataset_ceval_0_shot(self):
        choices = ["A", "B", "C", "D"]
        choice_tokens = [self.tokenizer.encode(choice, add_special_tokens=False)[0] for choice in choices]

        extraction_prompt = '综上所述,ABCD中正确的选项是:'

        def build_prompt(text):
            return "[Round {}]\n\n问:{}\n\n答:".format(1, text)

        def run_test():
            correct_total, sum_total = 0, 0
            for entry in glob.glob((Path(self.dataset_path) / "val/**/*.jsonl").as_posix(), recursive=True):
                correct, dataset = 0, []
                entry = get_valid_read_path(entry)
                with open(entry, encoding='utf-8') as file:
                    for line in file:
                        single_json = json.loads(line)
                        dataset.append(single_json)

                dataset_num = len(dataset)

                dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size)
                for batch in tqdm(dataloader):
                    texts = batch["inputs_pretokenized"]
                    queries = [build_prompt(query) for query in texts]
                    inputs = self.tokenizer(queries, padding=True, return_tensors=TENSOR_TYPE_PYTORCH, truncation=True,
                                            return_token_type_ids=self.tokenizer_return_type_id).to(self.device)
                    outputs = self.model.generate(**inputs, do_sample=False, max_new_tokens=512)
                    intermediate_outputs = []
                    for idx in range(len(outputs)):
                        output = outputs.tolist()[idx][len(inputs["input_ids"][idx]):]
                        response = self.tokenizer.decode(output)
                        intermediate_outputs.append(response)
                    answer_texts = []
                    for text, intermediate in zip(texts, intermediate_outputs):
                        answer_texts.append(text + intermediate + "\n" + extraction_prompt)
                    input_tokens = []
                    for answer_text in answer_texts:
                        input_tokens.append(build_prompt(answer_text))
                    inputs = self.tokenizer(input_tokens, padding=True, return_tensors=TENSOR_TYPE_PYTORCH,
                                            truncation=True,
                                            return_token_type_ids=self.tokenizer_return_type_id).to(self.device)
                    outputs = self.model(**inputs)
                    logits = outputs.logits[:, -1, :]
                    logits = logits[:, choice_tokens]
                    preds = logits.argmax(dim=-1)
                    correct += (preds.cpu() == batch["label"]).sum().item()
                correct_total += correct
                sum_total += dataset_num
            return correct_total, sum_total

        with torch.no_grad():
            correct_total, sum_total = run_test()

        if sum_total == 0:
            self.logger.error("Did not ran any test, maybe wrong cevel_0_shot dataset folder.")
        else:
            self.logger.info(
                f"correct rate:{correct_total / sum_total}, correct total:{correct_total}, sum_total:{sum_total}")

    def __run_full_dataset_ceval_5_shot(self):
        choices = ["A", "B", "C", "D"]
        shot = 5

        def get_subject_mapping():
            subject_mapping_path = os.path.join(self.dataset_path, "subject_mapping.json")
            subject_mapping_path = get_valid_read_path(subject_mapping_path)
            with open(subject_mapping_path) as f:
                subject_mapping = json.load(f)
            return subject_mapping

        def load_csv_by_task_name(task_name, dataset_path):
            dev_file_path = get_valid_path(os.path.join(dataset_path, "dev", task_name + "_dev.csv"))
            dev_df = pd.read_csv(dev_file_path, header=None)[:shot + 1]
            val_file_path = get_valid_path(os.path.join(dataset_path, "val", task_name + "_val.csv"))
            val_df = pd.read_csv(val_file_path, header=None)

            dev_df = dev_df.iloc[1:, 1:]
            val_df = val_df.iloc[1:, 1:]
            return dev_df, val_df

        def format_subject(subject):
            str_units = subject.split("_")
            res = ""
            for entry in str_units:
                res += " " + entry
            return res

        def format_example(df, idx, include_answer=True):
            prompt = df.iloc[idx, 0]
            k = len(choices)
            for j in range(k):
                prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1])
            prompt += "\nAnswer:"
            if include_answer:
                prompt += " {}\n\n".format(df.iloc[idx, k + 1])
            return prompt

        def gen_prompt(train_df, subject, k=-1):
            prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
                format_subject(subject))
            if k == -1:
                k = train_df.shape[0]
            for i in range(k):
                prompt += format_example(train_df, i)
            return prompt

        correct_total = 0
        sum_total = 0

        subject_mapping = get_subject_mapping()
        index = 1
        for task_name in tqdm(subject_mapping):
            dev_df, val_df = load_csv_by_task_name(task_name, self.dataset_path)
            correct = 0
            task_len = val_df.shape[0]
            for i in range(math.ceil(task_len / self.batch_size)):
                q_num = self.batch_size if (i + 1) * self.batch_size <= task_len else task_len - i * self.batch_size
                prompt_ends = []
                for j in range(q_num):
                    prompt_ends.append(format_example(val_df, i * self.batch_size + j, include_answer=False))
                train_prompts = [gen_prompt(dev_df, task_name, shot)] * q_num
                prompt = [t + p for t, p in zip(train_prompts, prompt_ends)]
                labels = [val_df.iloc[i * self.batch_size + j, val_df.shape[1] - 1] for j in range(q_num)]
                prompts = [prpt.encode().decode(encoding="utf8") for prpt in prompt]
                inputs = self.tokenizer(prompts, padding=True, return_tensors=TENSOR_TYPE_PYTORCH, truncation=True,
                                        return_token_type_ids=self.tokenizer_return_type_id).to(self.device)
                outputs = self.model.generate(**inputs, do_sample=False, max_new_tokens=20)
                answers = []
                for idx in range(len(outputs)):
                    output = outputs.tolist()[idx][len(inputs["input_ids"][idx]):]
                    response = self.tokenizer.decode(output)
                    answers.append(response)

                answer_results = [answer.lstrip()[0] if answer.lstrip() else "-1" for answer in answers]
                is_correct = []
                for answer_result, label in zip(answer_results, labels):
                    is_correct.append("Correct" if answer_result == label else "Wrong")
                correct += is_correct.count("Correct")
            correct_total += correct
            sum_total += task_len
            index += 1
        if sum_total == 0:
            self.logger.error(
                f"Did not ran any test, maybe wrong cevel_5_shot dataset folder. Current folder is {self.dataset_path}")
        else:
            self.logger.info(
                f"correct rate:{correct_total / sum_total}, correct total:{correct_total}, sum_total:{sum_total}")

    def __run_full_dataset_humaneval(self):
        def cleanup_code(code: str) -> str:
            code_splits = code.split("\n")
            is_empty_line = False
            ind_empty_line = None
            for i, line in enumerate(code_splits):
                if len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t':
                    is_empty_line = True
                    ind_empty_line = i
                    break
            if is_empty_line:
                code = "\n".join(code_splits[:ind_empty_line])
            else:
                end_words = ["\ndef", "\nclass", "\n#", "\nassert", '\n"""', "\nprint", "\nif", "\n\n\n"]
                for w in end_words:
                    if w in code:
                        code = code[:code.rfind(w)]
            return code

        def run_test():
            for entry in tqdm(glob.glob((Path(self.dataset_path) / "*.jsonl").as_posix(), recursive=True),
                              desc='global'):
                dataset = []
                entry = get_valid_read_path(entry)
                with open(entry, encoding='utf-8') as f:
                    for line in f:
                        line_json = json.loads(line)
                        dataset.append(line_json)

                samples = []
                dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size)
                for _, batch in enumerate(tqdm(dataloader)):
                    task_ids = [task_id.split('/')[1] for task_id in batch["task_id"]]
                    queries = [prompt.strip() for prompt in batch["prompt"]]
                    inputs = self.tokenizer(queries, padding=True, return_tensors=TENSOR_TYPE_PYTORCH,
                                            truncation=True).to(self.device)
                    outputs = self.model.generate(inputs=inputs.input_ids,
                                                  attention_mask=inputs.attention_mask, do_sample=False,
                                                  max_new_tokens=512)
                    for idx, output in enumerate(outputs.tolist()):
                        output = output[len(inputs["input_ids"][idx]):]
                        response = self.tokenizer.decode(output)
                        response_cleaned_up = cleanup_code(response)
                        result = dict(
                            task_id="HumanEval/" + task_ids[idx],
                            completion=response_cleaned_up,
                        )
                        samples += [result]
                self.__save_humaneval_res(samples)
                self.logger.info(f"Single question check result has been written to {self.result_file}.")
                result = evaluate_functional_correctness(self.result_file, [1], 4, 3.0, entry)
                self.logger.info(result)

        try:
            from human_eval.evaluation import evaluate_functional_correctness
        except ImportException as e:
            self.logger.error(f"Import human_eval test package failed. Error info is as below. \n {e}")
            return

        with torch.no_grad():
            try:
                run_test()
            except Exception as e:
                self.logger.error(f"Unable to run humaneval test. {e}")

    def __run_full_dataset_boolq(self):
        sample_yes = "How can we learning machine learning: yes"
        sample_no = "How can we learning machine learning: no"
        choice_tokens = [
            self.tokenizer([sample_yes], return_tensors=TENSOR_TYPE_PYTORCH, max_length=2048,
                           add_special_tokens=False).input_ids[0, -1].item(),
            self.tokenizer([sample_no], return_tensors=TENSOR_TYPE_PYTORCH, max_length=2048,
                           add_special_tokens=False).input_ids[0, -1].item()
        ]

        def build_prompt(title, text, passage):
            prompt = f"{title} -- {passage}\nQuestion: {text}?\nAnswer:"
            return prompt

        def run_test(choice_tokens, correct_total, sum_total):
            for entry in tqdm(glob.glob((Path(self.dataset_path) / "*.jsonl").as_posix(), recursive=True),
                              desc='global'):
                dataset = []
                entry = get_valid_read_path(entry)
                with open(entry, encoding='utf-8') as f:
                    for line in f:
                        line_json = json.loads(line)
                        dataset.append(line_json)

                correct = 0
                dataset_num = len(dataset)
                dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size)
                for _, batch in enumerate(tqdm(dataloader)):
                    titles = batch["title"]
                    texts = batch["question"]
                    passages = batch["passage"]
                    queries = []
                    for title, query, passage in zip(titles, texts, passages):
                        queries.append(build_prompt(title, query, passage))
                    inputs = self.tokenizer(queries, padding=True, return_tensors=TENSOR_TYPE_PYTORCH, truncation=True,
                                            return_token_type_ids=self.tokenizer_return_type_id).to(self.model.device)
                    outputs = self.model(**inputs)
                    logits = outputs.logits[:, -1, :]
                    logits_softmax = F.log_softmax(logits.float(), dim=-1)

                    logits_softmax = logits_softmax[:, choice_tokens]
                    for idx, ans in enumerate(batch['answer']):
                        choice = (logits_softmax[idx, 0] > logits_softmax[idx, 1]).cpu()
                        correct += 1 if choice == ans else 0
                correct_total += correct
                sum_total += dataset_num
            return correct_total, sum_total

        correct_total = 0
        sum_total = 0
        with torch.no_grad():
            correct_total, sum_total = run_test(choice_tokens, correct_total, sum_total)
        if sum_total == 0:
            self.logger.error("Did not ran any test, maybe wrong humaneval dataset folder.")
        else:
            self.logger.info(
                f"correct rate:{correct_total / sum_total}, correct total:{correct_total}, sum_total:{sum_total}")

    def __save_humaneval_res(self, results):
        self.result_file = os.path.dirname(os.path.abspath(__file__)) + os.sep + "result.jsonl"
        self.result_file = get_valid_write_path(self.result_file)
        mode = stat.S_IWUSR | stat.S_IRUSR
        flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
        with os.fdopen(os.open(self.result_file, flags=flags, mode=mode), "w") as fp:
            for result in results:
                fp.write((json.dumps(result) + "\n"))

    def __run_full_dataset_truthfulqa(self):
        best_col = 'Best Answer'
        answer_col = 'Correct Answers'
        incorrect_col = 'Incorrect Answers'

        def run_answers():
            frame = pd.read_csv((Path(self.dataset_path) / "TruthfulQA.csv").as_posix())
            frame.dropna(axis=1, how='all', inplace=True)

            if tag not in frame.columns:
                frame[tag] = ''

            frame[tag].fillna('', inplace=True)
            frame[tag] = frame[tag].astype(str)

            num_rows = frame.shape[0]
            num_batches = math.ceil(num_rows / self.batch_size)

            seq_start = np.array(tokenizer('A:')['input_ids'])
            seq_end = np.array(tokenizer('Q:')['input_ids'])
            try:
                self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            except AttributeError:
                self.logger.info("this model can't set attribute 'pad_token'")
            with torch.no_grad():
                for batch in tqdm(range(num_batches)):
                    q_num = self.batch_size if (batch + 1) * self.batch_size <= num_rows \
                        else num_rows - self.batch_size * batch
                    idx_list = [i for i in range(batch * self.batch_size, batch * self.batch_size + q_num)]
                    prompt = [truthfulqa_eval.format_prompt(frame.loc[idx]) for idx in idx_list]
                    input_ids = tokenizer(prompt, padding=True, return_tensors=TENSOR_TYPE_PYTORCH,
                                            truncation=True).input_ids
                    max_len = input_ids.shape[-1] + 50
                    input_ids = input_ids.to(self.device)
                    outputs = self.model.generate(input_ids, do_sample=False, max_length=max_len,
                                                    pad_token_id=self.tokenizer.eos_token_id)
                    output_token_ids_list = [output[len(input_ids[idx]):] 
                                                for idx, output in enumerate(outputs.tolist())] 
                    gen_arrs = np.array(output_token_ids_list)
                    
                    idx_start = [truthfulqa_eval.find_subsequence(gen_arr, seq_start, start=True)
                                    for gen_arr in gen_arrs]
                    idx_end = [truthfulqa_eval.find_subsequence(gen_arr, seq_end, start=False)
                                for gen_arr in gen_arrs]

                    output_token_ids_list = [output_token_ids[idx_start[output_token_ids_idx]:
                                                                idx_end[output_token_ids_idx]] 
                    for output_token_ids_idx, output_token_ids in enumerate(output_token_ids_list)]
                    output_strs = [tokenizer.decode(output_token_ids, skip_special_tokens=True)
                                    for output_token_ids in output_token_ids_list]
                    output_str = [output_str.strip() for output_str in output_strs]

                    for idx in idx_list:
                        frame.loc[idx, tag] = output_str[idx % self.batch_size] 
            return frame
        
        def run_probs(frame):
            truthfulqa_eval.set_columns(tag, frame)
            with torch.no_grad():
                for idx in tqdm(frame.index):
                    if pd.isnull(frame.loc[idx, incorrect_col]):
                        self.logger.warning("References missing for {0}!".format(idx))
                        continue
                    if len(frame.loc[idx, incorrect_col]) == 0:
                        self.logger.warning("References missing for {0}!".format(idx))
                        continue

                    ref_best = truthfulqa_eval.format_best(frame.loc[idx, best_col])
                    ref_true = truthfulqa_eval.split_multi_answer(frame.loc[idx, answer_col])
                    ref_false = truthfulqa_eval.split_multi_answer(frame.loc[idx, incorrect_col])

                    input_prompt = truthfulqa_eval.format_prompt(frame.loc[idx])

                    scores_true = get_scores(input_prompt, frame, idx, ref_true)
                    scores_false = get_scores(input_prompt, frame, idx, ref_false)
                    calc_param = CalcParam(tag, frame, idx, scores_true, scores_false, ref_true, ref_best)
                    frame = truthfulqa_eval.mc_calcs(calc_param)
            return frame

        def get_scores(input_prompt, frame, idx, ref_answer):
            scores_answer = []
            for temp_ans in ref_answer:
                prompt = [truthfulqa_eval.format_prompt_with_answer_strings(frame.loc[idx, 'Question'], temp_ans)]
                input_ids = tokenizer(input_prompt, return_tensors=TENSOR_TYPE_PYTORCH).input_ids
                prompt_ids = tokenizer(prompt, return_tensors=TENSOR_TYPE_PYTORCH).input_ids
                input_ids = input_ids.to(self.device)
                prompt_ids = prompt_ids.to(self.device)
                logits = self.model(prompt_ids)[0].squeeze(0)
                logits = logits[input_ids.shape[-1] - 1: -1, :]

                logits_softmax = F.log_softmax(logits.float(), dim=-1)
                prompt_ids = prompt_ids[0, input_ids.shape[-1]:]
                log_probs = logits_softmax[range(logits_softmax.shape[0]), prompt_ids.squeeze(0)]
                log_probs = log_probs[3:]
                scores_answer.append(log_probs.sum().item())
            return scores_answer

        tokenizer = self.tokenizer
        tag = "test_model"
        
        frame = run_answers()
        frame = run_probs(frame)
        frame = truthfulqa_eval.run_bleu_and_rouge(tag, frame)
    
        results = truthfulqa_eval.format_frame(frame)

        results = results.mean(axis=0)
        results = results.reset_index().rename(columns={'level_0': 'Model',
                                                        'level_1': 'Metric',
                                                        0: 'Value'})

        results = results[results['Metric'].isin(['MC1', 'MC2',
                                                'bleu diff',
                                                'rouge1 diff',
                                                'BLEURT diff'])]
        
        results = pd.pivot_table(results, 'Value', 'Model', 'Metric')

        results = results.rename(columns={'bleu diff': 'BLEU',
                                            'rouge1 diff': 'ROUGE',
                                            'BLEURT diff': 'BLEURT'})
        self.logger.info(f"results:{results}")

    def __run_full_dataset_mmlu(self):
        if self.shot == 0:
            self.__run_full_dataset_mmlu_0_shot()
        else:
            self.__run_full_dataset_mmlu_few_shots()

    def __run_full_dataset_mmlu_0_shot(self):
        def get_subject_mapping():            
            subject_mapping_path = os.path.join(self.dataset_path, "subject_mapping.json")
            subject_mapping = json_safe_load(subject_mapping_path)
            return subject_mapping
        
        def load_csv_by_task_name(task_name, dataset_path):
            csv_file_path = get_valid_path(os.path.join(dataset_path, "test", task_name + "_test.csv"))
            val_df = pd.read_csv(csv_file_path, header=None)
            return val_df
        
        def format_example(name, df, idx):
            question = df.iloc[idx, 0]
            a = df.iloc[idx, 1]
            b = df.iloc[idx, 2]
            c = df.iloc[idx, 3]
            d = df.iloc[idx, 4]
            prompt = "\nThere is a single choice question about {}. Answer the question by replying A, B, C or D.\n" \
                     "Q: {}\nA. {}\nB. {}\nC. {}\nD. {}\n" \
                     "Let's think step by step. A: ".format(name.replace("_", " "), question, a, b, c, d)
            return prompt

        correct_total = 0
        sum_total = 0
        result_total = []
        try:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        except AttributeError:
            self.logger.info("this model can't set attribute 'pad_token'")

        subject_mapping = get_subject_mapping()
        subject_mapping = subject_mapping["mmlu_all_sets"]
        index = 1
        for task_name in tqdm(subject_mapping):
            self.logger.info(f"dataset {index} start, task name: {task_name}")
            val_df = load_csv_by_task_name(task_name, self.dataset_path)
            correct = 0
            task_len = val_df.shape[0]
            if task_len == 0:
                raise ValueError(f"dateset {task_name} is empty, please check.")
            for i in range(math.ceil(task_len / self.batch_size)):
                q_num = self.batch_size if (i + 1) * self.batch_size <= task_len else task_len - i * self.batch_size
                name = task_name
                prompt = [format_example(name, val_df, i * self.batch_size + j) for j in range(q_num)]
                labels = [val_df.iloc[i * self.batch_size + j, val_df.shape[1] - 1] for j in range(q_num)]
                prompts = [prpt.encode().decode(encoding="utf8") for prpt in prompt]
                inputs = self.tokenizer(prompts, padding=True, return_tensors=TENSOR_TYPE_PYTORCH, truncation=True)
                inputs = inputs.to(self.device)
                tokenizer_out_ids = inputs.input_ids.to(self.device)
                attention_mask = inputs.attention_mask.to(self.device)
                outputs = self.model.generate(inputs=tokenizer_out_ids, attention_mask=attention_mask,
                                                do_sample=False, max_new_tokens=1024,
                                                pad_token_id=self.tokenizer.eos_token_id)
                answers = []
                for idx, output in enumerate(outputs.tolist()):
                    output = output[len(inputs["input_ids"][idx]):]
                    answers.append(self.tokenizer.decode(output))
                
                answer_results = [self.__postprocess(answer, "ABCD") for answer in answers]
                is_correct = []
                for answer_result, label in zip(answer_results, labels):
                    is_correct.append("Correct" if answer_result == label else "Wrong")
                correct += is_correct.count("Correct")

            result = [task_name, correct / task_len, correct, task_len]
            self.logger.info(f"dataset {index} finish, result:{result}")
            result_total.append(result)
            correct_total += correct
            sum_total += task_len
            index += 1
        if sum_total == 0:
            self.logger.error(
                f"Did not run any test, maybe wrong mmlu dataset folder. Current folder is {self.dataset_path}")
            return
        total = ["total", correct_total / sum_total, correct_total, sum_total]
        self.logger.info(f"total result:{total}")

    
    def __run_full_dataset_mmlu_few_shots(self): 
        choices = ["A", "B", "C", "D"]
        test_set = {"mmlu": "test"}

        def get_subject_mapping():
            subject_mapping_path = os.path.join(self.dataset_path, "subject_mapping.json")
            subject_mapping = json_safe_load(subject_mapping_path)
            return subject_mapping["mmlu_all_sets"]
        
        def load_csv_by_task_name(task_name, dataset_path):
            row_begin_idx = 0
            col_begin_idx = 0
            dev_file_path = get_valid_path(os.path.join(dataset_path, "dev", task_name + "_dev.csv"))
            ori_dev_df = pd.read_csv(dev_file_path, header=None)
            val_file_path = get_valid_path(os.path.join(dataset_path, test_set.get(self.dataset), 
                                                  f"{task_name}_{test_set.get(self.dataset)}.csv"))
            ori_val_df = pd.read_csv(val_file_path, header=None)

            dev_df = ori_dev_df.iloc[row_begin_idx:row_begin_idx + self.shot, col_begin_idx:]
            val_df = ori_val_df.iloc[row_begin_idx:, col_begin_idx:]
            return dev_df, val_df
        
        def format_subject(subject):
            sub_set = subject.split("_")
            res = ""
            for entry in sub_set:
                res += " " + entry    
            return res
            
        def format_example(df, idx, include_answer=True):
            prompt = df.iloc[idx, 0]
            k = len(choices)
            for j in range(k):
                prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1])
            prompt += "\nAnswer:"
            if include_answer:
                prompt += " {}\n\n".format(df.iloc[idx, k + 1])
            return prompt
        
        def gen_prompt(train_df, subject, k=-1):
            prompt = "The following are multiple choice questions " \
                     "(with answers) about {}.\n\n".format(format_subject(subject))
            if k == -1:
                k = train_df.shape[0]
            for i in range(k):
                prompt += format_example(train_df, i)
            return prompt

        correct_total = 0
        sum_total = 0
        result_total = []
        is_result = False
        try:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        except AttributeError:
            self.logger.info("this model can't set attribute 'pad_token'")

        subject_mapping = get_subject_mapping()
        index = 1
        for task_name in tqdm(subject_mapping):
            self.logger.info(f"dataset {index} start, task name: {task_name}")
            dev_df, val_df = load_csv_by_task_name(task_name, self.dataset_path)
            correct = 0
            task_len = val_df.shape[0]
            if task_len == 0:
                raise ValueError(f"dateset {task_name} is empty, please check.")
            for i in range(math.ceil(task_len / self.batch_size)):
                q_num = self.batch_size if (i + 1) * self.batch_size <= task_len else task_len - i * self.batch_size
                prompt_ends = []
                for j in range(q_num):
                    prompt_ends.append(format_example(val_df, i * self.batch_size + j, include_answer=False))                
                train_prompts = [gen_prompt(dev_df, task_name, self.shot)] * q_num
                prompt = [t + p for t, p in zip(train_prompts, prompt_ends)]
                labels = [val_df.iloc[i * self.batch_size + j, val_df.shape[1] - 1] for j in range(q_num)]
                prompts = [prpt.encode().decode(encoding="utf8") for prpt in prompt]
                inputs = self.tokenizer(prompts, padding=True, return_tensors=TENSOR_TYPE_PYTORCH, truncation=True)
                inputs = inputs.to(self.device)
                tokenizer_out_ids = inputs.input_ids.to(self.device)
                attention_mask = inputs.attention_mask.to(self.device)
                outputs = self.model.generate(inputs=tokenizer_out_ids, attention_mask=attention_mask, do_sample=False,
                                                max_new_tokens=20,
                                                pad_token_id=self.tokenizer.eos_token_id)
                answers = []
                for idx in range(len(outputs)):
                    output = outputs.tolist()[idx][len(inputs["input_ids"][idx]):]
                    response = self.tokenizer.decode(output)
                    answers.append(response)

                answer_results = [answer.lstrip()[0] if answer.lstrip() else "-1" for answer in answers]
                is_correct = []
                for answer_result, label in zip(answer_results, labels):
                    is_correct.append("Correct" if answer_result == label else "Wrong")
                correct += is_correct.count("Correct")

            result = [task_name, correct / task_len, correct, task_len]
            self.logger.info(f"dataset {index} finish, result:{result}")
            result_total.append(result)
            correct_total += correct
            sum_total += task_len
            index += 1
        
        if sum_total == 0:
            self.logger.error(
                f"Did not run any test, maybe wrong mmlu dataset folder. Current folder is {self.dataset_path}")
            return
        total = ["total", correct_total / sum_total, correct_total, sum_total]
        self.logger.info(f"total result:{total}")