import random
import re
from typing import Any
try:
from jinja2 import Template
except ImportError as e:
raise ImportError("Jinja2 is required. Please install it with: pip install jinja2") from e
from slime.rollout.sglang_rollout import GenerateState
from slime.utils.http_utils import post
from slime.utils.types import Sample
try:
from slime.rollout.rm_hub.math_dapo_utils import compute_score as math_dapo_compute_score
except ImportError as e:
raise ImportError("MathDapo is not installed") from e
from tool_sandbox import SEMAPHORE, TOOL_CONFIGS, tool_registry
# Probability that a given rollout prints its verbose per-turn log
# (compared against random.random() once per generate() call).
_LOG_SAMPLE_PROB = 0.05
# Number of characters in the horizontal separator rules of the verbose log.
_LOG_WIDTH = 300
def _trunc(s: str, n: int = 300) -> str:
"""Truncate *s* to at most *n* characters for display."""
if len(s) <= n:
return s
return s[:n] + f"…[+{len(s) - n}]"
# ChatML-style Jinja2 prompt template rendered by format_conversation_with_tools.
# It emits a system turn (the first message's content when that message has role
# "system", otherwise a default sentence), an optional "# Tools" section that
# serialises every tool spec as JSON, then each user/assistant message, and
# finally opens a new assistant turn for the model to complete.
# NOTE(review): the `{%-` / `{{-` markers strip the whitespace (newline included)
# in front of each tag, so the role header and the message content appear to be
# rendered with no newline between them — confirm this is intended.
TOOL_TEMPLATE = """<|im_start|>system
{%- if messages[0]['role'] == 'system' %}
{{- messages[0]['content'] }}
{%- else %}
You are a helpful assistant.
{%- endif %}
{%- if tools %}
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{%- for tool in tools %}
{{- tool | tojson }}
{%- endfor %}
</tools>
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
{%- endif %}
<|im_end|>
{%- for message in messages %}
{%- if message['role'] == 'user' %}
<|im_start|>user
{{- message['content'] }}<|im_end|>
{%- elif message['role'] == 'assistant' %}
<|im_start|>assistant
{{- message['content'] }}<|im_end|>
{%- endif %}
{%- endfor %}
<|im_start|>assistant
"""
def format_conversation_with_tools(
    prompt: str,
    tools: list[dict[str, Any]] | None = None,
    system_prompt: str | None = None,
    messages: list[dict[str, Any]] | None = None,
) -> str:
    """Render a conversation into prompt text with the module's Jinja2 template.

    The rendered message list always starts with a system message, followed by
    *prompt* as a user message (when non-empty) and then any extra *messages*.

    Args:
        prompt: User query; skipped when empty.
        tools: Tool specifications passed to the template; ``None`` is treated
            as an empty list, which suppresses the tools section.
        system_prompt: System message content. When empty or ``None`` a
            default math/code-interpreter instruction is used.
        messages: Additional message dicts appended after the user prompt.

    Returns:
        The rendered prompt string, ending with an open assistant turn.
    """
    system_content = system_prompt or (
        "You are a helpful assistant that can use Python "
        "tools to solve mathematical problems. When you need "
        "to perform calculations, use the code_interpreter "
        "tool to execute code and get results."
    )

    messages_to_render: list[dict[str, Any]] = [{"role": "system", "content": system_content}]
    if prompt:
        messages_to_render.append({"role": "user", "content": prompt})
    if messages:
        messages_to_render.extend(messages)

    return Template(TOOL_TEMPLATE).render(messages=messages_to_render, tools=tools or [])
