import re
from functools import partial
from typing import List, Tuple

from torch import distributed as dist

agieval_single_choice_sets = [
    'gaokao-chinese',
    'gaokao-english',
    'gaokao-geography',
    'gaokao-history',
    'gaokao-biology',
    'gaokao-chemistry',
    'gaokao-mathqa',
    'logiqa-zh',
    'lsat-ar',
    'lsat-lr',
    'lsat-rc',
    'logiqa-en',
    'sat-math',
    'sat-en',
    'sat-en-without-passage',
    'aqua-rat',
]

agieval_multiple_choices_sets = [
    'gaokao-physics',
    'jec-qa-kd',
    'jec-qa-ca',
]

agieval_cloze_sets = ['gaokao-mathcloze', 'math']

agieval_chinese_sets = [
    'gaokao-chinese',
    'gaokao-english',
    'gaokao-geography',
    'gaokao-history',
    'gaokao-biology',
    'gaokao-chemistry',
    'gaokao-physics',
    'gaokao-mathqa',
    'logiqa-zh',
    'gaokao-mathcloze',
    'jec-qa-kd',
    'jec-qa-ca',
]

agieval_english_sets = [
    'lsat-ar',
    'lsat-lr',
    'lsat-rc',
    'logiqa-en',
    'sat-math',
    'sat-en',
    'sat-en-without-passage',
    'aqua-rat',
    'math',
]

agieval_gaokao_sets = [
    'gaokao-chinese',
    'gaokao-english',
    'gaokao-geography',
    'gaokao-history',
    'gaokao-biology',
    'gaokao-chemistry',
    'gaokao-physics',
    'gaokao-mathqa',
]


template_mapping = {
        'gaokao-chinese':
        '以下是一道中国高考语文选择题,请选择正确的答案。',
        'gaokao-english':
        '以下是一道中国高考英语选择题,请选择正确的答案。',
        'gaokao-geography':
        '以下是一道中国高考地理选择题,请选择正确的答案。',
        'gaokao-history':
        '以下是一道中国高考历史选择题,请选择正确的答案。',
        'gaokao-biology':
        '以下是一道中国高考生物选择题,请选择正确的答案。',
        'gaokao-chemistry':
        '以下是一道中国高考化学选择题,请选择正确的答案。',
        'gaokao-physics':
        '以下是一道中国高考物理选择题,请选择正确的答案。',
        'gaokao-mathqa':
        '以下是一道中国高考数学选择题,请选择正确的答案。',
        'logiqa-zh':
        '以下是一道中国公务员考试题,请选择正确的答案。',
        'lsat-ar':
        'The following is a LSAT Analytical Reasoning question. Please select the correct answer.',
        'lsat-lr':
        'The following is a LSAT Logical Reasoning question. Please select the correct answer.',
        'lsat-rc':
        'The following is a LSAT Reading Comprehension question. Please select the correct answer.',
        'logiqa-en':
        'The following is a Logic Reasoning question. Please select the correct answer.',
        'sat-math':
        'The following is a SAT Math question. Please select the correct answer.',
        'sat-en':
        'The following is a SAT English question. Please select the correct answer.',
        'sat-en-without-passage':
        'The following is a SAT English question. Please select the correct answer.',
        'aqua-rat':
        'The following is a AQUA-RAT question. Please select the correct answer.',
        'jec-qa-kd':
        '以下是一道中国司法考试基础知识题,请选择正确的答案。',
        'jec-qa-ca':
        '以下是一道中国司法考试案例分析题,请选择正确的答案。',
        'gaokao-mathcloze':
        '以下是一道中国高考数学填空题,请填入正确的答案。',
        'math':
        'The following is a Math question. Please select the correct answer.',
    }


def get_default_instruction(item):
    if item['passage']:
        question = item['passage'] + '\n' + item['question']
    else:
        question = item['question']
    if item['options']:
        options = '\n'.join(item['options'])
    else:
        options = ""
    if item['label']:
        if isinstance(item['label'], list):
            correct = ','.join(item['label'])
        else:
            correct = item['label']
    else:
        if item['answer']:
            correct = item['answer'].replace('$', '')
        else:
            correct = None
    return question, options, correct


