# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
# Copyright 2020 OpenCompass Authors. All rights reserved.
# Copyright (c) Microsoft Corporation.
import enum
import math
import random
import os
import json
import tempfile
from dataclasses import dataclass
import tiktoken
from mindspeed_llm.tasks.evaluation import file_utils

DATA = "data"
FULL = "full"
NEEDLEBENCH = "needlebench"


class Position(enum.Enum):
    START = 0
    END = 1


@dataclass
class NeedleBenchConfig:
    dataset_dir: str
    context_length_list: list
    depth_percent_list: list
    file_names: list
    language: str
    length_buffer: int
    num_repeats_per_file: int = 10
    tokenizer_model: str = 'gpt-4'
    position: Position = Position.END


def logistic(x, length=100, x0=50, k=0.1):
    return round(length / (1 + math.exp(-k * (x - x0))), 3)


def generate_linear_space(start, end, num):
    if num == 1:
        return [start]
    elif num < 1:
        raise ValueError('num must be at least 1.')
    step = (end - start) / (num - 1)
    return [start + step * i for i in range(num)]


def generate_depth_percents(intervals, interval_type):
    if interval_type == 'linear':
        return generate_linear_space(0, 100, intervals)
    elif interval_type == 'sigmoid':
        linear_space = generate_linear_space(0, 100, intervals)
        return [logistic(x) for x in linear_space]
    else:
        raise ValueError('Unsupported interval type')


def load_datasets(config: NeedleBenchConfig):
    def get_random_line(counter, file_path):
        with open(file_path, 'r', encoding='utf-8') as file:
            lines = [
                json.loads(line.strip())
                for line in file
                if json.loads(line.strip())['language'] == config.language
            ]

        if lines:
            random.seed(counter)
            random_line = random.choice(lines)
            return {
                'needle': random_line['needle'],
                'retrieval_question': random_line['retrieval_question'],
                'keyword': random_line['arg2']
            }
        else:
            return {}

    def _generate_context(tokens_context, depth_percent, tokenizer, needle):
        tokens_needle = tokenizer.encode(needle)
        insertion_point = int(len(tokens_context) * (depth_percent / 100))
        tokens_context = (tokens_context[:insertion_point] +
                          tokens_needle + tokens_context[insertion_point:])
        new_context = tokenizer.decode(tokens_context)
        return new_context

    def _modify_retrieval_question(retrieval_question):
        language = config.language
        if language == 'Chinese':
            parts = retrieval_question.split('请按照')
            guide_retrieval_question = (parts[0] + '在回答之前,请思考文档中与此问题'
                                                   '最相关的内容是什么。请按照' + parts[1])
            return guide_retrieval_question
        elif language == 'English':
            parts = retrieval_question.split('Please answer in the format')
            guide_retrieval_question = (
                    parts[0] + 'Before answering, please consider'
                               ' what in the document is most relevant to this question.'
                               ' Please answer in the format' + parts[1])
            return guide_retrieval_question
        else:
            raise ValueError(f"Language '{language}' is not supported.")

    def _generate_prompt(context, retrieval_question):  # End: 将问句添加到末尾
        retrieval_question = _modify_retrieval_question(
            retrieval_question)
        language = config.language
        position = config.position
        if language == 'Chinese':
            if position == Position.END:
                prompt = (
                    '你是一个善于回答用户问题的智能AI助手\n'
                    '请保持你的回答简洁清楚。不要说和下面文档中的无关的话'
                    ',或重复你的回答\n'
                    f'用户现在给你的文档是{context}\n\n'
                    f'现在请问:{retrieval_question}'
                )
            elif position == Position.START:
                prompt = (
                    '你是一个善于回答用户问题的智能AI助手\n'
                    '请保持你的回答简洁清楚。不要说和下面文档中的无关的话'
                    ',或重复你的回答\n'
                    f'现在请问:{retrieval_question}',
                    f'用户现在给你的文档是{context}\n\n'
                )
            else:
                raise ValueError('Unsupported position. '
                                 'Position must be POSITION_END or POSITION_START.')
        elif language == 'English':
            if position == Position.END:
                prompt = ('You are an intelligent AI assistant skilled in '
                          'answering user questions.\n'
                          'Please keep your answers concise and clear. Do '
                          'not talk about irrelevant topics or repeat '
                          'your answers.\nThe document '
                          f'given to you by the user is {context}\n\n'
                          f'Now, the question is: {retrieval_question}')
            elif position == Position.START:
                prompt = ('You are an intelligent AI assistant skilled in '
                          'answering user questions.\n'
                          'Please keep your answers concise and clear. Do '
                          'not talk about irrelevant topics or repeat '
                          'your answers.\n'
                          f'Now, the question is: {retrieval_question}'
                          'The document given to you by the user'
                          f' is {context}\n\n')
            else:
                raise ValueError(f'Unsupported position {position}. '
                                 'Position must be POSITION_END or POSITION_START.')
        else:
            raise ValueError(f"Language '{language}' is not supported.")

        return prompt

    def load_dataset(config, original_context_length, depth_percent):
        dataset = []
        needle_file_path = os.path.join(config.dataset_dir, "needles.jsonl")
        cache_dir = os.path.join(tempfile.gettempdir(), "data-gym-cache")
        os.makedirs(cache_dir, exist_ok=True)
        cache_file = "9b5ad71b2ce5302211f9c61530b329a4922fc6a4"
        dest_filepath = os.path.join(cache_dir, cache_file)
        # 参考docs中的needlebench-evaluation.md下载cl100k_base.tiktoken,将其放置到该数据集文件夹
        if not file_utils.is_path_exists(dest_filepath):
            src_file = os.path.join(config.dataset_dir, "cl100k_base.tiktoken")
            print(dest_filepath)
            print(src_file)
            if os.path.exists(src_file):
                import shutil
                shutil.copy(src_file, dest_filepath)
            else:
                raise FileNotFoundError("[ERROR] tiktoken is needed, no this file in /tmp/data-gym-cache/")

        tokenizer = tiktoken.get_encoding("cl100k_base")
        for file_name in config.file_names:
            file_path = os.path.join(config.dataset_dir, f"{file_name}.jsonl")
            with file_utils.safe_open(file_path, 'r', encoding='utf-8') as f:
                lines_bak = [json.loads(line.strip()) for line in f]
            lines = lines_bak.copy()
            for counter in range(config.num_repeats_per_file):
                random.seed(counter)
                random.shuffle(lines)
                random_needle = get_random_line(counter, needle_file_path)
                needle = '\n' + random_needle.get('needle') + '\n'
                retrieval_question = random_needle.get('retrieval_question')
                keyword = random_needle.get('keyword')

                context_length = original_context_length - config.length_buffer
                target_length_per_record = context_length - len(
                    tokenizer.encode(needle))
                target_length_per_record = max(target_length_per_record, 0)
                accumulated_tokens = []
                for line in lines:
                    tokens_current_line = tokenizer.encode(line.get('text'))
                    accumulated_tokens.extend(tokens_current_line)

                    if len(accumulated_tokens) >= target_length_per_record:
                        break
                data = {'prompt': '', 'answer': ''}
                processed_text = _generate_context(
                    accumulated_tokens[:target_length_per_record], depth_percent,
                    tokenizer, needle)

                processed_prompt = _generate_prompt(processed_text,
                                                    retrieval_question)
                data['prompt'] = processed_prompt
                data['answer'] = needle + '*' + keyword
                dataset.append(data)

        return dataset

    datasets = []
    data = 'data'
    for file_name in config.file_names:
        for original_context_length in config.context_length_list:
            for depth_percent in config.depth_percent_list:
                dataset = {'name': f'{file_name}_{original_context_length}_{depth_percent}', data: ''}
                dataset[data] = load_dataset(config, original_context_length, depth_percent)
                datasets.append(dataset)
    return datasets


