"""
Runner 初始化模块
在 FastAPI 启动时调用,配置 openJiuwen Runner 和 Checkpointer。
支持三种模式:in_memory(开发)、persistence(单机生产)、redis(分布式生产)。
"""
import logging
import importlib
import types
from openjiuwen.core.runner import Runner
from openjiuwen.core.runner.runner_config import RunnerConfig
from openjiuwen.core.session.checkpointer import CheckpointerFactory
from openjiuwen.core.session.checkpointer.checkpointer import CheckpointerConfig
from openjiuwen.core.session.interaction.interactive_input import InteractiveInput
from server.core.config import settings
from server.core.kb_obs_requirement import assert_kb_obs_configured_for_redis
logger = logging.getLogger(name=__name__)

# Backend names accepted for ``settings.checkpointer_type`` (compared after
# ``strip().lower()`` normalisation).
SUPPORTED_CHECKPOINTER_TYPES = set(("in_memory", "persistence", "redis"))

# ``id()`` values of checkpointer instances that the interactive-recovery
# compatibility patch has been applied to.
PATCHED_CHECKPOINTER_IDS = set()
async def _workflow_state_exists(checkpointer, session) -> bool:
"""判断当前会话是否存在可恢复的 workflow state。
优先兼容 persistence / redis 这类暴露 ``_workflow_storage`` 的实现;
若未暴露统一存储对象,再回退兼容 in_memory checkpointer 的
``_workflow_stores`` 会话内存表。这样业务层就不需要感知具体
checkpointer 类型,统一由服务端决定是否走 InteractiveInput 恢复。
Args:
checkpointer: 当前启用的 checkpointer 实例。
session: workflow session。
Returns:
若当前 session 已存在 workflow checkpoint,返回 ``True``;否则返回 ``False``。
"""
workflow_storage = getattr(checkpointer, "_workflow_storage", None)
if workflow_storage and hasattr(workflow_storage, "exists"):
return await workflow_storage.exists(session)
workflow_stores = getattr(checkpointer, "_workflow_stores", None)
if isinstance(workflow_stores, dict):
workflow_store = workflow_stores.get(session.session_id())
if workflow_store and hasattr(workflow_store, "exists"):
return await workflow_store.exists(session)
return False
def _patch_checkpointer_interactive_recovery():
    """Wrap the active checkpointer's ``pre_workflow_execute`` to accept dict inputs.

    Business code passes ordinary dict inputs such as ``{"query": ...}``. The
    installed wrapper rewrites them before delegating to the original method:

    * If ``inputs["query"]`` is already an ``InteractiveInput``, that object is
      passed on instead of the dict.
    * Otherwise, if ``query`` is not ``None`` and the session already has a
      workflow checkpoint (per ``_workflow_state_exists``), the query is
      wrapped in ``InteractiveInput``. This satisfies the checkpointer's
      interactive recovery path for persistence/redis backends.
    * In every other case the inputs are passed through unchanged.

    Does nothing when ``CheckpointerFactory`` has no checkpointer, and applies
    the patch at most once per checkpointer instance.
    """
    marker = "_interactive_recovery_patched"

    checkpointer = CheckpointerFactory.get_checkpointer()
    if checkpointer is None:
        return
    # Detect an existing patch through a marker on the installed function
    # rather than by id(): ids are reused after garbage collection, so a new
    # checkpointer could otherwise be mistaken for an already-patched one.
    # (Bound methods forward attribute reads to their underlying function.)
    if getattr(checkpointer.pre_workflow_execute, marker, False):
        return

    original = checkpointer.pre_workflow_execute

    async def _patched_pre_workflow_execute(self, session, inputs):
        effective_inputs = inputs
        if isinstance(inputs, dict):
            query = inputs.get("query")
            if isinstance(query, InteractiveInput):
                effective_inputs = query
            elif query is not None:
                # Only probe storage when there is a query that could be
                # turned into a recovery message.
                should_recover = False
                try:
                    should_recover = await _workflow_state_exists(self, session)
                except Exception as e:
                    # Best-effort: fall back to a normal (non-recovery) run, but
                    # make the failure visible since a resume would be lost.
                    logger.warning(
                        "Failed to auto-detect workflow recovery state: %s",
                        e,
                        exc_info=True,
                    )
                if should_recover:
                    effective_inputs = InteractiveInput(query)
        return await original(session, effective_inputs)

    setattr(_patched_pre_workflow_execute, marker, True)
    checkpointer.pre_workflow_execute = types.MethodType(_patched_pre_workflow_execute, checkpointer)
    # Keep the module-level registry populated for any code that inspects it.
    PATCHED_CHECKPOINTER_IDS.add(id(checkpointer))
    logger.info("Applied checkpointer interactive recovery compatibility patch.")
