import concurrent.futures
import json
import logging
import os
import re
import threading
import time
import traceback
import uuid
from typing import Any, Dict, List, Optional, Tuple
import requests
# Timeout constant (presumably seconds); not referenced by the code visible in this section.
DEFAULT_TIMEOUT = 10
# Total number of HTTP attempts call_sandbox_api makes before giving up.
MAX_RETRIES = 4
# Base back-off in seconds: _try_single_api_call sleeps
# INITIAL_RETRY_DELAY * (attempt + 1) after a retryable status, except on the last attempt.
INITIAL_RETRY_DELAY = 40
# Extra seconds added on top of 2 * timeout to form the HTTP request timeout.
API_TIMEOUT = 10
# Language identifiers accepted by call_sandbox_api; any other value is rejected
# before a request is sent.
SUPPORTED_LANGUAGES = [
"python",
"cpp",
"nodejs",
"go",
"go_test",
"java",
"php",
"csharp",
"bash",
"typescript",
"sql",
"rust",
"cuda",
"lua",
"R",
"perl",
"D_ut",
"ruby",
"scala",
"julia",
"pytest",
"junit",
"kotlin_script",
"jest",
"verilog",
"python_gpu",
"lean",
"swift",
"racket",
]
logger = logging.getLogger(__name__)
# The log level is read from the VERL_LOGGING_LEVEL environment variable ("INFO" if unset).
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))
# Default bonus added to the reward for each rewarded tool call.
DEFAULT_TOOL_CALL_BONUS_PER_CALL = 0.005
# Default cap on how many tool calls can earn the bonus.
DEFAULT_MAX_REWARDED_TOOL_CALLS = 4
def _invalid_test_cases_structure(candidate):
    """Log a rejected test-case payload and return the (None, error_result) pair."""
    logger.error("Invalid test_cases structure.")
    logger.error("%s ...", str(candidate)[:100])
    return None, (0.0, [{"error": "Invalid test_cases structure (missing inputs/outputs)"}])


def _parse_test_cases(ground_truth):
    """
    Parse and validate test cases from ground_truth.

    Accepted encodings of ``ground_truth``:
      * a dict, used as-is;
      * a JSON document (str/bytes) that decodes to a dict;
      * a base64-encoded, zlib-compressed pickle of such a JSON string
        (only tried when plain JSON decoding fails).

    The decoded value must be a dict containing both "input" and "output".

    Returns:
        tuple: (test_cases, error_result) where error_result is None if successful,
        or a tuple (score, metadata) if parsing failed. This function never raises
        for malformed input; every failure is reported through error_result.
    """
    if isinstance(ground_truth, dict):
        test_cases = ground_truth
    else:
        try:
            test_cases = json.loads(ground_truth)
        except (TypeError, ValueError) as json_error:
            # TypeError covers non-string input (e.g. None); ValueError covers
            # json.JSONDecodeError. Fall back to the compressed-pickle encoding.
            try:
                import base64
                import pickle
                import zlib

                raw = ground_truth.encode("utf-8") if isinstance(ground_truth, str) else ground_truth
                # SECURITY: pickle.loads can execute arbitrary code. This path must
                # only ever see ground-truth data from a trusted dataset.
                test_cases = json.loads(pickle.loads(zlib.decompress(base64.b64decode(raw))))
            except Exception as fallback_error:
                # base64, zlib and pickle each raise their own exception types
                # (pickle can raise almost anything), so this boundary is broad.
                logger.error(
                    "Failed to parse test_cases JSON: %s (compressed fallback: %s)",
                    json_error,
                    fallback_error,
                )
                return None, (0.0, [{"error": "Invalid test_cases JSON format"}])

    # Require a real mapping: a JSON string or list that merely contains the
    # words "input"/"output" is not a valid test-case container.
    if not isinstance(test_cases, dict) or "input" not in test_cases or "output" not in test_cases:
        return _invalid_test_cases_structure(test_cases)
    return test_cases, None