def test_single_4k(dataset_dir):
    context_lengths = list(range(1000, 5000, 1000))
    document_depth_percent_intervals = 20
    document_depth_percent_interval_type = 'linear'
    file_list = ['PaulGrahamEssays']
    depths_list = generate_depth_percents(
        document_depth_percent_intervals,
        document_depth_percent_interval_type)
    config = NeedleBenchConfig(dataset_dir, context_lengths, depths_list, file_list, 'English', 600)
    needlebench_en_datasets = load_datasets(config)

    file_list = ['zh_finance']
    config = NeedleBenchConfig(dataset_dir, context_lengths, depths_list, file_list, 'Chinese', 200)
    needlebench_zh_datasets = load_datasets(config)
    return needlebench_en_datasets + needlebench_zh_datasets


def test_single_8k(dataset_dir):
    context_lengths = list(range(5000, 9000, 1000))
    document_depth_percent_intervals = 20
    document_depth_percent_interval_type = 'linear'
    file_list = ['PaulGrahamEssays']
    depths_list = generate_depth_percents(
        document_depth_percent_intervals,
        document_depth_percent_interval_type)
    config = NeedleBenchConfig(dataset_dir, context_lengths, depths_list, file_list, 'English', 600)
    needlebench_en_datasets = load_datasets(config)

    file_list = ['zh_finance']
    config = NeedleBenchConfig(dataset_dir, context_lengths, depths_list, file_list, 'Chinese', 200)
    needlebench_zh_datasets = load_datasets(config)
    return needlebench_en_datasets + needlebench_zh_datasets