def alternativate_prompt_instruction(item, subject_name):
    if subject_name in agieval_chinese_sets:
        _hint = '答案是: '
    else:
        _hint = 'The answer is '
    prompt = f'{{question}}\n{{options}}\n{_hint}'
    if item['passage']:
        question = item['passage'] + item['question']
    else:
        question = item['question']
    if item['options']:
        options = '\n'.join(item['options'])
    else:
        options = ""
    if item['label']:
        if isinstance(item['label'], list):
            correct = ','.join(item['label'])
        else:
            correct = item['label']
    else:
        if item['answer']:
            correct = item['answer'].replace('$', '')
        else:
            correct = None
    if options:
        prompt = prompt.format(question=question,
                            options=options.strip(),
                            _hint=_hint
        )
    else:
        prompt = f'{{question}}\n{_hint}'
        prompt = prompt.format(question=question,
                            _hint=_hint
        )
    if item['options']:
        num_choice = len(item['options'])
        choices = generate_alphabet_string(num_choice)
    else:
        choices = None

    prompt = template_mapping[subject_name] + '\n' + prompt
    return prompt, correct, choices


def generate_alphabet_string(length):
    if length <= 0:
        return ""

    result = ''.join(chr(ord('A') + i) for i in range(length))
    return result


def get_pred_postprocess_func(subject_name, choices):
    if subject_name in agieval_multiple_choices_sets:
        return first_capital_postprocess_multi
    
    if subject_name in agieval_single_choice_sets:
        return partial(first_option_postprocess, options=choices)

    if subject_name in agieval_cloze_sets:
        return parse_math_answer
    
    raise ValueError(f"Unknown subject_name: {subject_name}")


def first_option_postprocess(text: str, options: str, cushion=True) -> str:
    """Find first valid option for text."""

    patterns = [
        f'答案是?\s*([{options}])',
        f'答案是?\s*:\s*([{options}])',
        f'答案是?\s*:\s*([{options}])',
        f'答案选项应?该?是\s*([{options}])',
        f'答案选项应?该?为\s*([{options}])',
        f'答案应该?是\s*([{options}])',
        f'答案应该?选\s*([{options}])',
        f'答案选项为?\s*:\s*([{options}])',
        f'答案选项为?\s+\(?\*?\*?([{options}])\*?\*?\)?',
        f'答案选项是?\s*:\s*([{options}])',
        f'答案为\s*([{options}])',
        f'答案选\s*([{options}])',
        f'选择?\s*([{options}])',
        f'故选?\s*([{options}])'
        f'只有选?项?\s?([{options}])\s?是?对',
        f'只有选?项?\s?([{options}])\s?是?错',
        f'只有选?项?\s?([{options}])\s?不?正确',
        f'只有选?项?\s?([{options}])\s?错误',
        f'说法不?对选?项?的?是\s?([{options}])',
        f'说法不?正确选?项?的?是\s?([{options}])',
        f'说法错误选?项?的?是\s?([{options}])',
        f'([{options}])\s?是正确的',
        f'([{options}])\s?是正确答案',
        f'选项\s?([{options}])\s?正确',
        f'所以答\s?([{options}])',
        f'所以\s?([{options}][.。$]?$)',
        f'所有\s?([{options}][.。$]?$)',
        f'[\s,::,]([{options}])[。,,\.]?$',
        f'[\s,,::][故即]([{options}])[。\.]?$',
        f'[\s,,::]因此([{options}])[。\.]?$',
        f'[是为。]\s?([{options}])[。\.]?$',
        f'因此\s?([{options}])[。\.]?$',
        f'显然\s?([{options}])[。\.]?$',
        f'答案是\s?(\S+)(?:。|$)',
        f'答案应该是\s?(\S+)(?:。|$)',
        f'答案为\s?(\S+)(?:。|$)',
        f'(?i)ANSWER\s*:\s*([{options}])',
        f'[Tt]he answer is:?\s+\(?([{options}])\)?',
        f'[Tt]he answer is:?\s+\(?\*?\*?([{options}])\*?\*?\)?',
        f'[Tt]he answer is option:?\s+\(?([{options}])\)?',
        f'[Tt]he correct answer is:?\s+\(?([{options}])\)?',
        f'[Tt]he correct answer is option:?\s+\(?([{options}])\)?',
        f'[Tt]he correct answer is:?.*?boxed{{([{options}])}}',
        f'[Tt]he correct option is:?.*?boxed{{([{options}])}}',
        f'[Tt]he correct answer option is:?.*?boxed{{([{options}])}}',
        f'[Tt]he answer to the question is:?\s+\(?([{options}])\)?',
        f'^选项\s?([{options}])',
        f'^([{options}])\s?选?项',
        f'(\s|^)[{options}][\s。,,::\.$]',
        f'1.\s?(.*?)$',
        f'1.\s?([{options}])[.。$]?$',
    ]
    cushion_patterns = [
        f'([{options}]):',
        f'([{options}])',
    ]

    if cushion:
        patterns.extend(cushion_patterns)
    for pattern in patterns:
        text = text.strip()
        match = re.search(pattern, text, re.DOTALL)
        if match:
            if match.group(1) is not None and match.group(1) != '':
                outputs = match.group(1)
            else:
                outputs = match.group(0)
            for i in options:
                if i in outputs:
                    return i
    return ''


