"""
图表生成模块
功能:
1. 根据图表名称和生成数据,LLM生成Python代码
2. 执行代码生成图表
3. 可选的VLM迭代反馈机制
"""
import asyncio
import json
import logging
import os
import re
import textwrap
import weakref
from typing import Dict, List, Tuple, Optional, Any

from openjiuwen_deepsearch.utils.constants_utils.node_constants import AgentLlmName
from openjiuwen_deepsearch.algorithm.chart_generation.utils import (
    call_model,
    CallModelInput,
)
from openjiuwen_deepsearch.common.exception import CustomValueException
from openjiuwen_deepsearch.common.status_code import StatusCode
from openjiuwen_deepsearch.algorithm.chart_generation.sandbox import (
    AsyncCodeExecutor,
)
logger = logging.getLogger(__name__)

# Font file expected in the "fonts" directory next to this module.
FONT_PATH = os.path.join(os.path.dirname(__file__), "fonts", "kt_font.ttf")
# Upper bound on chart tasks running at once across all sections.
MAX_GLOBAL_CONCURRENT_CHART_TASKS = 10
# Upper bound on chart tasks running at once inside a single section.
MAX_SECTION_CONCURRENT_CHART_TASKS = 5
# Score threshold stored by ChartGenerator as its chart acceptance threshold.
CHART_THRESHOLD = 85

# Most recently created global semaphore. Kept for backward compatibility and
# used as the fallback when no per-loop semaphore can be resolved.
_global_chart_semaphore: Optional[asyncio.Semaphore] = None
# One semaphore per event loop. Weak keys let a discarded loop (and its
# semaphore) be garbage collected.
_loop_chart_semaphores: "weakref.WeakKeyDictionary" = weakref.WeakKeyDictionary()


def _get_global_chart_semaphore() -> asyncio.Semaphore:
    """Return the semaphore that caps chart tasks for the running event loop.

    The semaphore is created lazily. asyncio synchronization primitives are
    tied to a single event loop, so a separate semaphore is kept for every
    loop instead of sharing one module-level instance: reusing one instance
    from a second loop fails as soon as a task has to wait on it.

    When called outside a running loop, or when the loop object cannot be used
    as a weak dictionary key, the legacy module-level singleton is returned.

    Returns:
        asyncio.Semaphore: semaphore initialised with
        MAX_GLOBAL_CONCURRENT_CHART_TASKS permits.
    """
    global _global_chart_semaphore

    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        # Not inside a running loop: fall back to the singleton below.
        running_loop = None

    if running_loop is not None:
        try:
            semaphore = _loop_chart_semaphores.get(running_loop)
            if semaphore is None:
                semaphore = asyncio.Semaphore(MAX_GLOBAL_CONCURRENT_CHART_TASKS)
                _loop_chart_semaphores[running_loop] = semaphore
                _global_chart_semaphore = semaphore
            return semaphore
        except TypeError:
            # Loop implementation does not support weak references.
            pass

    if _global_chart_semaphore is None:
        _global_chart_semaphore = asyncio.Semaphore(MAX_GLOBAL_CONCURRENT_CHART_TASKS)
    return _global_chart_semaphore
