"""用于问题推荐的工具"""
import random
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any, Self, ClassVar
from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
from pydantic import Field
from pydantic.json_schema import SkipJsonSchema
from apps.common.security import Security
from apps.llm.function import FunctionLLM
from apps.scheduler.call.core import CoreCall
from apps.scheduler.call.suggest.prompt import SUGGEST_PROMPT
from apps.scheduler.call.suggest.schema import (
SingleFlowSuggestionConfig,
SuggestGenResult,
SuggestionInput,
SuggestionOutput,
)
from apps.schemas.enum_var import CallOutputType, LanguageType
from apps.schemas.pool import NodePool
from apps.schemas.record import RecordContent
from apps.schemas.scheduler import (
CallError,
CallInfo,
CallOutputChunk,
CallVars,
)
from apps.services.record import RecordManager
from apps.services.user_domain import UserDomainManager
if TYPE_CHECKING:
from apps.scheduler.executor.step import StepExecutor
class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionOutput):
    """Question suggestion call.

    Produces follow-up questions displayed under the answer. For every entry in
    ``configs`` one question is emitted: either the fixed ``config.question`` or one
    question picked at random from LLM-generated candidates for that flow. The
    remaining slots (up to ``num`` in total) are then filled with LLM-generated
    questions for the flow currently being executed.
    """

    to_user: bool = Field(default=True, description="是否将推荐的问题推送给用户")
    configs: list[SingleFlowSuggestionConfig] = Field(description="问题推荐配置", default=[])
    num: int = Field(default=3, ge=1, le=6, description="推荐问题的总数量(必须大于等于configs中涉及的Flow的数量)")

    # Injected by ``instance``; excluded from serialization and from the JSON schema.
    context: SkipJsonSchema[list[dict[str, str]]] = Field(description="Executor的上下文", exclude=True)
    conversation_id: SkipJsonSchema[str] = Field(description="对话ID", exclude=True)

    # Localized display name/description of this call.
    i18n_info: ClassVar[dict[str, dict]] = {
        LanguageType.CHINESE: {
            "name": "问题推荐",
            "description": "在答案下方显示推荐的下一个问题",
        },
        LanguageType.ENGLISH: {
            "name": "Question Suggestion",
            "description": "Display the suggested next question under the answer",
        },
    }

    @classmethod
    async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self:
        """Create the call from a step executor.

        The task's current question/answer pair is stored as ``context`` so it can be
        rendered into the suggestion prompt.
        """
        context = [
            {
                "role": "user",
                "content": executor.task.runtime.question,
            },
            {
                "role": "assistant",
                "content": executor.task.runtime.answer,
            },
        ]
        obj = cls(
            name=executor.step.step.name,
            description=executor.step.step.description,
            node=node,
            context=context,
            conversation_id=executor.task.ids.conversation_id,
            **kwargs,
        )
        await obj._set_input(executor)
        return obj

    async def _init(self, call_vars: CallVars) -> SuggestionInput:
        """Load conversation history and the current app's flows, then build the input.

        Side effects: sets ``_history_questions``, ``_app_id``, ``_flow_id``, ``_env``
        and ``_avaliable_flows`` on the instance.
        """
        # Local import, presumably to avoid a circular import — TODO confirm.
        from apps.services.appcenter import AppCenterManager

        self._history_questions = await self._get_history_questions(
            call_vars.ids.user_sub,
            self.conversation_id,
        )
        self._app_id = call_vars.ids.app_id
        self._flow_id = call_vars.ids.flow_id
        app_metadata = await AppCenterManager.fetch_app_data_by_id(self._app_id)

        self._env = SandboxedEnvironment(
            loader=BaseLoader(),
            autoescape=True,
            trim_blocks=True,
            lstrip_blocks=True,
        )

        # flow_id -> {"name": ..., "description": ...} for every flow of the current app.
        # (Attribute name spelling kept as-is for compatibility.)
        self._avaliable_flows = {
            flow.id: {
                "name": flow.name,
                "description": flow.description,
            }
            for flow in app_metadata.flows
        }

        return SuggestionInput(
            question=call_vars.question,
            user_sub=call_vars.ids.user_sub,
            history_questions=self._history_questions,
        )

    async def _get_history_questions(self, user_sub: str, conversation_id: str) -> list[str]:
        """Return the decrypted questions of this conversation's records.

        The third argument (15) to ``query_record_by_conversation_id`` presumably limits
        how many records are fetched — TODO confirm against RecordManager.
        """
        records = await RecordManager.query_record_by_conversation_id(
            user_sub,
            conversation_id,
            15,
        )
        return [
            RecordContent.model_validate_json(Security.decrypt(record.content, record.key)).question
            for record in records
        ]

    def _render_prompt(self, prompt_tpl: Any, flow_id: str, user_domain: Any) -> str:
        """Render the suggestion prompt for one flow of the current app."""
        flow = self._avaliable_flows[flow_id]
        return prompt_tpl.render(
            conversation=self.context,
            history=self._history_questions,
            # Pass the flow's human-readable name and its description text
            # (not the raw ID / the whole flow dict).
            tool={
                "name": flow["name"],
                "description": flow["description"],
            },
            preference=user_domain,
        )

    async def _generate_questions(self, prompt_tpl: Any, flow_id: str, user_domain: Any) -> list[str]:
        """Ask the LLM for candidate questions about ``flow_id``; may return an empty list."""
        prompt = self._render_prompt(prompt_tpl, flow_id, user_domain)
        result = await FunctionLLM().call(
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
            ],
            schema=SuggestGenResult.model_json_schema(),
        )
        return SuggestGenResult.model_validate(result).predicted_questions

    def _build_chunk(self, question: str, flow_id: str) -> CallOutputChunk:
        """Wrap one suggested question for ``flow_id`` into a DATA output chunk."""
        flow = self._avaliable_flows[flow_id]
        return CallOutputChunk(
            type=CallOutputType.DATA,
            content=SuggestionOutput(
                question=question,
                flowName=flow["name"],
                flowId=flow_id,
                flowDescription=flow["description"],
            ).model_dump(by_alias=True, exclude_none=True),
        )

    async def _exec(
        self, input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE
    ) -> AsyncGenerator[CallOutputChunk, None]:
        """Stream up to ``num`` suggested questions.

        Raises:
            CallError: if ``num`` is smaller than the number of ``configs``; if a
                configured flow ID is not a flow of the current app; or if the
                questions must be topped up from the current flow but that flow is
                not a flow of the current app.
        """
        data = SuggestionInput(**input_data)
        if self.num < len(self.configs):
            raise CallError(
                message="推荐问题的数量必须大于等于配置的数量",
                data={},
            )

        user_domain = await UserDomainManager.get_user_domain_by_user_sub_and_topk(data.user_sub, 5)
        prompt_tpl = self._env.from_string(SUGGEST_PROMPT[language])
        pushed_questions = 0

        # Phase 1: one question per configured flow.
        for config in self.configs:
            if config.flow_id not in self._avaliable_flows:
                raise CallError(
                    message="配置的Flow ID不存在",
                    data={},
                )
            if config.question:
                question = config.question
            else:
                candidates = await self._generate_questions(prompt_tpl, config.flow_id, user_domain)
                if not candidates:
                    # Nothing generated for this flow; the slot is filled in phase 2.
                    continue
                question = random.choice(candidates)
            yield self._build_chunk(question, config.flow_id)
            pushed_questions += 1

        if pushed_questions >= self.num:
            return

        # Phase 2: top up with questions for the flow currently being executed.
        if self._flow_id not in self._avaliable_flows:
            raise CallError(
                message="当前Flow ID不存在",
                data={},
            )
        # Every non-empty LLM round pushes at least one question, so this many rounds
        # suffice; the cap prevents an endless loop when the LLM keeps returning nothing.
        max_rounds = self.num - pushed_questions
        for _ in range(max_rounds):
            if pushed_questions >= self.num:
                break
            candidates = await self._generate_questions(prompt_tpl, self._flow_id, user_domain)
            for question in candidates[: self.num - pushed_questions]:
                yield self._build_chunk(question, self._flow_id)
                pushed_questions += 1