from typing import Callable
import string

import torch
import pandas as pd

from mindspeed_mm.tasks.evaluation.utils.file_utils import parse_file
from mindspeed_mm.tasks.evaluation.utils.string_utils import is_cn_string


class BasePromptTemplate:

    def __init__(self, use_custom_prompt):
        self._use_custom_prompt = use_custom_prompt
        self.device = torch.cuda.current_device()

    def build_prompt(self, line, dump_image: Callable, dataset_name=None):
        raise NotImplementedError('you must implement build_prompt')

    @staticmethod
    def check_content_type(message):
        """Check the content type of the input. Four types are allowed: str, dict, ListOfString, ListOfDict.
        """
        if isinstance(message, str):
            return 'String'
        elif isinstance(message, dict):
            return 'Dict'
        elif isinstance(message, list):
            types = [BasePromptTemplate.check_content_type(m) for m in message]
            if all(t == 'String' for t in types):
                return 'ListOfString'
            elif all(t == 'Dict' for t in types):
                return 'ListOfDict'
            else:
                raise ValueError(f'Unknown prompt type')
        else:
            raise ValueError(f'Unknown prompt type')

    @staticmethod
    def _build_prompt(line):
        question = line['question']
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        if hint is not None:
            question = hint + '\n' + question

        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            question += f'\n{key}. {item}'
        prompt = question

        if len(options):
            prompt += '\n请直接回答选项字母。' if is_cn_string(
                prompt) else "\nAnswer with the option's letter from the given choices directly."
        else:
            prompt += '\n请直接回答问题。' if is_cn_string(prompt) else '\nAnswer the question directly.'
        return prompt

    @staticmethod
    def preprocess_content(inputs):

        """Convert the raw input messages to a list of dicts.

        Args:
            inputs: raw input messages.

        Returns:
            list(dict): The preprocessed input messages. Will return None if failed to preprocess the input.
        """
        if BasePromptTemplate.check_content_type(inputs) == 'String':
            return [dict(type='text', value=inputs)]
        elif BasePromptTemplate.check_content_type(inputs) == 'Dict':
            if not ('type' in inputs and 'value' in inputs):
                raise ValueError("inputs must contain both 'type' and 'value' keys")
            return [inputs]
        elif BasePromptTemplate.check_content_type(inputs) == 'ListOfString':
            res = []
            for s in inputs:
                mime, pth = parse_file(s)
                if mime is None or mime == 'Unknown':
                    res.append(dict(type='text', value=s))
                else:
                    res.append(dict(type=mime.split('/')[0], value=pth))
            return res
        elif BasePromptTemplate.check_content_type(inputs) == 'ListOfDict':
            for item in inputs:
                if not ('type' in item and 'value' in item):
                    raise ValueError("item must contain both 'type' and 'value' keys")
                mime, s = parse_file(item['value'])
                if mime is None:
                    if item['type'] != 'text':
                        raise ValueError("item type must be 'text'")
                else:
                    if not mime.split('/')[0] == item['type']:
                        raise ValueError(f"The MIME type '{mime}' does not match the expected type '{item['type']}'")
                    item['value'] = s
            return inputs
        else:
            return None