"""提取事实工具"""
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any, Self
from jinja2.loaders import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
from pydantic import Field
from apps.llm import json_generator
from apps.models import LanguageType, NodeInfo
from apps.scheduler.call.core import CoreCall
from apps.schemas.enum_var import CallOutputType
from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars
from apps.services.tag import TagManager
from apps.services.user_tag import UserTagManager
from .func import DOMAIN_FUNCTION, FACTS_FUNCTION
from .schema import (
DomainGen,
FactsGen,
FactsInput,
FactsOutput,
)
if TYPE_CHECKING:
from apps.scheduler.executor.step import StepExecutor
class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput):
    """Fact-extraction tool.

    Sends the current question and the accumulated answer to the LLM to
    extract facts. When any tags exist, it then asks the LLM which of those
    tag names match the exchange and records each recognised one for the user
    via ``UserTagManager``.
    """

    # Filled from ``executor.task.runtime.fullAnswer`` in ``instance()``
    # (i.e. the assistant's answer, despite the schema description text).
    answer: str = Field(description="用户输入")

    @classmethod
    def info(cls, language: LanguageType = LanguageType.CHINESE) -> CallInfo:
        """Return the localized name and description of this Call.

        Raises:
            KeyError: if ``language`` has no entry in the translation table.
        """
        i18n_info = {
            LanguageType.CHINESE: CallInfo(name="提取事实", description="从对话上下文和文档片段中提取事实。"),
            LanguageType.ENGLISH: CallInfo(
                name="Fact Extraction",
                description="Extract facts from the conversation context and document snippets.",
            ),
        }
        return i18n_info[language]

    @classmethod
    async def instance(cls, executor: "StepExecutor", node: NodeInfo | None, **kwargs: Any) -> Self:
        """Build the Call from the executor's current step and task runtime, then set its input."""
        obj = cls(
            answer=executor.task.runtime.fullAnswer,
            name=executor.step.step.name,
            description=executor.step.step.description,
            node=node,
            **kwargs,
        )
        await obj._set_input(executor)
        return obj

    async def _init(self, call_vars: CallVars) -> FactsInput:
        """Build the tool input.

        The input is the current question and this Call's answer as a
        user/assistant message pair, plus the user id.
        """
        message = [
            {"role": "user", "content": call_vars.question},
            {"role": "assistant", "content": self.answer},
        ]
        return FactsInput(
            user_id=call_vars.ids.user_id,
            message=message,
        )

    @staticmethod
    def _data_chunk(facts: list[Any], domain: list[Any]) -> CallOutputChunk:
        """Wrap facts/domain in a DATA chunk (FactsOutput dumped by alias, None fields dropped)."""
        return CallOutputChunk(
            type=CallOutputType.DATA,
            content=FactsOutput(
                facts=facts,
                domain=domain,
            ).model_dump(by_alias=True, exclude_none=True),
        )

    @staticmethod
    def _conversation(message: list[Any]) -> list[Any]:
        """Return a new conversation list: the generic system prompt followed by ``message``."""
        return [
            {"role": "system", "content": "You are a helpful assistant."},
            *message,
        ]

    async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
        """Extract facts and domain keywords, yielding exactly one DATA chunk.

        1. Blank answer: yield empty facts and domain without calling the LLM.
        2. Ask the LLM for facts using the ``facts`` prompt.
        3. No tags defined: yield the facts with an empty domain list.
        4. Otherwise render the ``domain`` prompt with every tag name, ask the
           LLM for matching keywords, and record each distinct keyword that
           ``TagManager`` recognises for the user. Then yield the facts and the
           keywords as returned by the LLM.
        """
        if not self.answer or not self.answer.strip():
            yield self._data_chunk([], [])
            return

        data = FactsInput(**input_data)

        facts_prompt = self._load_prompt("facts")
        facts_result = await json_generator.generate(
            function=FACTS_FUNCTION[self._sys_vars.language],
            conversation=self._conversation(data.message),
            prompt=facts_prompt,
        )
        facts_obj = FactsGen.model_validate(facts_result)

        all_tags = await TagManager.get_all_tag()
        tag_names = [tag.name for tag in all_tags]
        if not tag_names:
            yield self._data_chunk(facts_obj.facts, [])
            return

        domain_prompt_template_str = self._load_prompt("domain")
        # Sandboxed Jinja: the template comes from a prompt file and the tag
        # names are data; autoescape is off because the output is plain text.
        jinja_env = SandboxedEnvironment(
            loader=BaseLoader(),
            autoescape=False,
        )
        domain_prompt = jinja_env.from_string(domain_prompt_template_str).render(available_keywords=tag_names)

        domain_result = await json_generator.generate(
            function=DOMAIN_FUNCTION[self._sys_vars.language],
            conversation=self._conversation(data.message),
            prompt=domain_prompt,
        )
        domain_list = DomainGen.model_validate(domain_result)

        # The LLM may repeat a keyword. Look up and record each distinct one
        # only once (dict.fromkeys keeps first-seen order), so no duplicate
        # update call is issued for a single extraction.
        for domain in dict.fromkeys(domain_list.keywords):
            tag = await TagManager.get_tag_by_name(domain)
            if tag:
                await UserTagManager.update_user_domain_by_user_and_domain_name(data.user_id, domain)

        yield self._data_chunk(facts_obj.facts, domain_list.keywords)

    async def exec(self, executor: "StepExecutor", input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
        """Run ``_exec``, store the extracted facts on the task runtime, and re-yield each chunk.

        Raises:
            TypeError: if a chunk's content is not a dict (i.e. not a dumped FactsOutput).
        """
        async for chunk in self._exec(input_data):
            content = chunk.content
            if not isinstance(content, dict):
                err = "[FactsCall] 工具输出格式错误"
                raise TypeError(err)
            executor.task.runtime.fact = FactsOutput.model_validate(content).facts
            yield chunk