import copy
import json
from abc import abstractmethod
from datetime import datetime
from enum import Enum
from typing import Any
from dateutil.tz import tzlocal
from openjiuwen.core.common.exception.codes import StatusCode
from openjiuwen.core.common.exception.errors import BaseError
from openjiuwen.core.common.logging import session_logger, LogEventType
from openjiuwen.core.session.stream.manager import StreamWriterManager
from openjiuwen.core.session.tracer.data import InvokeType, NodeStatus
from openjiuwen.core.session.tracer.span import Span, TraceAgentSpan, TraceWorkflowSpan
from openjiuwen.core.session.tracer.span import SpanManager
from openjiuwen.core.graph.pregel import GraphInterrupt
class TracerHandlerName(Enum):
    """Event-type identifiers used by the trace handlers.

    Each handler returns one of these values from ``event_name()`` and puts
    it in the ``"type"`` field of the payload it writes to the trace stream.
    """

    # Identifier emitted by the agent trace handler.
    TRACE_AGENT = "tracer_agent"
    # Identifier emitted by the workflow trace handler.
    TRACER_WORKFLOW = "tracer_workflow"
class TraceBaseHandler:
    """Shared plumbing for trace handlers that publish span snapshots.

    Keeps the trace writer obtained from the stream writer manager and the
    span manager, and provides helpers to snapshot a span, format it and
    write it to the trace stream. Subclasses supply ``_format_data``.
    """

    def __init__(self, stream_writer_manager: StreamWriterManager, span_manager: SpanManager):
        """Store the trace writer and the span manager.

        Args:
            stream_writer_manager: Provides the trace writer through
                ``get_trace_writer()``.
            span_manager: Span manager kept for use by subclasses.
        """
        self._stream_writer = stream_writer_manager.get_trace_writer()
        self._span_manager = span_manager

    async def emit_stream_writer(self, data):
        """Format ``data`` (a span) and write it to the trace stream.

        Does nothing when no trace writer is available.
        """
        await self._emit_stream_writer(data)

    @abstractmethod
    def _format_data(self, span: Span) -> dict:
        """Convert ``span`` into the dict written to the trace stream."""
        ...

    async def _emit_stream_writer(self, span):
        """Write the formatted span, skipping silently if there is no writer."""
        if self._stream_writer is None:
            return
        await self._stream_writer.write(self._format_data(span))

    async def _send_data(self, span, exclude=None):
        """Emit a detached snapshot of ``span``.

        Args:
            span: The span to publish. It is never passed on directly; a copy
                is emitted so later formatting works on the copy.
            exclude: Optional set of field names to drop. When given, the span
                is dumped without those fields and re-validated into a new
                instance of the same class; otherwise a deep copy is used.
        """
        if exclude:
            clean_dict = span.model_dump(exclude=exclude)
            snapshot = type(span).model_validate(clean_dict)
        else:
            snapshot = copy.deepcopy(span)
        await self.emit_stream_writer(snapshot)

    @staticmethod
    def _get_elapsed_time(start_time: datetime, end_time: datetime) -> str:
        """Return the duration between two datetimes as a short string.

        Durations that round to less than 1000 ms are rendered as whole
        milliseconds (``"250ms"``); everything else is rendered in seconds
        with two decimals (``"1.50s"``).
        """
        ms = (end_time - start_time).total_seconds() * 1000
        # Compare the *rounded* value: the ".0f" format rounds as well, so a
        # raw comparison would render e.g. 999.6 ms as "1000ms".
        if round(ms) < 1000:
            return f"{ms:.0f}ms"
        return f"{(ms / 1000):.2f}s"

    @staticmethod
    def _get_node_status(span: Span) -> str:
        """Derive the node status string from the span's current fields.

        Precedence: an ``error`` or (if present) ``inner_error`` yields ERROR;
        a span with ``on_invoke_data`` is RUNNING until it has an
        ``end_time``; a span with an ``end_time`` is FINISH; otherwise START.
        """
        if span.error:
            return NodeStatus.ERROR.value
        # ``inner_error`` is read defensively because not every span exposes it.
        if getattr(span, "inner_error", None):
            return NodeStatus.ERROR.value
        if span.on_invoke_data:
            return NodeStatus.FINISH.value if span.end_time else NodeStatus.RUNNING.value
        if span.end_time:
            return NodeStatus.FINISH.value
        return NodeStatus.START.value
