# Adapted from https://github.com/volcengine/verl/blob/cb809d66e46dfd3342d008628891a14a054fa424/recipe/retool/retool.py

import random

import re

from typing import Any



try:

    from jinja2 import Template

except ImportError as e:

    raise ImportError("Jinja2 is required. Please install it with: pip install jinja2") from e



from slime.rollout.sglang_rollout import GenerateState

from slime.utils.http_utils import post

from slime.utils.types import Sample



# Import reward models

try:

    from slime.rollout.rm_hub.math_dapo_utils import compute_score as math_dapo_compute_score

except ImportError as e:

    raise ImportError("MathDapo is not installed") from e



# Import tool sandbox functionality

from tool_sandbox import SEMAPHORE, TOOL_CONFIGS, tool_registry



# ── Sample-level verbose logging ──────────────────────────────────────────

# Roughly 1/20 samples are logged so the output stays readable.

_LOG_SAMPLE_PROB = 0.05

_LOG_WIDTH = 300





def _trunc(s: str, n: int = 300) -> str:

    """Truncate *s* to at most *n* characters for display."""

    if len(s) <= n:

        return s

    return s[:n] + f"…[+{len(s) - n}]"





# Jinja2 template for tool-enabled conversations

TOOL_TEMPLATE = """<|im_start|>system

{%- if messages[0]['role'] == 'system' %}

{{- messages[0]['content'] }}

{%- else %}

You are a helpful assistant.

{%- endif %}

{%- if tools %}

# Tools



You may call one or more functions to assist with the user query.



You are provided with function signatures within <tools></tools> XML tags:

<tools>

{%- for tool in tools %}

{{- tool | tojson }}

{%- endfor %}

</tools>



For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:

<tool_call>

{"name": <function-name>, "arguments": <args-json-object>}

</tool_call>

{%- endif %}

<|im_end|>

{%- for message in messages %}

{%- if message['role'] == 'user' %}

<|im_start|>user

{{- message['content'] }}<|im_end|>

{%- elif message['role'] == 'assistant' %}

<|im_start|>assistant

{{- message['content'] }}<|im_end|>

{%- endif %}

{%- endfor %}

<|im_start|>assistant

"""





def format_conversation_with_tools(

    prompt: str, tools: list[dict[str, Any]] = None, system_prompt: str = None, messages: list[dict[str, Any]] = None

) -> str:

    """Format conversation using Jinja2 template with tool support"""

    template = Template(TOOL_TEMPLATE)



    # Prepare messages

    messages_to_render = []



    # Always add system message - use provided one or default

    if system_prompt:

        system_content = system_prompt

    else:

        system_content = (

            "You are a helpful assistant that can use Python "

            "tools to solve mathematical problems. When you need "

            "to perform calculations, use the code_interpreter "

            "tool to execute code and get results."

        )



    messages_to_render.append({"role": "system", "content": system_content})



    # Add user message if provided

    if prompt:

        messages_to_render.append({"role": "user", "content": prompt})



    # Add assistant responses from previous turns if provided

    if messages:

        messages_to_render.extend(messages)



    # Render template

    formatted_text = template.render(messages=messages_to_render, tools=tools or [])



    return formatted_text