def _process_api_response(api_response, error_msg, test_cases, solution):
    """
    Flatten the outcome of a sandbox call into a single metadata record.

    Args:
        api_response: Decoded sandbox response (a mapping), or a falsy value
            when no response is available.
        error_msg: Error text from the API call; when truthy the record is
            marked ``api_error`` and the response is ignored.
        test_cases: Test cases sent to the sandbox; stored stringified under "input".
        solution: Generated solution text; only logged (first 200 characters)
            when an API error is reported.

    Returns:
        dict: Metadata dictionary with score and execution details.
    """
    metadata = {
        "input": str(test_cases),
        "api_request_error": error_msg,
        "api_response": None,
        "status": "unknown",
    }
    # Remaining fields start empty; insertion order matters because the record
    # may later be serialized as-is.
    for empty_field in (
        "stdout",
        "stderr",
        "exit_code",
        "duration",
        "compile_duration",
        "compile_stderr",
        "api_status",
        "compile_status",
        "run_status",
    ):
        metadata[empty_field] = None
    metadata["score"] = 0.0

    if error_msg:
        metadata["status"] = "api_error"
        logger.error("Sandbox Error Report: API error occurred: %s", error_msg)
        if len(solution) > 200:
            snippet = solution[:200] + "..."
        else:
            snippet = solution
        logger.error("Sandbox Error Report: Generation: %s", snippet)
        return metadata

    if not api_response:
        return metadata

    logger.debug("Sandbox Debug Report: API Response: %s", api_response)
    metadata["api_response"] = api_response
    metadata["api_status"] = api_response.get("status")

    # Copy the compile and run sections field by field (metadata key, response key).
    compile_section = api_response.get("compile_result")
    if compile_section:
        for meta_key, source_key in (
            ("compile_status", "status"),
            ("compile_duration", "execution_time"),
            ("compile_stderr", "stderr"),
        ):
            metadata[meta_key] = compile_section.get(source_key)

    run_section = api_response.get("run_result")
    if run_section:
        for meta_key, source_key in (
            ("run_status", "status"),
            ("stdout", "stdout"),
            ("stderr", "stderr"),
            ("exit_code", "return_code"),
            ("duration", "execution_time"),
        ):
            metadata[meta_key] = run_section.get(source_key)

    if api_response.get("accepted", None) is True:
        metadata["status"] = "success"
        metadata["score"] = 1.0
        return metadata

    # Not accepted: award partial credit as the fraction of passed test entries.
    metadata["status"] = "wrong_answer"
    test_results = api_response.get("tests", [])
    total = len(test_results)
    passed = len([entry for entry in test_results if entry and entry.get("passed", False)])
    if total > 0:
        metadata["score"] = passed / total
    return metadata
def _get_tool_call_count(extra_info: Optional[Dict[str, Any]]) -> int:
    """
    Read the number of tool calls recorded in ``extra_info``.

    Lookup order:
      1. "tool_call_count", converted with ``int`` (0 if the conversion fails);
      2. the length of the "tool_calls" list;
      3. the length of the "tool_call_names" list.

    Returns:
        int: The tool-call count, or 0 when ``extra_info`` is empty, not a
        dict, or carries none of the fields above.
    """
    if not (extra_info and isinstance(extra_info, dict)):
        return 0

    # An explicit counter wins even when it is unusable; the lists are not consulted then.
    if "tool_call_count" in extra_info:
        try:
            explicit_count = int(extra_info["tool_call_count"])
        except (TypeError, ValueError):
            explicit_count = 0
        return explicit_count

    for list_key in ("tool_calls", "tool_call_names"):
        recorded = extra_info.get(list_key)
        if isinstance(recorded, list):
            return len(recorded)
    return 0
def _compute_code_score_and_metadata(
    solution_str: str,
    ground_truth: Any,
    sandbox_fusion_url: Optional[str],
    timeout: int,
) -> Tuple[float, List[Dict[str, Any]]]:
    """
    Score a generated solution by running it against the test cases in the sandbox.

    Args:
        solution_str: Raw model output; any ``<think>...</think>`` section is
            removed before the text is submitted.
        ground_truth: Test cases in any form accepted by ``_parse_test_cases``.
        sandbox_fusion_url: URL of the sandbox API endpoint.
        timeout: Timeout in seconds forwarded to ``call_sandbox_api``.

    Returns:
        tuple: (score, metadata_list). metadata_list holds a single dict: the
        processed sandbox metadata on the normal path, or an "error" entry
        when test-case parsing or the sandbox call failed.
    """
    solution = re.sub(r"<think>.*?</think>", "", solution_str, flags=re.DOTALL).strip()

    # Detect the language from the text that is actually submitted. Searching
    # the raw output would let a code fence inside the stripped reasoning
    # section pick the wrong language for the final answer.
    fence_match = re.search(r"```(\w+)", solution)
    language = fence_match.group(1) if fence_match else "python"

    test_cases, error_result = _parse_test_cases(ground_truth)
    if error_result is not None:
        return float(error_result[0]), error_result[1]

    try:
        api_response, error_msg = call_sandbox_api(
            sandbox_fusion_url=sandbox_fusion_url,
            code=solution,
            in_outs=test_cases,
            timeout=timeout,
            language=language,
        )
        metadata = _process_api_response(api_response, error_msg, test_cases, solution)
        score = float(metadata.get("score", 0.0))
        final_metadata = [metadata]
        logger.info("Sandbox Info Report: Results: %s", score)
    except Exception as e:
        # Scoring must not crash the caller, so the failure becomes a zero
        # score; the traceback is logged so it is not silently lost.
        logger.exception("Sandbox Error Report: Unhandled exception while scoring: %s", e)
        score = 0.0
        final_metadata = [{"error": f"Unhandled exception: {e}"}]
    return float(score), final_metadata