class TraceAgentHandler(TraceBaseHandler):
    """Trace handler for agent-side invocations.

    Every ``on_<kind>_start`` / ``on_<kind>_end`` / ``on_<kind>_error``
    callback updates the given span through the span manager and then emits a
    snapshot of it. The callbacks differ only in the invoke type recorded on
    start, so they all delegate to three private helpers.
    """

    def __init__(self, stream_writer_manager, span_manager):
        super().__init__(stream_writer_manager, span_manager)

    def event_name(self):
        """Return the event type string placed in emitted payloads."""
        return TracerHandlerName.TRACE_AGENT.value

    def _format_data(self, span: TraceAgentSpan) -> dict:
        """Refresh the span status and wrap the dumped span in an event dict."""
        # An INTERRUPTED status is kept; any other status is recomputed.
        if span.status != NodeStatus.INTERRUPTED.value:
            span.status = self._get_node_status(span)
        return {"type": self.event_name(), "payload": span.model_dump(by_alias=True)}

    def _get_tracer_agent_span(self, invoke_id: str) -> TraceAgentSpan:
        """Return the span registered for ``invoke_id``, creating one if absent."""
        span = self._span_manager.get_span(invoke_id)
        if span is not None:
            return span
        return self._span_manager.create_agent_span(self._span_manager.last_span)

    @staticmethod
    def _now() -> datetime:
        """Return the current local time as a naive datetime."""
        return datetime.now(tz=tzlocal()).replace(tzinfo=None)

    def _build_end_update(self, span: TraceAgentSpan, end_time: datetime, **fields) -> dict:
        """Build the update dict for a finished span.

        Contains ``end_time``, the given extra ``fields`` and, when the span
        has a ``start_time``, the formatted ``elapsed_time``.
        """
        update_data = {"end_time": end_time, **fields}
        if span.start_time:
            update_data["elapsed_time"] = self._get_elapsed_time(span.start_time, end_time)
        return update_data

    def _update_start_trace_data(self, span: TraceAgentSpan, invoke_type: str, inputs: Any, instance_info: dict,
                                 **kwargs):
        """Record the start of an invocation on ``span``.

        Args:
            span: Span to update.
            invoke_type: Invoke type value stored on the span.
            inputs: Invocation inputs stored on the span.
            instance_info: Description of the invoked instance; must contain
                ``"class_name"``, which becomes the span name.
            **kwargs: Accepted for signature compatibility; not used here.

        Raises:
            ValueError: If ``instance_info`` cannot be converted to JSON.
            KeyError: If ``instance_info`` has no ``"class_name"`` entry.
        """
        try:
            # Round-trip through JSON to get a plain, serializable copy;
            # objects json cannot handle are replaced by a placeholder string.
            meta_data = json.loads(
                json.dumps(instance_info, ensure_ascii=False,
                           default=lambda _obj: f"<<no-serializable: {type(_obj).__qualname__}>>")
            )
        except (TypeError, ValueError) as err:
            # json.dumps raises TypeError (e.g. unsupported dict keys) or
            # ValueError (e.g. circular references); JSONDecodeError from
            # json.loads is a ValueError subclass and is covered as well.
            session_logger.error(
                "Failed to process metadata for trace",
                event_type=LogEventType.SYSTEM_ERROR,
                metadata={"error": str(err), "instance_info": str(instance_info)}
            )
            raise ValueError("meta_data error: Decoder error") from err
        update_data = {
            "start_time": self._now(),
            "invoke_type": invoke_type,
            "inputs": inputs,
            "instance_info": instance_info,
            "name": instance_info["class_name"],
            "meta_data": meta_data
        }
        self._span_manager.update_span(span, update_data)

    def _update_end_trace_data(self, span: TraceAgentSpan, outputs, **kwargs):
        """Record a successful end: end time, outputs and elapsed time."""
        end_time = self._now()
        self._span_manager.update_span(span, self._build_end_update(span, end_time, outputs=outputs))

    def _update_error_trace_data(self, span: TraceAgentSpan, error, **kwargs):
        """Record a failed end: end time, error info and elapsed time.

        A ``BaseError`` contributes its own status code and message; any other
        error is reported with the WORKFLOW_EXECUTION_ERROR status code.
        """
        end_time = self._now()
        if isinstance(error, BaseError):
            error_info = {"error_code": error.status.code, "message": error.message}
        else:
            fallback = StatusCode.WORKFLOW_EXECUTION_ERROR
            error_info = {"error_code": fallback.code,
                          "message": fallback.errmsg.format(reason=str(error), workflow="")}
        self._span_manager.update_span(span, self._build_end_update(span, end_time, error=error_info))

    def _update_running_trace_data(self, span: TraceAgentSpan, **kwargs):
        """Append ``kwargs`` as one entry to the span's ``on_invoke_data`` list."""
        if not isinstance(span.on_invoke_data, list):
            span.on_invoke_data = []
        span.on_invoke_data.append(kwargs)
        self._span_manager.update_span(span, {})

    async def _trace_start(self, invoke_type: str, span: TraceAgentSpan, inputs: Any, instance_info: dict,
                           **kwargs):
        """Record the start of an invocation of ``invoke_type`` and emit the span."""
        self._update_start_trace_data(invoke_type=invoke_type, span=span, inputs=inputs,
                                      instance_info=instance_info, **kwargs)
        await self._send_data(span)

    async def _trace_end(self, span: TraceAgentSpan, outputs, **kwargs):
        """Record a successful end and emit the span."""
        self._update_end_trace_data(span=span, outputs=outputs, **kwargs)
        await self._send_data(span)

    async def _trace_error(self, span: TraceAgentSpan, error, **kwargs):
        """Record a failed end and emit the span."""
        self._update_error_trace_data(span=span, error=error, **kwargs)
        await self._send_data(span)

    # --- chain ---------------------------------------------------------
    async def on_chain_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs):
        await self._trace_start(InvokeType.CHAIN.value, span, inputs, instance_info, **kwargs)

    async def on_chain_end(self, span: TraceAgentSpan, outputs, **kwargs):
        await self._trace_end(span, outputs, **kwargs)

    async def on_chain_error(self, span: TraceAgentSpan, error, **kwargs):
        await self._trace_error(span, error, **kwargs)

    # --- llm -----------------------------------------------------------
    async def on_llm_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs):
        await self._trace_start(InvokeType.LLM.value, span, inputs, instance_info, **kwargs)

    async def on_llm_request(self, span: TraceAgentSpan, **kwargs):
        """Record intermediate request data (``kwargs``) and emit the span."""
        self._update_running_trace_data(span=span, **kwargs)
        await self._send_data(span)

    async def on_llm_end(self, span: TraceAgentSpan, outputs, **kwargs):
        await self._trace_end(span, outputs, **kwargs)

    async def on_llm_error(self, span: TraceAgentSpan, error, **kwargs):
        await self._trace_error(span, error, **kwargs)

    # --- prompt --------------------------------------------------------
    async def on_prompt_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs):
        await self._trace_start(InvokeType.PROMPT.value, span, inputs, instance_info, **kwargs)

    async def on_prompt_end(self, span: TraceAgentSpan, outputs, **kwargs):
        await self._trace_end(span, outputs, **kwargs)

    async def on_prompt_error(self, span: TraceAgentSpan, error, **kwargs):
        await self._trace_error(span, error, **kwargs)

    # --- plugin --------------------------------------------------------
    async def on_plugin_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs):
        await self._trace_start(InvokeType.PLUGIN.value, span, inputs, instance_info, **kwargs)

    async def on_plugin_end(self, span: TraceAgentSpan, outputs, **kwargs):
        await self._trace_end(span, outputs, **kwargs)

    async def on_plugin_error(self, span: TraceAgentSpan, error, **kwargs):
        await self._trace_error(span, error, **kwargs)

    # --- retriever -----------------------------------------------------
    async def on_retriever_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs):
        await self._trace_start(InvokeType.RETRIEVER.value, span, inputs, instance_info, **kwargs)

    async def on_retriever_end(self, span: TraceAgentSpan, outputs, **kwargs):
        await self._trace_end(span, outputs, **kwargs)

    async def on_retriever_error(self, span: TraceAgentSpan, error, **kwargs):
        await self._trace_error(span, error, **kwargs)

    # --- evaluator -----------------------------------------------------
    async def on_evaluator_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs):
        await self._trace_start(InvokeType.EVALUATOR.value, span, inputs, instance_info, **kwargs)

    async def on_evaluator_end(self, span: TraceAgentSpan, outputs, **kwargs):
        await self._trace_end(span, outputs, **kwargs)

    async def on_evaluator_error(self, span: TraceAgentSpan, error, **kwargs):
        await self._trace_error(span, error, **kwargs)

    # --- workflow ------------------------------------------------------
    async def on_workflow_start(self, span: TraceAgentSpan, inputs: Any, instance_info: dict, **kwargs):
        await self._trace_start(InvokeType.WORKFLOW.value, span, inputs, instance_info, **kwargs)

    async def on_workflow_end(self, span: TraceAgentSpan, outputs, **kwargs):
        await self._trace_end(span, outputs, **kwargs)

    async def on_workflow_error(self, span: TraceAgentSpan, error, **kwargs):
        await self._trace_error(span, error, **kwargs)