def postprocess_predictions(prediction: str):

    """Extract action and content from prediction string"""

    # Check for Answer: \boxed{...} format (only format we need for math_dapo)

    # Use a more robust regex that handles nested braces

    answer_pattern = r"Answer:\s*\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}"

    answer_match = re.search(answer_pattern, prediction, re.DOTALL)

    if answer_match:

        content = answer_match.group(1).strip()

        return "answer", content



    # Then check for <tool_call> tags (new format from Jinja2 template)

    tool_call_pattern = r"<tool_call>\s*(\{.*?\})\s*</tool_call>"

    tool_call_match = re.search(tool_call_pattern, prediction, re.DOTALL)

    if tool_call_match:

        try:

            import json



            # Clean up the JSON string by removing newlines and extra

            # whitespace

            json_str = tool_call_match.group(1)

            # Replace newlines in string values with \n

            json_str = json_str.replace("\n", "\\n")

            tool_call_data = json.loads(json_str)

            tool_name = tool_call_data.get("name")

            arguments = tool_call_data.get("arguments", {})



            if tool_name == "code_interpreter":

                code = arguments.get("code", "")

                if code.strip():

                    return "code", code

        except (json.JSONDecodeError, KeyError, AttributeError):

            pass



    # Check for GLM4.7-native tool call format:

    # <tool_call>funcname<arg_key>argname</arg_key><arg_value>raw_python_code</arg_value></tool_call>

    # The argument value is raw code, NOT JSON.

    glm_tool_call_pattern = (

        r"<tool_call>\s*(\w[\w.]*)\s*"

        r"(?:<arg_key>[^<]*</arg_key>)?\s*"

        r"<arg_value>(.*?)</arg_value>\s*"

        r"</tool_call>"

    )

    glm_match = re.search(glm_tool_call_pattern, prediction, re.DOTALL)

    if glm_match:

        tool_name = glm_match.group(1).strip()

        code = glm_match.group(2).strip()

        if tool_name == "code_interpreter" and code:

            return "code", code



    # Then check for <code> tags

    code_pattern = r"<code>(.*?)</code>"

    code_match = re.search(code_pattern, prediction, re.DOTALL)

    if code_match:

        content = code_match.group(1).strip()

        return "code", content



    # Finally check for ```python code blocks (lowest priority)

    python_code_pattern = r"```python\s*(.*?)\s*```"

    python_code_match = re.search(python_code_pattern, prediction, re.DOTALL)

    if python_code_match:

        content = python_code_match.group(1).strip()

        return "code", content



    return None, ""





def postprocess_responses(resp: str) -> str:

    """Post-process response to ensure tag completeness"""

    # Handle <tool_call> tags (Qwen and GLM4.7 formats)

    if "<tool_call>" in resp:

        # Try Qwen-style JSON format: <tool_call>{"name": ..., "arguments": {...}}</tool_call>

        tool_call_pattern = r"<tool_call>\s*\{.*?\}\s*</tool_call>"

        matches = list(re.finditer(tool_call_pattern, resp, re.DOTALL))

        if matches:

            last_match = matches[-1]

            return resp[: last_match.end()]

        # Try GLM4.7-native format:

        # <tool_call>funcname<arg_key>argname</arg_key><arg_value>code</arg_value></tool_call>

        glm_tool_call_pattern = (

            r"<tool_call>\s*\w[\w.]*\s*"

            r"(?:<arg_key>[^<]*</arg_key>)?\s*"

            r"<arg_value>.*?</arg_value>\s*"

            r"</tool_call>"

        )

        glm_matches = list(re.finditer(glm_tool_call_pattern, resp, re.DOTALL))

        if glm_matches:

            return resp[: glm_matches[-1].end()]



    # Handle <code> tags

    if "</code>" in resp:

        return resp.split("</code>")[0] + "</code>"



    # Handle ```python code blocks

    if "```python" in resp:

        # Find the last occurrence of ```python...```

        python_pattern = r"```python\s*.*?```"

        matches = list(re.finditer(python_pattern, resp, re.DOTALL))

        if matches:

            last_match = matches[-1]

            return resp[: last_match.end()]



    # Handle Answer: \boxed{...} format (only format we need for math_dapo)

    if "Answer:" in resp and "\\boxed{" in resp:

        # Find the last occurrence of Answer: \boxed{...} with nested braces support

        answer_pattern = r"Answer:\s*\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}"

        matches = list(re.finditer(answer_pattern, resp, re.DOTALL))

        if matches:

            last_match = matches[-1]

            return resp[: last_match.end()]



    return resp





