ZZhiyu Chengadd files
abf93d00创建于 2025年10月23日历史提交
# SPDX-License-Identifier: Apache-2.0

import ast
import json
import re
from collections.abc import Sequence
from typing import Union

import partial_json_parser
from partial_json_parser.core.options import Allow

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaFunctionCall, DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser,
    ToolParserManager,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

logger = init_logger(__name__)


@ToolParserManager.register_module("llama_nemotron_xml")
class LlamaNemotronXMLToolParser(ToolParser):

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1  # Potentially for streaming
        self.streamed_args_for_tool: list[str] = [] # Potentially for streaming

        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"

        # Regex to find full <tool_call>...</tool_call> blocks and capture their content
        self.tool_call_block_regex = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
        # Regex to find <tool>...</tool> within a tool_call block content
        self.name_regex = re.compile(r"<tool>(.*?)</tool>", re.DOTALL)
        # Regex to find <key>value</key> pairs within the tool_call block content (excluding <tool> tags)
        self.param_regex = re.compile(r"<([^/>\s]+)>(.*?)</\1>", re.DOTALL)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:

        tool_call_start_index = model_output.find(self.tool_call_start_token)

        if tool_call_start_index == -1:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )
        
        content = model_output[:tool_call_start_index].strip()
        tool_calls_str_content = model_output[tool_call_start_index:]

        parsed_tool_calls = []
        
        try:
            # Find all occurrences of <tool_call>...</tool_call>
            xml_tool_call_contents = self.tool_call_block_regex.findall(tool_calls_str_content)

            for tool_content_str in xml_tool_call_contents:
                name_match = self.name_regex.search(tool_content_str)
                if not name_match:
                    logger.warning(f"Could not find tool name in XML block: {tool_content_str}")
                    continue
                tool_name = name_match.group(1).strip()

                parsed_arguments = {}
                
                # Find all parameter tags in the tool_call content, excluding the <tool> tag
                param_matches = self.param_regex.finditer(tool_content_str)
                
                for match in param_matches:
                    param_name = match.group(1).strip()
                    param_value_str = match.group(2).strip()
                    
                    # Skip the <tool> tag since it's not a parameter
                    if param_name == "tool":
                        continue
                    
                    target_type = None
                    # Try to get type from request.tools schema
                    if request.tools:
                        for tool_def in request.tools:
                            if tool_def.function.name == tool_name:
                                if tool_def.function.parameters and \
                                   isinstance(tool_def.function.parameters, dict) and \
                                   "properties" in tool_def.function.parameters and \
                                   isinstance(tool_def.function.parameters["properties"], dict) and \
                                   param_name in tool_def.function.parameters["properties"] and \
                                   isinstance(tool_def.function.parameters["properties"][param_name], dict):
                                    target_type = tool_def.function.parameters["properties"][param_name].get("type")
                                break
                    
                    typed_param_value = param_value_str # Default to string
                    if target_type:
                        try:
                            if target_type == "string":
                                typed_param_value = param_value_str
                            elif target_type == "integer":
                                typed_param_value = int(param_value_str)
                            elif target_type == "number":
                                typed_param_value = float(param_value_str)
                            elif target_type == "boolean":
                                typed_param_value = param_value_str.lower() == 'true'
                            elif target_type in ["object", "array"]:
                                try:
                                    typed_param_value = json.loads(param_value_str)
                                except json.JSONDecodeError:
                                    # Fallback for non-strict JSON like Python dict/list string
                                    typed_param_value = ast.literal_eval(param_value_str)
                            else: # Unknown type, keep as string
                                typed_param_value = param_value_str
                        except (ValueError, SyntaxError, json.JSONDecodeError) as e:
                            logger.warning(
                                f"Could not convert param '{param_name}' with value '{param_value_str}' "
                                f"to type '{target_type}'. Error: {e}. Using string value."
                            )
                            typed_param_value = param_value_str
                    else: # No schema type, try ast.literal_eval
                        try:
                            # For values like "true", "123", "['a', 'b']"
                            # ast.literal_eval('some_string_without_quotes') will raise SyntaxError
                            if (param_value_str.startswith("'") and param_value_str.endswith("'")) or \
                               (param_value_str.startswith('"') and param_value_str.endswith('"')) or \
                               (param_value_str.startswith('[') and param_value_str.endswith(']')) or \
                               (param_value_str.startswith('{') and param_value_str.endswith('}')) or \
                               param_value_str.lower() in ['true', 'false', 'none'] or \
                               param_value_str.replace('.', '', 1).isdigit() or \
                               (param_value_str.startswith('-') and param_value_str[1:].replace('.', '', 1).isdigit()):
                                typed_param_value = ast.literal_eval(param_value_str)
                            else: # It's likely a plain string not meant for ast.literal_eval
                                typed_param_value = param_value_str
                        except (ValueError, SyntaxError):
                            typed_param_value = param_value_str # Keep as string if ast.literal_eval fails

                    parsed_arguments[param_name] = typed_param_value
                
                parsed_tool_calls.append(ToolCall(
                    id=f"call_{random_uuid()}",
                    type="function",
                    function=FunctionCall(
                        name=tool_name,
                        arguments=json.dumps(parsed_arguments, ensure_ascii=False),
                    ),
                ))

            return ExtractedToolCallInformation(
                tools_called=len(parsed_tool_calls) > 0,
                tool_calls=parsed_tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception(f"Error in extracting XML tool call from response. Response: {model_output}")
            # Fallback to original model output if parsing fails catastrophically
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:

        raise NotImplementedError("Tool calling is not supported in streaming mode!")


@ToolParserManager.register_module("llama_nemotron_json")
class LlamaNemotronJSONToolParser(ToolParser):

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []

        self.tool_call_start_token: str = "<TOOLCALL>"
        self.tool_call_end_token: str = "</TOOLCALL>"

        self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>", re.DOTALL)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:

        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        else:

            try:
                str_tool_calls = self.tool_call_regex.findall(model_output)[0].strip()
                if not str_tool_calls.startswith("["):
                    str_tool_calls = "[" + str_tool_calls
                if not str_tool_calls.endswith("]"):
                    str_tool_calls = str_tool_calls + "]"
                json_tool_calls = json.loads(str_tool_calls)
                tool_calls = []
                for tool_call in json_tool_calls:
                    try:
                        tool_calls.append(ToolCall(
                            type="function",
                            function=FunctionCall(
                                name=tool_call["name"],
                                arguments=json.dumps(tool_call["arguments"], ensure_ascii=False) \
                                    if isinstance(tool_call["arguments"], dict) else tool_call["arguments"],
                            ),
                        ))
                    except:
                        continue

                content = model_output[:model_output.rfind(self.tool_call_start_token)]

                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=tool_calls,
                    content=content if content else None,
                )

            except Exception:
                logger.exception(f"Error in extracting tool call from response. Response: {model_output}")
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=model_output,
                )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:

        raise NotImplementedError("Tool calling is not supported in streaming mode!")