def postprocess_predictions(prediction: str):
"""Extract action and content from prediction string"""
answer_pattern = r"Answer:\s*\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}"
answer_match = re.search(answer_pattern, prediction, re.DOTALL)
if answer_match:
content = answer_match.group(1).strip()
return "answer", content
tool_call_pattern = r"<tool_call>\s*(\{.*?\})\s*</tool_call>"
tool_call_match = re.search(tool_call_pattern, prediction, re.DOTALL)
if tool_call_match:
try:
import json
json_str = tool_call_match.group(1)
json_str = json_str.replace("\n", "\\n")
tool_call_data = json.loads(json_str)
tool_name = tool_call_data.get("name")
arguments = tool_call_data.get("arguments", {})
if tool_name == "code_interpreter":
code = arguments.get("code", "")
if code.strip():
return "code", code
except (json.JSONDecodeError, KeyError, AttributeError):
pass
glm_tool_call_pattern = (
r"<tool_call>\s*(\w[\w.]*)\s*"
r"(?:<arg_key>[^<]*</arg_key>)?\s*"
r"<arg_value>(.*?)</arg_value>\s*"
r"</tool_call>"
)
glm_match = re.search(glm_tool_call_pattern, prediction, re.DOTALL)
if glm_match:
tool_name = glm_match.group(1).strip()
code = glm_match.group(2).strip()
if tool_name == "code_interpreter" and code:
return "code", code
code_pattern = r"<code>(.*?)</code>"
code_match = re.search(code_pattern, prediction, re.DOTALL)
if code_match:
content = code_match.group(1).strip()
return "code", content
python_code_pattern = r"```python\s*(.*?)\s*```"
python_code_match = re.search(python_code_pattern, prediction, re.DOTALL)
if python_code_match:
content = python_code_match.group(1).strip()
return "code", content
return None, ""
def postprocess_responses(resp: str) -> str:
"""Post-process response to ensure tag completeness"""
if "<tool_call>" in resp:
tool_call_pattern = r"<tool_call>\s*\{.*?\}\s*</tool_call>"
matches = list(re.finditer(tool_call_pattern, resp, re.DOTALL))
if matches:
last_match = matches[-1]
return resp[: last_match.end()]
glm_tool_call_pattern = (
r"<tool_call>\s*\w[\w.]*\s*"
r"(?:<arg_key>[^<]*</arg_key>)?\s*"
r"<arg_value>.*?</arg_value>\s*"
r"</tool_call>"
)
glm_matches = list(re.finditer(glm_tool_call_pattern, resp, re.DOTALL))
if glm_matches:
return resp[: glm_matches[-1].end()]
if "</code>" in resp:
return resp.split("</code>")[0] + "</code>"
if "```python" in resp:
python_pattern = r"```python\s*.*?```"
matches = list(re.finditer(python_pattern, resp, re.DOTALL))
if matches:
last_match = matches[-1]
return resp[: last_match.end()]
if "Answer:" in resp and "\\boxed{" in resp:
answer_pattern = r"Answer:\s*\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}"
matches = list(re.finditer(answer_pattern, resp, re.DOTALL))
if matches:
last_match = matches[-1]
return resp[: last_match.end()]
return resp
def _uses_glm_tool_format(prediction: str) -> bool:
"""Detect GLM4.7-native tool call format by presence of </arg_value>."""
return "</arg_value>" in prediction
async def execute_predictions(prediction: str) -> tuple[str, bool]:
    """Act on one model turn and build the observation to feed back.

    The turn is classified with ``postprocess_predictions``:

    * ``"answer"`` - nothing to observe; the episode is finished.
    * ``"code"`` - the code is run through the ``code_interpreter`` tool
      (guarded by ``SEMAPHORE``) and its result is wrapped in
      ``<interpreter>`` tags. Blank code yields an error observation instead.
      When the turn used the ``<arg_value>`` tool-call format, the
      observation ends with an ``<|assistant|>`` marker.
    * anything else - a corrective message reminding the model of the
      expected formats.

    Args:
        prediction: Raw text produced by the model for one turn.

    Returns:
        ``(next_obs, done)``: the observation text to append to the
        conversation and whether the episode has ended.
    """
    action, content = postprocess_predictions(prediction)

    if action == "answer":
        return "", True

    if action == "code":
        code = content.strip()
        if code:
            async with SEMAPHORE:
                result = await tool_registry.execute_tool("code_interpreter", {"code": code})
        else:
            result = "Error: No Python code found"
        # GLM-format turns get an explicit assistant marker after the observation.
        suffix = "<|assistant|>\n" if _uses_glm_tool_format(prediction) else "\n"
        return f"\n\n<interpreter>\n{result}\n</interpreter>\n{suffix}", False

    invalid_action_msg = (
        "\nMy previous action is invalid. "
        "If I want to execute code, I should put the code between "
        "<code> and </code>. "
        "If I want to give the final answer, I should use the format "
        "'Answer: \\boxed{answer}'. Let me try again.\n"
    )
    return invalid_action_msg, False
