# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""
Core Call类是定义了所有Call都应具有的方法和参数的Pydantic类。

所有Call类必须继承此类，并根据需求重载方法。
"""

import logging
import uuid
from collections.abc import AsyncGenerator, Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Self

from pydantic import BaseModel, ConfigDict, Field
from pydantic.json_schema import SkipJsonSchema

from apps.common.config import config
from apps.models import ExecutorHistory, LanguageType, NodeInfo
from apps.schemas.enum_var import CallOutputType
from apps.schemas.scheduler import (
    CallIds,
    CallInfo,
    CallOutputChunk,
    CallVars,
)

if TYPE_CHECKING:
    from apps.scheduler.executor.step import StepExecutor


logger = logging.getLogger(__name__)


class DataBase(BaseModel):
    """Base class for the Pydantic input/output models used by every Call."""

    @classmethod
    def model_json_schema(cls, override: dict[str, Any] | None = None, **kwargs: Any) -> dict[str, Any]:
        """
        Generate the model's JSON schema, optionally patching individual properties.

        :param override: mapping of property name -> property schema; each entry
            replaces (or adds) the corresponding item under ``properties``.
            Ignored when None or empty.
        :param kwargs: forwarded unchanged to ``BaseModel.model_json_schema``
        :return: the generated JSON schema with the overrides applied
        """
        generated = super().model_json_schema(**kwargs)
        if not override:
            return generated
        # Overrides take precedence over the schema Pydantic generated for a field.
        generated["properties"].update(override)
        return generated


class CoreCall(BaseModel):
    """
    Base class of every Call, holding the logic shared by all Calls.

    Subclasses declare their I/O models as class keyword arguments, e.g.
    ``class MyCall(CoreCall, input_model=MyInput, output_model=MyOutput)``,
    and must implement ``info`` and ``_init``. ``_exec`` and ``_after_exec``
    may be overridden to produce output and run post-processing.
    """

    # Step metadata: excluded from serialization and hidden from the JSON schema.
    name: SkipJsonSchema[str] = Field(description="Step的名称", exclude=True)
    description: SkipJsonSchema[str] = Field(description="Step的描述", exclude=True)
    node: SkipJsonSchema[NodeInfo | None] = Field(description="节点信息", exclude=True)
    enable_filling: SkipJsonSchema[bool] = Field(description="是否需要进行自动参数填充", default=False, exclude=True)
    # Class-level I/O model types; each subclass sets them via ``__init_subclass__``.
    input_model: ClassVar[SkipJsonSchema[type[DataBase]]] = Field(
        description="Call的输入Pydantic类型;不包含override的模板",
        exclude=True,
        frozen=True,
    )
    output_model: ClassVar[SkipJsonSchema[type[DataBase]]] = Field(
        description="Call的输出Pydantic类型;不包含override的模板",
        exclude=True,
        frozen=True,
    )

    to_user: bool = Field(description="是否需要将输出返回给用户", default=False)

    # extra="allow" is what permits undeclared attributes such as ``self.input``
    # (assigned in ``_set_input``) to be set on the model instance.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="allow",
    )


    def __init_subclass__(cls, input_model: type[DataBase], output_model: type[DataBase], **kwargs: Any) -> None:
        """
        Record the subclass's input/output model types.

        :param input_model: Pydantic type of the Call's input (template without override)
        :param output_model: Pydantic type of the Call's output (template without override)
        :param kwargs: forwarded to the parent ``__init_subclass__``
        """
        super().__init_subclass__(**kwargs)
        cls.input_model = input_model
        cls.output_model = output_model


    @classmethod
    def info(cls, language: LanguageType = LanguageType.CHINESE) -> CallInfo:
        """
        Return the Call's name and description.

        :param language: language the info should be returned in
        :raises NotImplementedError: always; every subclass must override this
        """
        err = "[CoreCall] 必须手动实现info方法"
        raise NotImplementedError(err)


    def _load_prompt(self, prompt_id: str) -> str:
        """
        Load a prompt from ``<data_dir>/prompts/call/<prompt_id>.<language>.md``.

        The language is read from ``self._sys_vars``, so ``_set_input`` must have
        run before this is called.

        :param prompt_id: prompt ID, e.g. "domain" or "facts"
        :return: the prompt text, decoded as UTF-8
        :raises OSError: if the prompt file cannot be read (e.g. it does not exist)
        """
        language = self._sys_vars.language.value

        filename = f"{prompt_id}.{language}.md"
        prompt_dir = Path(config.deploy.data_dir) / "prompts" / "call"
        prompt_file = prompt_dir / filename
        return prompt_file.read_text(encoding="utf-8")


    @staticmethod
    def _assemble_call_vars(executor: "StepExecutor") -> CallVars:
        """
        Build the ``CallVars`` handed to a Call from the executor's task.

        :param executor: the StepExecutor running the current step
        :return: the assembled CallVars
        :raises ValueError: if ``executor.task.state`` is empty
        """
        if not executor.task.state:
            err = "[CoreCall] 当前ExecutorState为空"
            logger.error(err)
            raise ValueError(err)

        # Index past step records by step ID and keep their order as found in task.context.
        history: dict[uuid.UUID, ExecutorHistory] = {}
        history_order: list[uuid.UUID] = []
        for item in executor.task.context:
            history[item.stepId] = item
            history_order.append(item.stepId)

        return CallVars(
            language=executor.task.runtime.language,
            ids=CallIds(
                task_id=executor.task.metadata.id,
                executor_id=executor.task.state.executorId,
                auth_header=executor.task.runtime.authHeader,
                user_id=executor.task.metadata.userId,
                app_id=executor.task.state.appId,
                conversation_id=executor.task.metadata.conversationId,
            ),
            question=executor.question,
            step_data=history,
            step_order=history_order,
            background=executor.background,
            thinking=executor.task.runtime.reasoning,
            app_metadata=executor.app_metadata,
        )


    @staticmethod
    def _extract_history_variables(path: str, history: dict[uuid.UUID, ExecutorHistory]) -> Any:
        """
        Look up a variable inside a previous step's output.

        :param path: path of the form ``step_id/key/to/variable``; every segment
            after the step ID descends one level into the step's ``outputData``
        :param history: step history, i.e. ``call_vars.step_data``
        :return: the variable; None (after logging) if the path is malformed, the
            step is unknown, a key is missing, or a non-mapping value is encountered
            before the path is exhausted
        """
        split_path = path.split("/")
        # str.split always returns at least one element, so validate the step-ID
        # segment itself (covers "" and paths starting with "/").
        if not split_path[0]:
            err = f"[CoreCall] 路径格式错误: {path}"
            logger.error(err)
            return None

        # Convert the step ID from its string form to a UUID
        try:
            step_id = uuid.UUID(split_path[0])
        except ValueError:
            err = f"[CoreCall] 步骤ID格式错误: {split_path[0]}"
            logger.exception(err)
            return None

        if step_id not in history:
            err = f"[CoreCall] 步骤{step_id}不存在"
            logger.error(err)
            return None

        data: Any = history[step_id].outputData
        for key in split_path[1:]:
            # Only mappings can be descended into: ``in``/``[]`` on other values either
            # raise TypeError (int, None) or do something unrelated (substring test on str).
            if not isinstance(data, Mapping):
                err = f"[CoreCall] 路径 {path} 中 Key {key} 的上层数据不是字典"
                logger.error(err)
                return None
            if key not in data:
                err = f"[CoreCall] 输出Key {key} 不存在"
                logger.error(err)
                return None
            data = data[key]
        return data


    @classmethod
    async def instance(cls, executor: "StepExecutor", node: NodeInfo | None, **kwargs: Any) -> Self:
        """
        Create a Call for the executor's current step and prepare its input.

        :param executor: the StepExecutor running the current step
        :param node: node information for the step, if any
        :param kwargs: extra field values passed to the constructor
        :return: the constructed Call with ``input`` populated
        """
        obj = cls(
            name=executor.step.step.name,
            description=executor.step.step.description,
            node=node,
            **kwargs,
        )

        await obj._set_input(executor)
        return obj


    async def _set_input(self, executor: "StepExecutor") -> None:
        """
        Bind the executor's LLM and system variables, then compute the Call's input.

        Sets ``self._llm_obj``, ``self._sys_vars`` and ``self.input`` (the model
        returned by ``_init`` dumped by alias, with None values dropped).
        """
        self._llm_obj = executor.llm
        self._sys_vars = self._assemble_call_vars(executor)
        input_data = await self._init(self._sys_vars)
        self.input = input_data.model_dump(by_alias=True, exclude_none=True)


    async def _init(self, call_vars: CallVars) -> DataBase:
        """
        Initialize the Call and return its input model.

        :param call_vars: system variables assembled for this Call
        :raises NotImplementedError: always; every subclass must override this
        """
        err = "[CoreCall] 初始化方法必须手动实现"
        raise NotImplementedError(err)


    async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
        """Stream the Call's output; the default yields a single empty text chunk."""
        yield CallOutputChunk(type=CallOutputType.TEXT, content="")


    async def _after_exec(self, input_data: dict[str, Any]) -> None:
        """Hook run after ``_exec`` finishes; does nothing by default."""


    async def exec(self, executor: "StepExecutor", input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
        """
        Execute the Call: relay every chunk from ``_exec``, then run ``_after_exec``.

        :param executor: the StepExecutor running the step (not used by this base implementation)
        :param input_data: the Call's input
        """
        async for chunk in self._exec(input_data):
            yield chunk
        await self._after_exec(input_data)


    async def _llm(self, messages: list[dict[str, Any]], *, streaming: bool = False) -> AsyncGenerator[str, None]:
        """
        Call the bound LLM and yield its output as text fragments.

        Reasoning content is wrapped in ``<think>``/``</think>``. The closing tag is
        emitted when regular content starts or, if none follows, after the LLM
        stream ends, so the tags are always balanced.

        :param messages: chat messages passed straight to the LLM
        :param streaming: forwarded to the LLM call
        """
        think_tag_opened = False
        async for chunk in self._llm_obj.call(messages, streaming=streaming):
            if chunk.reasoning_content:
                if not think_tag_opened:
                    yield "<think>"
                    think_tag_opened = True
                yield chunk.reasoning_content

            if chunk.content:
                if think_tag_opened:
                    yield "</think>"
                    think_tag_opened = False
                yield chunk.content

        # A reasoning-only response would otherwise leave "<think>" unclosed.
        if think_tag_opened:
            yield "</think>"