@ToolParserManager.register_module("llama_nemotron_pythonic")
class LlamaNemotronPythonicToolParser(ToolParser):

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []

        self.tool_call_start_token: str = "<TOOLCALL>"
        self.tool_call_end_token: str = "</TOOLCALL>"

        self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>", re.DOTALL)
        # Regex to parse pythonic function calls: function_name(arg1="value1", arg2=123, arg3=True)
        self.function_call_regex = re.compile(r"(\w+)\((.*?)\)$", re.DOTALL)

    def parse_function_arguments(self, args_str: str) -> dict:
        """Parse pythonic function arguments string into a dictionary"""
        if not args_str.strip():
            return {}
        
        # Use ast.parse to safely parse the function call arguments
        # We'll construct a temporary function call and parse it
        try:
            # Create a dummy function call to parse arguments
            dummy_code = f"dummy_func({args_str})"
            parsed = ast.parse(dummy_code, mode='eval')
            
            # Extract arguments from the AST
            call_node = parsed.body
            if not isinstance(call_node, ast.Call):
                return {}
            
            arguments = {}
            
            # Handle keyword arguments
            for keyword in call_node.keywords:
                if keyword.arg is None:  # **kwargs
                    continue
                    
                # Convert AST value to Python value
                try:
                    value = ast.literal_eval(keyword.value)
                    arguments[keyword.arg] = value
                except (ValueError, TypeError):
                    # If literal_eval fails, try to get the raw value
                    if isinstance(keyword.value, ast.Name):
                        arguments[keyword.arg] = keyword.value.id
                    elif isinstance(keyword.value, ast.Constant):
                        arguments[keyword.arg] = keyword.value.value
                    else:
                        # Fallback: convert to string
                        arguments[keyword.arg] = ast.unparse(keyword.value)
            
            # Handle positional arguments (less common in tool calls but supported)
            for i, arg in enumerate(call_node.args):
                try:
                    value = ast.literal_eval(arg)
                    arguments[f"arg_{i}"] = value
                except (ValueError, TypeError):
                    if isinstance(arg, ast.Name):
                        arguments[f"arg_{i}"] = arg.id
                    elif isinstance(arg, ast.Constant):
                        arguments[f"arg_{i}"] = arg.value
                    else:
                        arguments[f"arg_{i}"] = ast.unparse(arg)
            
            return arguments
            
        except (SyntaxError, ValueError) as e:
            logger.warning(f"Failed to parse function arguments '{args_str}': {e}")
            return {}

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:

        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        tool_call_start_index = model_output.find(self.tool_call_start_token)
        content = model_output[:tool_call_start_index].strip()
        
        try:
            # Extract content between <TOOLCALL> tags
            tool_call_matches = self.tool_call_regex.findall(model_output)
            if not tool_call_matches:
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=model_output,
                )
            
            tool_calls_content = tool_call_matches[0].strip()
            
            # Split by lines to get individual function calls
            function_lines = [line.strip() for line in tool_calls_content.split('\n') if line.strip()]
            
            parsed_tool_calls = []
            
            for func_line in function_lines:
                # Parse each function call
                match = self.function_call_regex.match(func_line)
                if not match:
                    logger.warning(f"Could not parse function call: {func_line}")
                    continue
                
                function_name = match.group(1)
                args_str = match.group(2)
                
                # Parse arguments
                parsed_arguments = self.parse_function_arguments(args_str)
                
                # Apply type conversion based on schema if available
                if request.tools:
                    for tool_def in request.tools:
                        if tool_def.function.name == function_name:
                            schema_properties = {}
                            if (tool_def.function.parameters and 
                                isinstance(tool_def.function.parameters, dict) and 
                                "properties" in tool_def.function.parameters and 
                                isinstance(tool_def.function.parameters["properties"], dict)):
                                schema_properties = tool_def.function.parameters["properties"]
                            
                            # Convert arguments based on schema types
                            for arg_name, arg_value in parsed_arguments.items():
                                if arg_name in schema_properties:
                                    param_info = schema_properties[arg_name]
                                    target_type = param_info.get("type")
                                    
                                    try:
                                        if target_type == "string" and not isinstance(arg_value, str):
                                            parsed_arguments[arg_name] = str(arg_value)
                                        elif target_type == "integer" and not isinstance(arg_value, int):
                                            parsed_arguments[arg_name] = int(arg_value)
                                        elif target_type == "number" and not isinstance(arg_value, (int, float)):
                                            parsed_arguments[arg_name] = float(arg_value)
                                        elif target_type == "boolean" and not isinstance(arg_value, bool):
                                            if isinstance(arg_value, str):
                                                parsed_arguments[arg_name] = arg_value.lower() in ['true', '1', 'yes']
                                            else:
                                                parsed_arguments[arg_name] = bool(arg_value)
                                        elif target_type in ["object", "array"]:
                                            if isinstance(arg_value, str):
                                                try:
                                                    parsed_arguments[arg_name] = json.loads(arg_value)
                                                except json.JSONDecodeError:
                                                    # Keep as string if JSON parsing fails
                                                    pass
                                    except (ValueError, TypeError) as e:
                                        logger.warning(f"Type conversion failed for {arg_name}: {e}")
                                        # Keep original value if conversion fails
                            break
                
                parsed_tool_calls.append(ToolCall(
                    id=f"call_{random_uuid()}",
                    type="function",
                    function=FunctionCall(
                        name=function_name,
                        arguments=json.dumps(parsed_arguments, ensure_ascii=False),
                    ),
                ))

            return ExtractedToolCallInformation(
                tools_called=len(parsed_tool_calls) > 0,
                tool_calls=parsed_tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception(f"Error in extracting pythonic tool call from response. Response: {model_output}")
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:

        raise NotImplementedError("Tool calling is not supported in streaming mode!")