from abc import ABC, abstractmethod
from pymilvus.client.utils import is_successful
from openjiuwen.core.common.constants.constant import INTERACTIVE_INPUT
from openjiuwen.core.common.exception.codes import StatusCode
from openjiuwen.core.common.exception.errors import BaseError, build_error
from openjiuwen.core.common.logging import session_logger, LogEventType
from openjiuwen.core.graph.store import (
Serializer,
Store,
create_serializer,
)
from openjiuwen.core.session.checkpointer import Checkpointer
from openjiuwen.core.session.checkpointer.base import Storage
from openjiuwen.core.session.constants import FORCE_DEL_WORKFLOW_STATE_KEY
from openjiuwen.core.session.interaction.interactive_input import InteractiveInput
from openjiuwen.core.session.session import BaseSession
class InMemoryCheckpointer(Checkpointer):
    """Checkpointer that keeps every checkpoint in process memory.

    Per-session storages live in three dictionaries keyed by ``session_id``:
    ``AgentStorage`` (agents), ``AgentTeamStorage`` (agent teams) and
    ``WorkflowStorage`` (workflows).  Graph checkpoints are kept in one
    ``InMemoryStore``; ``_session_to_workflow_ids`` records which workflow ids
    of a session have saved checkpoints so :meth:`release` can delete them from
    that graph store.  Nothing is persisted: all state lives on this object.
    """

    def __init__(self):
        self._agent_stores = {}
        self._agent_team_stores = {}
        self._workflow_stores = {}
        # Imported here rather than at module level, presumably to avoid an
        # import cycle with the graph package — confirm before moving it.
        from openjiuwen.core.graph import InMemoryStore
        self._graph_store = InMemoryStore()
        self._session_to_workflow_ids = {}

    async def pre_workflow_execute(self, session: BaseSession, inputs: InteractiveInput):
        """Prepare checkpoint state before a workflow runs.

        Ensures a ``WorkflowStorage`` exists for the session.  For an
        ``InteractiveInput`` the saved workflow state is restored and the input
        applied.  For any other input, leftover state for this workflow is
        either cleared (graph checkpoints and cached state) when the
        ``FORCE_DEL_WORKFLOW_STATE_KEY`` env flag is set, or rejected.

        Raises:
            The error built with ``CHECKPOINTER_PRE_WORKFLOW_EXECUTION_ERROR``
            when leftover state exists, the input is not interactive and forced
            cleanup is disabled.
        """
        session_id = session.session_id()
        workflow_id = session.workflow_id()
        log_message = dict(
            session_id=session_id,
            workflow_id=workflow_id,
            metadata={"storage_type": "inmemory"},
        )

        workflow_store = self._workflow_stores.get(session_id)
        if workflow_store is None:
            # Only build a storage when one is actually needed.
            workflow_store = WorkflowStorage()
            self._workflow_stores[session_id] = workflow_store
            session_logger.info("Create a new workflow checkpointer store before workflow execute",
                                event_type=LogEventType.CHECKPOINTER_STORE_ADD, **log_message)
        # setdefault is idempotent; keeping it unconditional guarantees the id set
        # exists whenever the workflow store does.
        self._session_to_workflow_ids.setdefault(session_id, set())

        if isinstance(inputs, InteractiveInput):
            session_logger.info("Begin to restore workflow session before workflow execute",
                                event_type=LogEventType.CHECKPOINT_RESTORE, **log_message)
            await workflow_store.recover(session, inputs)
            session_logger.info("Succeed to restore workflow session before workflow execute",
                                event_type=LogEventType.CHECKPOINT_RESTORE, **log_message)
            return

        # Non-interactive run: nothing to do unless state for this workflow is left over.
        if not await workflow_store.exists(session):
            return
        if not session.config().get_env(FORCE_DEL_WORKFLOW_STATE_KEY, False):
            raise build_error(StatusCode.CHECKPOINTER_PRE_WORKFLOW_EXECUTION_ERROR, session_id=session_id,
                              workflow=workflow_id,
                              reason="workflow state exists but non-interactive input and cleanup is disabled")

        session_logger.info(
            "Begin to clear all of current workflow's checkpoints forcefully before workflow execute",
            event_type=LogEventType.CHECKPOINT_CLEAR, **log_message)
        try:
            await self._graph_store.delete(session_id, workflow_id)
        finally:
            # Drop the cached workflow state even when the graph delete failed.
            await workflow_store.clear(workflow_id)
        session_logger.info(
            "Succeed to clear all of current workflow's checkpoints forcefully before workflow execute",
            event_type=LogEventType.CHECKPOINT_CLEAR, **log_message)

    async def post_workflow_execute(self, session: BaseSession, result, exception):
        """Checkpoint or clean up after a workflow run.

        * ``exception`` set: save the workflow checkpoint, then re-raise it.
        * ``result`` carries a ``TASK_STATUS_INTERRUPT`` entry: save the
          workflow checkpoint so ``pre_workflow_execute`` can restore it.
        * otherwise (completion): clear the workflow's checkpoints and, unless
          the session's parent is an ``AgentSession``, drop the session's
          workflow store and workflow id set.

        Raises:
            ``exception`` itself, or the error built with
            ``CHECKPOINTER_POST_WORKFLOW_EXECUTION_ERROR`` when a checkpoint
            must be saved but no workflow store exists for the session.
        """
        session_id = session.session_id()
        workflow_id = session.workflow_id()
        workflow_store = self._workflow_stores.get(session_id)

        if exception is not None:
            if workflow_store is None:
                # Chain so the original workflow failure stays visible.
                raise build_error(StatusCode.CHECKPOINTER_POST_WORKFLOW_EXECUTION_ERROR, workflow=workflow_id,
                                  reason="workflow store not found") from exception
            await self._inner_save_workflow_checkpoint(workflow_id, session_id, session,
                                                       f"workflow exception {exception}")
            raise exception

        from openjiuwen.core.graph.pregel import TASK_STATUS_INTERRUPT
        if result.get(TASK_STATUS_INTERRUPT) is not None:
            if workflow_store is None:
                raise build_error(StatusCode.CHECKPOINTER_POST_WORKFLOW_EXECUTION_ERROR, workflow=workflow_id,
                                  reason="workflow store not found")
            await self._inner_save_workflow_checkpoint(workflow_id, session_id, session, "workflow interruption")
            return

        try:
            await self._inner_clear_workflow_session(workflow_id=workflow_id, session_id=session_id,
                                                     reason="workflow execute completion")
        finally:
            from openjiuwen.core.session.internal.agent import AgentSession
            # Workflows run inside an agent share the agent's session, so their
            # store is kept; standalone workflow sessions are dropped entirely.
            if not isinstance(session.parent(), AgentSession):
                self._workflow_stores.pop(session_id, None)
                self._session_to_workflow_ids.pop(session_id, None)
                session_logger.info(
                    "Remove workflow checkpoint store on workflow execute completion",
                    event_type=LogEventType.CHECKPOINTER_STORE_REMOVE,
                    session_id=session_id,
                    workflow_id=workflow_id,
                    metadata={"storage_type": "inmemory"},
                )

    async def _inner_save_workflow_checkpoint(self, workflow_id, session_id, session, reason):
        """Save the workflow state of ``session`` and record ``workflow_id``.

        Callers must ensure a workflow store exists for ``session_id``.
        ``reason`` is only used in log messages.
        """
        workflow_store = self._workflow_stores.get(session_id)
        log_message = dict(
            event_type=LogEventType.CHECKPOINT_SAVE,
            session_id=session_id,
            workflow_id=workflow_id,
            metadata={"storage_type": "inmemory"},
        )
        session_logger.info(f"Begin to save workflow checkpoint on {reason}", **log_message)
        await workflow_store.save(session)
        # setdefault avoids an AttributeError if the id set is missing.
        self._session_to_workflow_ids.setdefault(session_id, set()).add(workflow_id)
        session_logger.info(f"Succeed to save workflow checkpoint on {reason}", **log_message)

    async def _inner_clear_workflow_session(self, workflow_id, session_id, reason):
        """Delete ``workflow_id``'s graph checkpoints and cached workflow state.

        The cached state and the id-set entry are dropped even when the graph
        delete fails.  Every failure is logged and re-raised; an exception from
        the cache clear (raised inside ``finally``) supersedes one from the
        graph delete.
        """
        workflow_store = self._workflow_stores.get(session_id)
        workflow_ids = self._session_to_workflow_ids.get(session_id)
        log_message = dict(
            event_type=LogEventType.CHECKPOINT_CLEAR,
            session_id=session_id,
            workflow_id=workflow_id,
            metadata={"storage_type": "inmemory"},
        )
        session_logger.info(f"Begin to clear all of current workflow's checkpoints on {reason}", **log_message)
        try:
            await self._graph_store.delete(session_id, workflow_id)
        except Exception as e:
            session_logger.error(f"Failed to clear all of current workflow's checkpoints on {reason}",
                                 exception=e, **log_message)
            raise
        finally:
            if workflow_ids is not None:
                workflow_ids.discard(workflow_id)
            if workflow_store is not None:
                try:
                    await workflow_store.clear(workflow_id)
                except Exception as e:
                    # Always log: previously a failure here after a successful
                    # graph delete propagated without any log entry.
                    session_logger.error(f"Failed to clear all of current workflow's checkpoints on {reason}",
                                         exception=e, **log_message)
                    raise
        # Reached only when both the graph delete and the cache clear succeeded.
        session_logger.info(f"Succeed to clear all of current workflow's checkpoints on {reason}", **log_message)

    async def pre_agent_execute(self, session: BaseSession, inputs):
        """Restore the agent's saved state before it runs.

        Creates the session's ``AgentStorage`` if needed, restores the state
        saved for this agent (if any) and, when ``inputs`` is not None, sets the
        agent state to ``{INTERACTIVE_INPUT: [inputs]}`` via ``set_state``.
        """
        agent_id = session.agent_id() if hasattr(session, "agent_id") else 'Na'
        session_id = session.session_id()
        log_message = dict(
            session_id=session_id,
            agent_id=agent_id,
            metadata={"storage_type": "inmemory"},
        )
        agent_store = self._agent_stores.get(session_id)
        if agent_store is None:
            agent_store = AgentStorage()
            self._agent_stores[session_id] = agent_store
            session_logger.info("Create a new agent checkpointer store before agent execute",
                                event_type=LogEventType.CHECKPOINTER_STORE_ADD, **log_message)
        session_logger.info("Begin to restore agent session before agent execute",
                            event_type=LogEventType.CHECKPOINT_RESTORE, **log_message)
        await agent_store.recover(session)
        session_logger.info("Succeed to restore agent session before agent execute",
                            event_type=LogEventType.CHECKPOINT_RESTORE, **log_message)
        if inputs is not None:
            session.state().set_state({INTERACTIVE_INPUT: [inputs]})

    async def pre_agent_team_execute(self, session: BaseSession, inputs):
        """Restore the agent team's saved global state before it runs.

        Creates the session's ``AgentTeamStorage`` if needed, restores the
        saved team state (if any) and, when ``inputs`` is not None, merges
        ``{INTERACTIVE_INPUT: [inputs]}`` into the global state.
        """
        team_id = session.team_id() if hasattr(session, "team_id") else "Na"
        session_id = session.session_id()
        # NOTE(review): the team id is logged under ``workflow_id``; presumably
        # the logger has no dedicated team field — confirm.
        log_message = dict(
            session_id=session_id,
            workflow_id=team_id,
            metadata={"storage_type": "inmemory"},
        )
        team_store = self._agent_team_stores.get(session_id)
        if team_store is None:
            team_store = AgentTeamStorage()
            self._agent_team_stores[session_id] = team_store
            session_logger.info("Create a new agent team checkpointer store before team execute",
                                event_type=LogEventType.CHECKPOINTER_STORE_ADD, **log_message)
        session_logger.info("Begin to restore agent team session before execute",
                            event_type=LogEventType.CHECKPOINT_RESTORE, **log_message)
        await team_store.recover(session)
        session_logger.info("Succeed to restore agent team session before execute",
                            event_type=LogEventType.CHECKPOINT_RESTORE, **log_message)
        if inputs is not None:
            session.state().update_global({INTERACTIVE_INPUT: [inputs]})

    async def _save_agent_checkpoint(self, session: BaseSession, reason: str, status_code: StatusCode):
        """Save the agent state of ``session`` into its session's agent store.

        Args:
            session: agent session whose state is saved.
            reason: text appended to the log messages ("... on {reason}").
            status_code: code used to build the error when no store exists.

        Raises:
            The error built from ``status_code`` when no agent store exists for
            the session; any failure of the store's ``save`` (logged first).
        """
        agent_id = session.agent_id()
        session_id = session.session_id()
        agent_store = self._agent_stores.get(session_id)
        if agent_store is None:
            raise build_error(status_code, agent=agent_id, reason="agent store not found")
        log_message = dict(
            session_id=session_id,
            agent_id=agent_id,
            metadata={"storage_type": "inmemory"},
        )
        session_logger.info(f"Begin to save agent checkpoint on {reason}",
                            event_type=LogEventType.CHECKPOINT_SAVE, **log_message)
        try:
            await agent_store.save(session)
        except Exception as e:
            session_logger.error(f"Failed to save agent checkpoint on {reason}",
                                 event_type=LogEventType.CHECKPOINT_SAVE, exception=e, **log_message)
            raise
        session_logger.info(f"Succeed to save agent checkpoint on {reason}",
                            event_type=LogEventType.CHECKPOINT_SAVE, **log_message)

    async def interrupt_agent_execute(self, session: BaseSession):
        """Save the agent checkpoint when the agent is interrupted."""
        await self._save_agent_checkpoint(session, "agent interruption",
                                          StatusCode.CHECKPOINTER_INTERRUPT_AGENT_ERROR)

    async def post_agent_execute(self, session: BaseSession):
        """Save the agent checkpoint when the agent finishes executing."""
        await self._save_agent_checkpoint(session, "agent execute completion",
                                          StatusCode.CHECKPOINTER_POST_AGENT_EXECUTION_ERROR)

    async def post_agent_team_execute(self, session: BaseSession):
        """Save the agent team checkpoint when the team finishes executing.

        Raises:
            The error built with ``CHECKPOINTER_POST_AGENT_EXECUTION_ERROR``
            when no team store exists; any failure of ``save`` (logged first).
        """
        team_id = session.team_id()
        session_id = session.session_id()
        team_store = self._agent_team_stores.get(session_id)
        if team_store is None:
            raise build_error(StatusCode.CHECKPOINTER_POST_AGENT_EXECUTION_ERROR,
                              agent=team_id, reason="agent team store not found")
        log_message = dict(
            session_id=session_id,
            workflow_id=team_id,
            metadata={"storage_type": "inmemory"},
        )
        session_logger.info("Begin to save agent team checkpoint on team execute completion",
                            event_type=LogEventType.CHECKPOINT_SAVE, **log_message)
        try:
            await team_store.save(session)
        except Exception as e:
            session_logger.error("Failed to save agent team checkpoint on team execute completion",
                                 exception=e, event_type=LogEventType.CHECKPOINT_SAVE, **log_message)
            raise
        session_logger.info("Succeed to save agent team checkpoint on team execute completion",
                            event_type=LogEventType.CHECKPOINT_SAVE, **log_message)

    async def session_exists(self, session_id: str) -> bool:
        """Return True if any agent, team or workflow store exists for ``session_id``."""
        return (
            session_id in self._agent_stores
            or session_id in self._agent_team_stores
            or session_id in self._workflow_stores
        )

    async def release(self, session_id: str, agent_id: str = None):
        """Release checkpoints held for ``session_id``.

        With ``agent_id``: clear only that agent's saved state in the session's
        agent store.  Best effort — a failure is logged, not raised.

        Without ``agent_id``: delete every recorded workflow checkpoint of the
        session from the graph store (failures logged as warnings), then drop
        the session's workflow store, every agent store whose session id starts
        with ``session_id``, and the session's agent team store.
        """
        if agent_id is not None:
            agent_store = self._agent_stores.get(session_id)
            if agent_store is None:
                return
            session_logger.info("Begin to clear all of current agent's checkpoints on manually release",
                                event_type=LogEventType.CHECKPOINT_CLEAR,
                                agent_id=agent_id, session_id=session_id, metadata={"storage_type": "inmemory"})
            try:
                await agent_store.clear(agent_id)
            except Exception as e:
                session_logger.error("Failed to clear all of current agent's checkpoints on manually release",
                                     agent_id=agent_id,
                                     event_type=LogEventType.CHECKPOINT_CLEAR,
                                     session_id=session_id, exception=e, metadata={"storage_type": "inmemory"})
            else:
                session_logger.info("Succeed to clear all of current agent's checkpoints on manually release",
                                    event_type=LogEventType.CHECKPOINT_CLEAR,
                                    agent_id=agent_id, session_id=session_id,
                                    metadata={"storage_type": "inmemory"})
            return

        workflow_ids = self._session_to_workflow_ids.get(session_id)
        workflow_ids_text = str(workflow_ids) if workflow_ids else '[]'
        session_logger.info("Begin to clear all of current agent's workflow checkpoints on manually release",
                            agent_id=agent_id,
                            event_type=LogEventType.CHECKPOINT_CLEAR,
                            session_id=session_id,
                            workflow_id=workflow_ids_text,
                            metadata={"storage_type": "inmemory"})
        # Iterate a snapshot: the set may be mutated by other coroutines while
        # this loop is suspended in ``await``.
        for workflow_id in list(workflow_ids or ()):
            try:
                await self._graph_store.delete(session_id, workflow_id)
            except Exception as e:
                session_logger.warning("Failed to clear workflow checkpoint",
                                       event_type=LogEventType.CHECKPOINT_CLEAR,
                                       exception=e, agent_id=agent_id,
                                       session_id=session_id, workflow_id=workflow_id,
                                       metadata={"storage_type": "inmemory"})
        self._session_to_workflow_ids.pop(session_id, None)
        session_logger.info("Succeed to clear all of current agent's workflow checkpoints on manually release",
                            agent_id=agent_id,
                            event_type=LogEventType.CHECKPOINT_CLEAR,
                            session_id=session_id,
                            workflow_id=workflow_ids_text,
                            metadata={"storage_type": "inmemory"})

        if self._workflow_stores.pop(session_id, None):
            session_logger.info(
                "Remove workflow checkpoint store on manually release",
                event_type=LogEventType.CHECKPOINTER_STORE_REMOVE,
                agent_id=agent_id,
                session_id=session_id,
                metadata={"storage_type": "inmemory"},
            )

        # NOTE(review): prefix matching also removes stores of unrelated ids that
        # merely share the prefix (e.g. "s_1" also matches "s_10"); presumably it
        # targets child sessions derived from ``session_id`` — confirm.
        matching_session_ids = [sid for sid in list(self._agent_stores) if sid.startswith(session_id)]
        for sid in matching_session_ids:
            if self._agent_stores.pop(sid, None):
                session_logger.info(
                    "Remove agent checkpoint store on manually release",
                    event_type=LogEventType.CHECKPOINTER_STORE_REMOVE,
                    session_id=sid,
                    agent_id=agent_id,
                    metadata={"storage_type": "inmemory"},
                )

        if self._agent_team_stores.pop(session_id, None):
            session_logger.info(
                "Remove agent team checkpoint store on manually release",
                event_type=LogEventType.CHECKPOINTER_STORE_REMOVE,
                session_id=session_id,
                metadata={"storage_type": "inmemory"},
            )

    def graph_store(self) -> Store:
        """Return the in-memory graph checkpoint store shared by all sessions."""
        return self._graph_store