def _uses_glm_tool_format(prediction: str) -> bool:

    """Detect GLM4.7-native tool call format by presence of </arg_value>."""

    return "</arg_value>" in prediction





async def execute_predictions(prediction: str) -> str:

    """Execute predictions and return results"""

    action, content = postprocess_predictions(prediction)

    glm_format = _uses_glm_tool_format(prediction)



    if action == "code":

        # Content is already the Python code (extracted by

        # postprocess_predictions)

        code = content.strip()

        if code:

            async with SEMAPHORE:

                result = await tool_registry.execute_tool("code_interpreter", {"code": code})

            if glm_format:

                # GLM4.7: model generates <|observation|> as stop token before

                # halting; append result and signal the next assistant turn.

                next_obs = f"\n\n<interpreter>\n{result}\n</interpreter>\n<|assistant|>\n"

            else:

                next_obs = f"\n\n<interpreter>\n{result}\n</interpreter>\n\n"

            done = False

        else:

            if glm_format:

                next_obs = "\n\n<interpreter>\nError: No Python code found\n</interpreter>\n<|assistant|>\n"

            else:

                next_obs = "\n\n<interpreter>\nError: No Python code found\n</interpreter>\n\n"

            done = False

    elif action == "answer":

        next_obs = ""

        done = True

    else:

        next_obs = (

            "\nMy previous action is invalid. "

            "If I want to execute code, I should put the code between "

            "<code> and </code>. "

            "If I want to give the final answer, I should use the format "

            "'Answer: \\boxed{answer}'. Let me try again.\n"

        )

        done = False



    return next_obs, done