def test_single_32k(dataset_dir):
    context_lengths = [9000, 13000, 17000, 21000, 25000, 29000, 31000, 32000]
    depths_list = [0, 10, 21, 31, 42, 52, 63, 73, 84, 94, 100]
    file_list = ['PaulGrahamEssays']
    config = NeedleBenchConfig(dataset_dir, context_lengths, depths_list, file_list, 'English', 3000)
    needlebench_en_datasets = load_datasets(config)

    file_list = ['zh_finance']
    config = NeedleBenchConfig(dataset_dir, context_lengths, depths_list, file_list, 'Chinese', 200)
    needlebench_zh_datasets = load_datasets(config)
    return needlebench_en_datasets + needlebench_zh_datasets


def test_single_128k(dataset_dir):
    context_lengths = [16000, 32000, 48000, 64000, 80000, 96000, 112000, 128000]
    depths_list = [0, 10, 21, 31, 42, 52, 63, 73, 84, 94, 100]
    num_repeats_per_file = 10
    if os.environ.get('TEST_MODEL') == 'true':
        context_lengths = [16000, 32000]
        depths_list = [0, 10]
        num_repeats_per_file = 1
    file_list = ['PaulGrahamEssays']
    config = NeedleBenchConfig(dataset_dir, context_lengths, depths_list, file_list, 'English', 600,
                               num_repeats_per_file)
    needlebench_en_datasets = load_datasets(config)

    file_list = ['zh_finance']
    config = NeedleBenchConfig(dataset_dir, context_lengths, depths_list, file_list, 'Chinese', 200,
                               num_repeats_per_file)
    needlebench_zh_datasets = load_datasets(config)
    return needlebench_en_datasets + needlebench_zh_datasets


def test_single_200k(dataset_dir):
    context_lengths = [16000, 48000, 80000, 112000, 128000, 144000, 176000, 200000]
    depths_list = [0, 10, 21, 31, 42, 52, 63, 73, 84, 94, 100]
    file_list = ['PaulGrahamEssays']
    config = NeedleBenchConfig(dataset_dir, context_lengths, depths_list, file_list, 'English', 600)
    needlebench_en_datasets = load_datasets(config)

    file_list = ['zh_finance']
    config = NeedleBenchConfig(dataset_dir, context_lengths, depths_list, file_list, 'Chinese', 200)
    needlebench_zh_datasets = load_datasets(config)
    return needlebench_en_datasets + needlebench_zh_datasets


def test_single_256k(dataset_dir):
    context_lengths = [32000, 128000, 256000]
    depths_list = [0, 10, 21, 31, 42, 52, 63, 73, 84, 94, 100]
    file_list = ['PaulGrahamEssays']
    config = NeedleBenchConfig(dataset_dir, context_lengths, depths_list, file_list, 'English', 600)
    needlebench_en_datasets = load_datasets(config)

    file_list = ['zh_finance']
    config = NeedleBenchConfig(dataset_dir, context_lengths, depths_list, file_list, 'Chinese', 200)
    needlebench_zh_datasets = load_datasets(config)
    return needlebench_en_datasets + needlebench_zh_datasets


def test_single_512k(dataset_dir):
    context_lengths = [32000, 128000, 256000, 512000]
    depths_list = [0, 10, 21, 31, 42, 52, 63, 73, 84, 94, 100]
    file_list = ['PaulGrahamEssays']
    config = NeedleBenchConfig(dataset_dir, context_lengths, depths_list, file_list, 'English', 600)
    needlebench_en_datasets = load_datasets(config)

    file_list = ['zh_finance']
    config = NeedleBenchConfig(dataset_dir, context_lengths, depths_list, file_list, 'Chinese', 200)
    needlebench_zh_datasets = load_datasets(config)
    return needlebench_en_datasets + needlebench_zh_datasets


def test_single_1000k(dataset_dir):
    context_lengths = [20000, 160000, 300000, 440000, 580000, 720000, 860000, 1000000]
    depths_list = [0, 10, 21, 31, 42, 52, 63, 73, 84, 94, 100]
    file_list = ['PaulGrahamEssays']
    config = NeedleBenchConfig(dataset_dir, context_lengths, depths_list, file_list, 'English', 600)
    needlebench_en_datasets = load_datasets(config)

    file_list = ['zh_finance']
    config = NeedleBenchConfig(dataset_dir, context_lengths, depths_list, file_list, 'Chinese', 200)
    needlebench_zh_datasets = load_datasets(config)
    return needlebench_en_datasets + needlebench_zh_datasets