import re
from dataclasses import dataclass
from copy import deepcopy
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple
from typing_extensions import override
from transformers import PreTrainedTokenizer
from mindspeed_mm.data.data_utils.func_utils.convert import Role
from mindspeed_mm.data.data_utils.func_utils.formatters import EmptyFormatter, Formatter, StringFormatter, FunctionFormatter, ToolFormatter
from mindspeed_mm.data.data_utils.func_utils.formatters import SLOTS
from mindspeed_mm.data.data_utils.func_utils.log import get_logger
from mindspeed_mm.data.data_utils.func_utils.mm_plugin import get_mm_plugin
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from .formatters import SLOTS
from .mm_plugin import BasePlugin
# Module-level logger, created through the project's logging helper.
logger = get_logger(__name__)
@dataclass
class Template:
    r"""
    Chat template that renders role-tagged messages into lists of token ids.

    Every ``format_*`` field is a formatter whose ``apply`` method returns a list of
    elements; ``_convert_elements_to_ids`` turns those elements into token ids.
    """

    # Formatter applied to messages whose role is Role.USER.
    format_user: "Formatter"
    # Formatter applied to assistant messages (also used for Role.TOOL_CALL).
    format_assistant: "Formatter"
    # Formatter applied once, on the first message, to the system prompt.
    format_system: "Formatter"
    # Formatter applied to observation messages (also used for Role.TOOL_RESPONSE).
    format_observation: "Formatter"
    # Formatter applied without arguments at the very start of the conversation.
    format_prefix: "Formatter"
    # Formatter applied to messages whose role is Role.FUNCTION.
    format_function: "Formatter"
    # Not read by the methods of this class.
    format_tools: "Formatter"
    # Formatter that renders the tool list appended to the system prompt.
    # May be None: `_register_template` passes None when no tool prompt is configured.
    format_tool: Optional["Formatter"]
    # System prompt used when the caller does not supply one.
    default_system: str
    # Read by `get_template_and_fix_tokenizer`; not read by the methods of this class.
    stop_words: List[str]
    efficient_eos: bool
    replace_eos: bool
    # Only consulted by subclasses in this file (see ReasoningTemplate).
    enable_thinking: Optional[bool]
    # Opening and closing markers that delimit a thought inside assistant content.
    thought_words: tuple[str, str]
    # Not read by the methods of this class.
    tool_call_words: tuple[str, str]
    # Not read by the methods of this class.
    mm_plugin: "BasePlugin"

    def encode_oneturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[List[str]] = None,
    ) -> Tuple[List[int], List[int]]:
        r"""
        Returns a single pair of token ids representing prompt and response respectively.

        Every encoded message except the last one is concatenated into the prompt;
        the last encoded message is returned as the response.

        Raises:
            IndexError: if ``messages`` is empty.
        """
        encoded_messages = self._encode(tokenizer, messages, system, tools)
        prompt_ids: List[int] = []
        for encoded_ids in encoded_messages[:-1]:
            prompt_ids.extend(encoded_ids)
        return prompt_ids, encoded_messages[-1]

    def encode_multiturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[List[str]] = None,
    ) -> List[Tuple[List[int], List[int]]]:
        r"""
        Returns multiple pairs of token ids representing prompts and responses respectively.

        Encoded messages are paired as (0, 1), (2, 3), ...

        Raises:
            IndexError: if the number of messages is odd.
        """
        encoded_messages = self._encode(tokenizer, messages, system, tools)
        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]

    def add_thought(self, content: str = "") -> str:
        r"""Add empty thought to assistant message."""
        return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content

    def remove_thought(self, content: str) -> str:
        r"""Remove thought from assistant message."""
        # Non-greedy and DOTALL: each thought block is removed separately, even across lines.
        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
        return re.sub(pattern, "", content).lstrip("\n")

    def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
        r"""Get the token ids of thought words."""
        return tokenizer.encode(self.add_thought(), add_special_tokens=False)

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: Sequence[Dict[str, str]],
        system: Optional[str],
        tools: Optional[List[str]],
    ) -> List[List[int]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: prefix + system + query resp
        Turn t: sep + query resp

        Returns one list of token ids per input message, in order.

        Raises:
            ValueError: if ``tools`` are given together with a system prompt but the
                template has no tool prompt formatter (``format_tool`` is None).
            NotImplementedError: if a message carries an unknown role.
        """
        system = system or self.default_system
        encoded_messages = []
        for i, message in enumerate(messages):
            elements = []

            if i == 0:
                elements += self.format_prefix.apply()
                # NOTE(review): tools are only rendered when a system prompt is present;
                # with an empty system prompt they are dropped - confirm this is intended.
                if system:
                    if tools:
                        if self.format_tool is None:
                            # Fail with a clear message instead of an AttributeError on None.
                            raise ValueError(
                                "Tools were provided, but this template does not define a tool prompt formatter.")
                        tools_prompt = self.format_tool.apply(content="\n".join(tools))
                        system += tools_prompt[0]
                    elements += self.format_system.apply(content=system)

            if message["role"] == Role.USER.value:
                elements += self.format_user.apply(
                    content=message["content"], idx=str(i // 2))
            elif message["role"] == Role.ASSISTANT.value:
                elements += self.format_assistant.apply(
                    content=message["content"])
            elif message["role"] == Role.OBSERVATION.value:
                elements += self.format_observation.apply(
                    content=message["content"])
            elif message["role"] == Role.FUNCTION.value:
                elements += self.format_function.apply(
                    content=message["content"])
            # NOTE(review): the two branches below compare against the enum member itself
            # rather than `.value`; this matches plain strings only if Role is a
            # str-based enum - confirm.
            elif message["role"] == Role.TOOL_CALL:
                elements += self.format_assistant.apply(
                    content=message["content"])
            elif message["role"] == Role.TOOL_RESPONSE:
                elements += self.format_observation.apply(
                    content=message["content"])
            else:
                raise NotImplementedError(
                    "Unexpected role: {}".format(message["role"]))

            encoded_messages.append(
                self._convert_elements_to_ids(tokenizer, elements))

        return encoded_messages

    @staticmethod
    def _convert_elements_to_ids(tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
        r"""
        Converts elements to token ids.

        Strings are tokenized without special tokens (empty strings are skipped), dicts
        are looked up through their ``"token"`` entry, and sets select the tokenizer's
        BOS or EOS id when that id is defined.

        Raises:
            ValueError: if an element is neither a string, a dict nor a set.
        """
        token_ids = []
        for elem in elements:
            if isinstance(elem, str):
                if len(elem) != 0:
                    token_ids += tokenizer.encode(elem,
                                                  add_special_tokens=False)
            elif isinstance(elem, dict):
                token_ids += [
                    tokenizer.convert_tokens_to_ids(elem.get("token"))]
            elif isinstance(elem, set):
                # A set whose token id is undefined on the tokenizer contributes nothing.
                if "bos_token" in elem and tokenizer.bos_token_id is not None:
                    token_ids += [tokenizer.bos_token_id]
                elif "eos_token" in elem and tokenizer.eos_token_id is not None:
                    token_ids += [tokenizer.eos_token_id]
            else:
                raise ValueError(
                    "Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))

        return token_ids
@dataclass
class ReasoningTemplate(Template):
    r"""Template variant that manages thought markers in assistant messages."""

    @override
    def encode_oneturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
    ) -> tuple[list[int], list[int]]:
        r"""
        Encode a conversation into one (prompt, response) pair.

        Thoughts are stripped from the earlier assistant turns, and also from the final
        message when ``enable_thinking`` is exactly False. If the final message contains
        neither thought marker, the ids of an empty thought are appended to the prompt
        (thinking disabled) or prepended to the response (thinking enabled).
        The caller's ``messages`` are not modified; a deep copy is edited instead.
        """
        history = deepcopy(messages)

        # Odd positions before the last two messages.
        for past_reply in history[1:-2:2]:
            past_reply["content"] = self.remove_thought(past_reply["content"])

        if self.enable_thinking is False:
            history[-1]["content"] = self.remove_thought(history[-1]["content"])

        prompt_ids, response_ids = super().encode_oneturn(tokenizer, history, system, tools)

        final_reply = history[-1]["content"]
        if self.thought_words[0] in final_reply or self.thought_words[1] in final_reply:
            return prompt_ids, response_ids

        empty_thought_ids = self.get_thought_word_ids(tokenizer)
        if self.enable_thinking:
            response_ids = empty_thought_ids + response_ids
        else:
            prompt_ids += empty_thought_ids
        return prompt_ids, response_ids

    @override
    def encode_multiturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
    ) -> list[tuple[list[int], list[int]]]:
        r"""
        Encode a conversation into one (prompt, response) pair per turn.

        When ``enable_thinking`` is exactly False, thoughts are stripped from every
        odd-indexed message first. For each turn whose response contains neither thought
        marker, the ids of an empty thought are appended to that turn's prompt (thinking
        disabled) or prepended to its response (thinking enabled).
        The caller's ``messages`` are not modified; a deep copy is edited instead.
        """
        history = deepcopy(messages)

        if self.enable_thinking is False:
            for reply in history[1::2]:
                reply["content"] = self.remove_thought(reply["content"])

        encoded = self._encode(tokenizer, history, system, tools)

        for prompt_idx in range(0, len(history), 2):
            reply_idx = prompt_idx + 1
            reply_text = history[reply_idx]["content"]
            if self.thought_words[0] in reply_text or self.thought_words[1] in reply_text:
                continue

            empty_thought_ids = self.get_thought_word_ids(tokenizer)
            if self.enable_thinking:
                encoded[reply_idx] = empty_thought_ids + encoded[reply_idx]
            else:
                encoded[prompt_idx].extend(empty_thought_ids)

        return [(encoded[k], encoded[k + 1]) for k in range(0, len(encoded), 2)]
# Registry of chat templates keyed by name; filled by `_register_template` and
# read by `get_template_and_fix_tokenizer`.
TEMPLATES: Dict[str, "Template"] = {}
@dataclass
class RegisterParams:
    r"""
    Optional settings for registering a chat template.

    Every field has a default, so callers only set what differs. Fields left as None
    are replaced by fallbacks inside ``_register_template``.
    """

    # Formatter for user messages.
    format_user: Optional["Formatter"] = None
    # Formatter for assistant messages.
    format_assistant: Optional["Formatter"] = None
    # Formatter for the system prompt. The default must be a plain None (not a tuple)
    # so that the `or` fallback in `_register_template` can take effect.
    format_system: Optional["Formatter"] = None
    # Formatter for observation messages.
    format_observation: Optional["Formatter"] = None
    # Formatter applied at the very start of a conversation.
    format_prefix: Optional["Formatter"] = None
    # Formatter for function messages.
    format_function: Optional["Formatter"] = None
    format_tools: Optional["Formatter"] = None
    # Passed to the template as its `format_tool` field by `_register_template`.
    tool_prompt: Optional["Formatter"] = None
    tool_call_words: Optional[tuple[str, str]] = None
    # System prompt used when the caller supplies none.
    default_system: str = ""
    stop_words: Optional[Sequence[str]] = None
    efficient_eos: bool = False
    replace_eos: bool = False
    enable_thinking: Optional[bool] = True
    # Opening and closing thought markers; None selects the registration default.
    thought_words: Optional[tuple[str, str]] = None
def _register_template(
    name: str,
    params: RegisterParams,
    mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
    template_class: type["Template"] = Template,
) -> None:
    r"""
    Registers a chat template under ``name`` in ``TEMPLATES``.

    An existing entry with the same name is overwritten. Formatters left unset in
    ``params`` fall back as follows: user, system, function and tools use a plain
    ``{{content}}`` formatter; assistant uses ``{{content}}`` followed by the EOS token
    unless ``efficient_eos`` is set; observation uses the user formatter; prefix uses an
    empty formatter. ``thought_words`` defaults to ``("<think>", "</think>")`` and
    ``tool_call_words`` to ``("", "")``.

    To add the following chat template:
    ```
    [HUMAN]:
    user prompt here
    [AI]:
    model response here
    [HUMAN]:
    user prompt here
    [AI]:
    model response here
    ```

    The corresponding code should be:
    ```
    _register_template(
        name="custom",
        params=RegisterParams(
            format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]),
            format_assistant=StringFormatter(slots=["{{content}}\n"]),
            efficient_eos=True,
        ),
    )
    ```

    Args:
        name: key under which the template is stored.
        params: formatter and flag overrides for this template.
        mm_plugin: multimodal plugin stored on the template. The default plugin is
            created once, when this function is defined.
        template_class: class instantiated for the template.
    """
    eos_slots = [] if params.efficient_eos else [{"eos_token"}]
    default_user_formatter = StringFormatter(slots=["{{content}}"])
    default_assistant_formatter = StringFormatter(
        slots=["{{content}}"] + eos_slots)
    default_prefix_formatter = EmptyFormatter()
    TEMPLATES[name] = template_class(
        format_user=params.format_user or default_user_formatter,
        format_assistant=params.format_assistant or default_assistant_formatter,
        format_system=params.format_system or default_user_formatter,
        format_function=params.format_function or default_user_formatter,
        format_tools=params.format_tools or default_user_formatter,
        format_observation=params.format_observation or params.format_user or default_user_formatter,
        format_prefix=params.format_prefix or default_prefix_formatter,
        # No fallback formatter: templates without a tool prompt store None here.
        format_tool=params.tool_prompt or None,
        default_system=params.default_system,
        stop_words=[] if params.stop_words is None else params.stop_words,
        tool_call_words=params.tool_call_words or ("", ""),
        efficient_eos=params.efficient_eos,
        replace_eos=params.replace_eos,
        enable_thinking=params.enable_thinking,
        thought_words=params.thought_words or ("<think>", "</think>"),
        mm_plugin=mm_plugin
    )
# "qwen2vl": turns are delimited by <|im_start|>/<|im_end|>. The user and observation
# formatters end with the assistant header, so the assistant formatter only emits the
# content and the closing <|im_end|>. With replace_eos=True the first stop word is
# installed as the tokenizer EOS by `get_template_and_fix_tokenizer`.
_register_template(
name="qwen2vl",
params=RegisterParams(
format_user=StringFormatter(
slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(
slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(
slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True),
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>")
)
# "qwen3_vl": same turn delimiters as above, no default system prompt. Observations are
# rendered as a user turn wrapping the content in <tool_response> tags. Registered with
# ReasoningTemplate, so thought markers are handled during encoding.
_register_template(
name="qwen3_vl",
params=RegisterParams(
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
stop_words=["<|im_end|>"],
replace_eos=True),
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
template_class=ReasoningTemplate,
)
# "qwen3_vl_nothink": same formatters and plugin as "qwen3_vl", but no template_class is
# passed, so the plain Template class is used.
_register_template(
name="qwen3_vl_nothink",
params=RegisterParams(
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
stop_words=["<|im_end|>"],
replace_eos=True),
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
# "qwen2_omni": <|im_start|>/<|im_end|> turns with a default system prompt; the plugin
# is configured with audio, image and video placeholder tokens. Uses the plain Template.
_register_template(
name="qwen2_omni",
params=RegisterParams(
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(
slots=[
"<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True),
mm_plugin=get_mm_plugin(
name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>")
)
# "qwen3_omni": same formatters as "qwen2_omni" but without a default system prompt and
# with different placeholder tokens; registered with ReasoningTemplate.
_register_template(
name="qwen3_omni",
params=RegisterParams(
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(
slots=[
"<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
stop_words=["<|im_end|>"],
replace_eos=True),
mm_plugin=get_mm_plugin(
name="qwen3_omni", audio_token="<|audio_pad|>", image_token="<|image_pad|>", video_token="<|video_pad|>"),
template_class=ReasoningTemplate
)
# Tool-calling instructions used as the `tool_prompt` slot of "qwen3_omni_nothink".
# `Template._encode` passes the newline-joined tool list as `content`;
# `{{content}}` is presumably where the formatter inserts it.
tools_slot = '''
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{{content}}
</tools>
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>'''
# "qwen3_omni_nothink": same formatters and plugin as "qwen3_omni", registered with the
# plain Template class. It is the only registration here that sets `tool_prompt`, i.e.
# the only template whose `format_tool` is not None.
_register_template(
name="qwen3_omni_nothink",
params=RegisterParams(
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(
slots=[
"<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
stop_words=["<|im_end|>"],
replace_eos=True,
tool_prompt=StringFormatter(slots=[tools_slot])),
mm_plugin=get_mm_plugin(
name="qwen3_omni", audio_token="<|audio_pad|>", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
# "glm4.1v": role markers <|user|>, <|assistant|>, <|system|>, <|observation|> with a
# "[gMASK]<sop>" prefix. efficient_eos=True means the default assistant formatter would
# carry no EOS slot; the explicit assistant formatter here has none either.
# Registered with ReasoningTemplate.
_register_template(
name="glm4.1v",
params=RegisterParams(
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
stop_words=["<|user|>", "<|observation|>", "</answer>"],
efficient_eos=True
),
mm_plugin=get_mm_plugin(name="glm4.1v", image_token="<|image|>", video_token="<|video|>"),
template_class=ReasoningTemplate
)
# "glm4v_moe": identical settings to "glm4.1v", including the plugin name "glm4.1v";
# only the registry key differs.
_register_template(
name="glm4v_moe",
params=RegisterParams(
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
stop_words=["<|user|>", "<|observation|>", "</answer>"],
efficient_eos=True
),
mm_plugin=get_mm_plugin(name="glm4.1v", image_token="<|image|>", video_token="<|video|>"),
template_class=ReasoningTemplate
)
# "mistral": [INST]/[SYSTEM_PROMPT]/[TOOL_CALLS]/[TOOL_RESULTS] markers with a BOS prefix.
# No assistant formatter is given, so the registration default ({{content}} + EOS) applies.
# The default system prompt asks for a [THINK]...[/THINK] draft before the answer.
# Uses the plain Template class with the "pixtral" plugin.
_register_template(
name="mistral",
params=RegisterParams(
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
format_system=StringFormatter(slots=["[SYSTEM_PROMPT]{{content}}[/SYSTEM_PROMPT]"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
default_system="""First draft your thinking process (inner monologue) until you arrive at a response. Format your response using Markdown, and use LaTeX for any mathematical equations. Write both your thoughts and the response in the same language as the input.
Your thinking process must follow the template below:[THINK]Your thoughts or/and draft, like working through an exercise on scratch paper. Be as casual and as long as you want until you are confident to generate the response. Use the same language as the input.[/THINK]Here, provide a self-contained response. """,
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
),
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
)
# "default": plain "Human:" / "Assistant:" / "System:" text with an EOS token after each
# part. No mm_plugin is passed, so the function's default plugin is used.
_register_template(
name="default",
params=RegisterParams(
format_user=StringFormatter(slots=["Human: {{content}}", {"eos_token"}, "\nAssistant:"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
format_system=StringFormatter(slots=["System: {{content}}", {"eos_token"}, "\n"]),
),
)
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", template: str) -> "Template":
    r"""
    Gets chat template and fixes the tokenizer.

    The tokenizer is modified in place:
    - if the template sets ``replace_eos``, its first stop word becomes the EOS token;
    - if there is still no EOS token, ``<|endoftext|>`` is installed;
    - if there is no pad token, the EOS token is reused as pad token;
    - the remaining stop words are added as additional special tokens.

    Args:
        tokenizer: tokenizer to adjust.
        template: name of a template registered in ``TEMPLATES``.

    Returns:
        The registered template object.

    Raises:
        ValueError: if the name is not registered, or if ``replace_eos`` is set on a
            template that has no stop words.
    """
    # Keep the requested name separate from the lookup result so the error message
    # can report the name instead of None.
    template_name = template
    chat_template = TEMPLATES.get(template_name, None)
    if chat_template is None:
        raise ValueError(
            "Template {} does not exist.".format(template_name))

    stop_words = chat_template.stop_words
    if chat_template.replace_eos:
        if not stop_words:
            raise ValueError(
                "Stop words are required to replace the EOS token.")

        _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
        # The first stop word is now the EOS token; only the rest still need adding.
        stop_words = stop_words[1:]

    if tokenizer.eos_token_id is None:
        _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.info("Add pad token: {}".format(tokenizer.pad_token))

    if stop_words:
        num_added_tokens = tokenizer.add_special_tokens(
            dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
        )
        logger.info("Add {} to stop words.".format(",".join(stop_words)))
        if num_added_tokens > 0:
            logger.warning(
                "New tokens have been added, make sure `resize_vocab` is True.")

    return chat_template
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
    r"""
    Installs ``eos_token`` as the tokenizer's EOS token and logs what happened.

    Logs "Add" when the tokenizer had no EOS id before the call and "Replace"
    otherwise, then warns if the call added new tokens.
    """
    # Must be captured before the tokenizer is modified.
    had_no_eos = tokenizer.eos_token_id is None
    added_count = tokenizer.add_special_tokens({"eos_token": eos_token})

    message = "Add eos token: {}" if had_no_eos else "Replace eos token: {}"
    logger.info(message.format(tokenizer.eos_token))

    if added_count <= 0:
        return
    logger.warning(
        "New tokens have been added, make sure `resize_vocab` is True.")