from typing import (
Any,
Optional,
Union,
)
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_validator,
)
from redis.asyncio.client import Redis
from redis.asyncio.cluster import RedisCluster
from openjiuwen.core.common.constants.constant import INTERACTIVE_INPUT
from openjiuwen.core.common.exception.codes import StatusCode
from openjiuwen.core.common.exception.errors import build_error
from openjiuwen.core.common.logging import logger
from openjiuwen.core.graph.pregel import TASK_STATUS_INTERRUPT
from openjiuwen.core.session import (
BaseSession,
Checkpointer,
InteractiveInput,
)
from openjiuwen.core.session.checkpointer import (
CheckpointerFactory,
CheckpointerProvider,
)
from openjiuwen.core.session.constants import FORCE_DEL_WORKFLOW_STATE_KEY
from openjiuwen.extensions.checkpointer.redis.storage import (
AgentStorage,
AgentGroupStorage,
GraphStore,
WorkflowStorage,
)
from openjiuwen.extensions.store.kv import RedisStore
class RedisTTLConfig(BaseModel):
    """
    Expiration (time-to-live) settings for data stored in Redis.

    Attributes:
        default_ttl: TTL in minutes. When set, all stored data gets this
            expiration time. Defaults to ``None``.
        refresh_on_read: When True, the TTL is refreshed whenever data is
            read, which extends the lifetime of frequently accessed data.
            Defaults to ``False``.
    """

    default_ttl: Optional[float] = Field(None, description="Default TTL in minutes for stored data")
    refresh_on_read: bool = Field(False, description="Whether to refresh TTL when data is read")
class RedisConnectionConfig(BaseModel):
    """
    Redis connection configuration.

    This class provides a structured way to configure Redis connections
    with validation and type safety.

    Attributes:
        redis_client: Pre-configured Redis or RedisCluster client instance.
            When provided, ``is_cluster_mode`` is decided from the client's
            type and ``cluster_mode`` is not consulted.
        url: Redis connection URL. Can be a standalone URL (redis:// or
            rediss://) or a cluster URL (redis+cluster:// or
            rediss+cluster://).
        cluster_mode: Explicitly enable/disable cluster mode.
            If None, auto-detected from the URL scheme.
        connection_args: Additional connection arguments passed to the
            Redis client.

    Examples:
        Standalone Redis:
        >>> config = RedisConnectionConfig(url="redis://localhost:6379")

        Cluster mode with explicit flag:
        >>> config = RedisConnectionConfig(
        ...     url="redis://localhost:7000",
        ...     cluster_mode=True
        ... )

        Cluster mode with URL scheme:
        >>> config = RedisConnectionConfig(url="redis+cluster://localhost:7000")

        With pre-configured client:
        >>> redis_client = Redis.from_url("redis://localhost:6379")
        >>> config = RedisConnectionConfig(redis_client=redis_client)
    """

    # Redis / RedisCluster are not pydantic types, so arbitrary types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    redis_client: Optional[Union[Redis, RedisCluster]] = Field(
        default=None,
        description="Pre-configured Redis or RedisCluster client instance"
    )
    url: Optional[str] = Field(
        default=None,
        description="Redis connection URL (redis:// or redis+cluster://)"
    )
    cluster_mode: Optional[bool] = Field(
        default=None,
        description="Explicitly enable/disable cluster mode (auto-detected if None)"
    )
    connection_args: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional connection arguments for Redis client"
    )

    @field_validator('url')
    @classmethod
    def validate_url(cls, v: Optional[str]) -> Optional[str]:
        """
        Validate that the URL, when given, starts with a supported scheme.

        Args:
            v: The candidate URL, or None.

        Returns:
            Optional[str]: The URL unchanged.

        Raises:
            ValueError: If the URL does not start with redis://, rediss://,
                redis+cluster:// or rediss+cluster://.
        """
        if v is None:
            return v
        supported_schemes = (
            "redis://",
            "rediss://",
            "redis+cluster://",
            "rediss+cluster://",
        )
        # str.startswith accepts a tuple, so one call covers every scheme.
        if not v.startswith(supported_schemes):
            raise ValueError(
                f"Invalid Redis URL format: {v}. "
                "URL must start with redis://, rediss://, "
                "redis+cluster://, or rediss+cluster://"
            )
        return v

    @model_validator(mode='after')
    def validate_config(self) -> 'RedisConnectionConfig':
        """
        Validate that at least one connection method is provided.

        Raises:
            ValueError: If both ``redis_client`` and ``url`` are None.
        """
        if self.redis_client is None and self.url is None:
            raise ValueError(
                "Either 'redis_client' or 'url' must be provided in RedisConnectionConfig"
            )
        return self

    def is_cluster_mode(self) -> bool:
        """
        Determine if cluster mode should be used.

        Precedence: the type of ``redis_client`` (if given), then the explicit
        ``cluster_mode`` flag, then the URL scheme.

        Returns:
            bool: True if cluster mode should be used, False otherwise.
        """
        if self.redis_client is not None:
            return isinstance(self.redis_client, RedisCluster)
        if self.cluster_mode is not None:
            return self.cluster_mode
        if self.url is not None:
            return self.url.startswith(("redis+cluster://", "rediss+cluster://"))
        return False

    def get_connection_url(self) -> Optional[str]:
        """
        Get the connection URL, normalizing cluster URLs if needed.

        A leading ``redis+cluster://`` becomes ``redis://`` and a leading
        ``rediss+cluster://`` becomes ``rediss://``. Only the scheme prefix is
        rewritten; the same text appearing later in the URL (for example
        inside a password or query string) is left untouched.

        Returns:
            Optional[str]: Normalized connection URL, or None if no URL is set.
        """
        if self.url is None:
            return None
        scheme_map = (
            ("redis+cluster://", "redis://"),
            ("rediss+cluster://", "rediss://"),
        )
        for cluster_scheme, plain_scheme in scheme_map:
            if self.url.startswith(cluster_scheme):
                # Slice off the prefix instead of str.replace, which would
                # rewrite every occurrence of the scheme text in the URL.
                return plain_scheme + self.url[len(cluster_scheme):]
        return self.url