class BaseSingleStateStorage(Storage, ABC):
    """Storage keeping exactly one serialized state blob per entity id.

    Subclasses choose the key (``_get_entity_id``), the state to capture
    (``_get_state_to_save``) and how it is written back (``_restore_state``).
    Blobs come from a pickle serializer built by ``create_serializer``; within
    this class they are only written by :meth:`save`.
    """

    def __init__(self):
        # Maps entity id -> serialized blob as produced by ``serde.dumps_typed``.
        self.state_blobs: dict[str, tuple[str, bytes]] = {}
        self.serde: Serializer = create_serializer("pickle")

    @abstractmethod
    def _get_entity_id(self, session: BaseSession) -> str:
        """Return the key under which ``session``'s state is stored."""

    @abstractmethod
    def _get_state_to_save(self, session: BaseSession):
        """Return the state object of ``session`` that should be serialized."""

    @abstractmethod
    def _restore_state(self, session: BaseSession, state) -> None:
        """Write a previously saved ``state`` back into ``session``."""

    async def save(self, session: BaseSession):
        """Serialize the session's state and store it under its entity id.

        A falsy serialization result leaves any previous blob untouched.
        """
        key = self._get_entity_id(session)
        blob = self.serde.dumps_typed(self._get_state_to_save(session))
        if not blob:
            return
        self.state_blobs[key] = blob

    async def recover(self, session: BaseSession, inputs: InteractiveInput = None):
        """Restore the saved state into ``session``; no-op when nothing is saved.

        ``inputs`` is accepted for interface compatibility and is not used here.
        """
        blob = self.state_blobs.get(self._get_entity_id(session))
        if blob is not None:
            self._restore_state(session, self.serde.loads_typed(blob))

    async def clear(self, entity_id: str):
        """Forget the blob stored for ``entity_id`` (missing ids are ignored)."""
        if entity_id in self.state_blobs:
            del self.state_blobs[entity_id]

    async def exists(self, session: BaseSession) -> bool:
        """Return True when a blob is stored for the session's entity id."""
        blob = self.state_blobs.get(self._get_entity_id(session))
        return blob is not None
