"""Runtime contexts for auto-harness pipelines and stages."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from openjiuwen.auto_harness.schema import (
Experience,
OptimizationTask,
)
from openjiuwen.core.session.stream.base import (
OutputSchema,
)
if TYPE_CHECKING:
from openjiuwen.auto_harness.orchestrator import (
AutoHarnessOrchestrator,
)
def task_key(task: OptimizationTask) -> str:
"""Return the scoped artifact key for a task."""
return task.topic or "task"
@dataclass
class TaskRuntime:
"""Prepared task-scoped execution dependencies."""
related: list[Experience]
wt_path: str
edit_safety_rail: Any
preexisting_dirty_files: list[str]
task_agent: Any
commit_agent: Any
task_session: Any = None
fix_agent: Any = None
@dataclass
class BaseExecutionContext:
"""Shared execution context surface."""
orchestrator: "AutoHarnessOrchestrator"
@property
def task_id(self) -> str:
return ""
def get_artifact(
self,
name: str,
*,
default: Any = None,
) -> Any:
return self.orchestrator.artifacts.get(
name,
task_id=self.task_id,
default=default,
)
def require_artifact(self, name: str) -> Any:
return self.orchestrator.artifacts.require(
name,
task_id=self.task_id,
)
def put_artifact(self, name: str, value: Any) -> None:
self.orchestrator.artifacts.put(
name,
value,
task_id=self.task_id,
)
def put_artifacts(
self,
artifacts: dict[str, Any],
) -> None:
self.orchestrator.artifacts.put_many(
artifacts,
task_id=self.task_id,
)
@staticmethod
def message(
text: str,
*,
stage: str = "",
) -> OutputSchema:
"""Build a progress message OutputSchema."""
payload: dict[str, Any] = {"content": text}
if stage:
payload["stage"] = stage
return OutputSchema(
type="message",
index=0,
payload=payload,
)
@staticmethod
def stage_result_output(
*,
stage: str,
status: str,
error: str = "",
messages: list[str] | None = None,
metrics: dict[str, Any] | None = None,
) -> OutputSchema:
"""Build a stage_result OutputSchema."""
return OutputSchema(
type="stage_result",
index=0,
payload={
"stage": stage,
"status": status,
"error": error,
"messages": messages or [],
"metrics": metrics or {},
},
)
@dataclass
class SessionContext(BaseExecutionContext):
"""Runtime context passed into session pipelines and stages."""
@dataclass
class TaskContext(SessionContext):
"""Runtime context passed into task pipelines and stages."""
task: OptimizationTask
runtime: TaskRuntime
@property
def task_id(self) -> str:
return task_key(self.task)