def _format_score_result(
    score: float,
    metadata: List[Dict[str, Any]],
    return_dict: bool,
    include_metadata: bool,
    extra_fields: Optional[Dict[str, Any]] = None,
):
    """
    Shape a score and its metadata into the form the caller asked for.

    Args:
        score: Numeric score; always converted to ``float``.
        metadata: Metadata list accompanying the score.
        return_dict: When true, return a dict; otherwise a ``(score, metadata)`` tuple.
        include_metadata: Dict form only - add the metadata as an ASCII-safe
            JSON string under "metadata".
        extra_fields: Dict form only - extra entries merged in after "score".

    Returns:
        dict or tuple: ``{"score": ..., **extra_fields, "metadata": ...}`` or
        ``(score, metadata)``.
    """
    numeric_score = float(score)
    if not return_dict:
        return numeric_score, metadata

    payload = {"score": numeric_score}
    if extra_fields:
        payload.update(extra_fields)
    if include_metadata:
        payload["metadata"] = json.dumps(metadata, ensure_ascii=True)
    return payload
def _get_tool_call_bonus_config(kwargs: Dict[str, Any]) -> Tuple[float, int]:
    """
    Resolve the tool-call bonus settings from ``kwargs``.

    The per-call bonus is read from "tool_call_bonus_per_call", then from
    "tool_call_reward", then from DEFAULT_TOOL_CALL_BONUS_PER_CALL. The cap is
    read from "max_rewarded_tool_calls" or DEFAULT_MAX_REWARDED_TOOL_CALLS.
    Values that cannot be converted fall back to the defaults, and a negative
    cap is clamped to 0.

    Returns:
        tuple: (bonus_per_call as float, max_rewarded_tool_calls as int >= 0).
    """
    legacy_bonus = kwargs.get("tool_call_reward", DEFAULT_TOOL_CALL_BONUS_PER_CALL)
    raw_bonus = kwargs.get("tool_call_bonus_per_call", legacy_bonus)
    try:
        bonus_per_call = float(raw_bonus)
    except (TypeError, ValueError):
        bonus_per_call = DEFAULT_TOOL_CALL_BONUS_PER_CALL

    raw_cap = kwargs.get("max_rewarded_tool_calls", DEFAULT_MAX_REWARDED_TOOL_CALLS)
    try:
        rewarded_cap = int(raw_cap)
    except (TypeError, ValueError):
        rewarded_cap = DEFAULT_MAX_REWARDED_TOOL_CALLS
    if rewarded_cap < 0:
        rewarded_cap = 0

    return bonus_per_call, rewarded_cap
def _compute_tool_call_bonus(
    extra_info: Optional[Dict[str, Any]],
    kwargs: Dict[str, Any],
) -> Tuple[int, float, float, int]:
    """
    Compute the bonus earned by tool calls.

    Args:
        extra_info: Per-sample info passed to ``_get_tool_call_count``.
        kwargs: Options passed to ``_get_tool_call_bonus_config``.

    Returns:
        tuple: (tool_call_count, tool_call_bonus, bonus_per_call, max_rewarded_tool_calls).
        The count is clamped to be non-negative, and only the first
        ``max_rewarded_tool_calls`` calls contribute to the bonus.
    """
    bonus_per_call, rewarded_cap = _get_tool_call_bonus_config(kwargs)

    observed_calls = _get_tool_call_count(extra_info)
    if observed_calls < 0:
        observed_calls = 0

    # Calls beyond the cap are still reported but earn nothing.
    rewarded_calls = observed_calls if observed_calls <= rewarded_cap else rewarded_cap
    bonus = float(rewarded_calls) * bonus_per_call
    return observed_calls, bonus, bonus_per_call, rewarded_cap