def _resolve_checkpointer_type() -> str:
    """Return ``settings.checkpointer_type`` normalised and validated.

    The configured value is stripped and lower-cased; ``None`` is treated as an
    empty string.

    Raises:
        ValueError: If the normalised value is not in
            ``SUPPORTED_CHECKPOINTER_TYPES``.
    """
    cp_type = (settings.checkpointer_type or "").strip().lower()
    if cp_type not in SUPPORTED_CHECKPOINTER_TYPES:
        raise ValueError(
            "Invalid CHECKPOINTER_TYPE: "
            f"{settings.checkpointer_type}. Supported values: "
            f"{', '.join(sorted(SUPPORTED_CHECKPOINTER_TYPES))}."
        )
    return cp_type


def _build_checkpointer_config() -> CheckpointerConfig:
    """Build a ``CheckpointerConfig`` from the application settings.

    * ``redis``: connection (``redis_url``, ``redis_cluster_mode``) and TTL
      (``redis_ttl``, ``redis_refresh_on_read``) options.
    * ``persistence``: ``checkpointer_db_type`` and ``checkpointer_db_path``.
    * ``in_memory``: an empty ``conf``.

    Returns:
        CheckpointerConfig: The configured checkpointer settings object.

    Raises:
        ValueError: If ``settings.checkpointer_type`` is not supported.
    """
    cp_type = _resolve_checkpointer_type()
    if cp_type == "redis":
        conf = {
            "connection": {
                "url": settings.redis_url,
                "cluster_mode": settings.redis_cluster_mode,
            },
            "ttl": {
                "default_ttl": settings.redis_ttl,
                "refresh_on_read": settings.redis_refresh_on_read,
            },
        }
    elif cp_type == "persistence":
        conf = {
            "db_type": settings.checkpointer_db_type,
            "db_path": settings.checkpointer_db_path,
        }
    else:
        conf = {}
    return CheckpointerConfig(type=cp_type, conf=conf)


async def init_runner():
    """Configure and start the Runner with the configured checkpointer.

    Intended to be called from the FastAPI lifespan startup phase. Steps:

    1. Validate ``settings.checkpointer_type``.
    2. For ``redis``, call ``assert_kb_obs_configured_for_redis()`` and import
       the redis checkpointer module. For ``persistence``, import the
       persistence checkpointer module. Per the log messages, importing these
       modules is what registers the provider.
    3. Build a ``RunnerConfig`` (``distributed_mode`` is always ``False``),
       install it, and start the Runner.
    4. Apply the checkpointer interactive-recovery compatibility patch.

    Raises:
        ValueError: If ``settings.checkpointer_type`` is not supported.
            Anything raised by ``assert_kb_obs_configured_for_redis`` or by
            the provider imports also propagates.
    """
    cp_type = _resolve_checkpointer_type()
    if cp_type == "redis":
        assert_kb_obs_configured_for_redis()
        importlib.import_module("openjiuwen.extensions.checkpointer.redis.checkpointer")
        logger.info("Redis checkpointer provider registered.")
    elif cp_type == "persistence":
        importlib.import_module("openjiuwen.core.session.checkpointer.persistence")
        logger.info("Persistence checkpointer provider registered.")
    runner_config = RunnerConfig()
    runner_config.distributed_mode = False
    runner_config.checkpointer_config = _build_checkpointer_config()
    Runner.set_config(runner_config)
    await Runner.start()
    _patch_checkpointer_interactive_recovery()
    logger.info(
        "Runner initialized with checkpointer type: %s",
        cp_type,
    )
async def shutdown_runner():
    """Stop the Runner and release its resources.

    Intended to be called from the FastAPI lifespan shutdown phase. Shutdown
    is best-effort:

    * If ``Runner`` has no ``stop`` attribute, nothing is stopped.
    * Any exception raised while stopping is logged with its traceback rather
      than propagated, so application teardown can continue.
    """
    try:
        stop = getattr(Runner, "stop", None)
        if stop is None:
            logger.info("Runner has no stop(); nothing to shut down.")
            return
        await stop()
    except Exception as e:
        logger.warning("Error shutting down runner: %s", e, exc_info=True)
        return
    logger.info("Runner shut down.")