"""综合Json生成器"""
import json
import logging
import re
from typing import Any
from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
from jsonschema import Draft7Validator
from apps.models import LLMType
from apps.schemas.llm import LLMFunctions
from .llm import LLM
from .token import token_calculator
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
# Upper bound on the number of generation attempts made per request.
JSON_GEN_MAX_TRIAL = 3
class JsonValidationError(Exception):
    """Signal that generated JSON failed validation.

    The rejected payload is kept on ``result`` so that whoever catches the
    exception can still inspect what was generated.
    """

    def __init__(self, message: str, result: dict[str, Any]) -> None:
        """Record the rejected payload and forward the message to ``Exception``."""
        self.result = result
        super().__init__(message)
class JsonGenerator:
    """Generate schema-validated JSON from an LLM, retrying on invalid output.

    A function description in OpenAI function format supplies the JSON Schema
    (``function["parameters"]``). Depending on the configured LLM, the result
    is obtained either through a tool call or by parsing a JSON object out of
    the plain-text reply. An instance may be created without an LLM and
    configured later through :meth:`init`.
    """

    def __init__(self, llm: LLM | None = None) -> None:
        """Create the generator and its sandboxed Jinja2 environment.

        Args:
            llm: Optional LLM to use. When omitted, :meth:`init` must be
                called before :meth:`generate`.
        """
        self._env = SandboxedEnvironment(
            loader=BaseLoader(),
            autoescape=False,
            trim_blocks=True,
            lstrip_blocks=True,
            extensions=["jinja2.ext.loopcontrols"],
        )
        self._llm: LLM | None = llm
        self._support_function_call: bool = False
        if llm is not None:
            self._check_function_call_support()

    def init(self, llm: LLM) -> None:
        """Set the LLM to use and refresh the function-call capability flag."""
        self._llm = llm
        self._check_function_call_support()

    def _require_llm(self) -> LLM:
        """Return the configured LLM.

        Raises:
            RuntimeError: If no LLM has been configured yet.
        """
        if self._llm is None:
            err = "[JSONGenerator] 未初始化,请先调用init()方法"
            raise RuntimeError(err)
        return self._llm

    def _check_function_call_support(self) -> None:
        """Record whether ``LLMType.FUNCTION`` is among the LLM's configured types."""
        if self._llm is not None and LLMType.FUNCTION in self._llm.config.llmType:
            _logger.info("[JSONGenerator] LLM支持FunctionCall")
            self._support_function_call = True
        else:
            _logger.info("[JSONGenerator] LLM不支持FunctionCall,将使用prompt方式")
            self._support_function_call = False

    def _build_messages(
        self,
        prompt: str,
        conversation: list[dict[str, str]],
        retry_messages: list[dict[str, str]] | None = None,
    ) -> list[dict[str, str]]:
        """Assemble the message list: conversation, then prompt, then retry records.

        When an LLM is configured and conversation + prompt exceed the context
        length left over after the retry records, the oldest user/assistant
        history messages are dropped one at a time until the budget is met.
        The prompt itself (always the last message before the retry records)
        is never dropped.

        Args:
            prompt: Rendered prompt, appended as a ``user`` message.
            conversation: Prior messages; the list is not modified.
            retry_messages: Feedback from earlier failed attempts, appended last.

        Returns:
            A new list of messages ready to send to the LLM.

        Raises:
            RuntimeError: If the budget is still exceeded and no history
                message is left that may be dropped.
        """
        messages = [*conversation, {"role": "user", "content": prompt}]
        if self._llm is not None:
            token_count = token_calculator.calculate_token_length(messages)
            retry_token_count = (
                token_calculator.calculate_token_length(retry_messages) if retry_messages else 0
            )
            # Retry records are always sent, so they are reserved up front.
            available_ctx = self._llm.config.ctxLength - retry_token_count
            if token_count > available_ctx:
                _logger.warning(
                    "[JSONGenerator] 当前对话 Token 数量 (%d) 超过可用上下文长度 (%d),"
                    "进行消息裁剪(重试记录 Token: %d)",
                    token_count,
                    available_ctx,
                    retry_token_count,
                )
                while len(messages) > 1 and token_count > available_ctx:
                    # Only history is eligible: messages[-1] is the prompt and
                    # dropping it would send a request without the actual task.
                    drop_index = next(
                        (
                            index
                            for index, msg in enumerate(messages[:-1])
                            if msg["role"] in ("user", "assistant")
                        ),
                        None,
                    )
                    if drop_index is None:
                        err = (
                            "[JSONGenerator] 无法裁剪消息以满足上下文长度限制,"
                            f"当前 Token 数量: {token_count},可用上下文长度: {available_ctx}"
                        )
                        raise RuntimeError(err)
                    del messages[drop_index]
                    token_count = token_calculator.calculate_token_length(messages)
                _logger.info(
                    "[JSONGenerator] 裁剪后对话 Token 数量: %d(重试记录 Token: %d,总计: %d)",
                    token_count,
                    retry_token_count,
                    token_count + retry_token_count,
                )
        messages.extend(retry_messages or [])
        return messages

    async def _single_trial(
        self,
        function: dict[str, Any],
        prompt: str,
        conversation: list[dict[str, str]],
        retry_messages: list[dict[str, str]] | None = None,
    ) -> dict[str, Any]:
        """Run one generation attempt and validate its result.

        Args:
            function: Function description in OpenAI function format; its
                ``parameters`` entry is the JSON Schema to validate against.
            prompt: Jinja2 template source, rendered with ``use_xml_format``.
            conversation: Prior messages to send before the prompt.
            retry_messages: Feedback from earlier failed attempts.

        Returns:
            The generated object, once it passes schema validation.

        Raises:
            RuntimeError: If no LLM has been configured.
            JsonValidationError: If the generated object fails validation; the
                rejected object is attached as ``result``.
        """
        self._require_llm()
        validator = Draft7Validator(function["parameters"])

        # The template is told to use the XML format only when the LLM cannot
        # do function calls.
        use_xml_format = not self._support_function_call
        formatted_prompt = self._env.from_string(prompt).render(use_xml_format=use_xml_format)
        if self._support_function_call:
            result = await self._call_with_function(function, formatted_prompt, conversation, retry_messages)
        else:
            result = await self._call_without_function(function, formatted_prompt, conversation, retry_messages)
        _logger.info("[JSONGenerator] 函数调用结果: %s", result)

        try:
            validator.validate(result)
        except Exception as err:
            # Keep only the first paragraph of the validator's message.
            err_info = str(err).split("\n\n")[0]
            _logger.info("[JSONGenerator] 验证失败:%s", err_info)
            raise JsonValidationError(err_info, result) from err
        return result

    async def _call_with_function(
        self,
        function: dict[str, Any],
        prompt: str,
        conversation: list[dict[str, str]],
        retry_messages: list[dict[str, str]] | None = None,
    ) -> dict[str, Any]:
        """Call the LLM with a tool definition and return the first tool call's arguments.

        Returns:
            The arguments of the first tool call seen, or ``{}`` when the
            reply contains no tool call.

        Raises:
            RuntimeError: If no LLM has been configured.
        """
        llm = self._require_llm()
        messages = self._build_messages(prompt, conversation, retry_messages)
        tool = LLMFunctions(
            name=function["name"],
            description=function["description"],
            param_schema=function["parameters"],
        )
        async for chunk in llm.call(
            messages, include_thinking=False, streaming=False, tools=[tool], temperature=0,
        ):
            if chunk.tool_call:
                return chunk.tool_call[0].arguments
        return {}

    @staticmethod
    def _extract_json(text: str) -> dict[str, Any]:
        """Parse a JSON object out of free-form reply text.

        The span from the first ``{`` to the last ``}`` is tried first. If
        that span is not valid JSON (for example because later text contains
        another ``}``), the object that starts at the first ``{`` is decoded
        on its own.

        Returns:
            The decoded object, or ``{}`` when the text contains no ``{...}`` span.

        Raises:
            json.JSONDecodeError: If a ``{...}`` span exists but no JSON
                object can be decoded from it.
        """
        match = re.search(r"\{.*\}", text, re.DOTALL)
        if match is None:
            return {}
        candidate = match.group(0)
        try:
            return json.loads(candidate)
        except json.JSONDecodeError as greedy_error:
            try:
                parsed, _ = json.JSONDecoder().raw_decode(candidate)
            except json.JSONDecodeError:
                # Report the error for the full span; its ``doc`` holds the
                # whole candidate text.
                raise greedy_error from None
            return parsed

    async def _call_without_function(
        self,
        _function: dict[str, Any],
        prompt: str,
        conversation: list[dict[str, str]],
        retry_messages: list[dict[str, str]] | None = None,
    ) -> dict[str, Any]:
        """Call the LLM without tools and parse a JSON object from its text reply.

        Returns:
            The decoded object, or ``{}`` when the reply contains no ``{...}`` span.

        Raises:
            RuntimeError: If no LLM has been configured.
            json.JSONDecodeError: If the reply contains braces but no
                decodable JSON object.
        """
        llm = self._require_llm()
        messages = self._build_messages(prompt, conversation, retry_messages)
        parts: list[str] = []
        async for chunk in llm.call(messages, include_thinking=False, streaming=False, temperature=0):
            if chunk.content:
                parts.append(chunk.content)
        return self._extract_json("".join(parts))

    @staticmethod
    def _build_retry_messages(function_name: str, error: Exception) -> list[dict[str, str]]:
        """Build the assistant/user message pair describing a failed attempt.

        The assistant message echoes what was actually produced: the raw text
        for a ``json.JSONDecodeError``, otherwise the rejected object carried
        on the error's ``result`` attribute. The user message carries the
        first paragraph of the error text.
        """
        err_info = str(error).split("\n\n")[0]
        if isinstance(error, json.JSONDecodeError):
            attempted = error.doc
        else:
            # default=str keeps feedback construction from failing on values
            # that are not JSON-serialisable.
            attempted = json.dumps(error.result, ensure_ascii=False, default=str)
        return [
            {
                "role": "assistant",
                "content": f"I called function {function_name} with parameters: {attempted}",
            },
            {
                "role": "user",
                "content": (
                    f"The previous function call failed with validation error: {err_info}. "
                    "Please try again and ensure the output strictly follows the required schema."
                ),
            },
        ]

    async def generate(
        self,
        function: dict[str, Any],
        prompt: str,
        conversation: list[dict[str, str]] | None = None,
    ) -> dict[str, Any]:
        """Generate a JSON object matching ``function["parameters"]``.

        Up to ``JSON_GEN_MAX_TRIAL`` attempts are made. After a validation or
        JSON decoding failure, the rejected output and the error are fed back
        to the LLM on the next attempt. Any other exception is logged and the
        attempt is simply repeated.

        Args:
            function: Function description in OpenAI function format.
            prompt: Jinja2 template source for the prompt.
            conversation: Optional prior messages; defaults to none.

        Returns:
            The validated object, or ``{}`` when every attempt failed.

        Raises:
            RuntimeError: If no LLM has been configured.
            jsonschema.exceptions.SchemaError: If ``function["parameters"]``
                is not a valid Draft 7 schema.
        """
        self._require_llm()
        Draft7Validator.check_schema(function["parameters"])

        if conversation is None:
            conversation = []
        retry_messages: list[dict[str, str]] = []

        for attempt in range(1, JSON_GEN_MAX_TRIAL + 1):
            try:
                return await self._single_trial(function, prompt, conversation, retry_messages)
            except (JsonValidationError, json.JSONDecodeError) as e:
                _logger.exception(
                    "[JSONGenerator] 第 %d/%d 次尝试失败",
                    attempt,
                    JSON_GEN_MAX_TRIAL,
                )
                # Feedback is only useful if another attempt follows.
                if attempt < JSON_GEN_MAX_TRIAL:
                    retry_messages.extend(self._build_retry_messages(function["name"], e))
            except Exception:
                _logger.exception(
                    "[JSONGenerator] 第 %d/%d 次尝试失败(非验证错误)",
                    attempt,
                    JSON_GEN_MAX_TRIAL,
                )
        return {}
# Shared module-level instance, created without an LLM; init() must be called
# on it before generate() can be used.
json_generator = JsonGenerator()