async def generate(args, sample: Sample, sampling_params) -> Sample:

    """Custom generation function supporting tool calls"""

    assert not args.partial_rollout, "Partial rollout is not supported for " "this function at the moment."



    state = GenerateState(args)

    url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"



    # Set up the initial prompt with system prompt and tools (outside the loop)

    tool_specs = tool_registry.get_tool_specs()

    # Use the tokenizer's own chat template so tool call instructions match the

    # model's native format (e.g. GLM4.7 vs Qwen).  Fall back to the Qwen-style

    # TOOL_TEMPLATE for tokenizers that do not support the tools= parameter.

    #

    # sample.prompt may be a plain string or a list of message dicts depending

    # on whether --apply-chat-template was set and how the dataset stores prompts.

    if isinstance(sample.prompt, list):

        messages = sample.prompt

    else:

        messages = [{"role": "user", "content": sample.prompt}]

    try:

        prompt = state.tokenizer.apply_chat_template(

            messages,

            tools=tool_specs,

            tokenize=False,

            add_generation_prompt=True,

        )

    except Exception:

        # Fallback: extract raw text from the last user turn if needed.

        if isinstance(sample.prompt, list):

            raw_prompt = next(

                (m["content"] for m in reversed(sample.prompt) if m.get("role") == "user"),

                "",

            )

        else:

            raw_prompt = sample.prompt

        prompt = format_conversation_with_tools(prompt=raw_prompt, tools=tool_specs)



    prompt_tokens_ids = state.tokenizer(prompt, add_special_tokens=False)["input_ids"]

    response = ""

    response_token_ids = []

    loss_masks = []

    tool_call_count = 0  # Track actual tool call rounds

    obs_truncated = False  # Flag: obs caused total length to exceed max_context_length

    # Calculate max context length once at the beginning

    max_context_length = len(prompt_tokens_ids) + args.rollout_max_response_len

    print(f"max_context_length is set to {max_context_length}", flush=True)



    # Randomly select a small fraction of samples for detailed turn-by-turn logging.

    verbose = random.random() < _LOG_SAMPLE_PROB

    if verbose:

        _sep = "═" * _LOG_WIDTH

        print(f"\n{_sep}", flush=True)

        _prompt_display = (

            "".join(m.get("content", "") for m in sample.prompt)

            if isinstance(sample.prompt, list)

            else sample.prompt

        )

        print(f"[ReTool LOG] prompt ({len(prompt_tokens_ids)} tokens): {_trunc(_prompt_display, 200)}", flush=True)

        print(_sep, flush=True)



    for turn in range(TOOL_CONFIGS["max_turns"]):

        # Use token IDs instead of text

        current_token_ids = prompt_tokens_ids + response_token_ids



        # Check if total length exceeds max context length

        total_length = len(current_token_ids)

        if total_length >= max_context_length:

            sample.status = Sample.Status.TRUNCATED

            break



        # Dynamically calculate remaining token budget for this turn

        remaining_tokens = max_context_length - total_length



        # Update max_new_tokens for this turn to respect the remaining budget

        # Make a copy to avoid modifying the original sampling_params

        current_sampling_params = sampling_params.copy()

        current_sampling_params["max_new_tokens"] = min(

            sampling_params.get("max_new_tokens", args.rollout_max_response_len),

            remaining_tokens

        )



        # Check if we have budget for more tokens

        if current_sampling_params["max_new_tokens"] <= 0:

            sample.status = Sample.Status.TRUNCATED

            break



        payload = {

            "input_ids": current_token_ids,

            "sampling_params": current_sampling_params,

            "return_logprob": True,  # Request log probabilities for training

        }



        # Log payload to wandb for debugging

        try:

            import wandb



            if wandb.run is not None:

                # Count available tools (from tool_specs)

                available_tools = len(tool_specs)

                # Count tools used in the current response

                tools_used = response.count("<interpreter>")



                wandb.log(

                    {

                        "debug/payload_length": len(prompt + response),

                        "debug/available_tools": available_tools,

                        "debug/tools_used": tools_used,

                        "debug/turn": turn,

                    }

                )

        except ImportError:

            pass  # wandb not available



        output = await post(url, payload)



        # Handle abort

        if output["meta_info"]["finish_reason"]["type"] == "abort":

            sample.status = Sample.Status.ABORTED

            return sample



        if "output_token_logprobs" in output["meta_info"]:

            cur_response_token_ids = [item[1] for item in output["meta_info"]["output_token_logprobs"]]

            cur_response = state.tokenizer.decode(cur_response_token_ids)

            cur_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]]

            if sample.rollout_log_probs is None:

                sample.rollout_log_probs = []

            sample.rollout_log_probs += cur_log_probs



        else:

            cur_response = output["text"]

            cur_response = postprocess_responses(cur_response)

            cur_response_token_ids = state.tokenizer(cur_response, add_special_tokens=False)["input_ids"]



        response += cur_response

        response_token_ids += cur_response_token_ids

        loss_masks += [1] * len(cur_response_token_ids)



        # verbose: show what the model generated this turn

        if verbose:

            n_tok = len(cur_response_token_ids)

            finish = output["meta_info"]["finish_reason"]["type"]

            print(f"\n{'─' * _LOG_WIDTH}", flush=True)

            print(f"[Turn {turn + 1}] model output ({n_tok} tok, finish={finish}):", flush=True)

            print("  " + _trunc(cur_response).replace("\n", "\n  "), flush=True)



        # Check length limit

        if output["meta_info"]["finish_reason"]["type"] == "length":

            if verbose:

                print(f"[Turn {turn + 1}] → length limit reached, stopping.", flush=True)

            break



        next_obs, done = await execute_predictions(cur_response)



        # verbose: show action and observation

        if verbose:

            if done:

                print(f"[Turn {turn + 1}] → answer detected (DONE)", flush=True)

            elif "<interpreter>" in next_obs:

                obs_display = "  " + _trunc(next_obs, 300).replace("\n", "\n  ")

                print(f"[Turn {turn + 1}] → code executed, observation:", flush=True)

                print(obs_display, flush=True)

            else:

                print(f"[Turn {turn + 1}] → invalid action (no recognized code or answer)", flush=True)



        if done:

            break



        # Count tool calls (when we get interpreter output, it means a tool

        # was called)

        if "<interpreter>" in next_obs:

            tool_call_count += 1



        assert next_obs != "", "Next observation should not be empty."

        obs_tokens_ids = state.tokenizer(next_obs, add_special_tokens=False)["input_ids"]

        response += next_obs

        response_token_ids += obs_tokens_ids

        loss_masks += [0] * len(obs_tokens_ids)



        # Add dummy log probs for observation tokens (they won't be used due to loss_mask=0)

        # Check if maximum tool call count reached

        if sample.rollout_log_probs is not None:

            sample.rollout_log_probs += [0.0] * len(obs_tokens_ids)



            assert len(response_token_ids) == len(

                sample.rollout_log_probs

            ), f"Token/logp length mismatch at turn {turn}: {len(response_token_ids)} tokens vs {len(sample.rollout_log_probs)} logps"



        # Truncate if obs pushed total response tokens beyond max_context_length

        max_response_tokens = max_context_length - len(prompt_tokens_ids)

        if len(response_token_ids) > max_response_tokens:

            response_token_ids = response_token_ids[:max_response_tokens]

            loss_masks = loss_masks[:max_response_tokens]

            if sample.rollout_log_probs is not None:

                sample.rollout_log_probs = sample.rollout_log_probs[:max_response_tokens]

            obs_truncated = True

            break



        if tool_call_count >= TOOL_CONFIGS["max_tool_calls"]:

            break



    if verbose:

        print(f"\n{'═' * _LOG_WIDTH}", flush=True)

        print(

            f"[ReTool LOG] finished | tool_calls={tool_call_count} | "

            f"response_tokens={len(response_token_ids)} | "

            f"finish={output['meta_info']['finish_reason']['type']}",

            flush=True,

        )

        print("═" * _LOG_WIDTH + "\n", flush=True)



    # Set sample attributes

    sample.tokens = prompt_tokens_ids + response_token_ids

    sample.response_length = len(response_token_ids)

    sample.response = response

    sample.loss_mask = loss_masks



    # Store payload information for wandb logging

    sample.payload_text = prompt + response

    sample.payload_has_system = "<|im_start|>system" in prompt + response

    sample.payload_has_tools = "# Tools" in prompt + response



    # Store tool call count for reward calculation

    sample.tool_call_count = tool_call_count



    # Set status

    match output["meta_info"]["finish_reason"]["type"]:

        case "length":

            sample.status = Sample.Status.TRUNCATED

        case "abort":

            sample.status = Sample.Status.ABORTED

        case "stop":

            sample.status = Sample.Status.COMPLETED



    if obs_truncated:

        sample.status = Sample.Status.TRUNCATED



    return sample





async def reward_func(args, sample, **kwargs):

    """Tool call reward function using math_dapo as primary reward model"""

    if not isinstance(sample, Sample):

        raise TypeError("Sample must be an instance of Sample class.")



    # Build complete solution string.

    # sample.prompt may be a list of message dicts; flatten to plain text.

    if isinstance(sample.prompt, list):

        prompt_str = "".join(m.get("content", "") for m in sample.prompt)

    else:

        prompt_str = sample.prompt

    solution_str = prompt_str + sample.response



    # Get ground truth answer - label is a string, not a dict

    ground_truth = sample.label if sample.label is not None else ""



    # Get tool call count as num_turns

    num_turns = getattr(sample, "tool_call_count", 0)



    # use \\boxed{...} answer

    result = math_dapo_compute_score(solution_str, ground_truth, strict_box_verify=True)



    # encourage model to call tools

    if result["score"] < 0:

        tool_call_reward = (num_turns - 2) / 2 * 0.1

        result["score"] = min(-0.6, result["score"] + tool_call_reward)



    if result["pred"] is None:

        result["pred"] = ""



    return result