import os
import sys
import re
import json
import logging
from copy import deepcopy
from pathlib import Path
from dataclasses import dataclass
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union, Any, Set, ClassVar
from .formatter import (EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter,
FunctionFormatterForThink, ToolFormatterForThink)
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from .formatter import SLOTS, Formatter
# Module-level logger for this file.
logger = logging.getLogger(__name__)
# Absolute (not symlink-resolved) directory that contains this source file.
cur_file_dir = Path(__file__).absolute().parent
# Despite its name, this is the path of a JSON file (configs/finetune/templates.json)
# located relative to the directory three levels above `cur_file_dir`.
TEMPLATES_DIR = os.path.join(cur_file_dir.parent.parent.parent, "configs/finetune/templates.json")
@dataclass
class AlpacaTemplate:
    """Alpaca-style prompt tokens (the attribute names match those read by ``Prompter``).

    The attributes are plain, unannotated class attributes, so ``@dataclass``
    generates no fields for them; instances simply see the class-level values.
    """

    system_token = ""
    user_token = "### Instruction:"
    assistant_token = "### Response:"
    end_token = ""
    system = (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request. "
        "Please note that you need to think through your response logically and step by step."
    )
class Prompter(object):
    """Flatten a list of chat messages into a single training prompt string.

    The ``template`` object must expose ``system_token``, ``system``, ``user_token``,
    ``assistant_token`` and ``end_token`` string attributes.
    """

    def __init__(self, template, verbose: bool = False):
        self._verbose = verbose
        self.template = template
        self.user_role = "user"
        self.assistant_role = "assistant"

    def generate_training_prompt(self, messages) -> str:
        """Return the system header followed by one block per message.

        Messages whose role equals ``self.user_role`` are tagged with the user token;
        every other role is tagged with the assistant token.
        """
        tpl = self.template
        pieces = [tpl.system_token, "\n", tpl.system, tpl.end_token, "\n"]
        for msg in messages:
            is_user = msg["role"] == self.user_role
            role_token = tpl.user_token if is_user else tpl.assistant_token
            pieces.extend((role_token, "\n", msg["content"], tpl.end_token, "\n"))
        return "".join(pieces)
