# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""提取事实工具"""

from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any, Self

from jinja2.loaders import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
from pydantic import Field

from apps.llm import json_generator
from apps.models import LanguageType, NodeInfo
from apps.scheduler.call.core import CoreCall
from apps.schemas.enum_var import CallOutputType
from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars
from apps.services.tag import TagManager
from apps.services.user_tag import UserTagManager

from .func import DOMAIN_FUNCTION, FACTS_FUNCTION
from .schema import (
    DomainGen,
    FactsGen,
    FactsInput,
    FactsOutput,
)

if TYPE_CHECKING:
    from apps.scheduler.executor.step import StepExecutor


class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput):
    """Fact-extraction tool.

    Sends the current user question and assistant answer to the LLM to extract
    facts. When any tags are defined, it also asks the LLM to choose matching
    domain keywords and records each keyword that names an existing tag against
    the user via ``UserTagManager``.
    """

    answer: str = Field(description="用户输入")


    @classmethod
    def info(cls, language: LanguageType = LanguageType.CHINESE) -> CallInfo:
        """Return the localized name and description of this Call.

        Args:
            language: Language of the returned metadata.

        Returns:
            The ``CallInfo`` entry for ``language``.

        Raises:
            KeyError: If ``language`` has no entry in the table below.
        """
        i18n_info = {
            LanguageType.CHINESE: CallInfo(name="提取事实", description="从对话上下文和文档片段中提取事实。"),
            LanguageType.ENGLISH: CallInfo(
                name="Fact Extraction",
                description="Extract facts from the conversation context and document snippets.",
            ),
        }
        return i18n_info[language]


    @classmethod
    async def instance(cls, executor: "StepExecutor", node: NodeInfo | None, **kwargs: Any) -> Self:
        """Create the tool from the running step executor.

        ``answer`` is taken from ``executor.task.runtime.fullAnswer``; name and
        description come from the current step. Extra ``kwargs`` are forwarded
        to the constructor.

        Returns:
            The constructed tool, with its input already set via ``_set_input``.
        """
        obj = cls(
            answer=executor.task.runtime.fullAnswer,
            name=executor.step.step.name,
            description=executor.step.step.description,
            node=node,
            **kwargs,
        )

        await obj._set_input(executor)
        return obj


    async def _init(self, call_vars: CallVars) -> FactsInput:
        """Build the tool input from the call variables.

        The input holds the user id and a two-message exchange: the user's
        question followed by this tool's ``answer`` as the assistant turn.
        """
        message = [
            {"role": "user", "content": call_vars.question},
            {"role": "assistant", "content": self.answer},
        ]

        return FactsInput(
            user_id=call_vars.ids.user_id,
            message=message,
        )


    @staticmethod
    def _build_conversation(message: list[Any]) -> list[Any]:
        """Return a fresh conversation list: the shared system prompt followed by ``message``."""
        # A new list is built on every call so that each LLM request gets its own
        # conversation object, exactly as when the literal was written inline.
        return [
            {"role": "system", "content": "You are a helpful assistant."},
            *message,
        ]


    @staticmethod
    def _make_chunk(facts: list[Any], domain: list[Any]) -> CallOutputChunk:
        """Wrap ``facts`` and ``domain`` in a DATA chunk carrying a dumped ``FactsOutput``."""
        return CallOutputChunk(
            type=CallOutputType.DATA,
            content=FactsOutput(
                facts=facts,
                domain=domain,
            ).model_dump(by_alias=True, exclude_none=True),
        )


    async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
        """Extract facts and domain keywords from the latest exchange.

        Args:
            input_data: Data used to construct a ``FactsInput`` (user id and messages).

        Yields:
            Exactly one DATA chunk. Both ``facts`` and ``domain`` are empty when
            the answer is blank. ``domain`` is empty when no tags are defined.
        """
        # A blank answer has nothing to extract from: skip both LLM calls.
        if not self.answer or not self.answer.strip():
            yield self._make_chunk([], [])
            return

        data = FactsInput(**input_data)

        # Step 1: extract facts using the "facts" prompt.
        facts_prompt = self._load_prompt("facts")
        facts_result = await json_generator.generate(
            function=FACTS_FUNCTION[self._sys_vars.language],
            conversation=self._build_conversation(data.message),
            prompt=facts_prompt,
        )
        facts_obj = FactsGen.model_validate(facts_result)

        # Step 2: the domain step needs candidate keywords. Without any tags,
        # return the facts only.
        all_tags = await TagManager.get_all_tag()
        tag_names = [tag.name for tag in all_tags]
        if not tag_names:
            yield self._make_chunk(facts_obj.facts, [])
            return

        # Render the domain prompt with the available tag names in a sandboxed
        # Jinja environment (autoescape off: the output is an LLM prompt, not HTML).
        domain_prompt_template_str = self._load_prompt("domain")
        jinja_env = SandboxedEnvironment(
            loader=BaseLoader(),
            autoescape=False,
        )
        domain_prompt = jinja_env.from_string(domain_prompt_template_str).render(
            available_keywords=tag_names,
        )

        # Step 3: ask the LLM which domain keywords apply.
        domain_result = await json_generator.generate(
            function=DOMAIN_FUNCTION[self._sys_vars.language],
            conversation=self._build_conversation(data.message),
            prompt=domain_prompt,
        )
        domain_list = DomainGen.model_validate(domain_result)

        # Update the user profile only for keywords that name an existing tag.
        # The tag names were already fetched above, so check them in memory
        # instead of issuing one lookup query per keyword.
        # NOTE(review): assumes TagManager.get_tag_by_name was an exact-name match.
        known_tag_names = set(tag_names)
        for keyword in domain_list.keywords:
            if keyword in known_tag_names:
                await UserTagManager.update_user_domain_by_user_and_domain_name(data.user_id, keyword)

        # All LLM-returned keywords are reported, including ones without a matching tag.
        yield self._make_chunk(facts_obj.facts, domain_list.keywords)


    async def exec(self, executor: "StepExecutor", input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
        """Run the tool and store the extracted facts on the task runtime.

        Each chunk is re-yielded after its ``facts`` are written to
        ``executor.task.runtime.fact``.

        Raises:
            TypeError: If a chunk's content is not a dict.
        """
        async for chunk in self._exec(input_data):
            content = chunk.content
            if not isinstance(content, dict):
                err = "[FactsCall] 工具输出格式错误"
                raise TypeError(err)
            executor.task.runtime.fact = FactsOutput.model_validate(content).facts
            yield chunk