"""
Core Call类是定义了所有Call都应具有的方法和参数的PyDantic类。
所有Call类必须继承此类,并根据需求重载方法。
"""
import logging
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any, ClassVar, Self
from pydantic import BaseModel, ConfigDict, Field
from pydantic.json_schema import SkipJsonSchema
from apps.llm.function import FunctionLLM
from apps.llm.reasoning import ReasoningLLM
from apps.schemas.enum_var import CallOutputType, LanguageType
from apps.schemas.pool import NodePool
from apps.schemas.scheduler import (
CallError,
CallIds,
CallInfo,
CallOutputChunk,
CallTokens,
CallVars,
)
from apps.schemas.task import FlowStepHistory
from apps.schemas.enum_var import LanguageType
from apps.schemas.parameters import DefaultParamName
if TYPE_CHECKING:
from apps.scheduler.executor.step import StepExecutor
logger = logging.getLogger(__name__)
class DataBase(BaseModel):
"""所有Call的输入基类"""
@classmethod
def model_json_schema(cls, override: dict[str, Any] | None = None, **kwargs: Any) -> dict[str, Any]:
"""通过override参数,动态填充Schema内容"""
schema = super().model_json_schema(**kwargs)
if override:
for key, value in override.items():
schema["properties"][key] = value
return schema
class CoreCall(BaseModel):
"""所有Call的父类,包含通用的逻辑"""
name: SkipJsonSchema[str] = Field(description="Step的名称", exclude=True)
description: SkipJsonSchema[str] = Field(
description="Step的描述", exclude=True)
node: SkipJsonSchema[NodePool | None] = Field(
description="节点信息", exclude=True)
enable_filling: SkipJsonSchema[bool] = Field(
description="是否需要进行自动参数填充", default=False, exclude=True
)
tokens: SkipJsonSchema[CallTokens] = Field(
description="Call的输入输出Tokens信息",
default=CallTokens(),
exclude=True,
)
input_model: ClassVar[SkipJsonSchema[type[DataBase]]] = Field(
description="Call的输入Pydantic类型;不包含override的模板",
exclude=True,
frozen=True,
)
output_model: ClassVar[SkipJsonSchema[type[DataBase]]] = Field(
description="Call的输出Pydantic类型;不包含override的模板",
exclude=True,
frozen=True,
)
i18n_info: ClassVar[SkipJsonSchema[dict[str, dict]]] = {}
to_user: bool = Field(description="是否需要将输出返回给用户", default=False)
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="allow",
)
@classmethod
def info(cls, language: LanguageType = LanguageType.CHINESE) -> CallInfo:
"""
返回Call的名称和描述
:return: Call的名称和描述
:rtype: CallInfo
"""
lang_info = cls.i18n_info.get(
language, cls.i18n_info[LanguageType.CHINESE])
return CallInfo(name=lang_info["name"], description=lang_info["description"])
def __init_subclass__(
cls, input_model: type[DataBase], output_model: type[DataBase], **kwargs: Any
) -> None:
"""初始化子类"""
super().__init_subclass__(**kwargs)
cls.input_model = input_model
cls.output_model = output_model
@staticmethod
def _assemble_call_vars(executor: "StepExecutor") -> CallVars:
"""组装CallVars"""
if not executor.task.state:
err = "[CoreCall] 当前ExecutorState为空"
logger.error(err)
raise ValueError(err)
history = {}
history_order = []
for item in executor.task.context:
item_obj = FlowStepHistory.model_validate(item)
history[item_obj.step_id] = item_obj
history_order.append(item_obj.step_id)
return CallVars(
ids=CallIds(
task_id=executor.task.id,
flow_id=executor.task.state.flow_id,
session_id=executor.task.ids.session_id,
user_sub=executor.task.ids.user_sub,
app_id=executor.task.state.app_id,
),
question=executor.question,
history=history,
history_order=history_order,
summary=executor.task.runtime.summary,
)
@staticmethod
def _extract_history_variables(path: str, call_vars: CallVars) -> Any:
"""
提取History中的变量
:param path: 路径,格式为:step_id/key/to/variable
:param history: Step历史,即call_vars.history
:return: 变量
"""
split_path = path.split("/")
if len(split_path) < 1:
err = f"[CoreCall] 路径格式错误: {path}"
logger.error(err)
return None
if split_path[0] == "start":
if split_path[1] == DefaultParamName.QUESTION.value:
return call_vars.question
else:
err = f"[CoreCall] start步骤仅支持提取question变量,当前为: {split_path[1]}"
logger.error(err)
return None
history = call_vars.history
if split_path[0] not in history:
err = f"[CoreCall] 步骤{split_path[0]}不存在"
logger.error(err)
return None
data = history[split_path[0]].output_data
for key in split_path[1:]:
if key not in data:
err = f"[CoreCall] 输出Key {key} 不存在"
logger.error(err)
return None
data = data[key]
return data
@classmethod
async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self:
"""实例化Call类"""
obj = cls(
name=executor.step.step.name,
description=executor.step.step.description,
node=node,
**kwargs,
)
await obj._set_input(executor)
return obj
async def _set_input(self, executor: "StepExecutor") -> None:
"""获取Call的输入"""
self._sys_vars = self._assemble_call_vars(executor)
input_data = await self._init(self._sys_vars)
self.input = input_data.model_dump(by_alias=True, exclude_none=True)
async def _init(self, call_vars: CallVars) -> DataBase:
"""初始化Call类,并返回Call的输入"""
err = "[CoreCall] 初始化方法必须手动实现"
raise NotImplementedError(err)
async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""Call类实例的流式输出方法"""
yield CallOutputChunk(type=CallOutputType.TEXT, content="")
async def _after_exec(self, input_data: dict[str, Any]) -> None:
"""Call类实例的执行后方法"""
async def exec(
self,
executor: "StepExecutor",
input_data: dict[str, Any],
language: LanguageType = LanguageType.CHINESE,
) -> AsyncGenerator[CallOutputChunk, None]:
"""Call类实例的执行方法"""
async for chunk in self._exec(input_data, language):
yield chunk
await self._after_exec(input_data)
async def _llm(self, messages: list[dict[str, Any]]) -> str:
"""Call可直接使用的LLM非流式调用"""
result = ""
llm = ReasoningLLM()
async for chunk in llm.call(messages, streaming=False):
result += chunk
self.input_tokens = llm.input_tokens
self.output_tokens = llm.output_tokens
return result
async def _json(self, messages: list[dict[str, Any]], schema: type[BaseModel]) -> BaseModel:
"""Call可直接使用的JSON生成"""
json = FunctionLLM()
result = await json.call(messages=messages, schema=schema.model_json_schema())
return schema.model_validate(result)