@unique
class Role(str, Enum):
    """Message roles recognised by the templates in this module.

    Members mix in ``str``, so they compare equal to the raw role strings found in
    message dicts (for example ``Role.USER == "user"``).
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    OBSERVATION = "observation"
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
    """Split a ``max_len`` token budget between source and target.

    The target receives a share proportional to its length (at least
    ``reserved_label_len``); the source receives whatever the (possibly shorter)
    target does not consume.

    Returns:
        ``(max_source_len, max_target_len)``.
    """
    combined = source_len + target_len
    proportional_target = 0 if combined == 0 else int(max_len * (target_len / combined))
    max_target_len = max(proportional_target, reserved_label_len)
    # The target may be shorter than its allowance; give the slack back to the source.
    target_usage = min(max_target_len, target_len)
    return max_len - target_usage, max_target_len
@dataclass
class Template:
    r"""Formatter-driven chat template.

    Each ``format_*`` formatter's ``apply(...)`` returns a list of elements
    (``str``, ``set`` or ``dict``) that ``_convert_elements_to_ids`` turns into
    token ids; messages are then grouped into (query, response) pairs.
    """

    format_user: "Formatter"
    format_assistant: "Formatter"
    format_system: "Formatter"
    format_function: "Formatter"
    format_observation: "Formatter"
    format_tools: "Formatter"
    format_separator: "Formatter"
    format_prefix: "Formatter"
    default_system: str
    stop_words: List[str]
    thought_words: tuple[str, str]
    efficient_eos: bool
    replace_eos: bool
    force_system: bool
    enable_thinking: Optional[bool]
    reasoning_effort: Optional[str]
    drop_thinking: Optional[bool]

    def encode_oneturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        cutoff_len: int = 1_000_000,
        reserved_label_len: int = 1,
    ) -> Tuple[List[int], List[int]]:
        r"""
        Returns a single pair of token ids representing prompt and response respectively.

        Every pair except the last is flattened into the prompt, followed by the last
        query; the last response becomes the answer.
        """
        pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
        history: List[int] = []
        for query_ids, resp_ids in pairs[:-1]:
            history.extend(query_ids)
            history.extend(resp_ids)
        final_query, final_response = pairs[-1]
        return history + final_query, final_response

    def encode_multiturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        cutoff_len: int = 1_000_000,
        reserved_label_len: int = 1,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        r"""
        Returns multiple pairs of token ids representing prompts and responses respectively.
        """
        pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
        return pairs

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: str,
        tools: str,
        cutoff_len: int,
        reserved_label_len: int,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: prefix + system + query resp
        Turn t: sep + query resp
        """
        system = system or self.default_system
        per_message_ids = []
        for turn, message in enumerate(messages):
            elements = []
            if turn == 0:
                elements += self.format_prefix.apply()
                if system or tools:
                    if tools:
                        tool_text = self.format_tools.apply(content=tools)[0]
                    else:
                        tool_text = ""
                    elements += self.format_system.apply(content=(system + tool_text))
            elif turn % 2 == 0:
                # Every later query (even index) is preceded by the separator.
                elements += self.format_separator.apply()

            role = message["role"]
            if role == Role.USER.value:
                elements += self.format_user.apply(content=message["content"], idx=str(turn // 2))
            elif role == Role.ASSISTANT.value:
                elements += self.format_assistant.apply(content=message["content"])
            elif role == Role.OBSERVATION.value:
                elements += self.format_observation.apply(content=message["content"])
            elif role == Role.FUNCTION.value:
                elements += self.format_function.apply(content=message["content"])
            else:
                raise NotImplementedError("Unexpected role: {}".format(role))
            per_message_ids.append(self._convert_elements_to_ids(tokenizer, elements))
        return self._make_pairs(per_message_ids, cutoff_len, reserved_label_len)

    def _convert_elements_to_ids(
        self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
    ) -> List[int]:
        r"""
        Converts elements to token ids.

        * ``str``: tokenized without special tokens (empty strings are skipped).
        * ``dict``: ``{"token": ...}`` mapped via ``convert_tokens_to_ids``.
        * ``set``: ``{"bos_token"}`` / ``{"eos_token"}`` mapped to the tokenizer's id
          when that id is not ``None``.
        """
        ids: List[int] = []
        for piece in elements:
            if isinstance(piece, str):
                if piece:
                    ids.extend(tokenizer.encode(piece, add_special_tokens=False))
            elif isinstance(piece, dict):
                ids.append(tokenizer.convert_tokens_to_ids(piece.get("token")))
            elif isinstance(piece, set):
                if "bos_token" in piece and tokenizer.bos_token_id is not None:
                    ids.append(tokenizer.bos_token_id)
                elif "eos_token" in piece and tokenizer.eos_token_id is not None:
                    ids.append(tokenizer.eos_token_id)
            else:
                raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(piece)))
        return ids

    def _make_pairs(
        self,
        encoded_messages: Sequence[List[int]],
        cutoff_len: int,
        reserved_label_len: int,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        """Group consecutive (query, response) id lists and truncate them to ``cutoff_len`` in total."""
        pairs = []
        consumed = 0
        for query_idx in range(0, len(encoded_messages), 2):
            if consumed >= cutoff_len:
                break
            query, response = encoded_messages[query_idx], encoded_messages[query_idx + 1]
            max_source_len, max_target_len = infer_max_len(
                source_len=len(query),
                target_len=len(response),
                max_len=(cutoff_len - consumed),
                reserved_label_len=reserved_label_len,
            )
            source_ids = query[:max_source_len]
            target_ids = response[:max_target_len]
            consumed += len(source_ids) + len(target_ids)
            pairs.append((source_ids, target_ids))
        return pairs

    def add_thought(self, content: str = "") -> str:
        r"""Add empty thought to assistant message."""
        opening = self.thought_words[0]
        closing = self.thought_words[1]
        return opening + closing + content

    def remove_thought(self, content: str) -> str:
        r"""Remove thought from assistant message."""
        opening = re.escape(self.thought_words[0])
        closing = re.escape(self.thought_words[1])
        thought_re = re.compile(opening + "(.*?)" + closing, re.DOTALL)
        return thought_re.sub("", content).lstrip("\n")

    def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
        r"""Get the token ids of thought words."""
        empty_thought = self.add_thought()
        return tokenizer.encode(empty_thought, add_special_tokens=False)
@dataclass
class LFDefaultTemplate(Template):
    r"""Template variant whose ``_encode`` returns one id list per message.

    Truncation (``cutoff_len`` / ``reserved_label_len``) is applied only by
    ``encode_multiturn`` via ``_make_pairs``; ``encode_oneturn`` accepts these
    arguments for interface compatibility but does not use them.
    """

    def encode_oneturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        cutoff_len: int = 1_000_000,
        reserved_label_len: int = 1,
    ) -> Tuple[list[int], list[int]]:
        r"""Return a single pair of token ids representing prompt and response respectively."""
        per_message = self._encode(tokenizer, messages, system, tools)
        # Every message but the last is concatenated into the prompt.
        prompt_ids = [token for ids in per_message[:-1] for token in ids]
        return prompt_ids, per_message[-1]

    def encode_multiturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        cutoff_len: int = 1_000_000,
        reserved_label_len: int = 1,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        r"""Return multiple pairs of token ids representing prompts and responses respectively."""
        per_message = self._encode(tokenizer, messages, system, tools)
        pairs = self._make_pairs(per_message, cutoff_len, reserved_label_len)
        return pairs

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: list[dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
    ) -> Sequence[list[int]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: prefix + system + query resp
        Turn t: sep + query resp

        Note: this variant emits no separator between turns; it returns one
        token-id list per message.
        """
        system = system or self.default_system
        per_message = []
        for turn, message in enumerate(messages):
            elements = []
            if turn == 0:
                elements += self.format_prefix.apply()
                if system or tools:
                    if tools:
                        tool_text = self.format_tools.apply(content=tools)[0]
                    else:
                        tool_text = ""
                    elements += self.format_system.apply(content=(system + tool_text))

            role = message["role"]
            if role == Role.USER:
                elements += self.format_user.apply(content=message["content"], idx=str(turn // 2))
            elif role == Role.ASSISTANT:
                elements += self.format_assistant.apply(content=message["content"])
            elif role == Role.OBSERVATION:
                elements += self.format_observation.apply(content=message["content"])
            elif role == Role.FUNCTION:
                elements += self.format_function.apply(content=message["content"])
            else:
                raise NotImplementedError("Unexpected role: {}".format(role))
            per_message.append(self._convert_elements_to_ids(tokenizer, elements))
        return per_message

    def _make_pairs(
        self,
        encoded_messages: Sequence[List[int]],
        cutoff_len: int,
        reserved_label_len: int,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        """Pair consecutive messages and truncate with ``_infer_seqlen``.

        ``reserved_label_len`` tokens are subtracted from the budget up front.
        """
        from .decoder_packed_mtf_dataset import _infer_seqlen
        budget = cutoff_len - reserved_label_len
        pairs = []
        used = 0
        for query_idx in range(0, len(encoded_messages), 2):
            if used >= budget:
                break
            query, response = encoded_messages[query_idx], encoded_messages[query_idx + 1]
            max_source_len, max_target_len = _infer_seqlen(
                source_len=len(query),
                target_len=len(response),
                cutoff_len=(budget - used)
            )
            source_ids = query[:max_source_len]
            target_ids = response[:max_target_len]
            used += len(source_ids) + len(target_ids)
            pairs.append((source_ids, target_ids))
        return pairs
@dataclass
class Llama2Template(Template):
    r"""Llama-2 style template: the formatted system text is prepended to the first user message."""

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: str,
        tools: str,
        cutoff_len: int,
        reserved_label_len: int,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: system + query resp
        Turn t: sep + query resp
        """
        system = system or self.default_system
        per_message = []
        for turn, message in enumerate(messages):
            elements = []
            # Only ever non-empty on turn 0; folded into the user content below.
            system_prefix = ""
            if turn == 0:
                elements += self.format_prefix.apply()
                if system or tools:
                    if tools:
                        tool_text = self.format_tools.apply(content=tools)[0]
                    else:
                        tool_text = ""
                    system_prefix = self.format_system.apply(content=(system + tool_text))[0]
            elif turn % 2 == 0:
                elements += self.format_separator.apply()

            role = message["role"]
            if role == Role.USER.value:
                elements += self.format_user.apply(content=system_prefix + message["content"])
            elif role == Role.ASSISTANT.value:
                elements += self.format_assistant.apply(content=message["content"])
            elif role == Role.OBSERVATION.value:
                elements += self.format_observation.apply(content=message["content"])
            elif role == Role.FUNCTION.value:
                elements += self.format_function.apply(content=message["content"])
            else:
                raise NotImplementedError("Unexpected role: {}".format(role))
            per_message.append(self._convert_elements_to_ids(tokenizer, elements))
        return self._make_pairs(per_message, cutoff_len, reserved_label_len)
@dataclass
class ReasoningTemplate(LFDefaultTemplate):
    r"""A template that adds a (possibly empty) thought block to assistant messages.

    When an assistant message carries no thought delimiters, an empty thought
    (``thought_words[0] + thought_words[1]``) is inserted. With ``enable_thinking``
    falsy it goes into the prompt (no loss); otherwise it is prepended to the
    response.
    """

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: str,
        tools: str,
    ) -> Sequence[list[int]]:
        r"""
        Encodes formatted inputs to one token-id list per message.
        Turn 0: prefix + system + query resp
        Turn t: sep + query resp
        """
        system = system or self.default_system
        encoded_messages = []
        for i, message in enumerate(messages):
            elements = []
            if i == 0:
                elements += self.format_prefix.apply()
                if system or tools:
                    tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                    elements += self.format_system.apply(content=(system + tool_text))
            elif i > 0 and i % 2 == 0:
                elements += self.format_separator.apply()
            if message["role"] == Role.USER.value:
                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
            elif message["role"] == Role.ASSISTANT.value:
                elements += self.format_assistant.apply(content=message["content"])
            elif message["role"] == Role.OBSERVATION.value:
                elements += self.format_observation.apply(content=message["content"])
            elif message["role"] == Role.FUNCTION.value:
                elements += self.format_function.apply(content=message["content"])
            else:
                raise NotImplementedError("Unexpected role: {}".format(message["role"]))
            encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
        return encoded_messages

    def _has_thought(self, content: str) -> bool:
        r"""Return True if ``content`` already contains either thought delimiter.

        Delimiters are compared after ``strip()``. This way thought words configured
        with surrounding whitespace (e.g. ``"<think>\n"``) still match content written
        as ``"<think>...</think>"``, and a second, empty thought is not injected.
        """
        return (
            self.thought_words[0].strip() in content
            or self.thought_words[1].strip() in content
        )

    def encode_oneturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        cutoff_len: int = 1_000_000,
        reserved_label_len: int = 1,
    ) -> Tuple[list[int], list[int]]:
        r"""Return (prompt_ids, response_ids) for the last assistant turn.

        Thoughts are removed from all earlier assistant messages; the last one keeps
        its thought unless ``enable_thinking`` is explicitly ``False``.
        """
        messages = deepcopy(messages)
        # Strip reasoning from historical assistant turns (odd indices before the last pair).
        for i in range(1, len(messages) - 2, 2):
            messages[i]["content"] = self.remove_thought(messages[i]["content"])
        if self.enable_thinking is False:
            messages[-1]["content"] = self.remove_thought(messages[-1]["content"])
        prompt_ids, response_ids = super().encode_oneturn(
            tokenizer, messages, system, tools, cutoff_len, reserved_label_len
        )
        if not self._has_thought(messages[-1]["content"]):
            if not self.enable_thinking:
                # Empty thought belongs to the prompt: no loss is computed on it.
                prompt_ids += self.get_thought_word_ids(tokenizer)
            else:
                # Empty thought is part of the supervised response.
                response_ids = self.get_thought_word_ids(tokenizer) + response_ids
        return prompt_ids, response_ids

    def encode_multiturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        cutoff_len: int = 1_000_000,
        reserved_label_len: int = 1,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        r"""Return (source_ids, target_ids) pairs for every turn, each with a thought block."""
        messages = deepcopy(messages)
        if self.enable_thinking is False:
            for i in range(1, len(messages), 2):
                messages[i]["content"] = self.remove_thought(messages[i]["content"])
        encoded_messages = self._encode(tokenizer, messages, system, tools)
        for i in range(0, len(messages), 2):
            if not self._has_thought(messages[i + 1]["content"]):
                if not self.enable_thinking:
                    encoded_messages[i] += self.get_thought_word_ids(tokenizer)
                else:
                    encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]
        return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