class RedisCheckpointerConfig(BaseModel):
    """
    Top-level configuration model for the Redis checkpointer.

    Groups the required connection settings with the optional TTL settings
    so that both are validated together.

    Attributes:
        connection: Redis connection configuration (required).
        ttl: TTL configuration for stored data; defaults to ``None``.

    Examples:
        Minimal configuration (standalone Redis):
        >>> config = RedisCheckpointerConfig(
        ...     connection=RedisConnectionConfig(url="redis://localhost:6379")
        ... )

        With TTL configuration:
        >>> config = RedisCheckpointerConfig(
        ...     connection=RedisConnectionConfig(url="redis://localhost:6379"),
        ...     ttl=RedisTTLConfig(default_ttl=5, refresh_on_read=True)
        ... )

        Cluster mode:
        >>> config = RedisCheckpointerConfig(
        ...     connection=RedisConnectionConfig(
        ...         url="redis://localhost:7000",
        ...         cluster_mode=True
        ...     )
        ... )

        Using pre-configured client:
        >>> redis_client = Redis.from_url("redis://localhost:6379")
        >>> config = RedisCheckpointerConfig(
        ...     connection=RedisConnectionConfig(redis_client=redis_client),
        ...     ttl=RedisTTLConfig(default_ttl=10)
        ... )
    """

    # ``...`` as the default marks the field as required.
    connection: RedisConnectionConfig = Field(default=..., description="Redis connection configuration")
    ttl: Optional[RedisTTLConfig] = Field(None, description="TTL configuration for stored data")