async def generate(args, sample: Sample, sampling_params) -> Sample:
"""Custom generation function supporting tool calls"""
assert not args.partial_rollout, "Partial rollout is not supported for " "this function at the moment."
state = GenerateState(args)
url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
tool_specs = tool_registry.get_tool_specs()
if isinstance(sample.prompt, list):
messages = sample.prompt
else:
messages = [{"role": "user", "content": sample.prompt}]
try:
prompt = state.tokenizer.apply_chat_template(
messages,
tools=tool_specs,
tokenize=False,
add_generation_prompt=True,
)
except Exception:
if isinstance(sample.prompt, list):
raw_prompt = next(
(m["content"] for m in reversed(sample.prompt) if m.get("role") == "user"),
"",
)
else:
raw_prompt = sample.prompt
prompt = format_conversation_with_tools(prompt=raw_prompt, tools=tool_specs)
prompt_tokens_ids = state.tokenizer(prompt, add_special_tokens=False)["input_ids"]
response = ""
response_token_ids = []
loss_masks = []
tool_call_count = 0
obs_truncated = False
max_context_length = len(prompt_tokens_ids) + args.rollout_max_response_len
print(f"max_context_length is set to {max_context_length}", flush=True)
verbose = random.random() < _LOG_SAMPLE_PROB
if verbose:
_sep = "═" * _LOG_WIDTH
print(f"\n{_sep}", flush=True)
_prompt_display = (
"".join(m.get("content", "") for m in sample.prompt)
if isinstance(sample.prompt, list)
else sample.prompt
)
print(f"[ReTool LOG] prompt ({len(prompt_tokens_ids)} tokens): {_trunc(_prompt_display, 200)}", flush=True)
print(_sep, flush=True)
for turn in range(TOOL_CONFIGS["max_turns"]):
current_token_ids = prompt_tokens_ids + response_token_ids
total_length = len(current_token_ids)
if total_length >= max_context_length:
sample.status = Sample.Status.TRUNCATED
break
remaining_tokens = max_context_length - total_length
current_sampling_params = sampling_params.copy()
current_sampling_params["max_new_tokens"] = min(
sampling_params.get("max_new_tokens", args.rollout_max_response_len),
remaining_tokens
)
if current_sampling_params["max_new_tokens"] <= 0:
sample.status = Sample.Status.TRUNCATED
break
payload = {
"input_ids": current_token_ids,
"sampling_params": current_sampling_params,
"return_logprob": True,
}
try:
import wandb
if wandb.run is not None:
available_tools = len(tool_specs)
tools_used = response.count("<interpreter>")
wandb.log(
{
"debug/payload_length": len(prompt + response),
"debug/available_tools": available_tools,
"debug/tools_used": tools_used,
"debug/turn": turn,
}
)
except ImportError:
pass
output = await post(url, payload)
if output["meta_info"]["finish_reason"]["type"] == "abort":
sample.status = Sample.Status.ABORTED
return sample
if "output_token_logprobs" in output["meta_info"]:
cur_response_token_ids = [item[1] for item in output["meta_info"]["output_token_logprobs"]]
cur_response = state.tokenizer.decode(cur_response_token_ids)
cur_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]]
if sample.rollout_log_probs is None:
sample.rollout_log_probs = []
sample.rollout_log_probs += cur_log_probs
else:
cur_response = output["text"]
cur_response = postprocess_responses(cur_response)
cur_response_token_ids = state.tokenizer(cur_response, add_special_tokens=False)["input_ids"]
response += cur_response
response_token_ids += cur_response_token_ids
loss_masks += [1] * len(cur_response_token_ids)
if verbose:
n_tok = len(cur_response_token_ids)
finish = output["meta_info"]["finish_reason"]["type"]
print(f"\n{'─' * _LOG_WIDTH}", flush=True)
print(f"[Turn {turn + 1}] model output ({n_tok} tok, finish={finish}):", flush=True)
print(" " + _trunc(cur_response).replace("\n", "\n "), flush=True)
if output["meta_info"]["finish_reason"]["type"] == "length":
if verbose:
print(f"[Turn {turn + 1}] → length limit reached, stopping.", flush=True)
break
next_obs, done = await execute_predictions(cur_response)
if verbose:
if done:
print(f"[Turn {turn + 1}] → answer detected (DONE)", flush=True)
elif "<interpreter>" in next_obs:
obs_display = " " + _trunc(next_obs, 300).replace("\n", "\n ")
print(f"[Turn {turn + 1}] → code executed, observation:", flush=True)
print(obs_display, flush=True)
else:
print(f"[Turn {turn + 1}] → invalid action (no recognized code or answer)", flush=True)
if done:
break
if "<interpreter>" in next_obs:
tool_call_count += 1
assert next_obs != "", "Next observation should not be empty."
obs_tokens_ids = state.tokenizer(next_obs, add_special_tokens=False)["input_ids"]
response += next_obs
response_token_ids += obs_tokens_ids
loss_masks += [0] * len(obs_tokens_ids)
if sample.rollout_log_probs is not None:
sample.rollout_log_probs += [0.0] * len(obs_tokens_ids)
assert len(response_token_ids) == len(
sample.rollout_log_probs
), f"Token/logp length mismatch at turn {turn}: {len(response_token_ids)} tokens vs {len(sample.rollout_log_probs)} logps"
max_response_tokens = max_context_length - len(prompt_tokens_ids)
if len(response_token_ids) > max_response_tokens:
response_token_ids = response_token_ids[:max_response_tokens]
loss_masks = loss_masks[:max_response_tokens]
if sample.rollout_log_probs is not None:
sample.rollout_log_probs = sample.rollout_log_probs[:max_response_tokens]
obs_truncated = True
break
if tool_call_count >= TOOL_CONFIGS["max_tool_calls"]:
break
if verbose:
print(f"\n{'═' * _LOG_WIDTH}", flush=True)
print(
f"[ReTool LOG] finished | tool_calls={tool_call_count} | "
f"response_tokens={len(response_token_ids)} | "
f"finish={output['meta_info']['finish_reason']['type']}",
flush=True,
)
print("═" * _LOG_WIDTH + "\n", flush=True)
sample.tokens = prompt_tokens_ids + response_token_ids
sample.response_length = len(response_token_ids)
sample.response = response
sample.loss_mask = loss_masks
sample.payload_text = prompt + response
sample.payload_has_system = "<|im_start|>system" in prompt + response
sample.payload_has_tools = "# Tools" in prompt + response
sample.tool_call_count = tool_call_count
match output["meta_info"]["finish_reason"]["type"]:
case "length":
sample.status = Sample.Status.TRUNCATED
case "abort":
sample.status = Sample.Status.ABORTED
case "stop":
sample.status = Sample.Status.COMPLETED
if obs_truncated:
sample.status = Sample.Status.TRUNCATED
return sample
async def reward_func(args, sample, **kwargs):
    """Score a finished rollout with the math_dapo scorer.

    The prompt text (message contents concatenated when the prompt is a list
    of messages) followed by the response is scored against ``sample.label``
    (``""`` when the label is ``None``) with strict boxed-answer verification.
    A negative score is shifted by ``(tool_call_count - 2) / 2 * 0.1`` and
    then capped at ``-0.6``. A ``None`` prediction is replaced by ``""``.

    Args:
        args: Unused.
        sample: The rollout to score.
        **kwargs: Unused.

    Returns:
        The scorer's result dict with ``"score"`` and ``"pred"`` adjusted as
        described above.

    Raises:
        TypeError: If *sample* is not a ``Sample``.
    """
    if not isinstance(sample, Sample):
        raise TypeError("Sample must be an instance of Sample class.")

    raw_prompt = sample.prompt
    prompt_text = (
        "".join(turn.get("content", "") for turn in raw_prompt) if isinstance(raw_prompt, list) else raw_prompt
    )
    full_text = prompt_text + sample.response
    expected = "" if sample.label is None else sample.label
    tool_calls = getattr(sample, "tool_call_count", 0)

    scored = math_dapo_compute_score(full_text, expected, strict_box_verify=True)

    if scored["score"] < 0:
        # Wrong answers are shifted according to how many tool calls were made.
        shift = (tool_calls - 2) / 2 * 0.1
        scored["score"] = min(-0.6, scored["score"] + shift)

    if scored["pred"] is None:
        scored["pred"] = ""

    return scored