"""

Conversation prompt templates.



We kindly request that you import fastchat instead of copying this file if you wish to use it.

If you have changes in mind, please contribute back so the community can benefit collectively

and continue to maintain these valuable templates.

"""



import dataclasses

from enum import IntEnum, auto

from typing import Dict, List, Tuple, Union





class SeparatorStyle(IntEnum):
    """Enumerates the prompt layouts that ``Conversation.get_prompt`` can render."""

    # "<role>: <message><sep>" for every turn, after the system prompt.
    ADD_COLON_SINGLE = 1
    # Role tags chosen by turn position, alternating between two separators.
    LLAMA2 = 2
    # "<role><message><sep>" for every turn, after the system prompt.
    MPT = 3





@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history.

    The template fields (``system_template``, ``roles``, ``sep_style``,
    ``sep``, ``sep2``) decide how ``get_prompt`` renders the conversation,
    while ``messages`` accumulates the ``[role, message]`` turns.
    """

    # The name of this template
    name: str
    # The template of the system prompt; formatted with ``system_message=...``
    system_template: str = '{system_message}'
    # The system message
    system_message: str = ''
    # The names of two roles
    roles: Tuple[str, str] = ('USER', 'ASSISTANT')
    # All messages. Each item is [role, message].
    # A default_factory gives every instance its own mutable list, so
    # append_message() also works on a directly constructed Conversation
    # (an immutable ``()`` default has no ``append``).
    messages: List[List[str]] = dataclasses.field(default_factory=list)
    # The number of few shot examples
    offset: int = 0
    # The separator style and configurations
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
    sep: str = '\n'
    sep2: Union[str, None] = None
    # Stop criteria (the default one is EOS token)
    stop_str: Union[str, List[str], None] = None
    # Stops generation if meeting any token in this list
    stop_token_ids: Union[List[int], None] = None

    def get_prompt(self) -> str:
        """Get the prompt for generation.

        The system prompt is rendered from ``system_template`` and the
        messages are then laid out according to ``sep_style``. A message that
        is empty or ``None`` is rendered as the bare role tag without a
        trailing separator.

        Returns:
            The rendered prompt string.

        Raises:
            ValueError: If ``sep_style`` is not one of the handled styles.
        """
        system_prompt = self.system_template.format(system_message=self.system_message)

        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + message + self.sep
                else:
                    ret += role + ':'
            return ret

        if self.sep_style == SeparatorStyle.LLAMA2:
            # Even-indexed turns end with ``sep`` and odd-indexed turns with
            # ``sep2``, so ``sep2`` must be a string once a non-empty
            # odd-indexed turn is rendered.
            seps = [self.sep, self.sep2]
            ret = system_prompt if self.system_message else '[INST] '
            for i, (_, message) in enumerate(self.messages):
                # The tag is taken from the turn's position, not from the
                # role stored with the message.
                tag = self.roles[i % 2]
                if not message:
                    ret += tag
                elif i == 0:
                    # The first message is appended without a role tag.
                    ret += message + ' '
                else:
                    ret += tag + ' ' + message + seps[i % 2]
            return ret

        if self.sep_style == SeparatorStyle.MPT:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    if isinstance(message, tuple):
                        # A tuple message must have exactly three items; only
                        # the first one is rendered into the prompt.
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret

        raise ValueError(f'Invalid style: {self.sep_style}')

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: Union[str, None]):
        """Append a new ``[role, message]`` entry to the history."""
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format.

        Messages before ``offset`` are skipped. The remaining messages are
        paired by position into ``[first, second]`` lists, where ``second``
        stays ``None`` until the following message fills it.
        """
        ret = []
        for i, (_, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format.

        The system message always comes first. After skipping ``offset``
        messages, even positions become ``user`` entries and odd positions
        become ``assistant`` entries; an odd-position message that is
        ``None`` is left out.
        """
        ret = [{'role': 'system', 'content': self.system_message}]

        for i, (_, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append({'role': 'user', 'content': msg})
            elif msg is not None:
                ret.append({'role': 'assistant', 'content': msg})
        return ret

    def copy(self) -> 'Conversation':
        """Return a new Conversation with the same settings.

        ``messages`` is rebuilt as a fresh list of ``[role, message]`` lists,
        so appending to or updating the copy does not touch this instance.
        All other fields are passed through as-is.
        """
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        """Return the template name, system message, roles, messages and offset as a dict.

        The values are the instance's own objects (not copies).
        """
        return {
            'template_name': self.name,
            'system_message': self.system_message,
            'roles': self.roles,
            'messages': self.messages,
            'offset': self.offset,
        }





# A global registry for all conversation templates, keyed by template name.
# Entries are added by register_conv_template() and looked up by
# get_conv_template().

conv_templates: Dict[str, Conversation] = {}





def register_conv_template(template: Conversation, override: bool = False):
    """Store ``template`` in the global registry under ``template.name``.

    Args:
        template: The conversation template to register.
        override: If true, replace any template already registered under the
            same name instead of raising.

    Raises:
        AssertionError: If the name is already registered and ``override``
            is false.
    """
    # Guard clause: refuse to clobber an existing entry unless asked to.
    if not override and template.name in conv_templates:
        raise AssertionError(f'{template.name} has been registered.')

    conv_templates[template.name] = template





def get_conv_template(name: str) -> Conversation:
    """Look up the template registered as ``name`` and return a copy of it.

    The result comes from ``Conversation.copy()``, so messages appended to it
    do not end up in the registered template.

    Raises:
        KeyError: If no template is registered under ``name``.
    """
    registered = conv_templates[name]
    return registered.copy()





# Both Hermes-2 and internlm2-chat are chatml-format conversation templates. The difference

# is that during training, the preprocessing function for the Hermes-2 template doesn't add

# <s> at the beginning of the tokenized sequence, while the internlm2-chat template does.

# Therefore, they are completely equivalent during inference.

register_conv_template(
    Conversation(
        name='Hermes-2',
        # Prompt layout: role tags open each turn and `sep` closes it.
        sep_style=SeparatorStyle.MPT,
        roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
        sep='<|im_end|>',
        # System prompt.
        system_template='<|im_start|>system\n{system_message}',
        # note: The new system prompt was not used here to avoid changes in benchmark performance.
        # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
        system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
        # Stop criteria.
        stop_str='<|endoftext|>',
        stop_token_ids=[2, 6, 7, 8],
    )
)



register_conv_template(
    Conversation(
        name='internlm2-chat',
        # Prompt layout: role tags open each turn and `sep` closes it.
        sep_style=SeparatorStyle.MPT,
        roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
        sep='<|im_end|>',
        # System prompt.
        system_template='<|im_start|>system\n{system_message}',
        # note: The new system prompt was not used here to avoid changes in benchmark performance.
        # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
        system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
        # Stop criteria (no stop string is configured for this template).
        stop_token_ids=[2, 92543, 92542],
    )
)



register_conv_template(
    Conversation(
        name='phi3-chat',
        # Prompt layout: role tags open each turn and `sep` closes it.
        sep_style=SeparatorStyle.MPT,
        roles=('<|user|>\n', '<|assistant|>\n'),
        sep='<|end|>',
        # System prompt.
        system_template='<|system|>\n{system_message}',
        # note: The new system prompt was not used here to avoid changes in benchmark performance.
        # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
        system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
        # Stop criteria (no stop string is configured for this template).
        stop_token_ids=[2, 32000, 32007],
    )
)