@CheckpointerFactory.register("redis")
class RedisCheckpointerProvider(CheckpointerProvider):
    """
    Factory provider, registered under the name "redis", that builds
    RedisCheckpointer instances.

    Works with a standalone Redis client or a Redis Cluster client, and
    validates the incoming configuration through RedisCheckpointerConfig.
    """

    async def create(self, conf: dict) -> Checkpointer:
        """
        Build a RedisCheckpointer from a configuration dictionary.

        Configuration format:
        {
            "connection": {
                "redis_client": Redis(...),  # Optional: Pre-configured client
                "url": "redis://...",        # Required if redis_client not provided
                "cluster_mode": True,        # Optional: Auto-detected from URL if None
                "connection_args": {...}     # Optional: Additional connection args
            },
            "ttl": {                         # Optional
                "default_ttl": 5,            # Optional: TTL in minutes
                "refresh_on_read": True      # Optional: Refresh TTL on read
            }
        }

        Args:
            conf (dict): Configuration dictionary with 'connection' and optional 'ttl' keys.
                The 'connection' dict must contain either 'redis_client' or 'url'.

        Returns:
            Checkpointer: A RedisCheckpointer instance.

        Raises:
            ValueError: If the configuration fails validation, no URL or client
                is available, or the Redis client cannot be created from the URL.

        Examples:
            Standalone Redis:
            >>> conf = {
            ...     "connection": {"url": "redis://localhost:6379"}
            ... }

            Cluster mode:
            >>> conf = {
            ...     "connection": {
            ...         "url": "redis://localhost:7000",
            ...         "cluster_mode": True
            ...     }
            ... }

            With TTL:
            >>> conf = {
            ...     "connection": {"url": "redis://localhost:6379"},
            ...     "ttl": {"default_ttl": 5, "refresh_on_read": True}
            ... }
        """
        try:
            config = RedisCheckpointerConfig.model_validate(conf)
        except Exception as e:
            raise ValueError(
                f"Invalid Redis checkpointer configuration: {e}. "
                "Configuration must have a 'connection' key with either 'redis_client' or 'url'. "
                "Optional 'ttl' key for TTL configuration."
            ) from e

        connection = config.connection

        # A caller-supplied client wins; otherwise build one from the URL.
        client = connection.redis_client
        if client is None:
            target_url = connection.get_connection_url()
            if target_url is None:
                raise ValueError(
                    "Either 'redis_client' or 'url' must be provided in connection configuration"
                )

            use_cluster = connection.is_cluster_mode()
            client_cls = RedisCluster if use_cluster else Redis
            try:
                client = client_cls.from_url(target_url, **connection.connection_args)
            except Exception as e:
                # NOTE(review): the URL is echoed verbatim here and may embed
                # credentials -- consider redacting before surfacing it.
                raise ValueError(
                    f"Failed to create Redis client: {e}. "
                    f"URL: {target_url}, Cluster mode: {use_cluster}"
                ) from e

        store = RedisStore(client)
        ttl_settings = config.ttl.model_dump() if config.ttl else None
        return RedisCheckpointer(store, ttl_settings)