def _extract_common_kwargs(kwargs: Dict[str, Any]) -> Tuple[Optional[str], bool, bool, int]:
    """
    Pull the options shared by compute_reward and compute_score out of ``kwargs``.

    Returns:
        tuple: (sandbox_fusion_url, return_dict, include_metadata, timeout),
        using None, False, False and 30 for options that are absent.
    """
    option_defaults = (
        ("sandbox_fusion_url", None),
        ("return_dict", False),
        ("include_metadata", False),
        ("timeout", 30),
    )
    return tuple(kwargs.get(option, fallback) for option, fallback in option_defaults)
def compute_reward(
    data_source,
    solution_str,
    ground_truth,
    extra_info=None,
    **kwargs,
):
    """
    Computes training reward: code correctness score + tool-call bonus.

    Args:
        data_source: Not used by this function.
        solution_str: Model output to score.
        ground_truth: Test cases for the solution.
        extra_info: Optional per-sample info from which the tool-call count is read.
        **kwargs: Options read via ``_extract_common_kwargs``
            (sandbox_fusion_url, return_dict, include_metadata, timeout) and
            ``_get_tool_call_bonus_config`` (tool_call_bonus_per_call /
            tool_call_reward, max_rewarded_tool_calls).

    Returns:
        The value produced by ``_format_score_result``: either
        ``(reward, metadata)`` or, with ``return_dict``, a dict that also
        carries the score breakdown and tool-call statistics.
    """
    url, as_dict, with_metadata, run_timeout = _extract_common_kwargs(kwargs)

    code_score, metadata = _compute_code_score_and_metadata(
        solution_str=solution_str,
        ground_truth=ground_truth,
        sandbox_fusion_url=url,
        timeout=run_timeout,
    )
    call_count, bonus, bonus_per_call, rewarded_cap = _compute_tool_call_bonus(
        extra_info=extra_info,
        kwargs=kwargs,
    )

    code_score = float(code_score)
    # "code_score"/"base_score" and "tool_call_bonus_per_call"/"tool_call_reward"
    # are reported under both names with the same value.
    breakdown = {
        "code_score": code_score,
        "base_score": code_score,
        "tool_call_count": call_count,
        "tool_call_bonus": bonus,
        "tool_call_bonus_per_call": bonus_per_call,
        "max_rewarded_tool_calls": rewarded_cap,
        "tool_call_reward": bonus_per_call,
    }
    return _format_score_result(
        score=code_score + bonus,
        metadata=metadata,
        return_dict=as_dict,
        include_metadata=with_metadata,
        extra_fields=breakdown,
    )
def compute_score(
    data_source,
    solution_str,
    ground_truth,
    extra_info=None,
    **kwargs,
):
    """
    Computes code correctness score only (no tool-call bonus).

    Args:
        data_source: Not used by this function.
        solution_str: Model output to score.
        ground_truth: Test cases for the solution.
        extra_info: Not used by this function.
        **kwargs: Options read via ``_extract_common_kwargs``:
            sandbox_fusion_url, return_dict, include_metadata, timeout.

    Returns:
        ``(score, metadata)`` by default. With ``return_dict`` a dict
        ``{"score": ...}`` that also holds the JSON-encoded metadata under
        "metadata" when ``include_metadata`` is set.
    """
    sandbox_fusion_url, return_dict, include_metadata, timeout = _extract_common_kwargs(kwargs)
    score, final_metadata = _compute_code_score_and_metadata(
        solution_str=solution_str,
        ground_truth=ground_truth,
        sandbox_fusion_url=sandbox_fusion_url,
        timeout=timeout,
    )
    # Use the shared formatter so this function and compute_reward cannot
    # drift apart in how results are shaped.
    return _format_score_result(
        score=score,
        metadata=final_metadata,
        return_dict=return_dict,
        include_metadata=include_metadata,
    )
def _build_sandbox_payload(code: str, language: str, timeout: int, in_outs: Any) -> str:
    """
    Serialize a sandbox API request body to JSON.

    Args:
        code: Text sent as the "completion" field.
        language: Language identifier for the sandbox config.
        timeout: Used for both "compile_timeout" and "run_timeout".
        in_outs: Test cases, sent under config.provided_data.test_cases.

    Returns:
        str: The JSON-encoded request body.
    """
    sandbox_config = {
        "language": language,
        "compile_timeout": timeout,
        "run_timeout": timeout,
        "provided_data": {"test_cases": in_outs},
        "extra": {"run_all_cases": True, "total_timeout": 30},
    }
    request_body = {"completion": code, "config": sandbox_config}
    return json.dumps(request_body)
