"""
Callback Framework Chain
Callback chain execution with rollback support.
"""
import asyncio
import logging
from typing import (
Callable,
Dict,
List,
Optional,
)
import anyio
from openjiuwen.core.runner.callback.enums import ChainAction
from openjiuwen.core.runner.callback.models import (
CallbackInfo,
ChainContext,
ChainResult,
)
class CallbackChain:
    """Run a group of related callbacks in order, with retry and rollback.

    Callbacks are executed from highest to lowest ``priority``. When a
    callback fails permanently (exception or timeout after all retries,
    an explicit ``ROLLBACK`` action, or ``RETRY`` requests that use up all
    retries), the rollback handlers of the callbacks that already completed
    are awaited in reverse execution order.

    Attributes:
        name: Chain identifier.
        callbacks: Callback info objects, kept sorted by descending priority.
        rollback_handlers: Maps a callback to an async function awaited with
            the chain context when the chain rolls back.
        error_handlers: Maps a callback to an async function awaited with
            ``(exception, context)`` when that callback raises.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, name: str = ""):
        """Initialize an empty callback chain.

        Args:
            name: Optional chain identifier.
        """
        self.name = name
        self.callbacks: List[CallbackInfo] = []
        self.rollback_handlers: Dict[Callable, Callable] = {}
        self.error_handlers: Dict[Callable, Callable] = {}

    def add(
        self,
        callback_info: CallbackInfo,
        rollback_handler: Optional[Callable] = None,
        error_handler: Optional[Callable] = None
    ) -> None:
        """Add a callback to the chain, keeping the list in priority order.

        Args:
            callback_info: Callback metadata and configuration.
            rollback_handler: Optional async function awaited with the chain
                context if the chain rolls back after this callback ran.
            error_handler: Optional async function awaited with
                ``(exception, context)`` when this callback raises.
        """
        self.callbacks.append(callback_info)
        # list.sort is stable (also with reverse=True), so callbacks with
        # equal priority keep their insertion order.
        self.callbacks.sort(key=lambda info: info.priority, reverse=True)
        # Handlers are keyed by the callable itself; re-adding the same
        # callable replaces its handlers.
        if rollback_handler is not None:
            self.rollback_handlers[callback_info.callback] = rollback_handler
        if error_handler is not None:
            self.error_handlers[callback_info.callback] = error_handler

    def remove(self, callback: Callable) -> None:
        """Remove every entry for ``callback`` and its handlers.

        Args:
            callback: Callback function to remove.
        """
        self.callbacks = [ci for ci in self.callbacks if ci.callback != callback]
        self.rollback_handlers.pop(callback, None)
        self.error_handlers.pop(callback, None)

    @staticmethod
    def _callback_name(callback: Callable) -> str:
        """Return a printable name for ``callback``.

        Not every callable has ``__name__`` (e.g. ``functools.partial``
        objects or instances defining ``__call__``). Falling back to ``repr``
        keeps logging inside ``except`` blocks from raising AttributeError.
        """
        return getattr(callback, "__name__", None) or repr(callback)

    @staticmethod
    def _build_call_args(context: ChainContext) -> tuple:
        """Build the ``(args, kwargs)`` pair for the next callback call.

        When ``context.results`` is non-empty, the latest result is prepended
        to ``context.initial_args``. ``kwargs`` is a copy of
        ``context.initial_kwargs`` plus ``_chain_context``.
        """
        if context.results:
            args = (context.get_last_result(),) + context.initial_args
        else:
            args = context.initial_args
        kwargs = context.initial_kwargs.copy()
        # Every callback gets access to the shared execution context.
        kwargs['_chain_context'] = context
        return args, kwargs

    @staticmethod
    async def _invoke(callback_info: CallbackInfo, args: tuple, kwargs: dict):
        """Await the callback, enforcing ``callback_info.timeout`` when set."""
        callback = callback_info.callback
        if callback_info.timeout:
            # anyio.fail_after raises TimeoutError once the deadline passes.
            with anyio.fail_after(callback_info.timeout):
                return await callback(*args, **kwargs)
        return await callback(*args, **kwargs)

    async def execute(self, context: ChainContext) -> ChainResult:
        """Execute the enabled callbacks in priority order.

        Each callback is awaited with ``context.initial_args`` (prefixed by
        the latest entry of ``context.results`` when there is one),
        ``context.initial_kwargs`` and a ``_chain_context`` keyword argument.
        A callback may return a plain value, which is appended to
        ``context.results``. It may instead return a ``ChainResult`` whose
        action controls the chain:

        * ``BREAK`` - record its result and stop the chain early.
        * ``RETRY`` - call it again after ``retry_delay``. Once
          ``max_retries`` extra attempts are used up, roll back.
        * ``ROLLBACK`` - roll back the callbacks executed so far.
        * any other action - record its result and continue.

        Exceptions are first passed to the callback's error handler. A
        truthy handler return value is recorded as the callback's result.
        Otherwise the callback is retried up to ``max_retries`` times, and
        then the chain rolls back. Timeouts are retried the same way but do
        not go through the error handler.

        Args:
            context: Chain execution context. Its ``results``,
                ``current_index`` and ``is_completed`` fields are updated.

        Returns:
            ``ChainResult`` with action ``CONTINUE`` (all callbacks done),
            ``BREAK`` or ``ROLLBACK``.
        """
        executed_callbacks: List[Callable] = []
        for i, callback_info in enumerate(self.callbacks):
            if not callback_info.enabled:
                continue
            context.current_index = i
            callback = callback_info.callback
            name = self._callback_name(callback)
            # Error attached to the most recent RETRY request, if any.
            retry_error = None
            for attempt in range(callback_info.max_retries + 1):
                has_more_attempts = attempt < callback_info.max_retries
                try:
                    args, kwargs = self._build_call_args(context)
                    result = await self._invoke(callback_info, args, kwargs)
                    if isinstance(result, ChainResult):
                        if result.action == ChainAction.BREAK:
                            context.results.append(result.result)
                            return ChainResult(
                                ChainAction.BREAK,
                                result=result.result,
                                context=context
                            )
                        if result.action == ChainAction.RETRY:
                            retry_error = result.error
                            # Back off like the exception/timeout retries do.
                            if has_more_attempts:
                                await asyncio.sleep(callback_info.retry_delay)
                            continue
                        if result.action == ChainAction.ROLLBACK:
                            await self._rollback(executed_callbacks, context)
                            return ChainResult(
                                ChainAction.ROLLBACK,
                                context=context,
                                error=result.error
                            )
                        context.results.append(result.result)
                    else:
                        context.results.append(result)
                    executed_callbacks.append(callback)
                    if callback_info.once:
                        callback_info.enabled = False
                    break
                except TimeoutError:
                    self._logger.error("Callback %s timed out", name)
                    if has_more_attempts:
                        await asyncio.sleep(callback_info.retry_delay)
                        continue
                    await self._rollback(executed_callbacks, context)
                    return ChainResult(
                        ChainAction.ROLLBACK,
                        context=context,
                        error=TimeoutError("Callback timeout")
                    )
                except Exception as e:
                    handler = self.error_handlers.get(callback)
                    if handler is not None:
                        try:
                            error_result = await handler(e, context)
                        except Exception as handler_error:
                            self._logger.exception(
                                "Error handler failed: %s", handler_error
                            )
                        else:
                            # A truthy return means the handler recovered; its
                            # value stands in for the callback's result.
                            if error_result:
                                context.results.append(error_result)
                                executed_callbacks.append(callback)
                                break
                    if has_more_attempts:
                        self._logger.info(
                            "Retrying %s (attempt %d)", name, attempt + 1
                        )
                        await asyncio.sleep(callback_info.retry_delay)
                        continue
                    await self._rollback(executed_callbacks, context)
                    return ChainResult(ChainAction.ROLLBACK, context=context, error=e)
            else:
                # Reached only when the final attempt returned RETRY: every
                # other path breaks or returns. Fail like the other
                # exhausted-retry paths instead of silently skipping the
                # callback and continuing with a stale previous result.
                self._logger.error("Callback %s exhausted its retries", name)
                await self._rollback(executed_callbacks, context)
                return ChainResult(
                    ChainAction.ROLLBACK,
                    context=context,
                    error=retry_error or RuntimeError(
                        f"Callback {name} exhausted its retries"
                    )
                )
        context.is_completed = True
        return ChainResult(
            ChainAction.CONTINUE,
            result=context.get_last_result(),
            context=context
        )

    async def _rollback(
        self,
        executed_callbacks: List[Callable],
        context: ChainContext
    ) -> None:
        """Await rollback handlers for executed callbacks, newest first.

        Args:
            executed_callbacks: Callbacks that completed, in execution order.
            context: Chain execution context; marked as rolled back.
        """
        context.is_rolled_back = True
        for callback in reversed(executed_callbacks):
            handler = self.rollback_handlers.get(callback)
            if handler is None:
                continue
            try:
                await handler(context)
            except Exception as e:
                # Keep going so one failing handler does not block the rest.
                self._logger.exception(
                    "Rollback failed for %s: %s", self._callback_name(callback), e
                )