def first_capital_postprocess_multi(text: str) -> str:
    match = re.search(r'([A-D]+)', text)
    if match:
        return match.group(1)
    return ''


def parse_math_answer(raw_string):

    def remove_boxed(s):
        left = '\\boxed{'
        try:
            if s[:len(left)] != left:
                raise ValueError(f"s[:len(left)] should equal to left.")
            if s[-1] != '}':
                raise ValueError(f"s[-1] is not correct.")
            answer = s[len(left):-1]
            if '=' in answer:
                answer = answer.split('=')[-1].lstrip(' ')
            return answer
        except ValueError:
            return None

    def last_boxed_only_string(string):
        idx = string.rfind('\\boxed')
        if idx < 0:
            idx = string.rfind('\\fbox')
            if idx < 0:
                return None
        i = idx
        right_brace_idx = None
        num_left_braces_open = 0
        while i < len(string):
            if string[i] == '{':
                num_left_braces_open += 1
            if string[i] == '}':
                num_left_braces_open -= 1
                if num_left_braces_open == 0:
                    right_brace_idx = i
                    break
            i += 1

        if right_brace_idx is None:
            retval = None
        else:
            retval = string[idx:right_brace_idx + 1]

        return retval

    def get_answer_with_dollar_sign(s):
        first_pattern = '\$(.*)\$'
        last_match = None
        matches = re.findall(first_pattern, s)
        if matches:
            last_match = matches[-1]
            if '=' in last_match:
                last_match = last_match.split('=')[-1].lstrip(' ')
        return last_match

    def get_answer_without_dollar_sign(s):
        last_match = None
        if '=' in s:
            last_match = s.split('=')[-1].lstrip(' ').rstrip('.')
            if '\\n' in last_match:
                last_match = last_match.split('\\n')[0]
        else:
            pattern = '(?:\\$)?\d+(?:\.\d+)?(?![\w\d])'
            matches = re.findall(pattern, s)
            if matches:
                last_match = matches[-1]
        return last_match

    raw_string = remove_few_shot_prefix(raw_string)
    if '\\boxed' in raw_string:
        answer = remove_boxed(last_boxed_only_string(raw_string))
    else:
        answer = get_answer_with_dollar_sign(raw_string)
        if not answer:
            answer = get_answer_without_dollar_sign(raw_string)
    return answer


def remove_few_shot_prefix(string: str):
    prefix_list = ['The answer is therefore', '答案是']
    for prefix in prefix_list:
        if string.startswith(prefix):
            string = string[len(prefix):].strip()
        elif prefix in string:
            index = string.rfind(prefix)
            if index >= 0:
                string = string[index + len(prefix):].strip()
    return string