def _execute_single_request(url: str, payload: str, timeout: int) -> requests.Response:
    """
    POST one JSON request to the sandbox API and return the raw response.

    Args:
        url: Endpoint to call.
        payload: Already-serialized JSON request body.
        timeout: Timeout passed to ``requests.post``.

    Returns:
        requests.Response: The response; the status code is not checked here.
    """
    json_headers = {"Content-Type": "application/json", "Accept": "application/json"}
    response = requests.post(url, data=payload, timeout=timeout, headers=json_headers)
    return response
def _try_single_api_call(url: str, payload: str, timeout: int, attempt: int, log_prefix: str):
    """
    Make one attempt at calling the sandbox API.

    Args:
        url: Endpoint to call.
        payload: Serialized JSON request body.
        timeout: HTTP timeout for this attempt.
        attempt: Zero-based attempt index; also scales the back-off sleep.
        log_prefix: Prefix placed in front of every log and error message.

    Returns:
        tuple: (response, error). On success error is None. On failure
        response is None; an error starting with "RETRY:" means the status
        was 429/500/502/503/504 and the call may be retried. Any exception is
        returned as a non-retry error string.
    """
    retryable_statuses = (429, 500, 502, 503, 504)
    try:
        logger.info(
            "%sAttempt %d/%d: Calling sandbox API at %s",
            log_prefix,
            attempt + 1,
            MAX_RETRIES,
            url,
        )
        response = _execute_single_request(url, payload, timeout)
        status = response.status_code

        if status not in retryable_statuses:
            # Other error statuses raise here and are reported as non-retry errors below.
            response.raise_for_status()
            logger.info("%sSandbox API call successful on attempt %d", log_prefix, attempt + 1)
            return response, None

        logger.warning("%sReceived status %d", log_prefix, status)
        is_last_attempt = attempt >= MAX_RETRIES - 1
        if not is_last_attempt:
            # Back off before the caller retries; the wait grows with the attempt number.
            time.sleep(INITIAL_RETRY_DELAY * (attempt + 1))
        return None, "RETRY:%sReceived status %d" % (log_prefix, status)
    except Exception as exc:
        return None, "%sError: %s" % (log_prefix, exc)
def call_sandbox_api(
    sandbox_fusion_url: str,
    code: str,
    in_outs: Any,
    timeout: int,
    language: str,
) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """
    Calls the remote sandbox API to execute code and retries on specific HTTP errors.

    Up to MAX_RETRIES attempts are made. Only failures that
    ``_try_single_api_call`` marks with the "RETRY:" prefix are retried; any
    other failure, or a response body that is not valid JSON, ends the loop.

    Args:
        sandbox_fusion_url (str): The URL of the sandbox fusion API endpoint.
        code (str): The source code to be executed.
        in_outs (any): The test cases to be used for evaluation.
        timeout (int): The timeout in seconds for compilation and execution.
        language (str): The programming language of the code.

    Returns:
        tuple: (response_json, error_message) - response_json is None on failure.
        A failed call reports "API Call Failed: <detail>"; an unsupported
        language reports the request-prefixed "Unsupported language" message.
    """
    request_id = str(uuid.uuid4())
    log_prefix = "[Request ID: %s] " % request_id

    if language not in SUPPORTED_LANGUAGES:
        error_msg = "%sUnsupported language: %s" % (log_prefix, language)
        logger.error(error_msg)
        return None, error_msg

    payload = _build_sandbox_payload(code, language, timeout, in_outs)
    # The HTTP timeout must cover compile and run time plus an allowance.
    request_timeout = timeout * 2 + API_TIMEOUT
    retry_marker = "RETRY:"

    last_error = None
    for attempt in range(MAX_RETRIES):
        response, last_error = _try_single_api_call(
            sandbox_fusion_url, payload, request_timeout, attempt, log_prefix
        )
        if response is not None:
            try:
                return response.json(), None
            except ValueError as exc:
                # A 2xx reply with a non-JSON body: report it through the
                # error channel instead of raising to the caller.
                last_error = "%sInvalid JSON in sandbox response: %s" % (log_prefix, exc)
                break
        if last_error and not last_error.startswith(retry_marker):
            break

    if not last_error:
        logger.error("%sSandbox API call failed. Last error: %s", log_prefix, last_error)
        return None, "API Call Failed"

    # Strip the retry marker only when it is present. Slicing unconditionally
    # would cut into the request prefix of non-retry errors and garble them.
    if last_error.startswith(retry_marker):
        last_error = last_error[len(retry_marker):]
    if last_error.startswith(log_prefix):
        detail = last_error[len(log_prefix):]
    else:
        detail = last_error
    logger.error("%sSandbox API call failed. Last error: %s", log_prefix, detail)
    return None, "API Call Failed: " + detail