class ChartGenerator:
    """Generate charts by having an LLM write plotting code and executing it.

    Workflow for one chart:
      1. The LLM produces Python code from the chart title, description, type
         and data.
      2. The code is executed by an ``AsyncCodeExecutor``; the executor result
         is expected to carry the rendered image under ``chart_base64``.
      3. Optionally a VLM scores the image and returns suggestions that are fed
         back into another code-generation round.
    """

    # Number of generate-and-execute attempts per code-generation round.
    _MAX_CODE_ATTEMPTS = 3
    # Value passed to the executor as ``exec_timeout``.
    _CODE_EXEC_TIMEOUT = 120
    # Complete fenced block: ```python ... ```
    _FENCED_CODE_RE = re.compile(r"```(?:python)?\s*([\s\S]*?)\s*```")
    # Opening fence without a closing one: take everything after it.
    _OPEN_FENCE_RE = re.compile(r"```(?:python)?\s*([\s\S]*)")
    _LEADING_TABS_RE = re.compile(r"^\t+")

    def __init__(
        self,
        llm_model_name: str,
        output_dir: str,
        vlm_model_name: Optional[str] = None,
        vlm_max_iterations: int = 1,
    ):
        """Initialise the generator and make sure the output directory exists.

        Args:
            llm_model_name: Name of the LLM used to generate chart code.
            output_dir: Directory used as the executor working directory;
                created if it does not exist.
            vlm_model_name: Name of the VLM used for scoring/feedback. ``None``
                means the VLM feedback loop is not used.
            vlm_max_iterations: Maximum number of VLM-driven regeneration
                rounds. ``0`` (or a negative value) disables the feedback loop.
        """
        self._llm_model = llm_model_name
        self._vlm_model = vlm_model_name
        self.output_dir = output_dir
        self._vlm_max_iterations = vlm_max_iterations
        self._chart_threshold = CHART_THRESHOLD
        self._log_prefix = "[ChartGenerator]"
        self._max_section_concurrent_tasks = MAX_SECTION_CONCURRENT_CHART_TASKS
        self._max_global_concurrent_tasks = MAX_GLOBAL_CONCURRENT_CHART_TASKS
        os.makedirs(self.output_dir, exist_ok=True)

    def _create_code_executor(self) -> "AsyncCodeExecutor":
        """Create a dedicated executor for one chart task.

        A fresh executor per task keeps concurrently running tasks from sharing
        executor state.
        """
        return AsyncCodeExecutor(
            working_dir=self.output_dir, exec_timeout=self._CODE_EXEC_TIMEOUT
        )

    async def generate_charts(
        self,
        chart_tasks: Dict[int, List[Dict[str, Any]]],
    ) -> Dict[int, List[Dict[str, Any]]]:
        """Generate the charts of every section concurrently.

        Args:
            chart_tasks: Mapping of section index to that section's chart tasks.

        Returns:
            Mapping of section index to the list of successfully generated
            charts of that section (possibly empty).

        Raises:
            CustomValueException: If ``chart_tasks`` is empty or any unexpected
                error occurs while scheduling or collecting the sections.
        """
        try:
            if not chart_tasks:
                raise ValueError(f"{self._log_prefix} chart_tasks is empty!")

            section_indices: List[int] = list(chart_tasks)
            section_coroutines = [
                self._generate_section_charts(section_tasks, section_idx)
                for section_idx, section_tasks in chart_tasks.items()
            ]
            # return_exceptions=True: one failing section must not cancel the rest.
            results = await asyncio.gather(*section_coroutines, return_exceptions=True)
            return self._post_process_report_results(results, section_indices)
        except Exception as e:
            error_msg = f"Error generating charts: {e}"
            logger.error(error_msg)
            raise CustomValueException(
                StatusCode.CHART_VLM_GENERATION_ERROR.code,
                StatusCode.CHART_VLM_GENERATION_ERROR.errmsg.format(e=error_msg),
            ) from e

    @staticmethod
    def _post_process_report_results(
        results: List[Any], chart_tasks_section_idx: List[int]
    ) -> Dict[int, List[Dict[str, Any]]]:
        """Map per-section results back to their section indices.

        Args:
            results: One entry per section, in the same order as
                ``chart_tasks_section_idx``. Because the sections are gathered
                with ``return_exceptions=True`` an entry may be an exception.
            chart_tasks_section_idx: Section indices in scheduling order.

        Returns:
            Mapping of section index to its chart list. A section whose entry
            is an exception maps to an empty list. An empty dict is returned
            if the two inputs cannot be aligned.
        """
        try:
            if len(results) != len(chart_tasks_section_idx):
                raise ValueError(
                    f"Results length ({len(results)}) != "
                    f"section indices length ({len(chart_tasks_section_idx)})"
                )
            report_chart_results: Dict[int, List[Dict[str, Any]]] = {}
            for result, section_idx in zip(results, chart_tasks_section_idx):
                if isinstance(result, BaseException):
                    # Never hand an exception object to callers expecting a list.
                    logger.warning(
                        "Chart generation failed for section %s: %s",
                        section_idx,
                        result,
                    )
                    result = []
                report_chart_results[section_idx] = result
            return report_chart_results
        except Exception as e:
            error_msg = f"Error post processing report results: {e}"
            logger.error(error_msg)
            return {}

    async def _generate_section_charts(
        self, section_chart_tasks: List[Dict[str, Any]], section_idx: int
    ) -> List[Dict[str, Any]]:
        """Generate all charts of one section under two concurrency limits.

        1. A per-section semaphore limits concurrent tasks inside the section.
        2. The global semaphore limits concurrent tasks across all sections,
           preventing a "number of sections x per-section limit" fan-out.

        Args:
            section_chart_tasks: Chart tasks belonging to the section.
            section_idx: Index of the section, used to build chart ids.

        Returns:
            The successfully generated charts of the section.
        """
        global_semaphore = _get_global_chart_semaphore()
        section_semaphore = asyncio.Semaphore(self._max_section_concurrent_tasks)

        async def _generate_with_double_semaphore(chart_task: Dict[str, Any]):
            """Run one chart task while holding a section and a global permit."""
            # Take the section permit first. A task then never sits on a global
            # permit while it is merely queued behind its own section's limit,
            # which would starve the other sections.
            async with section_semaphore:
                async with global_semaphore:
                    return await self._generate_single_chart(chart_task)

        results = await asyncio.gather(
            *(_generate_with_double_semaphore(task) for task in section_chart_tasks),
            return_exceptions=True,
        )
        return self._post_process_section_results(
            results, section_chart_tasks, section_idx
        )

    @staticmethod
    def _post_process_section_results(
        results: List[Any],
        section_chart_tasks: List[Dict[str, Any]],
        section_idx: int,
    ) -> List[Dict[str, Any]]:
        """Attach generated charts to copies of their originating tasks.

        Args:
            results: One entry per task, aligned with ``section_chart_tasks``.
                An entry is a result dict, a falsy value for a failed chart,
                or an exception (tasks are gathered with
                ``return_exceptions=True``).
            section_chart_tasks: The chart tasks of the section.
            section_idx: Section index used in the generated ``chart_id``.

        Returns:
            For every successful chart, a copy of its task extended with
            ``chart_base64``, ``score`` and ``chart_id``
            (``chart_<section>_<n>``, ``n`` counting successful charts from 1).
            An empty list is returned if processing fails as a whole.
        """
        try:
            if len(results) != len(section_chart_tasks):
                raise ValueError(
                    f"Results length ({len(results)}) != chart tasks length ({len(section_chart_tasks)})"
                )
            section_chart_results: List[Dict[str, Any]] = []
            section_chart_idx = 1
            for result, chart_task in zip(results, section_chart_tasks):
                if isinstance(result, BaseException):
                    # An exception object is truthy; skip it explicitly so one
                    # failed task does not discard the whole section.
                    logger.warning(
                        "Chart task failed in section %s: %s", section_idx, result
                    )
                    continue
                if not result:
                    continue
                chart_result = chart_task.copy()
                chart_result["chart_base64"] = result.get("chart_base64", "")
                chart_result["score"] = result.get("score", 0)
                chart_result["chart_id"] = f"chart_{section_idx}_{section_chart_idx}"
                section_chart_results.append(chart_result)
                section_chart_idx += 1
            return section_chart_results
        except Exception as e:
            logger.error(f"Error post processing section results: {e}")
            return []

    async def _generate_single_chart(self, chart_task: Dict[str, Any]
                                     ) -> Dict[str, Any]:
        """Generate one chart and optionally refine it with VLM feedback.

        Args:
            chart_task: Chart task; reads ``data``, ``chart_title``,
                ``description``, ``chart_type``, ``section_index`` and
                ``chart_id_in_section``.

        Returns:
            ``{"chart_base64": ..., "score": ...}`` on success. An empty dict
            when the task has no data, generation fails, an error occurs, or
            the score stays below the threshold after the last iteration.
            ``score`` is only meaningful when the VLM loop ran; it is ``0``
            otherwise.
        """
        try:
            chart_data = chart_task.get("data", {})
            if not chart_data:
                logger.warning(f"No data for chart: {chart_task.get('chart_id', '')}")
                return {}

            chart_title = chart_task.get("chart_title", "")
            figure_id = (
                f"figure_{chart_task.get('section_index', -1)}"
                f"_{chart_task.get('chart_id_in_section', -1)}"
            )
            gen_chart_input = {
                "chart_title": chart_title,
                "chart_description": chart_task.get("description", ""),
                "chart_type": chart_task.get("chart_type", ""),
                "chart_data": chart_data,
                "font_path": FONT_PATH,
                "history_messages": {},
            }

            result = await self._generate_and_execute_code(gen_chart_input, figure_id)
            if not result or not result.get("chart_base64"):
                logger.warning(f"Failed to generate chart for {figure_id}")
                return {}
            code = result.get("code", "")
            chart_base64 = result.get("chart_base64", "")

            # No VLM configured (documented meaning of vlm_model_name=None) or
            # no iterations requested: keep the chart as generated.
            if not self._vlm_model or self._vlm_max_iterations <= 0:
                logger.info(f"Chart generated successfully: {figure_id},"
                            f"chart title: {chart_title}, without iterations.")
                return {"chart_base64": chart_base64, "score": 0}

            suggestion_list: List[str] = []
            # The chart is scored up to max_iterations + 1 times and
            # regenerated up to max_iterations times.
            for iteration in range(self._vlm_max_iterations + 1):
                suggestion_and_score = await self._vlm_iterate(
                    chart_base64, gen_chart_input, suggestion_list
                )
                score = suggestion_and_score.get("score", 0)
                suggestions = suggestion_and_score.get("suggestion", "")

                if score >= self._chart_threshold:
                    logger.info(f"Chart generated successfully: {figure_id},"
                                f"chart title: {chart_title}, score: {score}")
                    return {"chart_base64": chart_base64, "score": score}

                if iteration >= self._vlm_max_iterations:
                    logger.debug("Chart generated fail: %s, chart title: %s, score: %s",
                                 figure_id, chart_title, score)
                    return {}

                logger.debug("Chart optimization suggestions: %s, chart title: %s, suggestions: %s",
                             figure_id, chart_title, suggestions)
                suggestion_list.append(suggestions)
                gen_chart_input["history_messages"] = {
                    "code": code,
                    "error_msg": None,
                    "suggestion": [suggestions],
                }
                result = await self._generate_and_execute_code(
                    gen_chart_input, figure_id
                )
                if not result or not result.get("chart_base64"):
                    logger.warning(f"Failed to generate chart for {figure_id}")
                    return {}
                code = result.get("code", "")
                chart_base64 = result.get("chart_base64", "")
        except Exception as e:
            logger.warning(f"Error generating chart: {e}")
            return {}
        return {}

    async def _generate_and_execute_code(
        self, gen_chart_input: Dict[str, Any], figure_id: str
    ) -> Dict[str, str]:
        """Generate chart code and execute it, retrying on failure.

        On a failed attempt ``gen_chart_input["history_messages"]`` is
        overwritten in place with the failing code and the error output or a
        hint, so the next generation attempt sees them.

        Args:
            gen_chart_input: Input for the code-generation prompt.
            figure_id: Identifier used in log messages.

        Returns:
            ``{"code": <chart code>, "chart_base64": <image>}`` on success,
            otherwise an empty dict.
        """
        code_executor = self._create_code_executor()
        for _ in range(self._MAX_CODE_ATTEMPTS):
            pre_suggestion = gen_chart_input.get("history_messages", {}).get(
                "suggestion", []
            )
            if not isinstance(pre_suggestion, list):
                pre_suggestion = []

            code = await self._generate_chart_code(gen_chart_input)
            if not code:
                logger.warning(f"Failed to generate code for {figure_id}")
                return {}

            result = await code_executor.execute(code)
            if result.get("error"):
                logger.debug("Error executing chart code: %s", result.get("stderr"))
                gen_chart_input["history_messages"] = {
                    "code": code,
                    "error_msg": f"stdout: {result.get('stdout')}\nstderr: {result.get('stderr')}",
                    "suggestion": pre_suggestion + [""],
                }
                continue

            chart_base64 = result.get("chart_base64")
            if chart_base64:
                logger.info(f"Chart generated successfully in memory for {figure_id}")
                return {"code": code, "chart_base64": chart_base64}

            logger.warning(f"No chart generated in code execution for {figure_id}")
            gen_chart_input["history_messages"] = {
                "code": code,
                "error_msg": None,
                "suggestion": pre_suggestion
                + [
                    "\nThe chart was not generated. Please ensure you create a matplotlib figure."
                ],
            }
        return {}

    @staticmethod
    def _extract_code(response: str) -> str:
        """Extract Python code from an LLM response.

        Preference order: the content of the first complete fenced block;
        everything after an opening fence that is never closed; the stripped
        response itself.
        """
        match = ChartGenerator._FENCED_CODE_RE.search(response)
        if match is None:
            match = ChartGenerator._OPEN_FENCE_RE.search(response)
        if match is None:
            return response.strip()
        return match.group(1).strip()

    async def _generate_chart_code(
        self, gen_chart_input: Dict[str, Any]
    ) -> Optional[str]:
        """Ask the LLM for chart code and return it normalised.

        Args:
            gen_chart_input: Input dictionary for a single chart task.

        Returns:
            The normalised Python code, or ``None`` if the model call or the
            extraction fails.
        """
        try:
            # Passed through to call_model unchanged.
            detect_func_and_args = {
                "detection_func": None,
                "args": {},
                "option": "skip normalize",
            }
            call_model_input = CallModelInput(
                model_name=self._llm_model,
                prompt="vlm_generate_chart_code_prompt",
                user_input=gen_chart_input,
                agent_name=AgentLlmName.VLM_CHART_GENERATOR_GENERATE_CHART_CODE.value,
            )
            response = await call_model(
                call_model_input, detection_func_and_args=detect_func_and_args
            )
            return self._normalize_code(self._extract_code(response))
        except Exception as e:
            logger.error(f"Error generating chart code: {e}")
            return None

    @staticmethod
    def _normalize_code(code: Optional[str]) -> str:
        """Normalise generated code before execution.

        Steps: unwrap a JSON string literal, unify line endings to ``\\n``,
        dedent, expand leading tabs to four spaces each, and strip trailing
        whitespace from every line.

        Args:
            code: Python code, or ``None``.

        Returns:
            The normalised code; an empty string for ``None``.
        """
        if code is None:
            return ""
        s = code.strip()
        if len(s) >= 2 and s[0] == s[-1] and s[0] in ("'", '"'):
            # json.loads only accepts double-quoted strings, so a single-quoted
            # wrapper ends up in the except branch and is kept as is.
            try:
                loaded = json.loads(s)
                if isinstance(loaded, str):
                    s = loaded
            except Exception as e:
                logger.debug("Code is not a JSON string, using original: %s", str(e))
        s = s.replace("\r\n", "\n").replace("\r", "\n")
        s = textwrap.dedent(s)
        normalized_lines = [
            ChartGenerator._LEADING_TABS_RE.sub(
                lambda m: " " * 4 * len(m.group(0)), line
            ).rstrip()
            for line in s.split("\n")
        ]
        return "\n".join(normalized_lines).strip()

    async def _vlm_iterate(
        self,
        chart_base64: str,
        gen_chart_input: Dict[str, Any],
        suggestion_list: List[str],
    ) -> Dict[str, Any]:
        """Ask the VLM to score a chart and suggest improvements.

        Args:
            chart_base64: The chart image as base64.
            gen_chart_input: Chart task input (title, description, type, data).
            suggestion_list: Suggestions from earlier rounds, sent as
                ``history_suggestion``.

        Returns:
            The model response (the caller reads ``score`` and ``suggestion``
            from it), or an empty dict for an empty response.

        Raises:
            CustomValueException: If the model call fails.
        """
        vlm_llm_input = {
            "chart_title": gen_chart_input.get("chart_title", ""),
            "chart_description": gen_chart_input.get("chart_description", ""),
            "chart_type": gen_chart_input.get("chart_type", ""),
            "chart_data": gen_chart_input.get("chart_data", {}),
            "history_suggestion": suggestion_list,
            "chart_base64": chart_base64,
        }
        try:
            call_model_input = CallModelInput(
                model_name=self._vlm_model,
                prompt="vlm_iterate_prompt",
                user_input=vlm_llm_input,
                agent_name=AgentLlmName.VLM_CHART_GENERATOR_VLM_ITERATE.value,
            )
            response = await call_model(call_model_input, use_vlm=True)
            if not response:
                return {}
            return response
        except Exception as e:
            error_msg = f"Error in VLM iteration: {e}"
            logger.error(error_msg)
            raise CustomValueException(
                StatusCode.CHART_VLM_GENERATION_ERROR.code,
                StatusCode.CHART_VLM_GENERATION_ERROR.errmsg.format(e=error_msg),
            ) from e

    def set_vlm_name(self, model_name: str):
        """Change the VLM model name."""
        self._vlm_model = model_name

    def set_vlm_iteration(self, iteration: int):
        """Change the maximum number of VLM refinement iterations."""
        self._vlm_max_iterations = iteration