class TraceWorkflowHandler(TraceBaseHandler):
    """Trace handler for workflow component invocations.

    Callbacks look up (or create) the span for an ``invoke_id``, update it
    through the span manager and, where applicable, emit a snapshot of it.
    """

    def __init__(self, stream_writer_manager, span_manager):
        super().__init__(stream_writer_manager, span_manager)

    def event_name(self) -> str:
        """Return the event type string placed in emitted payloads."""
        return TracerHandlerName.TRACER_WORKFLOW.value

    def _format_data(self, span: TraceWorkflowSpan) -> dict:
        """Refresh the span status and wrap the dumped span in an event dict.

        ``None`` fields and the ``child_invokes_id`` / ``llm_invoke_data``
        fields are left out of the payload.
        """
        # An INTERRUPTED status is kept; any other status is recomputed.
        if span.status != NodeStatus.INTERRUPTED.value:
            span.status = self._get_node_status(span)
        result = span.model_dump(exclude_none=True, by_alias=True, exclude={"child_invokes_id", "llm_invoke_data"})
        return {"type": self.event_name(),
                "payload": result}

    def _get_tracer_workflow_span(self, invoke_id: str) -> TraceWorkflowSpan:
        """Return the span registered for ``invoke_id``, creating one if absent."""
        span = self._span_manager.get_span(invoke_id)
        if span is not None:
            return span
        return self._span_manager.create_workflow_span(invoke_id, self._span_manager.last_span)

    @staticmethod
    def _now() -> datetime:
        """Return the current local time as a naive datetime."""
        return datetime.now(tz=tzlocal()).replace(tzinfo=None)

    @staticmethod
    def _record_invoke_data(span: TraceWorkflowSpan, on_invoke_data) -> None:
        """Append ``on_invoke_data`` to the span, surfacing any ``inner_error``.

        Ensures ``span.on_invoke_data`` is a list before appending. When the
        entry is a dict with an ``"inner_error"`` key, that value is also
        copied to ``span.inner_error``.
        """
        if not isinstance(span.on_invoke_data, list):
            span.on_invoke_data = []
        if isinstance(on_invoke_data, dict) and "inner_error" in on_invoke_data:
            span.inner_error = on_invoke_data["inner_error"]
        span.on_invoke_data.append(on_invoke_data)

    async def on_call_start(self, invoke_id: str, metadata: dict = None, inputs: Any = None,
                            need_send: bool = False, source_ids: list = None,
                            **kwargs):
        """Reset the span for a new call and record its start.

        Args:
            invoke_id: Identifier of the invocation whose span is updated.
            metadata: Optional extra span fields; they override the defaults
                set here. ``None`` is treated as empty.
            inputs: Call inputs stored on the span.
            need_send: Emit the span after updating it.
            source_ids: Stored on the span as ``source_ids``.
            **kwargs: ``invoke_type``, if present, is stored on the span.
        """
        span = self._get_tracer_workflow_span(invoke_id)
        update_data = {"start_time": self._now()}
        # Store invoke_type only when the caller supplies it; a bare ``type``
        # here would resolve to the Python builtin rather than a real value.
        # NOTE(review): confirm which invoke_type value callers expect.
        if "invoke_type" in kwargs:
            update_data["invoke_type"] = kwargs["invoke_type"]
        update_data.update({
            "on_invoke_data": [],
            "inputs": inputs,
            "outputs": None,
            "stream_outputs": [],
            "source_ids": source_ids,
        })
        # metadata defaults to None, which cannot be unpacked as a mapping.
        update_data.update(metadata or {})
        self._span_manager.update_span(span, update_data)
        if need_send:
            await self._send_data(span)

    async def on_pre_invoke(self, invoke_id: str, inputs: Any, component_metadata: dict, need_send: bool = False,
                            **kwargs):
        """Store the inputs and component metadata before a component runs.

        When ``need_send`` is true the span is emitted without its
        ``outputs`` and ``stream_outputs`` fields.
        """
        span = self._get_tracer_workflow_span(invoke_id)
        update_data = {
            "inputs": inputs,
            **component_metadata
        }
        self._span_manager.update_span(span, update_data)
        if need_send:
            await self._send_data(span, exclude={"outputs", "stream_outputs"})

    async def on_pre_stream(self, invoke_id: str, chunk, need_send: bool = False, **kwargs):
        """Append a non-empty dict ``chunk`` to the span's stream inputs.

        When ``need_send`` is true the span is emitted without its
        ``outputs`` and ``stream_outputs`` fields.
        """
        span = self._get_tracer_workflow_span(invoke_id)
        if chunk and isinstance(chunk, dict):
            span.append_stream_inputs(chunk)
        if need_send:
            await self._send_data(span, exclude={"outputs", "stream_outputs"})

    async def on_invoke(self, invoke_id: str, on_invoke_data: dict = None, exception: Exception = None, **kwargs):
        """Record intermediate invoke data or a failure, then emit the span.

        Args:
            invoke_id: Identifier of the invocation whose span is updated.
            on_invoke_data: Entry appended to the span's ``on_invoke_data``.
                Without an exception it is always appended; with an exception
                it is appended only when truthy.
            exception: When given, the span is closed (end and elapsed time).
                A ``BaseError`` or any other non-interrupt exception sets
                ``span.error``; a ``GraphInterrupt`` sets the status to
                INTERRUPTED instead.
        """
        span = self._get_tracer_workflow_span(invoke_id)
        update_data = {}
        end_time = self._now()
        if exception is not None:
            if isinstance(exception, BaseError):
                span.error = {"error_code": exception.code, "message": exception.message}
            elif isinstance(exception, GraphInterrupt):
                span.status = NodeStatus.INTERRUPTED.value
            else:
                fallback = StatusCode.WORKFLOW_EXECUTION_ERROR
                # workflow="" mirrors the agent handler; str.format ignores
                # unused keyword arguments, so this is safe either way.
                span.error = {"error_code": fallback.code,
                              "message": fallback.errmsg.format(reason=str(exception), workflow="")}
            if on_invoke_data:
                self._record_invoke_data(span, on_invoke_data)
            update_data = {"end_time": end_time}
            if span.start_time:
                update_data["elapsed_time"] = self._get_elapsed_time(span.start_time, end_time)
        else:
            self._record_invoke_data(span, on_invoke_data)
        self._span_manager.update_span(span, update_data)
        await self._send_data(span)
        if exception and span.component_type == "LLM":
            self._span_manager.update_span(span, {})

    async def on_post_stream(self, invoke_id: str, chunk, **kwargs):
        """Append ``chunk`` to the span's stream output; nothing is emitted."""
        span = self._get_tracer_workflow_span(invoke_id)
        span.append_stream_output(chunk)

    async def on_post_invoke(self, invoke_id: str, outputs, inputs=None, **kwargs):
        """Store the outputs after a component ran; nothing is emitted.

        For "End" and "Message" components, truthy ``inputs`` also replace
        the span's inputs.
        """
        span = self._get_tracer_workflow_span(invoke_id)
        if inputs and span.component_type in ["End", "Message"]:
            span.inputs = inputs
        self._span_manager.update_span(span, {"outputs": outputs})

    async def on_call_done(self, invoke_id, outputs: Any = None, **kwargs):
        """Close the span (end time, elapsed time, optional outputs) and emit it.

        ``outputs`` is written only when it is not ``None``, so outputs
        stored earlier are kept otherwise.
        """
        span = self._get_tracer_workflow_span(invoke_id)
        end_time = self._now()
        update_data = {"end_time": end_time}
        if outputs is not None:
            update_data["outputs"] = outputs
        if span.start_time:
            update_data["elapsed_time"] = self._get_elapsed_time(span.start_time, end_time)
        self._span_manager.update_span(span, update_data)
        await self._send_data(span)
        if span.component_type == "End" and span.end_time:
            self._span_manager.update_span(span, {})

    async def on_interact(self, invoke_id: str, inputs: Any, component_metadata: dict, need_send: bool = False,
                          **kwargs):
        """Store interactive inputs and component metadata on the span.

        When ``need_send`` is true the span is emitted without its
        ``outputs`` and ``stream_outputs`` fields.
        """
        span = self._get_tracer_workflow_span(invoke_id)
        update_data = {
            "interactive_inputs": inputs,
            **component_metadata
        }
        self._span_manager.update_span(span, update_data)
        if need_send:
            await self._send_data(span, exclude={"outputs", "stream_outputs"})