"""
图表生成模块
功能:
1. 根据图表名称和生成数据,LLM生成Python代码
2. 执行代码生成图表
3. 可选的VLM迭代反馈机制
"""
import asyncio
import logging
import os
import re
import json
import textwrap
from typing import Dict, List, Tuple, Optional, Any
import base64
import io
from PIL import Image
import numpy as np
from openjiuwen_deepsearch.utils.constants_utils.node_constants import AgentLlmName
from openjiuwen_deepsearch.algorithm.chart_generation.utils import (
call_model,
CallModelInput,
)
from openjiuwen_deepsearch.common.exception import CustomValueException
from openjiuwen_deepsearch.common.status_code import StatusCode
from openjiuwen_deepsearch.algorithm.chart_generation.sandbox import (
AsyncCodeExecutor,
)
logger = logging.getLogger(__name__)
FONT_PATH = os.path.join(os.path.dirname(__file__), "fonts", "kt_font.ttf")
MAX_GLOBAL_CONCURRENT_CHART_TASKS = 10
MAX_SECTION_CONCURRENT_CHART_TASKS = 5
CHART_THRESHOLD = 85
_global_chart_semaphore: Optional[asyncio.Semaphore] = None
def _get_global_chart_semaphore() -> asyncio.Semaphore:
"""
获取全局图表生成信号量(懒加载,确保在asyncio事件循环中创建)
Returns:
asyncio.Semaphore: 全局并发控制信号量
"""
global _global_chart_semaphore
if _global_chart_semaphore is None:
_global_chart_semaphore = asyncio.Semaphore(MAX_GLOBAL_CONCURRENT_CHART_TASKS)
return _global_chart_semaphore
class ChartGenerator:
    """Generate charts by letting an LLM write plotting code and executing it.

    For every chart task the LLM produces Python code, the code is run through
    ``AsyncCodeExecutor`` and the resulting image (base64) is optionally scored
    by a VLM, whose suggestions are fed back to the LLM for another attempt.
    """

    # Opening fence with an optional python language tag.  The negative
    # lookahead keeps identifiers such as ``pythonic`` from being truncated.
    _FENCE_OPEN = r"```(?:(?:python3|python|py)(?![\w.]))?"
    _CLOSED_FENCE_RE = re.compile(_FENCE_OPEN + r"\s*([\s\S]*?)\s*```", re.IGNORECASE)
    _OPEN_FENCE_RE = re.compile(_FENCE_OPEN + r"\s*([\s\S]*)", re.IGNORECASE)

    def __init__(
        self,
        llm_model_name: str,
        output_dir: str,
        vlm_model_name: Optional[str] = None,
        vlm_max_iterations: int = 1,
    ):
        """Initialise the generator and make sure the output directory exists.

        Args:
            llm_model_name: name of the LLM used to generate chart code.
            output_dir: chart output directory (also the sandbox working dir).
            vlm_model_name: name of the VLM used for evaluation feedback;
                ``None`` means the VLM is not used.
            vlm_max_iterations: maximum number of VLM-driven regenerations;
                ``0`` (or less) disables VLM evaluation.
        """
        self._llm_model = llm_model_name
        self._vlm_model = vlm_model_name
        self.output_dir = output_dir
        self._vlm_max_iterations = vlm_max_iterations
        self._chart_threshold = CHART_THRESHOLD
        self._log_prefix = "[ChartGenerator]"
        self._max_section_concurrent_tasks = MAX_SECTION_CONCURRENT_CHART_TASKS
        self._max_global_concurrent_tasks = MAX_GLOBAL_CONCURRENT_CHART_TASKS
        os.makedirs(self.output_dir, exist_ok=True)

    def _create_code_executor(self) -> "AsyncCodeExecutor":
        """Create a dedicated sandbox executor so concurrent tasks share no globals."""
        return AsyncCodeExecutor(working_dir=self.output_dir, exec_timeout=120)

    async def generate_charts(
        self,
        chart_tasks: Dict[int, List[Dict[str, Any]]],
    ) -> Dict[int, List[Dict[str, Any]]]:
        """Generate the charts of all sections concurrently.

        Args:
            chart_tasks: mapping of section index to that section's chart tasks.

        Returns:
            Mapping of section index to the list of successfully generated
            charts of that section (an empty list when the section failed).

        Raises:
            CustomValueException: when ``chart_tasks`` is empty or generation
                fails as a whole.
        """
        try:
            if not chart_tasks:
                raise ValueError(f"{self._log_prefix} chart_tasks is empty!")

            section_indices: List[int] = []
            section_coroutines = []
            for section_idx, section_tasks in chart_tasks.items():
                section_indices.append(section_idx)
                section_coroutines.append(
                    self._generate_section_charts(section_tasks, section_idx)
                )

            results = await asyncio.gather(*section_coroutines, return_exceptions=True)
            return self._post_process_report_results(results, section_indices)
        except Exception as e:
            error_msg = f"Error generating charts: {e}"
            logger.error(error_msg)
            raise CustomValueException(
                StatusCode.CHART_VLM_GENERATION_ERROR.code,
                StatusCode.CHART_VLM_GENERATION_ERROR.errmsg.format(e=error_msg),
            ) from e

    @staticmethod
    def _post_process_report_results(
        results: List[Any], chart_tasks_section_idx: List[int]
    ) -> Dict[int, List[Dict[str, Any]]]:
        """Map every section's result list to its section index.

        Args:
            results: per-section outcomes as returned by
                ``asyncio.gather(..., return_exceptions=True)``; an entry may
                therefore be an exception instead of a chart list.
            chart_tasks_section_idx: section indices, in the same order.

        Returns:
            ``{section_idx: charts}``; a section whose outcome is an exception
            maps to ``[]``.  An empty dict is returned on a length mismatch.
        """
        if len(results) != len(chart_tasks_section_idx):
            logger.error(
                "Error post processing report results: "
                f"Results length ({len(results)}) != "
                f"section indices length ({len(chart_tasks_section_idx)})"
            )
            return {}

        report_chart_results: Dict[int, List[Dict[str, Any]]] = {}
        for result, section_idx in zip(results, chart_tasks_section_idx):
            if isinstance(result, BaseException):
                # A failed section must not leak its exception object to callers.
                logger.warning(f"Section {section_idx} chart generation failed: {result}")
                result = []
            report_chart_results[section_idx] = result
        return report_chart_results

    async def _generate_section_charts(
        self, section_chart_tasks: List[Dict[str, Any]], section_idx: int
    ) -> List[Dict[str, Any]]:
        """Generate all charts of one section under two concurrency limits.

        1. Section semaphore: caps concurrent tasks inside this section.
        2. Global semaphore: caps concurrent tasks across all sections.

        Args:
            section_chart_tasks: chart tasks of this section.
            section_idx: index of the section (used to build chart ids).

        Returns:
            The successfully generated charts of this section.
        """
        global_semaphore = _get_global_chart_semaphore()
        section_semaphore = asyncio.Semaphore(self._max_section_concurrent_tasks)

        async def _generate_with_double_semaphore(chart_task: Dict[str, Any]):
            # The section slot is taken first so that a task never occupies a
            # scarce global slot while it is only queuing for its own section;
            # a running task still holds both, so both limits are enforced.
            async with section_semaphore:
                async with global_semaphore:
                    return await self._generate_single_chart(chart_task)

        pending = [
            _generate_with_double_semaphore(chart_task)
            for chart_task in section_chart_tasks
        ]
        results = await asyncio.gather(*pending, return_exceptions=True)
        return self._post_process_section_results(
            results, section_chart_tasks, section_idx
        )

    @staticmethod
    def _post_process_section_results(
        results: List[Any],
        section_chart_tasks: List[Dict[str, Any]],
        section_idx: int,
    ) -> List[Dict[str, Any]]:
        """Attach generated charts to copies of their tasks.

        Args:
            results: per-task outcomes from
                ``asyncio.gather(..., return_exceptions=True)``; an entry may be
                a result dict, an empty/``None`` value, or an exception.
            section_chart_tasks: the tasks, in the same order as ``results``.
            section_idx: index of the section.

        Returns:
            One dict per successful chart: the task's fields plus
            ``chart_base64``, ``score`` and ``chart_id``
            (``chart_<section_idx>_<n>``, ``n`` counting successful charts
            from 1).  Failed tasks are skipped; ``[]`` on a length mismatch.
        """
        try:
            if len(results) != len(section_chart_tasks):
                raise ValueError(
                    f"Results length ({len(results)}) != chart tasks length ({len(section_chart_tasks)})"
                )

            section_chart_results: List[Dict[str, Any]] = []
            for result, chart_task in zip(results, section_chart_tasks):
                if isinstance(result, BaseException):
                    # One failed task must not discard the whole section.
                    logger.warning(f"Chart task failed in section {section_idx}: {result}")
                    continue
                if not result:
                    continue
                if not isinstance(result, dict):
                    logger.warning(
                        f"Unexpected chart result type in section {section_idx}: {type(result).__name__}"
                    )
                    continue
                chart_entry = chart_task.copy()
                chart_entry["chart_base64"] = result.get("chart_base64", "")
                chart_entry["score"] = result.get("score", 0)
                chart_entry["chart_id"] = (
                    f"chart_{section_idx}_{len(section_chart_results) + 1}"
                )
                section_chart_results.append(chart_entry)
            return section_chart_results
        except Exception as e:
            logger.error(f"Error post processing section results: {e}")
            return []

    async def _generate_single_chart(self, chart_task: Dict[str, Any]) -> Dict[str, Any]:
        """Generate one chart and optionally refine it with VLM feedback.

        Args:
            chart_task: chart task; reads ``data``, ``chart_title``,
                ``description``, ``chart_type``, ``section_index`` and
                ``chart_id_in_section``.

        Returns:
            ``{"chart_base64": ..., "score": ...}`` on success, ``{}`` when the
            chart could not be generated or was filtered out (score below the
            threshold, or a large blank area).  ``score`` is only meaningful
            when VLM evaluation is enabled; otherwise it is ``0``.
        """
        try:
            chart_data = chart_task.get("data", {})
            if not chart_data:
                logger.warning(f"No data for chart: {chart_task.get('chart_id', '')}")
                return {}
            if len(chart_data) <= 1:
                logger.warning(f"There is only one data to show. "
                               f"Filtered data: {json.dumps(chart_data, ensure_ascii=False)}")
                return {}

            chart_title = chart_task.get("chart_title", "")
            figure_id = (
                f"figure_{chart_task.get('section_index', -1)}"
                f"_{chart_task.get('chart_id_in_section', -1)}"
            )
            gen_chart_input = {
                "chart_title": chart_title,
                "chart_description": chart_task.get("description", ""),
                "chart_type": chart_task.get("chart_type", ""),
                "chart_data": chart_data,
                "font_path": FONT_PATH,
                "history_messages": {},
            }

            result = await self._generate_and_execute_code(gen_chart_input, figure_id)
            if not result or not result.get("chart_base64"):
                logger.warning(f"Failed to generate chart for {figure_id}")
                return {}
            code = result.get("code", "")
            chart_base64 = result.get("chart_base64", "")

            max_iterations = self._vlm_max_iterations
            # No VLM configured (documented as "not used") or no iterations
            # requested: accept the chart as generated.
            if not self._vlm_model or max_iterations <= 0:
                logger.info(f"Chart generated successfully: {figure_id},"
                            f"chart title: {chart_title}, without iterations.")
                return {"chart_base64": chart_base64, "score": 0}

            suggestion_list: List[str] = []
            # One evaluation per round; at most ``max_iterations`` regenerations.
            for round_idx in range(max_iterations + 1):
                suggestion_and_score = await self._vlm_iterate(
                    chart_base64, gen_chart_input, suggestion_list
                )
                score = suggestion_and_score.get("score", 0)
                suggestions = suggestion_and_score.get("suggestion", "")
                if not isinstance(score, (int, float)):
                    # Model output may carry the score as text, e.g. "90".
                    try:
                        score = float(score)
                    except (TypeError, ValueError):
                        score = 0

                if score >= self._chart_threshold:
                    if self._has_large_blank(chart_base64):
                        logger.info(f"{self._log_prefix} Filter the chart {figure_id} with large blank.")
                        return {}
                    logger.info(f"{self._log_prefix} Chart generated successfully: {figure_id},"
                                f"chart title: {chart_title}, score: {score}")
                    return {"chart_base64": chart_base64, "score": score}

                if round_idx >= max_iterations:
                    logger.debug("Chart generated fail: %s, chart title: %s, score: %s",
                                 figure_id, chart_title, score)
                    return {}

                logger.debug("Chart optimization suggestions: %s, chart title: %s, suggestions: %s",
                             figure_id, chart_title, suggestions)
                suggestion_list.append(suggestions)
                gen_chart_input["history_messages"] = {
                    "code": code,
                    "error_msg": None,
                    "suggestion": [suggestions],
                }
                result = await self._generate_and_execute_code(gen_chart_input, figure_id)
                if not result or not result.get("chart_base64"):
                    logger.warning(f"Failed to generate chart for {figure_id}")
                    return {}
                code = result.get("code", "")
                chart_base64 = result.get("chart_base64", "")
        except Exception as e:
            logger.warning(f"Error generating chart: {e}")
            return {}
        return {}

    async def _generate_and_execute_code(
        self, gen_chart_input: Dict[str, Any], figure_id: str
    ) -> Dict[str, str]:
        """Ask the LLM for chart code and execute it, retrying up to three times.

        After a failed attempt ``gen_chart_input["history_messages"]`` is
        updated with the failing code and the error (or a hint that no figure
        was produced) so the next attempt can correct it.

        Args:
            gen_chart_input: chart generation input; mutated as described above.
            figure_id: chart identifier used in log messages.

        Returns:
            ``{"code": ..., "chart_base64": ...}`` on success, otherwise ``{}``.
        """
        code_executor = self._create_code_executor()
        for _ in range(3):
            pre_suggestion = gen_chart_input.get("history_messages", {}).get(
                "suggestion", []
            )
            carried_suggestions = (
                pre_suggestion if isinstance(pre_suggestion, list) else []
            )

            code = await self._generate_chart_code(gen_chart_input)
            if not code or "no code" in code.lower():
                logger.warning(f"Failed to generate code for {figure_id}")
                return {}
            logger.debug("The origin code is: \n%s.", code)

            result = await code_executor.execute(code)
            if result["error"]:
                logger.debug("Error executing chart code: %s", result['stderr'])
                gen_chart_input["history_messages"] = {
                    "code": code,
                    "error_msg": f"stdout: {result['stdout']}\nstderr: {result['stderr']}",
                    "suggestion": carried_suggestions + [""],
                }
                continue

            chart_base64 = result.get("chart_base64")
            if chart_base64:
                logger.info(f"Chart generated successfully in memory for {figure_id}")
                return {"code": code, "chart_base64": chart_base64}

            logger.warning(f"No chart generated in code execution for {figure_id}")
            gen_chart_input["history_messages"] = {
                "code": code,
                "error_msg": None,
                "suggestion": carried_suggestions
                + [
                    "\nThe chart was not generated. Please ensure you create a matplotlib figure."
                ],
            }
        return {}

    async def _generate_chart_code(
        self, gen_chart_input: Dict[str, Any]
    ) -> Optional[str]:
        """Call the LLM and return the normalised Python code it produced.

        Args:
            gen_chart_input: input dict of a single chart generation task.

        Returns:
            The normalised code, or ``None`` when the model call failed.
        """
        try:
            detect_func_and_args = {
                "detection_func": None,
                "args": {},
                "option": "skip normalize",
            }
            call_model_input = CallModelInput(
                model_name=self._llm_model,
                prompt="vlm_generate_chart_code_prompt",
                user_input=gen_chart_input,
                agent_name=AgentLlmName.VLM_CHART_GENERATOR_GENERATE_CHART_CODE.value,
            )
            response = await call_model(
                call_model_input, detection_func_and_args=detect_func_and_args
            )
            return self._normalize_code(self._extract_code(response))
        except Exception as e:
            logger.error(f"Error generating chart code: {e}")
            return None

    @staticmethod
    def _extract_code(response: str) -> Optional[str]:
        """Extract Python code from an LLM response.

        The content of the first fenced block is preferred; the fence may carry
        a ``python``, ``python3`` or ``py`` tag (case-insensitive).  When the
        closing fence is missing, everything after the opening fence is used.
        Without any fence the stripped response itself is returned.

        Args:
            response: raw LLM response text.

        Returns:
            The extracted code; ``""`` when ``response`` is not a string.
        """
        if not isinstance(response, str):
            return ""
        for pattern in (ChartGenerator._CLOSED_FENCE_RE, ChartGenerator._OPEN_FENCE_RE):
            match = pattern.search(response)
            if match:
                return match.group(1).strip()
        return response.strip()

    @staticmethod
    def _normalize_code(code: str) -> str:
        """Normalise generated code before execution.

        Unwraps code delivered as a JSON string literal, unifies line endings
        to ``\\n``, removes common indentation, expands leading tabs to four
        spaces each and strips trailing whitespace.

        Args:
            code: Python code, possibly ``None``.

        Returns:
            The normalised code (``""`` for ``None``).
        """
        if code is None:
            return ""

        text = code.strip()
        if len(text) >= 2 and text[0] == text[-1] and text[0] in ("'", '"'):
            # The model sometimes returns the whole program as a quoted string.
            try:
                loaded = json.loads(text)
            except ValueError as e:
                logger.debug("Code is not a JSON string, using original: %s", str(e))
            else:
                if isinstance(loaded, str):
                    text = loaded

        text = text.replace("\r\n", "\n").replace("\r", "\n")
        text = textwrap.dedent(text)
        normalized_lines = [
            re.sub(r"^\t+", lambda m: " " * 4 * len(m.group(0)), line).rstrip()
            for line in text.split("\n")
        ]
        return "\n".join(normalized_lines).strip()

    async def _vlm_iterate(
        self,
        chart_base64: str,
        gen_chart_input: Dict[str, Any],
        suggestion_list: List[str],
    ) -> Dict[str, Any]:
        """Ask the VLM to score a chart and suggest improvements.

        Args:
            chart_base64: base64 of the chart image.
            gen_chart_input: chart generation input (title, description, type, data).
            suggestion_list: suggestions given in earlier rounds.

        Returns:
            The model response (read by the caller via the ``score`` and
            ``suggestion`` keys), or ``{}`` when the response is empty.

        Raises:
            CustomValueException: when the model call fails.
        """
        vlm_llm_input = {
            "chart_title": gen_chart_input.get("chart_title", ""),
            "chart_description": gen_chart_input.get("chart_description", ""),
            "chart_type": gen_chart_input.get("chart_type", ""),
            "chart_data": gen_chart_input.get("chart_data", {}),
            "history_suggestion": suggestion_list,
            "chart_base64": chart_base64,
        }
        try:
            call_model_input = CallModelInput(
                model_name=self._vlm_model,
                prompt="vlm_iterate_prompt",
                user_input=vlm_llm_input,
                agent_name=AgentLlmName.VLM_CHART_GENERATOR_VLM_ITERATE.value,
            )
            response = await call_model(call_model_input, use_vlm=True)
            return response if response else {}
        except Exception as e:
            error_msg = f"Error in VLM iteration: {e}"
            logger.error(error_msg)
            raise CustomValueException(
                StatusCode.CHART_VLM_GENERATION_ERROR.code,
                StatusCode.CHART_VLM_GENERATION_ERROR.errmsg.format(e=error_msg),
            ) from e

    def set_vlm_name(self, model_name: str):
        """Change the VLM model name."""
        self._vlm_model = model_name

    def set_vlm_iteration(self, iteration: int):
        """Change the maximum number of VLM optimisation iterations."""
        self._vlm_max_iterations = iteration

    @staticmethod
    def _longest_true_run(mask) -> int:
        """Return the length of the longest run of consecutive truthy items."""
        longest = 0
        current = 0
        for flag in mask:
            if flag:
                current += 1
                if current > longest:
                    longest = current
            else:
                current = 0
        return longest

    @staticmethod
    def _has_large_blank(png_base64: str) -> bool:
        """Tell whether the image contains one large contiguous blank band.

        A pixel counts as white when its grayscale value is >= 245; a row or
        column counts as blank when at least 80% of its pixels are white.  The
        longest run of blank rows (times the width) and of blank columns
        (times the height) is compared against 90% of the canvas area.

        Args:
            png_base64: base64-encoded image.

        Returns:
            ``True`` when such a blank band exists, when the image has a zero
            dimension, or when the image cannot be decoded; ``False`` otherwise.
        """
        white_threshold = 245
        blank_ratio_threshold = 0.8
        large_blank_ratio = 0.9
        try:
            image_data = base64.b64decode(png_base64)
            with Image.open(io.BytesIO(image_data)) as img:
                width, height = img.size
                if width == 0 or height == 0:
                    return True
                white_mask = np.asarray(img.convert("L")) >= white_threshold

            blank_rows = white_mask.mean(axis=1) >= blank_ratio_threshold
            blank_cols = white_mask.mean(axis=0) >= blank_ratio_threshold
            max_blank_area = max(
                ChartGenerator._longest_true_run(blank_rows) * width,
                ChartGenerator._longest_true_run(blank_cols) * height,
            )
            return max_blank_area > width * height * large_blank_ratio
        except Exception as e:
            # An unreadable image is treated as unusable and filtered out.
            logger.warning(f"Error checking large blank area: {e}")
            return True