@dataclass
class DeepSeek4Template(LFDefaultTemplate):
BOS_TOKEN: ClassVar[str] = "<|begin▁of▁sentence|>"
EOS_TOKEN: ClassVar[str] = "<|end▁of▁sentence|>"
USER_SP_TOKEN: ClassVar[str] = "<|User|>"
ASSISTANT_SP_TOKEN: ClassVar[str] = "<|Assistant|>"
LATEST_REMINDER_SP_TOKEN: ClassVar[str] = "<|latest_reminder|>"
THINKING_START: ClassVar[str] = "<think>"
THINKING_END: ClassVar[str] = "</think>"
DSML_TOKEN: ClassVar[str] = "|DSML|"
DS_TASK_SP_TOKENS: ClassVar[Dict[str, str]] = {
"action": "<|action|>",
"query": "<|query|>",
"authority": "<|authority|>",
"domain": "<|domain|>",
"title": "<|title|>",
"read_url": "<|read_url|>",
}
VALID_TASKS: ClassVar[Set[str]] = set(DS_TASK_SP_TOKENS.keys())
TOOL_CALLS_BLOCK_NAME: ClassVar[str] = "tool_calls"
TOOLS_TEMPLATE: ClassVar[str] = (
"## Tools\n\n"
"You have access to a set of tools to help answer the user's question. "
"You can invoke tools by writing a \"<{dsml}tool_calls>\" block like the following:\n\n"
"<{dsml}tool_calls>\n"
"<{dsml}invoke name=\"$TOOL_NAME\">\n"
"<{dsml}parameter name=\"$PARAMETER_NAME\" string=\"true|false\">$PARAMETER_VALUE</{dsml}parameter>\n"
"...\n"
"</{dsml}invoke>\n"
"<{dsml}invoke name=\"$TOOL_NAME2\">\n"
"...\n"
"</{dsml}invoke>\n"
"</{dsml}tool_calls>\n\n"
"String parameters should be specified as is and set `string=\"true\"`. "
"For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string=\"false\"`.\n\n"
"If thinking_mode is enabled (triggered by {ts}), you MUST output your complete reasoning inside {ts}...{te} BEFORE any tool calls or final response.\n\n"
"Otherwise, output directly after {te} with tool calls or final response.\n\n"
"### Available Tool Schemas\n\n"
"{tool_schemas}\n\n"
"You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.\n"
)
REASONING_EFFORT_MAX: ClassVar[str] = (
"Reasoning Effort: Absolute maximum with no shortcuts permitted.\n"
"You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, "
"rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n"
"Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, "
"and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n"
)
RESPONSE_FORMAT_TEMPLATE: ClassVar[str] = (
"## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
)
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: int = 1_000_000,
reserved_label_len: int = 1,
) -> Tuple[List[int], List[int]]:
"""Last-turn-only: returns (prompt_ids, response_ids)."""
v4_messages = self._normalize_to_v4_schema(messages, system, tools)
v4_messages = self._merge_tool_messages(v4_messages)
v4_messages = self._sort_tool_results_by_call_order(v4_messages)
effective_drop = (
self.drop_thinking
and self.enable_thinking
and not any(m.get("tools") for m in v4_messages)
)
if effective_drop:
v4_messages = self._drop_thinking_messages(v4_messages)
last_asst_idx = -1
for i, m in enumerate(v4_messages):
if m.get("role") == "assistant":
last_asst_idx = i
prompt_text = self.BOS_TOKEN
response_text = ""
for idx, _ in enumerate(v4_messages):
rendered = self._render_message(
idx, v4_messages,
thinking_mode="thinking" if self.enable_thinking else "chat",
drop_thinking=self.drop_thinking,
reasoning_effort=self.reasoning_effort if idx == 0 else None,
)
if last_asst_idx == -1 or idx < last_asst_idx:
prompt_text += rendered
elif idx == last_asst_idx:
response_text = rendered
else:
prompt_text += rendered
return self._encode(prompt_text, tokenizer), self._encode(response_text, tokenizer)
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: int = 1_000_000,
reserved_label_len: int = 1,
) -> Sequence[Tuple[List[int], List[int]]]:
"""All-turn loss: returns [(source_ids, target_ids), ...] per assistant turn.
"""
v4_messages = self._normalize_to_v4_schema(messages, system, tools)
v4_messages = self._merge_tool_messages(v4_messages)
v4_messages = self._sort_tool_results_by_call_order(v4_messages)
effective_drop = (
self.drop_thinking
and self.enable_thinking
and not any(m.get("tools") for m in v4_messages)
)
if effective_drop:
v4_messages = self._drop_thinking_messages(v4_messages)
encoded_segments: List[List[int]] = []
current_source_text = self.BOS_TOKEN
for idx, _ in enumerate(v4_messages):
rendered = self._render_message(
idx, v4_messages,
thinking_mode="thinking" if self.enable_thinking else "chat",
drop_thinking=effective_drop,
reasoning_effort=self.reasoning_effort if idx == 0 else None,
)
if v4_messages[idx].get("role") == "assistant":
if not current_source_text:
raise ValueError(
f"DeepSeek4Template.encode_multiturn: assistant at index "
f"{idx} has no preceding source segment. messages must "
f"alternate user/assistant after _merge_tool_messages."
)
encoded_segments.append(self._encode(current_source_text, tokenizer))
encoded_segments.append(self._encode(rendered, tokenizer))
current_source_text = ""
else:
current_source_text += rendered
return self._make_pairs(encoded_segments, cutoff_len, reserved_label_len)
def _encode(self, tokens, tokenizer: "PreTrainedTokenizer",):
return tokenizer.encode(tokens, add_special_tokens=False) if tokens else []
def _normalize_to_v4_schema(
self,
messages: List[Dict[str, Any]],
system: Optional[str],
tools: Optional[str],
) -> List[Dict[str, Any]]:
"""Translate (LlamaFactory-style messages, system_str, tools_str) into
V4-native messages.
How (system_str, tools_str) merge with `messages`:
- If messages already starts with system/developer: handler-supplied
system_str is prepended to its content; tools_str is attached only
if the message doesn't already have a tools field.
- Otherwise: a leading system message is synthesized from system_str
and tools_str.
- When both system_str and tools_str are empty: pass through unchanged.
"""
messages = deepcopy(messages) if messages else []
parsed_tools: Optional[List[Dict[str, Any]]] = None
if tools and isinstance(tools, str) and tools.strip():
try:
parsed = json.loads(tools)
except json.JSONDecodeError as e:
raise ValueError(
f"Failed to parse tools JSON: {e!r}; tools_str={tools[:200]!r}"
) from e
if isinstance(parsed, list) and parsed:
for i, t in enumerate(parsed):
if not isinstance(t, dict) or "function" not in t:
raise ValueError(
f"DeepSeek4Template only accepts OpenAI-format tools "
f"([{{'type': 'function', 'function': {{...}}}}, ...]). "
f"Bad entry at index {i}: {t!r}"
)
parsed_tools = parsed
system_text = system or ""
first_role = messages[0].get("role") if messages else None
first_is_system = first_role in ("system", "developer")
synthesize_leading = (system_text or parsed_tools) and not first_is_system
merge_into_first = (system_text or parsed_tools) and first_is_system
out: List[Dict[str, Any]] = []
if synthesize_leading:
sys_msg: Dict[str, Any] = {"role": "system", "content": system_text}
if parsed_tools:
sys_msg["tools"] = parsed_tools
out.append(sys_msg)
elif merge_into_first:
first = messages[0]
if system_text:
first["content"] = (
system_text
+ ("\n\n" if first.get("content") else "")
+ (first.get("content") or "")
)
if parsed_tools:
first.setdefault("tools", parsed_tools)
for msg in messages:
role = msg.get("role")
if role == "user":
new_msg = {"role": "user", "content": msg.get("content", "")}
for k in ("task", "mask", "wo_eos", "content_blocks"):
if k in msg:
new_msg[k] = msg[k]
out.append(new_msg)
elif role == "assistant":
new_msg: Dict[str, Any] = {"role": "assistant"}
content = msg.get("content", "") or ""
if "reasoning_content" in msg:
new_msg["reasoning_content"] = msg["reasoning_content"] or ""
new_msg["content"] = content
else:
m = re.compile(r"^\s*<think>\s*(.*?)\s*</think>\s*", re.DOTALL).match(content) if content else None
if m:
new_msg["reasoning_content"] = m.group(1)
new_msg["content"] = content[m.end():]
else:
new_msg["content"] = content
if msg.get("tool_calls"):
new_msg["tool_calls"] = msg["tool_calls"]
for k in ("task", "mask", "wo_eos"):
if k in msg:
new_msg[k] = msg[k]
out.append(new_msg)
elif role in ("tool", "function", "observation"):
new_msg = {"role": "tool", "content": msg.get("content", "")}
if "tool_call_id" in msg:
new_msg["tool_call_id"] = msg["tool_call_id"]
out.append(new_msg)
elif role == "system":
new_msg = {"role": "system", "content": msg.get("content", "")}
if msg.get("tools"):
new_msg["tools"] = msg["tools"]
out.append(new_msg)
elif role in ("developer", "latest_reminder"):
out.append(deepcopy(msg))
else:
raise NotImplementedError(
f"DeepSeek4Template: unsupported role {role!r}"
)
return out
@classmethod
def _render_message(
cls,
index: int,
messages: List[Dict[str, Any]],
thinking_mode: str,
drop_thinking: bool = True,
reasoning_effort: Optional[str] = None,
) -> str:
"""Render a single message into its V4-encoded text form.
thinking_mode: 'thinking' or 'chat'.
drop_thinking: whether earlier-turn reasoning_content was already
stripped (encode_oneturn=True; encode_multiturn=False).
reasoning_effort: only inserts the max-effort prefix when
(index == 0 and thinking_mode == 'thinking' and effort == 'max').
"""
if not (0 <= index < len(messages)):
raise IndexError(f"index {index} out of range for messages of length {len(messages)}")
if thinking_mode not in ("chat", "thinking"):
raise ValueError(f"Invalid thinking_mode: {thinking_mode!r}")
if reasoning_effort not in (None, "max", "high"):
raise ValueError(f"Invalid reasoning_effort: {reasoning_effort!r}")
prompt = ""
msg = messages[index]
last_user_idx = -1
for i in range(len(messages) - 1, -1, -1):
if messages[i].get("role") in ("user", "developer"):
last_user_idx = i
break
role = msg.get("role")
content = msg.get("content")
tools = msg.get("tools")
response_format = msg.get("response_format")
tool_calls = msg.get("tool_calls")
reasoning_content = msg.get("reasoning_content")
wo_eos = msg.get("wo_eos", False)
if tools:
tools = [t["function"] for t in tools]
if tool_calls:
tool_calls = [
{"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}
for tc in tool_calls
]
if index == 0 and thinking_mode == "thinking" and reasoning_effort == "max":
prompt += cls.REASONING_EFFORT_MAX
if role == "system":
prompt += content or ""
if tools:
tool_schemas = "\n".join(cls._to_json(t) for t in tools)
prompt += "\n\n" + cls.TOOLS_TEMPLATE.format(
tool_schemas=tool_schemas, dsml=cls.DSML_TOKEN,
ts=cls.THINKING_START, te=cls.THINKING_END,
)
if response_format:
prompt += "\n\n" + cls.RESPONSE_FORMAT_TEMPLATE.format(
schema=cls._to_json(response_format)
)
elif role == "developer":
if not content:
raise ValueError(f"Invalid developer message: {msg}")
content_dev = cls.USER_SP_TOKEN + content
if tools:
tool_schemas = "\n".join(cls._to_json(t) for t in tools)
content_dev += "\n\n" + cls.TOOLS_TEMPLATE.format(
tool_schemas=tool_schemas, dsml=cls.DSML_TOKEN,
ts=cls.THINKING_START, te=cls.THINKING_END,
)
if response_format:
content_dev += "\n\n" + cls.RESPONSE_FORMAT_TEMPLATE.format(
schema=cls._to_json(response_format)
)
prompt += content_dev
elif role == "user":
prompt += cls.USER_SP_TOKEN
content_blocks = msg.get("content_blocks")
if content_blocks:
parts = []
for block in content_blocks:
btype = block.get("type")
if btype == "text":
parts.append(block.get("text", ""))
elif btype == "tool_result":
tool_content = block.get("content", "")
if isinstance(tool_content, list):
text_parts = []
for b in tool_content:
if b.get("type") == "text":
text_parts.append(b.get("text", ""))
else:
text_parts.append(f"[Unsupported {b.get('type')}]")
tool_content = "\n\n".join(text_parts)
parts.append(f"<tool_result>{tool_content}</tool_result>")
else:
parts.append(f"[Unsupported {btype}]")
prompt += "\n\n".join(parts)
else:
prompt += content or ""
elif role == "latest_reminder":
prompt += cls.LATEST_REMINDER_SP_TOKEN + (content or "")
elif role == "tool":
raise NotImplementedError(
"tool messages must be merged into user via _merge_tool_messages first"
)
elif role == "assistant":
thinking_part = ""
tc_content = ""
if tool_calls:
tc_list = [
f'<{cls.DSML_TOKEN}invoke name="{tc.get("name")}">\n'
f"{cls._encode_arguments_to_dsml(tc)}\n"
f"</{cls.DSML_TOKEN}invoke>"
for tc in tool_calls
]
tc_content = (
f"\n\n<{cls.DSML_TOKEN}{cls.TOOL_CALLS_BLOCK_NAME}>\n"
+ "\n".join(tc_list)
+ f"\n</{cls.DSML_TOKEN}{cls.TOOL_CALLS_BLOCK_NAME}>"
)
summary_content = content or ""
rc = reasoning_content or ""
prev_has_task = index - 1 >= 0 and messages[index - 1].get("task") is not None
if thinking_mode == "thinking" and not prev_has_task:
if not drop_thinking or index > last_user_idx:
thinking_part = rc + cls.THINKING_END
assembled = thinking_part + summary_content + tc_content
prompt += assembled if wo_eos else assembled + cls.EOS_TOKEN
else:
raise NotImplementedError(f"Unknown role: {role}")
if (
index + 1 < len(messages)
and messages[index + 1].get("role") not in ("assistant", "latest_reminder")
):
return prompt
task = msg.get("task")
if task is not None:
if task not in cls.VALID_TASKS:
raise ValueError(
f"Invalid task: {task!r}. Valid: {sorted(cls.VALID_TASKS)}"
)
task_token = cls.DS_TASK_SP_TOKENS[task]
if task != "action":
prompt += task_token
else:
prompt += cls.ASSISTANT_SP_TOKEN
prompt += cls.THINKING_END if thinking_mode != "thinking" else cls.THINKING_START
prompt += task_token
elif role in ("user", "developer"):
prompt += cls.ASSISTANT_SP_TOKEN
if not drop_thinking and thinking_mode == "thinking":
prompt += cls.THINKING_START
elif drop_thinking and thinking_mode == "thinking" and index >= last_user_idx:
prompt += cls.THINKING_START
else:
prompt += cls.THINKING_END
return prompt
@classmethod
def _encode_arguments_to_dsml(cls, tool_call: Dict[str, str]) -> str:
"""Serialize a tool call's `arguments` (JSON string) into DSML parameter lines."""
try:
arguments = json.loads(tool_call["arguments"])
except Exception:
arguments = {"arguments": tool_call["arguments"]}
lines = []
for k, v in arguments.items():
is_str = "true" if isinstance(v, str) else "false"
value = v if isinstance(v, str) else cls._to_json(v)
lines.append(
f'<{cls.DSML_TOKEN}parameter name="{k}" string="{is_str}">{value}</{cls.DSML_TOKEN}parameter>'
)
return "\n".join(lines)
@staticmethod
def _merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Fold role='tool' messages into the preceding user message as
<tool_result> blocks via content_blocks."""
merged: List[Dict[str, Any]] = []
for msg in messages:
msg = deepcopy(msg)
role = msg.get("role")
if role == "tool":
tool_block = {
"type": "tool_result",
"tool_use_id": msg.get("tool_call_id", ""),
"content": msg.get("content", ""),
}
if (merged
and merged[-1].get("role") == "user"
and "content_blocks" in merged[-1]):
merged[-1]["content_blocks"].append(tool_block)
else:
merged.append({"role": "user", "content_blocks": [tool_block]})
elif role == "user":
text_block = {"type": "text", "text": msg.get("content", "")}
can_merge = (
merged
and merged[-1].get("role") == "user"
and "content_blocks" in merged[-1]
and merged[-1].get("task") is None
)
if can_merge:
merged[-1]["content_blocks"].append(text_block)
else:
new_msg = {
"role": "user",
"content": msg.get("content", ""),
"content_blocks": [text_block],
}
for k in ("task", "wo_eos", "mask"):
if k in msg:
new_msg[k] = msg[k]
merged.append(new_msg)
else:
merged.append(msg)
return merged
@staticmethod
def _sort_tool_results_by_call_order(
messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Reorder tool_result blocks within a user message to match the order
of tool_calls in the preceding assistant message."""
last_order: Dict[str, int] = {}
for msg in messages:
role = msg.get("role")
if role == "assistant" and msg.get("tool_calls"):
last_order = {}
for idx, tc in enumerate(msg["tool_calls"]):
tc_id = tc.get("id") or tc.get("function", {}).get("id", "")
if tc_id:
last_order[tc_id] = idx
elif role == "user" and msg.get("content_blocks"):
tool_blocks = [b for b in msg["content_blocks"] if b.get("type") == "tool_result"]
if len(tool_blocks) > 1 and last_order:
sorted_blocks = sorted(
tool_blocks,
key=lambda b: last_order.get(b.get("tool_use_id", ""), 0),
)
j = 0
new_blocks = []
for block in msg["content_blocks"]:
if block.get("type") == "tool_result":
new_blocks.append(sorted_blocks[j])
j += 1
else:
new_blocks.append(block)
msg["content_blocks"] = new_blocks
return messages
@staticmethod
def _drop_thinking_messages(
messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Strip reasoning_content from assistant messages occurring strictly
before the last user message."""
last_user_idx = -1
for i in range(len(messages) - 1, -1, -1):
if messages[i].get("role") in ("user", "developer"):
last_user_idx = i
break
keep_roles = {"user", "system", "tool", "latest_reminder", "direct_search_results"}
result = []
for idx, msg in enumerate(messages):
role = msg.get("role")
if role in keep_roles or idx >= last_user_idx:
result.append(msg)
elif role == "assistant":
msg = deepcopy(msg)
msg.pop("reasoning_content", None)
result.append(msg)
return result
@staticmethod
def _to_json(value: Any) -> str:
    """Serialize ``value`` to JSON, keeping non-ASCII characters readable.

    Non-ASCII text is emitted verbatim (``ensure_ascii=False``). If that
    output cannot be encoded as UTF-8, which happens when the input contains
    lone surrogate code points, the value is serialized again with
    ``ensure_ascii=True``. All non-ASCII characters then become ASCII escape
    sequences and the returned text is always valid UTF-8.

    Raises:
        TypeError, ValueError: propagated from ``json.dumps`` for values it
            cannot serialize (unsupported types, circular references).
    """
    text = json.dumps(value, ensure_ascii=False)
    try:
        text.encode("utf-8")
    except UnicodeEncodeError:
        # json.dumps does not fail on lone surrogates with ensure_ascii=False; it
        # copies them through, so the fallback has to be triggered explicitly.
        return json.dumps(value, ensure_ascii=True)
    return text
# Module-level registry of chat templates keyed by name; `_register_template`
# stores each newly built template here.
templates: Dict[str, Template] = dict()
def get_templates() -> Dict[str, Template]:
    """Return the live template registry (the module-level dict itself, not a copy)."""
    registry = templates
    return registry
def get_model_template(name, prompt_type_path, enable_thinking, reasoning_effort=None, drop_thinking=True):
    """Return the template called ``name``, registering it from JSON on first use.

    Args:
        name: Template name. ``None`` selects the ``"empty"`` template.
        prompt_type_path: JSON file holding template definitions. ``None`` means
            the default ``TEMPLATES_DIR``.
        enable_thinking: forwarded to ``register_custom_template``.
        reasoning_effort: forwarded to ``register_custom_template``.
        drop_thinking: forwarded to ``register_custom_template``.

    Returns:
        The registered template object.

    Raises:
        ValueError: if the template cannot be found or registered.
    """
    if name is None:
        # register_custom_template can never resolve ``None`` (template names come
        # from the JSON "name" field), so map it to the "empty" template up front.
        name = "empty"
    if prompt_type_path is None:
        # Passing None explicitly would override register_custom_template's
        # TEMPLATES_DIR default and make its path check fail with a TypeError.
        prompt_type_path = TEMPLATES_DIR
    name = register_custom_template(name, prompt_type_path, enable_thinking, reasoning_effort, drop_thinking)
    template = get_templates().get(name)
    if template is None:
        raise ValueError("Template {} does not exist.".format(name))
    return template
def fix_model_tokenizer(
    tokenizer: "PreTrainedTokenizer",
    name: Optional[str] = None,
    prompt_type_path: Optional[str] = None,
    enable_thinking: Optional[bool] = False,
    reasoning_effort: Optional[str] = None,
    drop_thinking: Optional[bool] = True
):
    """Adapt ``tokenizer`` in place to the chat template ``name``.

    Steps, in order:

    1. If the template sets ``replace_eos``, its first stop word becomes the
       EOS token and is removed from the remaining stop words. A
       ``ValueError`` is raised when the template has no stop words.
    2. If the tokenizer still has no EOS id, ``<|endoftext|>`` becomes the EOS token.
    3. A missing pad token is set to the EOS token.
    4. The remaining stop words are added as additional special tokens
       without replacing existing ones. A warning is logged if this added
       new tokens.
    """
    template = get_model_template(name, prompt_type_path, enable_thinking, reasoning_effort, drop_thinking)
    remaining_stop_words = template.stop_words

    if template.replace_eos:
        if not remaining_stop_words:
            raise ValueError("Stop words are required to replace the EOS token.")
        eos_candidate = remaining_stop_words[0]
        remaining_stop_words = remaining_stop_words[1:]
        _add_or_replace_eos_token(tokenizer, eos_token=eos_candidate)

    if tokenizer.eos_token_id is None:
        _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.info("Add pad token: {}".format(tokenizer.pad_token))

    if not remaining_stop_words:
        return

    added_count = tokenizer.add_special_tokens(
        {"additional_special_tokens": remaining_stop_words},
        replace_additional_special_tokens=False,
    )
    logger.info("Add {} to stop words.".format(",".join(remaining_stop_words)))
    if added_count > 0:
        logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
def _register_template(
    name: str,
    format_user: Optional["Formatter"] = None,
    format_assistant: Optional["Formatter"] = None,
    format_system: Optional["Formatter"] = None,
    format_function: Optional["Formatter"] = None,
    format_observation: Optional["Formatter"] = None,
    format_tools: Optional["Formatter"] = None,
    format_separator: Optional["Formatter"] = None,
    format_prefix: Optional["Formatter"] = None,
    default_system: str = "",
    stop_words: Optional[List[str]] = None,
    thought_words: Optional[tuple[str, str]] = None,
    efficient_eos: bool = False,
    replace_eos: bool = False,
    force_system: bool = False,
    enable_thinking: Optional[bool] = True,
    template_class: type["Template"] = Template,
    reasoning_effort: Optional[str] = None,
    drop_thinking: Optional[bool] = True
) -> None:
    r"""
    Registers a chat template.
    To add the following chat template:
    ```
    [HUMAN]:
    user prompt here
    [AI]:
    model response here
    [HUMAN]:
    user prompt here
    [AI]:
    model response here
    ```
    The corresponding code should be:
    ```
    _register_template(
        name="custom",
        format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]),
        format_separator=EmptyFormatter(slots=["\n\n"]),
        efficient_eos=True,
    )
    ```

    The built template is stored in the module-level ``templates`` dict under
    ``name``, replacing any previous entry with that name.

    Formatters left as ``None`` get defaults:

    * user and system: ``{{content}}``;
    * assistant: ``{{content}}`` followed by the EOS slot;
    * function: ``Action: {{name}}\nAction Input: {{arguments}}`` followed by
      the EOS slot;
    * observation: ``format_user`` if given, otherwise the default user formatter;
    * tools: ``ToolFormatter(tool_format="default")``;
    * separator and prefix: empty.

    The EOS slot is omitted when ``efficient_eos`` is true.

    ``stop_words`` defaults to an empty list. Each template gets its own copy,
    so registered templates never share a stop-word list with each other or
    with the caller. ``thought_words`` defaults to
    ``("<think>\n", "\n</think>\n\n")``.
    """
    eos_slots = [] if efficient_eos else [{"eos_token"}]
    default_user_formatter = StringFormatter(slots=["{{content}}"])
    default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
    default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
    default_tool_formatter = ToolFormatter(tool_format="default")
    default_separator_formatter = EmptyFormatter()
    default_prefix_formatter = EmptyFormatter()
    templates[name] = template_class(
        format_user=format_user or default_user_formatter,
        format_assistant=format_assistant or default_assistant_formatter,
        format_system=format_system or default_user_formatter,
        format_function=format_function or default_function_formatter,
        format_observation=format_observation or format_user or default_user_formatter,
        format_tools=format_tools or default_tool_formatter,
        format_separator=format_separator or default_separator_formatter,
        format_prefix=format_prefix or default_prefix_formatter,
        default_system=default_system,
        # Fresh list per template: a shared mutable default would leak mutations
        # of one template's stop words into every other default-registered template.
        stop_words=[] if stop_words is None else list(stop_words),
        thought_words=thought_words or ("<think>\n", "\n</think>\n\n"),
        efficient_eos=efficient_eos,
        replace_eos=replace_eos,
        force_system=force_system,
        enable_thinking=enable_thinking,
        reasoning_effort=reasoning_effort,
        drop_thinking=drop_thinking
    )
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
    """Install ``eos_token`` as the tokenizer's EOS token and log the change.

    The info log says "Add" when the tokenizer had no EOS id beforehand and
    "Replace" otherwise. A warning is logged when ``add_special_tokens``
    reports newly added tokens; the warning asks to enable ``resize_vocab``.
    """
    had_eos = tokenizer.eos_token_id is not None
    added_count = tokenizer.add_special_tokens({"eos_token": eos_token})
    message = "Replace eos token: {}" if had_eos else "Add eos token: {}"
    logger.info(message.format(tokenizer.eos_token))
    if added_count > 0:
        logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
def register_custom_template(name, json_file_path=TEMPLATES_DIR, enable_thinking=False, reasoning_effort=None, drop_thinking=True) -> str:
    """Register template ``name`` from a JSON definitions file (once) and return ``name``.

    If ``name`` is already registered, no file is read. For ``deepseek4`` only,
    a non-None ``reasoning_effort`` is applied to the existing template,
    together with ``drop_thinking``.

    Otherwise ``json_file_path`` is loaded. It must hold a list of objects with
    a ``name`` key. The entry called ``name`` is converted into formatter
    objects and passed to ``_register_template``. Function and tool formatters
    are special-cased by template name:

    * ``deepseek4``: none are built, so ``_register_template``'s defaults apply;
    * ``qwen3`` and ``bailing_mini``: the ``*ForThink`` variants are used;
    * all other templates: the plain ``FunctionFormatter`` and ``ToolFormatter``.

    Args:
        name: Template name to look up.
        json_file_path: Path of the templates JSON file. ``None`` falls back to
            ``TEMPLATES_DIR``.
        enable_thinking: Stored on a newly registered template.
        reasoning_effort: Stored on a newly registered template.
        drop_thinking: Stored on a newly registered template.

    Raises:
        ValueError: if the path fails validation or the file has no entry named ``name``.
    """
    if name in templates:
        # Already registered: only deepseek4 picks up per-call reasoning settings.
        if name == 'deepseek4' and reasoning_effort is not None:
            templates[name].reasoning_effort = reasoning_effort
            templates[name].drop_thinking = drop_thinking
        return name
    if json_file_path is None:
        # Treat a missing path like the default instead of crashing in re.match below.
        json_file_path = TEMPLATES_DIR
    # NOTE(review): re.match only anchors at the start of the string, so this check
    # accepts nearly any path containing "/" (or starting with ".") and rejects bare
    # file names such as "templates.json". Kept unchanged.
    if not bool(re.match(r'(?:(?:/|\.{1,2}/|[^/\0]+/)(?:[^/\0]+/)*[^/\0]*|\.{1,2})', json_file_path)):
        raise ValueError(f"Invalid Path: {json_file_path}, please provide a valid custom template path.")
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(json_file_path, 'r', encoding='utf-8') as file:
        all_templates = json.load(file)
    templates_dict = {template['name']: template for template in all_templates}
    config = templates_dict.get(name, None)
    if not config:
        raise ValueError(f"Can't find the template. Please provide a valid prompt type template in the {json_file_path}.")
    format_user = _format_custom_template(config.get("format_user", None))
    format_assistant = _format_custom_template(config.get("format_assistant", None))
    format_system = _format_custom_template(config.get("format_system", None))
    format_function = _format_custom_template(config.get("format_function", None))
    format_observation = _format_custom_template(config.get("format_observation", None))
    format_tools = _format_custom_template(config.get("format_tools", None))
    format_separator = _format_custom_template(config.get("format_separator", None))
    format_prefix = _format_custom_template(config.get("format_prefix", None))
    default_system = _format_custom_template(config.get("default_system", ""))
    stop_words = _format_custom_template(config.get("stop_words", []))
    efficient_eos = _format_custom_template(config.get("efficient_eos", False))
    replace_eos = _format_custom_template(config.get("replace_eos", False))
    force_system = _format_custom_template(config.get("force_system", False))
    template_class = _format_custom_template(config.get("template_class", None))
    thought_words = _format_custom_template(config.get("thought_words", None))
    # A multi-line system prompt may be written as a list of strings in the JSON.
    if isinstance(default_system, list):
        default_system = "".join(default_system) if all(isinstance(sentence, str) for sentence in default_system) else default_system
    format_user = StringFormatter(**format_user) if format_user else None
    format_assistant = StringFormatter(**format_assistant) if format_assistant else None
    format_system = StringFormatter(**format_system) if format_system else None
    format_observation = StringFormatter(**format_observation) if format_observation else None
    format_separator = EmptyFormatter(**format_separator) if format_separator else None
    format_prefix = EmptyFormatter(**format_prefix) if format_prefix else None
    template_class = _get_template_class(template_class) if template_class else Template
    if name == 'deepseek4':
        format_function = None
        format_tools = None
    elif name in ['qwen3', 'bailing_mini']:
        format_function = FunctionFormatterForThink(**format_function) if format_function else None
        format_tools = ToolFormatterForThink(**format_tools) if format_tools else None
    else:
        format_function = FunctionFormatter(**format_function) if format_function else None
        format_tools = ToolFormatter(**format_tools) if format_tools else None
    _register_template(
        name=name,
        format_user=format_user,
        format_assistant=format_assistant,
        format_system=format_system,
        format_function=format_function,
        format_observation=format_observation,
        format_tools=format_tools,
        format_separator=format_separator,
        format_prefix=format_prefix,
        default_system=default_system,
        stop_words=stop_words,
        thought_words=thought_words or ("<think>\n", "\n</think>\n\n"),
        efficient_eos=efficient_eos,
        replace_eos=replace_eos,
        force_system=force_system,
        enable_thinking=enable_thinking,
        template_class=template_class,
        reasoning_effort=reasoning_effort,
        drop_thinking=drop_thinking
    )
    return name
def _format_custom_template(slots: Any) -> Any:
    """Normalize one raw value read from the templates JSON.

    Only non-empty dicts are transformed. They are modified in place and also
    returned. Any other value is returned unchanged. For each key of a dict:

    * a falsy value (``[]``, ``""``, ``0``, ``False``, ``None``) becomes ``None``;
    * a list is copied with every nested list turned into a ``set``. JSON
      cannot express sets, so ``["eos_token"]`` inside ``slots`` becomes the
      ``{"eos_token"}`` slot form used by the built-in formatters;
    * any other value (strings, numbers, booleans, nested dicts) is kept as-is.
      Previously strings were split into lists of characters and numbers or
      booleans raised TypeError.
    """
    if not slots or not isinstance(slots, dict):
        return slots
    for key, value in slots.items():
        if not value:
            slots[key] = None
        elif isinstance(value, list):
            slots[key] = [set(item) if isinstance(item, list) else item for item in value]
        # Any other truthy value is a scalar or mapping: leave it untouched.
    return slots
def _get_template_class(template_name: str) -> type:
    """Resolve a template class by its name in this module.

    Returns the module attribute called ``template_name`` when it is a
    subclass of ``Template``. Otherwise a warning is logged and ``Template``
    is returned. This covers a misspelled name, and also a name that refers to
    something unrelated in this module (a function, an imported module),
    which must not be instantiated as a template.

    Raises:
        Exception: if this module cannot be found in ``sys.modules``.
    """
    current_module = sys.modules.get(__name__)
    if not current_module:
        raise Exception("current module not found")
    candidate = getattr(current_module, template_name, None)
    if isinstance(candidate, type) and issubclass(candidate, Template):
        template_class = candidate
    else:
        logger.warning(
            "template class %s not found or not a Template subclass, falling back to %s",
            template_name, Template.__name__,
        )
        template_class = Template
    logger.info("template will use %s to format dataset", template_class.__name__)
    return template_class