"""图表生成工具"""
import json
from collections.abc import AsyncGenerator
from typing import Any
from anyio import Path
from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
from pydantic import Field
from apps.models import LanguageType
from apps.scheduler.call.core import CoreCall
from apps.schemas.enum_var import CallOutputType
from apps.schemas.scheduler import (
CallError,
CallInfo,
CallOutputChunk,
CallVars,
)
from .prompt import GENERATE_STYLE_PROMPT
from .schema import (
RenderFormat,
RenderInput,
RenderOutput,
RenderStyleResult,
)
class Graph(CoreCall, input_model=RenderInput, output_model=RenderOutput):
"""Render Call,用于将SQL Tool查询出的数据转换为图表"""
dataset_key: str = Field(description="图表的数据来源(字段名)", default="")
@classmethod
def info(cls, language: LanguageType = LanguageType.CHINESE) -> CallInfo:
"""返回Call的名称和描述"""
i18n_info = {
LanguageType.CHINESE: CallInfo(name="图表", description="将SQL查询出的数据转换为图表"),
LanguageType.ENGLISH: CallInfo(name="Graph", description="Convert the data queried by SQL into a chart."),
}
return i18n_info[language]
async def _init(self, call_vars: CallVars) -> RenderInput:
"""初始化Render Call,校验参数,读取option模板"""
self._env = SandboxedEnvironment(
loader=BaseLoader(),
autoescape=False,
trim_blocks=True,
lstrip_blocks=True,
)
try:
option_location = Path(__file__).parent / "option.json"
f = await Path(option_location).open(encoding="utf-8")
data = await f.read()
self._option_template = json.loads(data)
await f.aclose()
except Exception as e:
raise CallError(message=f"图表模板读取失败:{e!s}", data={}) from e
if not self.dataset_key:
last_step_id = call_vars.step_order[-1]
self.dataset_key = f"{last_step_id}/dataset"
data = self._extract_history_variables(self.dataset_key, call_vars.step_data)
return RenderInput(
question=call_vars.question,
data=data,
)
async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""运行Render Call"""
data = RenderInput(**input_data)
malformed = True
if isinstance(data.data, list) and len(data.data) > 0 and isinstance(data.data[0], dict):
malformed = False
if malformed:
raise CallError(
message="数据格式错误,无法生成图表!",
data={"data": data.data},
)
processed_data = data.data
column_num = len(processed_data[0]) - 1
if column_num == 0:
processed_data = Graph._separate_key_value(processed_data)
column_num = 1
self._option_template["dataset"]["source"] = processed_data
try:
style_obj = self._env.from_string(GENERATE_STYLE_PROMPT[self._sys_vars.language])
style_prompt = style_obj.render(question=data.question)
result = ""
async for chunk in self._llm(messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": style_prompt},
], streaming=True):
result += chunk
llm_output = RenderStyleResult.model_validate_json(result)
self._parse_options(column_num, llm_output)
except Exception as e:
raise CallError(message=f"图表生成失败:{e!s}", data={"data": data}) from e
yield CallOutputChunk(
type=CallOutputType.DATA,
content=RenderOutput(
output=RenderFormat.model_validate(self._option_template),
).model_dump(exclude_none=True, by_alias=True),
)
@staticmethod
def _separate_key_value(data: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
若数据只有一组(例如:{"aaa": "bbb"}),则分离键值对。
样例:{"type": "aaa", "value": "bbb"}
:param data: 待分离的数据
:return: 分离后的数据
"""
result = []
for item in data:
for key, val in item.items():
result.append({"type": key, "value": val})
return result
def _parse_options(self, column_num: int, style: RenderStyleResult) -> None:
"""解析LLM做出的图表样式选择"""
series_template = {}
if style.chart_type == "line":
series_template["type"] = "line"
elif style.chart_type == "scatter":
series_template["type"] = "scatter"
elif style.chart_type == "pie":
column_num = 1
series_template["type"] = "pie"
if style.additional_style == "ring":
series_template["radius"] = ["40%", "70%"]
else:
series_template["type"] = "bar"
if style.additional_style == "stacked":
series_template["stack"] = "total"
if style.scale_type == "log":
self._option_template["yAxis"]["type"] = "log"
for _ in range(column_num):
self._option_template["series"].append(series_template)