class RedisCheckpointer(Checkpointer):
    """
    Checkpointer implementation backed by Redis.

    Every Redis operation goes through the supplied RedisStore, either
    directly or via the storage helpers built on top of it; no Redis client
    API is called from this class.
    """

    def __init__(self, redis_store: RedisStore, ttl: Optional[dict[str, Any]] = None):
        """
        Wire up the per-scope storage helpers on top of one RedisStore.

        Args:
            redis_store (RedisStore): The RedisStore instance for all Redis operations.
            ttl (Optional[dict[str, Any]]): Optional TTL configuration, passed
                unchanged to each storage helper.
        """
        self._redis_store = redis_store
        self._agent_storage = AgentStorage(redis_store, ttl)
        self._agent_group_storage = AgentGroupStorage(redis_store, ttl)
        self._workflow_storage = WorkflowStorage(redis_store, ttl)
        self._graph_state = GraphStore(redis_store, ttl)

    async def pre_agent_execute(self, session: BaseSession, inputs):
        """
        Recover the agent checkpoint for the session before the agent runs.

        When ``inputs`` is not None it is stored in the session state under
        INTERACTIVE_INPUT as a single-element list.
        """
        agent_id = session.agent_id()
        logger.info(
            f"agent: {agent_id} create or restore checkpoint from session: {session.session_id()}"
        )
        await self._agent_storage.recover(session)
        if inputs is None:
            return
        session.state().update({INTERACTIVE_INPUT: [inputs]})

    async def pre_agent_team_execute(self, session: BaseSession, inputs):
        """
        Recover the agent-group checkpoint for the session before the team runs.

        When ``inputs`` is not None it is stored in the global session state
        under INTERACTIVE_INPUT as a single-element list.
        """
        logger.info(
            f"agent group: {session.group_id()} create or restore checkpoint from session: {session.session_id()}"
        )
        await self._agent_group_storage.recover(session)
        if inputs is None:
            return
        session.state().update_global({INTERACTIVE_INPUT: [inputs]})

    async def interrupt_agent_execute(self, session: BaseSession):
        """Save the agent checkpoint when the agent is interrupted for interaction."""
        logger.info(
            f"interaction required, save checkpoint for "
            f"agent: {session.agent_id()} in session: {session.session_id()}"
        )
        await self._agent_storage.save(session)

    async def post_agent_execute(self, session: BaseSession):
        """Save the agent checkpoint after the agent finishes."""
        logger.info(
            f"agent finished, save checkpoint for "
            f"agent: {session.agent_id()} in session: {session.session_id()}"
        )
        await self._agent_storage.save(session)

    async def post_agent_team_execute(self, session: BaseSession):
        """Save the agent-group checkpoint after the team finishes."""
        logger.info(
            f"agent group finished, save checkpoint for group: {session.group_id()} "
            f"in session: {session.session_id()}"
        )
        await self._agent_group_storage.save(session)

    async def pre_workflow_execute(self, session: BaseSession, inputs: InteractiveInput):
        """
        Recover or reset workflow state before a workflow runs.

        - ``inputs`` is an InteractiveInput: recover the workflow state.
        - Otherwise, if no workflow state exists: nothing to do.
        - Otherwise, if FORCE_DEL_WORKFLOW_STATE_KEY is truthy in the session
          config env: delete the graph state and the workflow state.
        - Otherwise: raise CHECKPOINTER_PRE_WORKFLOW_EXECUTION_ERROR.

        Args:
            session (BaseSession): The session for the workflow.
            inputs (InteractiveInput): The input for the workflow execution.
        """
        workflow_id = session.workflow_id()
        logger.info(
            f"workflow: {workflow_id} create or restore checkpoint from "
            f"session: {session.session_id()}"
        )

        if isinstance(inputs, InteractiveInput):
            await self._workflow_storage.recover(session, inputs)
            return

        if not await self._workflow_storage.exists(session):
            return

        force_delete = session.config().get_env(FORCE_DEL_WORKFLOW_STATE_KEY, False)
        if not force_delete:
            raise build_error(
                StatusCode.CHECKPOINTER_PRE_WORKFLOW_EXECUTION_ERROR,
                workflow=workflow_id,
                reason="workflow state exists but non-interactive input and cleanup is disabled"
            )

        workflow_id = session.workflow_id()
        if workflow_id is None:
            logger.warning(f"workflow_id is None for session: {session.session_id()}")
            return

        session_id = session.session_id()
        await self._graph_state.delete(session_id, workflow_id)
        await self._workflow_storage.clear(workflow_id, session_id)
        logger.info(
            f"Force deleted workflow state for workflow: {workflow_id} "
            f"in session: {session_id}"
        )

    async def post_workflow_execute(self, session: BaseSession, result, exception):
        """
        Persist or clear workflow state after a workflow run.

        - ``exception`` is not None: save the workflow state, then re-raise it.
        - ``result`` carries a TASK_STATUS_INTERRUPT entry: save the workflow state.
        - Otherwise: delete the graph state and clear the workflow state.
        """
        workflow_id = session.workflow_id()
        session_id = session.session_id()

        if exception is not None:
            logger.info(
                f"exception in workflow, save checkpoint for "
                f"workflow: {workflow_id} in session: {session_id}"
            )
            await self._workflow_storage.save(session)
            raise exception

        interrupted = result.get(TASK_STATUS_INTERRUPT) is not None
        if interrupted:
            logger.info(
                f"interaction required, save checkpoint for "
                f"workflow: {workflow_id} in session: {session_id}"
            )
            await self._workflow_storage.save(session)
            return

        logger.info(f"clear checkpoint for workflow: {workflow_id} in session: {session_id}")
        await self._graph_state.delete(session_id, workflow_id)
        await self._workflow_storage.clear(workflow_id, session_id)

    async def session_exists(self, session_id: str) -> bool:
        """
        Report whether any data is stored for the given session.

        Performs a prefix lookup on ``"<session_id>:"`` through the RedisStore
        and checks whether anything came back.

        Args:
            session_id (str): The session ID to check.

        Returns:
            bool: True if the prefix lookup returned at least one entry,
                False otherwise (including when no RedisStore is set).
        """
        if self._redis_store is None:
            return False
        matches = await self._redis_store.get_by_prefix(f"{session_id}:")
        return len(matches) > 0

    async def release(self, session_id: str, agent_id: Optional[str] = None):
        """
        Release stored resources for a session, or for one agent within it.

        Args:
            session_id (str): The session ID to release resources for.
            agent_id (str, optional): If provided, only the checkpoint of this
                agent is cleared; otherwise every key under the
                ``"<session_id>:"`` prefix is deleted.
        """
        if self._redis_store is None:
            logger.warning("Cannot release resources: RedisStore is None")
            return

        if agent_id is not None:
            logger.info(f"clear checkpoint for agent: {agent_id} in session: {session_id}")
            await self._agent_storage.clear(agent_id, session_id)
            return

        logger.info(f"clear session: {session_id}")
        session_prefix = f"{session_id}:"
        await self._redis_store.delete_by_prefix(session_prefix, batch_size=500)
        logger.debug(f"Released all resources for session {session_id}")

    def graph_store(self):
        """Return the GraphStore instance created in ``__init__``."""
        return self._graph_state