class AgentStorage(BaseSingleStateStorage):
    """Single-state storage for agent sessions, keyed by ``session.agent_id()``."""

    def _get_entity_id(self, session: BaseSession) -> str:
        agent_id = session.agent_id()
        return agent_id

    def _get_state_to_save(self, session: BaseSession):
        # The whole agent state object is what gets serialized.
        agent_state = session.state()
        return agent_state.get_state()

    def _restore_state(self, session: BaseSession, state) -> None:
        agent_state = session.state()
        agent_state.set_state(state)
class WorkflowStorage(Storage):
    """In-memory storage for workflow session state, keyed by workflow id.

    Two serialized blobs are kept per workflow: the state returned by
    ``session.state().get_state()`` and the pending updates returned by
    ``session.state().get_updates()``.  Both use a pickle serializer built by
    ``create_serializer``; within this class the blobs are only written by
    :meth:`save`.
    """

    def __init__(self):
        self.serde: Serializer = create_serializer("pickle")
        # workflow id -> serialized state blob
        self.state_blobs: dict[str, tuple[str, bytes]] = {}
        # workflow id -> serialized pending-updates blob
        self.state_updates_blobs: dict[str, tuple[str, bytes]] = {}

    @staticmethod
    def _has_state(state_blob) -> bool:
        """Return True when ``state_blob`` is present and not tagged ``"empty"``.

        Presumably the first tuple element is the serializer's type tag and
        ``"empty"`` marks an empty state — confirm against the serializer.
        """
        return bool(state_blob) and state_blob[0] != "empty"

    async def save(self, session: BaseSession):
        """Serialize the workflow's state and pending updates under its workflow id."""
        workflow_id = session.workflow_id()
        state_blob = self.serde.dumps_typed(session.state().get_state())
        if state_blob:
            self.state_blobs[workflow_id] = state_blob
        updates_blob = self.serde.dumps_typed(session.state().get_updates())
        if updates_blob:
            self.state_updates_blobs[workflow_id] = updates_blob

    async def recover(self, session: BaseSession, inputs: InteractiveInput = None):
        """Restore saved workflow state into ``session`` and apply ``inputs``.

        Steps, in order: restore the saved state (if any), apply the
        interactive input (if given), then restore saved pending updates
        (if any).

        Args:
            session: workflow session to restore into.
            inputs: interactive input.  When ``raw_inputs`` is not None it is
                committed to the workflow state under ``INTERACTIVE_INPUT``;
                otherwise each ``user_inputs`` value is appended to its node's
                ``INTERACTIVE_INPUT`` list and the state is committed.
                ``None`` skips this step.
        """
        workflow_id = session.workflow_id()
        state_blob = self.state_blobs.get(workflow_id)
        if self._has_state(state_blob):
            session.state().set_state(self.serde.loads_typed(state_blob))

        # ``inputs`` defaults to None; guard instead of failing with AttributeError.
        if inputs is not None:
            self._apply_interactive_input(session, inputs)

        state_updates_blob = self.state_updates_blobs.get(workflow_id)
        if state_updates_blob:
            session.state().set_updates(self.serde.loads_typed(state_updates_blob))

    @staticmethod
    def _apply_interactive_input(session: BaseSession, inputs: InteractiveInput) -> None:
        """Write ``inputs`` into the workflow state (see :meth:`recover`)."""
        if inputs.raw_inputs is not None:
            session.state().update_and_commit_workflow_state({INTERACTIVE_INPUT: inputs.raw_inputs})
            return
        from openjiuwen.core.session.internal.workflow import NodeSession
        for node_id, value in inputs.user_inputs.items():
            node_session = NodeSession(session, node_id)
            interactive_input = node_session.state().get(INTERACTIVE_INPUT)
            if isinstance(interactive_input, list):
                interactive_input.append(value)
                node_session.state().update({INTERACTIVE_INPUT: interactive_input})
            else:
                node_session.state().update({INTERACTIVE_INPUT: [value]})
        session.state().commit()

    async def clear(self, workflow_id: str):
        """Forget both blobs stored for ``workflow_id`` (missing ids are ignored)."""
        self.state_blobs.pop(workflow_id, None)
        self.state_updates_blobs.pop(workflow_id, None)

    async def exists(self, session: BaseSession) -> bool:
        """Return True when a non-empty state blob is stored for the session's workflow."""
        return self._has_state(self.state_blobs.get(session.workflow_id()))
class AgentTeamStorage(BaseSingleStateStorage):
    """Single-state storage for agent team sessions, keyed by ``session.team_id()``.

    The team's global state is what gets saved and restored.
    """

    def _get_entity_id(self, session: BaseSession) -> str:
        team_id = session.team_id()
        return team_id

    def _get_state_to_save(self, session: BaseSession):
        team_state = session.state()
        return team_state.get_global(None)

    def _restore_state(self, session: BaseSession, state) -> None:
        global_state = session.state().global_state
        global_state.set_state(state)