"""提取事实工具"""
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any, Self, ClassVar
from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
from pydantic import Field
from apps.scheduler.call.core import CoreCall
from apps.scheduler.call.facts.prompt import DOMAIN_PROMPT, FACTS_PROMPT
from apps.scheduler.call.facts.schema import (
DomainGen,
FactsGen,
FactsInput,
FactsOutput,
)
from apps.schemas.enum_var import CallOutputType, LanguageType
from apps.schemas.pool import NodePool
from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars
from apps.services.user_domain import UserDomainManager
if TYPE_CHECKING:
from apps.scheduler.executor.step import StepExecutor
class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput):
    """Fact-extraction tool.

    Sends the current user question and this call's answer to the LLM twice:
    once to extract facts (``FACTS_PROMPT``) and once to extract domain
    keywords (``DOMAIN_PROMPT``). Each keyword is recorded for the user via
    ``UserDomainManager``; ``exec`` additionally stores the facts on the
    task runtime.
    """

    # The answer for the current turn; ``_init`` pairs it with the user's
    # question as the assistant message of the conversation.
    # NOTE(review): the Field description reads "user input" although the
    # value is the assistant answer — confirm whether that is intended.
    answer: str = Field(description="用户输入")

    # Localized display name/description, keyed by language.
    i18n_info: ClassVar[dict[str, dict]] = {
        LanguageType.CHINESE: {
            "name": "提取事实",
            "description": "从对话上下文和文档片段中提取事实。",
        },
        LanguageType.ENGLISH: {
            "name": "Fact Extraction",
            "description": "Extract facts from the conversation context and document snippets.",
        },
    }

    @classmethod
    async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self:
        """Create the tool from the executor's current step and prime its input.

        Args:
            executor: Step executor providing the runtime answer and the
                step's name/description.
            node: Optional node definition forwarded to the constructor.
            **kwargs: Additional constructor arguments, forwarded unchanged.

        Returns:
            The new instance, after ``_set_input(executor)`` has been awaited.
        """
        runtime_answer = executor.task.runtime.answer
        step_meta = executor.step.step
        call = cls(
            answer=runtime_answer,
            name=step_meta.name,
            description=step_meta.description,
            node=node,
            **kwargs,
        )
        await call._set_input(executor)
        return call

    async def _init(self, call_vars: CallVars) -> FactsInput:
        """Build the tool input from the current question and this call's answer.

        Returns:
            ``FactsInput`` carrying the caller's ``user_sub`` and a two-message
            conversation: the user's question followed by the assistant answer.
        """
        conversation = [
            {"role": "user", "content": call_vars.question},
            {"role": "assistant", "content": self.answer},
        ]
        return FactsInput(user_sub=call_vars.ids.user_sub, message=conversation)

    async def _render_and_ask(
        self,
        env: SandboxedEnvironment,
        template_source: str,
        conversation: Any,
        schema: type,
    ) -> Any:
        """Render ``template_source`` over ``conversation`` and request ``schema``-shaped JSON from the LLM."""
        prompt_text = env.from_string(template_source).render(conversation=conversation)
        request = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt_text},
        ]
        return await self._json(request, schema)

    async def _exec(
        self, input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE
    ) -> AsyncGenerator[CallOutputChunk, None]:
        """Extract facts and domain keywords, record the keywords, emit one DATA chunk.

        Args:
            input_data: Raw fields used to construct a ``FactsInput``.
            language: Selects which localized prompt template is rendered.

        Yields:
            A single ``CallOutputChunk`` of type DATA whose content is the
            dumped ``FactsOutput`` (by alias, ``None`` fields omitted).
        """
        parsed = FactsInput(**input_data)
        # Autoescaping is disabled: the rendered text is sent to the LLM as
        # message content, not emitted as HTML.
        jinja_env = SandboxedEnvironment(
            loader=BaseLoader(),
            autoescape=False,
            trim_blocks=True,
            lstrip_blocks=True,
        )

        facts_result: FactsGen = await self._render_and_ask(
            jinja_env, FACTS_PROMPT[language], parsed.message, FactsGen
        )
        domain_result: DomainGen = await self._render_and_ask(
            jinja_env, DOMAIN_PROMPT[language], parsed.message, DomainGen
        )

        keywords = domain_result.keywords
        # Record every extracted domain keyword for this user, one at a time.
        for keyword in keywords:
            await UserDomainManager.update_user_domain_by_user_sub_and_domain_name(parsed.user_sub, keyword)

        result = FactsOutput(facts=facts_result.facts, domain=keywords)
        yield CallOutputChunk(
            type=CallOutputType.DATA,
            content=result.model_dump(by_alias=True, exclude_none=True),
        )

    async def exec(
        self,
        executor: "StepExecutor",
        input_data: dict[str, Any],
        language: LanguageType = LanguageType.CHINESE,
    ) -> AsyncGenerator[CallOutputChunk, None]:
        """Run the tool and write the extracted facts to ``executor.task.runtime.facts``.

        Each chunk from ``_exec`` is validated as ``FactsOutput`` and then
        re-yielded unchanged.

        Raises:
            TypeError: if a chunk's content is not a dict.
        """
        async for chunk in self._exec(input_data, language):
            payload = chunk.content
            if not isinstance(payload, dict):
                err_msg = "[FactsCall] 工具输出格式错误"
                raise TypeError(err_msg)
            validated = FactsOutput.model_validate(payload)
            executor.task.runtime.facts = validated.facts
            yield chunk