import asyncio
from contextlib import asynccontextmanager
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Dict,
List,
Optional,
Tuple,
)
from uuid import uuid4
import anyio
from openjiuwen.core.common.logging import (
LogEventType,
runner_logger as logger,
)
from openjiuwen.core.common.task_manager.context import (
_current_task_id,
get_task_group,
reset_task_group,
set_task_group,
)
from openjiuwen.core.common.task_manager.exceptions import (
DuplicateTaskError,
)
from openjiuwen.core.common.task_manager.registry import TaskRegistry
from openjiuwen.core.common.task_manager.task import Task
from openjiuwen.core.runner import Runner
from openjiuwen.core.runner.callback.events import TaskManagerEvents
class TaskManager:
"""Coroutine Task Manager
Unified manager for all coroutine tasks with structured concurrency.
Uses WeakValueDictionary for automatic cleanup when tasks are no longer referenced.
"""
_instance: Optional["TaskManager"] = None
def __new__(cls) -> "TaskManager":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if hasattr(self, "_initialized"):
return
self.registry = TaskRegistry()
self._lock = asyncio.Lock()
self._initialized = True
self._callback_framework = Runner.callback_framework
logger.info("CoroutineTaskManager initialized with WeakValueDictionary auto-cleanup",
event_type=LogEventType.CORO_MANAGER_INIT,
metadata={"auto_cleanup": "WeakValueDictionary"})
@classmethod
def reset_instance(cls) -> None:
"""Reset singleton instance (for testing purposes)"""
cls._instance = None
async def create_task(
self,
coro: Coroutine,
*,
task_id: Optional[str] = None,
name: Optional[str] = None,
group: Optional[str] = None,
timeout: Optional[float] = None,
metadata: Optional[Dict] = None,
catch_exceptions: bool = False,
) -> Task:
"""Create and register a new coroutine task.
Examples:
Create a basic task:
>>> async def main():
... manager = get_task_manager()
... task = await manager.create_task(some_async_function())
... print(f"Created task: {task.task_id}")
Create a task with name and group:
>>> async def main():
... manager = get_task_manager()
... task = await manager.create_task(
... process_data(),
... name="data_processor",
... group="processors"
... )
Create a task with timeout and metadata:
>>> async def main():
... manager = get_task_manager()
... task = await manager.create_task(
... fetch_data(),
... name="data_fetcher",
... timeout=30.0,
... metadata={"priority": "high", "retry": 3}
... )
Args:
coro: The coroutine to execute as a task
task_id: Optional custom task ID (auto-generated if not provided)
name: Optional friendly name for the task
group: Optional group name for task organization
timeout: Optional timeout in seconds
metadata: Optional dict of metadata to attach to the task
catch_exceptions: Whether to catch exceptions in the task
Returns:
The created Task object
"""
task_id = task_id or str(uuid4())
tg = get_task_group()
task = Task(
task_id=task_id,
name=name,
group=group,
timeout=timeout,
metadata=metadata or {},
)
async with self._lock:
if self.registry.contains(task_id):
raise DuplicateTaskError(f"Task {task_id} already exists")
task.parent_task_id = _current_task_id.get()
self.registry.add(task)
async def callback_trigger(task: Task, status: str):
event_map = {
"running": TaskManagerEvents.TASK_RUNNING,
"completed": TaskManagerEvents.TASK_COMPLETED,
"cancelled": TaskManagerEvents.TASK_CANCELLED,
"timeout": TaskManagerEvents.TASK_TIMEOUT,
"failed": TaskManagerEvents.TASK_FAILED,
}
event_type = event_map.get(status)
if event_type:
await self._trigger_event(event_type, task)
tg.start_soon(task.execute, coro, callback_trigger, catch_exceptions)
await self._trigger_event(TaskManagerEvents.TASK_CREATED, task)
logger.debug("Created task", event_type=LogEventType.CORO_MANAGER_TASK_STATUS_CHANGED,
metadata={"task_id": task_id, "name": name, "parent": task.parent_task_id,
"previous_status": "pending", "current_status": "running"})
return task
@asynccontextmanager
async def task_group(self) -> AsyncGenerator[anyio.abc.TaskGroup, None]:
"""Create and manage a task group context.
Usage:
async with manager.task_group() as tg:
await manager.create_task(coro())
Yields:
anyio.abc.TaskGroup: The created task group
"""
tg = anyio.create_task_group()
token = set_task_group(tg)
try:
async with tg:
yield tg
finally:
reset_task_group(token)
async def _cascade_cancel(self, parent_id: str, reason: str = "parent_cancelled") -> None:
children = self.registry.get_by_parent(parent_id)
for child in children:
if not child.is_terminal and child.get_cancel_scope():
child.cancelled_by = parent_id
child.abort(reason=reason)
cancel_chain = self._build_cancel_chain(child.task_id)
logger.debug("Cascade cancelled child task", event_type=LogEventType.CORO_MANAGER_TASK_STATUS_CHANGED,
metadata={"task_id": child.task_id, "display_name": child.display_name,
"chain": cancel_chain,
"previous_status": "running", "current_status": "cancelled"})
await self._cascade_cancel(child.task_id, reason=reason)
async def cascade_cancel(self, task_id: str, reason: str = "parent_cancelled") -> None:
"""Cancel a task and all its children recursively (cascade).
Examples:
Cancel a parent task and all its children:
>>> async def main():
... manager = get_task_manager()
... # Assuming parent_task is an existing task with children
... await manager.cascade_cancel(parent_task.task_id, reason="user_requested")
This will cancel the specified task and recursively cancel all its
child tasks, building a cancellation chain for debugging.
Args:
task_id: The ID of the task to cancel
reason: The reason for cancellation (default: "parent_cancelled")
"""
await self._cascade_cancel(task_id, reason=reason)
async def cancel_group(self, group: str) -> int:
"""Cancel all tasks in a specific group.
Examples:
Cancel all tasks in a group:
>>> async def main():
... manager = get_task_manager()
... # Create tasks with group
... await manager.create_task(coro(), group="my_group")
... await manager.create_task(coro(), group="my_group")
... # Cancel all tasks in the group
... count = await manager.cancel_group("my_group")
... print(f"Cancelled {count} tasks")
Args:
group: The group name to cancel
Returns:
Number of tasks that were cancelled
"""
tasks = self.registry.get_by_group(group)
count = 0
for task in tasks:
if await task.cancel(cascade=False):
count += 1
logger.info("Cancelled tasks in group", event_type=LogEventType.CORO_MANAGER_TASK_CANCELLED,
metadata={"group": group, "count": count})
return count
async def cancel_all(self) -> int:
"""Cancel all running tasks managed by this manager.
Examples:
Cancel all running tasks:
>>> async def main():
... manager = get_task_manager()
... # Create several tasks
... await manager.create_task(task1())
... await manager.create_task(task2())
... # Cancel all tasks
... count = await manager.cancel_all()
... print(f"Cancelled {count} tasks")
Returns:
Number of tasks that were cancelled
"""
tasks = self.registry.get_running()
count = 0
for task in tasks:
if await task.cancel(cascade=False):
count += 1
logger.info("Cancelled running tasks", event_type=LogEventType.CORO_MANAGER_TASK_CANCELLED,
metadata={"count": count})
return count
def _build_cancel_chain(self, task_id: str) -> str:
chain = []
current_id = task_id
while current_id:
task = self.registry.get(current_id)
if not task:
break
chain.append(f"{task.display_name}({current_id[:8]})")
current_id = task.cancelled_by
chain.reverse()
return " -> ".join(chain) if chain else task_id[:8]
def get_task_tree(self, task_id: str) -> str:
"""Get a tree representation of a task and its descendants.
Examples:
Print task tree for debugging:
>>> async def main():
... manager = get_task_manager()
... # Create a parent task with child tasks
... parent = await manager.create_task(parent_task(), name="parent")
... await manager.create_task(child_task(), name="child1")
... await manager.create_task(child_task(), name="child2")
... # Get tree representation
... tree = manager.get_task_tree(parent.task_id)
... print(tree)
Args:
task_id: The ID of the root task
Returns:
A string representation of the task tree
"""
lines: List[str] = []
self._build_tree_recursive(task_id, lines, 0)
return "\n".join(lines)
def _build_tree_recursive(self, task_id: str, lines: List[str], indent: int) -> None:
task = self.registry.get(task_id)
if not task:
return
prefix = " " * indent + ("+- " if indent else "")
status_info = f"[{task.status.value}]"
if task.cancelled_by:
cancelled_by_task = self.registry.get(task.cancelled_by)
cancelled_by_name = cancelled_by_task.display_name if cancelled_by_task else task.cancelled_by[:8]
status_info += f" (cancelled by: {cancelled_by_name}, reason: {task.cancel_reason})"
elif task.cancel_reason:
status_info += f" (reason: {task.cancel_reason})"
lines.append(f"{prefix}{task.display_name} {status_info}")
children = self.registry.get_by_parent(task_id)
for child in children:
self._build_tree_recursive(child.task_id, lines, indent + 1)
async def wait_group(
self,
group: str,
timeout: Optional[float] = None,
return_exceptions: bool = False,
) -> List[Any]:
"""Wait for all tasks in a group to complete.
Examples:
Wait for all tasks in a group:
>>> async def main():
... manager = get_task_manager()
... # Create tasks with group
... await manager.create_task(async_add(1, 2), group="math")
... await manager.create_task(async_add(3, 4), group="math")
... # Wait for all tasks in the group to complete
... results = await manager.wait_group("math")
... print(f"Results: {results}") # [3, 7]
Wait with return_exceptions=True (like asyncio.gather):
>>> async def main():
... manager = get_task_manager()
... # Create tasks where some may fail
... await manager.create_task(failing_task(), group="workers")
... await manager.create_task(working_task(), group="workers")
... # Get results including exceptions
... results = await manager.wait_group("workers", return_exceptions=True)
... # results contains Exception objects for failed tasks
Args:
group: The group name to wait for
timeout: Optional timeout in seconds
return_exceptions: If False (default), raise first exception and cancel
other tasks. If True, return exception objects instead
of raising them.
Returns:
List of results from all tasks in the group. If return_exceptions=True,
failed tasks will contain Exception objects.
"""
async with self._lock:
task_ids = self.registry.get_group_task_ids(group)
tasks = [self.registry.get(tid) for tid in task_ids if self.registry.contains(tid)]
results = []
first_exception = None
async with anyio.create_task_group() as tg:
async def wait_one(task: Task) -> None:
nonlocal first_exception
if return_exceptions:
try:
result = await task.wait()
results.append(result)
except Exception as e:
results.append(e)
else:
try:
result = await task.wait()
results.append(result)
except Exception as e:
first_exception = e
raise
for task in tasks:
tg.start_soon(wait_one, task)
if first_exception and not return_exceptions:
raise first_exception
return results
async def wait_all(
self,
timeout: Optional[float] = None,
return_exceptions: bool = False,
) -> List[Any]:
"""Wait for all tasks managed by this manager to complete.
Examples:
Wait for all tasks:
>>> async def main():
... manager = get_task_manager()
... # Create several tasks
... await manager.create_task(async_work(1))
... await manager.create_task(async_work(2))
... # Wait for all tasks to complete
... results = await manager.wait_all()
... print(f"All results: {results}")
Wait with return_exceptions=True (like asyncio.gather):
>>> async def main():
... manager = get_task_manager()
... # Create tasks where some may fail
... await manager.create_task(failing_task())
... await manager.create_task(working_task())
... # Get results including exceptions
... results = await manager.wait_all(return_exceptions=True)
... # results contains Exception objects for failed tasks
Args:
timeout: Optional timeout in seconds for all tasks
return_exceptions: If False (default), raise first exception and cancel
other tasks. If True, return exception objects instead
of raising them.
Returns:
List of results from all tasks. If return_exceptions=True, failed tasks
will contain Exception objects.
"""
async with self._lock:
task_ids = list(self.registry.keys())
results = []
first_exception = None
for task_id in task_ids:
task = self.registry.get(task_id)
if not task:
results.append(None)
continue
if first_exception and not return_exceptions:
await task.cancel(reason="other_task_failed")
results.append(None)
continue
try:
if timeout:
with anyio.fail_after(timeout):
result = await task.wait()
else:
result = await task.wait()
results.append(result)
except Exception as e:
if return_exceptions:
results.append(e)
else:
first_exception = e
results.append(None)
for remaining_id in task_ids[task_ids.index(task_id) + 1:]:
remaining_task = self.registry.get(remaining_id)
if remaining_task:
await remaining_task.cancel(reason="other_task_failed")
if first_exception and not return_exceptions:
raise first_exception
return results
async def as_completed(
self,
tasks: List[Task],
timeout: Optional[float] = None
) -> AsyncIterator[Tuple[Task, Any]]:
"""Return an iterator that yields completed tasks as they finish.
Similar to asyncio.as_completed. Yields (task, result) tuples,
with completed tasks yielded first.
Args:
tasks: List of Task objects to wait on
timeout: Optional timeout in seconds for all tasks
Yields:
Tuples of (task, result) as tasks complete
"""
if not tasks:
return
result_queue: anyio.abc.Queue[Tuple[Task, Any]] = anyio.create_queue(len(tasks))
completed_count = 0
total_tasks = len(tasks)
async def wait_task(task: Task) -> None:
"""Wait for a single task to complete and put result in queue."""
nonlocal completed_count
try:
result = await task.wait()
await result_queue.put((task, result))
except Exception as e:
await result_queue.put((task, e))
finally:
completed_count += 1
tg = anyio.create_task_group()
for task in tasks:
tg.start_soon(wait_task, task)
tg.start()
try:
while completed_count < total_tasks:
try:
if timeout is not None:
with anyio.fail_after(timeout):
result = await result_queue.get()
else:
result = await result_queue.get()
yield result
except anyio.TimeoutError as e:
raise TimeoutError(f"as_completed() timed out after {timeout} second(s)") from e
finally:
tg.cancel_scope.cancel()
async def _trigger_event(self, event_type: str, task: Task) -> None:
"""Trigger an event through the callback framework."""
if self._callback_framework:
await self._callback_framework.trigger(
event_type,
task=task,
task_id=task.task_id,
name=task.name,
group=task.group,
status=task.status.value if task.status else None,
)
async def on(self, event_type: str, callback: Callable) -> None:
"""Register a callback for an event.
Examples:
Register callbacks for task events:
>>> async def on_completed(task):
... print(f"Task {task.name} completed!")
...
>>> async def main():
... manager = get_task_manager()
... await manager.on("task_completed", on_completed)
... await manager.create_task(some_task(), name="my_task")
Using convenience methods:
>>> async def main():
... manager = get_task_manager()
... await manager.on_completed(lambda t: print(f"Completed: {t.name}"))
... await manager.on_failed(lambda t: print(f"Failed: {t.name}"))
Args:
event_type: Event name string (use TaskManagerEvents constants)
callback: Callback function to register
"""
if self._callback_framework:
self._callback_framework.register_sync(event_type, callback)
async def on_created(self, callback: Callable[[Task], Awaitable[None]]) -> None:
await self.on(TaskManagerEvents.TASK_CREATED, callback)
async def on_running(self, callback: Callable[[Task], Awaitable[None]]) -> None:
await self.on(TaskManagerEvents.TASK_RUNNING, callback)
async def on_completed(self, callback: Callable[[Task], Awaitable[None]]) -> None:
await self.on(TaskManagerEvents.TASK_COMPLETED, callback)
async def on_failed(self, callback: Callable[[Task], Awaitable[None]]) -> None:
await self.on(TaskManagerEvents.TASK_FAILED, callback)
async def on_cancelled(self, callback: Callable[[Task], Awaitable[None]]) -> None:
await self.on(TaskManagerEvents.TASK_CANCELLED, callback)
async def on_timeout(self, callback: Callable[[Task], Awaitable[None]]) -> None:
await self.on(TaskManagerEvents.TASK_TIMEOUT, callback)
async def off(self, event_type: str, callback: Callable) -> None:
"""Unregister a callback from an event.
Examples:
Unregister a callback:
>>> async def on_completed(task):
... print(f"Task completed!")
...
>>> async def main():
... manager = get_task_manager()
... await manager.on("task_completed", on_completed)
... # Later, unregister the callback
... await manager.off("task_completed", on_completed)
Args:
event_type: Event name string (use TaskManagerEvents constants)
callback: Callback function to unregister
"""
if self._callback_framework:
await self._callback_framework.unregister(event_type, callback)
def _remove_task_unsafe(self, task_id: str) -> None:
self.registry.remove_unsafe(task_id)
async def remove_task(self, task_id: str) -> bool:
"""Remove a task from the registry.
Examples:
Remove a specific task:
>>> async def main():
... manager = get_task_manager()
... task = await manager.create_task(some_task())
... # Remove the task from registry
... removed = await manager.remove_task(task.task_id)
... print(f"Task removed: {removed}")
Args:
task_id: The ID of the task to remove
Returns:
True if the task was found and removed, False otherwise
"""
async with self._lock:
if not self.registry.contains(task_id):
return False
self.registry.remove_unsafe(task_id)
return True
async def remove_completed(self) -> int:
"""Remove all completed tasks from the registry.
Examples:
Clean up completed tasks:
>>> async def main():
... manager = get_task_manager()
... # Create and run some tasks
... await manager.create_task(some_task())
... await manager.wait_all()
... # Remove all completed tasks
... count = await manager.remove_completed()
... print(f"Removed {count} completed tasks")
Returns:
Number of completed tasks that were removed
"""
async with self._lock:
completed = [tid for tid, t in self.registry.items() if t.is_terminal]
for task_id in completed:
self.registry.remove_unsafe(task_id)
return len(completed)
def get_task_manager() -> TaskManager:
"""Get the global TaskManager singleton instance.
Examples:
Get the singleton instance:
>>> async def main():
... manager = get_task_manager()
... task = await manager.create_task(async_work())
Note: This returns the same instance throughout the application,
making it easy to share task management across modules.
Returns:
The global TaskManager singleton instance
"""
return TaskManager()
async def create_task(
coro: Coroutine,
*,
task_id: Optional[str] = None,
name: Optional[str] = None,
group: Optional[str] = None,
timeout: Optional[float] = None,
metadata: Optional[Dict] = None,
catch_exceptions: bool = False,
) -> Task:
"""Create a task using the global TaskManager (convenience function).
Examples:
Create a task using the convenience function:
>>> async def main():
... task = await create_task(
... fetch_data(),
... name="data_fetcher",
... group="io_tasks",
... timeout=30.0
... )
Args:
coro: The coroutine to execute as a task
task_id: Optional custom task ID (auto-generated if not provided)
name: Optional friendly name for the task
group: Optional group name for task organization
timeout: Optional timeout in seconds
metadata: Optional dict of metadata to attach to the task
catch_exceptions: Whether to catch exceptions in the task
Returns:
The created Task object
"""
return await get_task_manager().create_task(
coro,
task_id=task_id,
name=name,
group=group,
timeout=timeout,
metadata=metadata,
catch_exceptions=catch_exceptions,
)
async def cancel_group(group: str) -> int:
"""Cancel all tasks in a group (convenience function).
Examples:
Cancel all tasks in a group:
>>> async def main():
... await create_task(task1(), group="workers")
... await create_task(task2(), group="workers")
... count = await cancel_group("workers")
... print(f"Cancelled {count} tasks")
Args:
group: The group name to cancel
Returns:
Number of tasks that were cancelled
"""
return await get_task_manager().cancel_group(group)
async def cancel_all() -> int:
"""Cancel all running tasks (convenience function).
Examples:
Cancel all running tasks:
>>> async def main():
... await create_task(task1())
... await create_task(task2())
... count = await cancel_all()
... print(f"Cancelled {count} tasks")
Returns:
Number of tasks that were cancelled
"""
return await get_task_manager().cancel_all()
def print_task_tree(task_id: Optional[str] = None) -> None:
"""Print the task tree for debugging (convenience function).
Examples:
Print tree for a specific task:
>>> async def main():
... task = await create_task(parent_task())
... await create_task(child_task(), parent_id=task.task_id)
... print_task_tree(task.task_id)
Print all task trees:
>>> async def main():
... # Creates several tasks
... print_task_tree() # Prints all root tasks and their children
Args:
task_id: Optional task ID to print tree for. If None, prints all root tasks.
"""
manager = get_task_manager()
if task_id:
task_tree = manager.get_task_tree(task_id)
logger.debug("Task tree", event_type=LogEventType.CORO_MANAGER_DEBUG_TASK_TREE,
metadata={"task_id": task_id, "tree": task_tree})
else:
all_tasks = manager.registry.get_all()
root_tasks = [t for t in all_tasks if not t.parent_task_id]
if not root_tasks:
logger.debug("No tasks found", event_type=LogEventType.CORO_MANAGER_DEBUG_TASK_TREE)
return
for task in root_tasks:
task_tree = manager.get_task_tree(task.task_id)
logger.debug("Task tree", event_type=LogEventType.CORO_MANAGER_DEBUG_TASK_TREE,
metadata={"task_id": task.task_id, "tree": task_tree})
async def as_completed(
tasks: List[Task],
timeout: Optional[float] = None
) -> AsyncIterator[Tuple[Task, Any]]:
"""Return an iterator that yields completed tasks as they finish (convenience function).
Similar to asyncio.as_completed. Yields (task, result) tuples,
with completed tasks yielded first.
Args:
tasks: List of Task objects to wait on
timeout: Optional timeout in seconds for all tasks
Yields:
Tuples of (task, result) as tasks complete
"""
async for task, result in get_task_manager().as_completed(tasks, timeout):
yield task, result