import json
from unittest.mock import AsyncMock, patch
import pytest
from openjiuwen_deepsearch.framework.openjiuwen.agent.workflow import DeepresearchAgent
from openjiuwen_deepsearch.utils.constants_utils.session_contextvars import (
local_search_context,
session_context,
web_search_context,
)
from tests.utils.mock_config import get_default_agent_config
@pytest.mark.asyncio
async def test_run_keeps_workflow_llm_usage_for_non_terminal_stream():
    """Verify a run that stops before ALL END keeps conversation-level usage.

    The stream consumer is stubbed to yield a single ``waiting_user_input``
    chunk whose second tuple element is ``False``; the test then checks that
    ``pop_workflow_llm_usage`` was never invoked.
    """
    agent = DeepresearchAgent()
    # Session release hooks are stubbed so the run needs no real session backend.
    agent.agent.release_session = AsyncMock()
    agent._release_checkpointer_session = AsyncMock()

    config = get_default_agent_config()
    config["stats_info_llm"] = True
    conv_id = "workflow-token-lifecycle"

    async def _waiting_stream(*args, **kwargs):
        """Simulate one non-terminal streaming pass that waits for user input."""
        chunk = ('{"event":"waiting_user_input"}', False, {})
        yield chunk

    def _context_tokens(_):
        """Return context-var tokens that can safely be reset later."""
        web_token = web_search_context.set({})
        local_token = local_search_context.set({})
        return web_token, local_token

    # Patchers are only constructed here; they take effect when entered below.
    tools_patch = patch.object(agent, "_initialize_tools", side_effect=_context_tokens)
    llm_patch = patch(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.create_llm_obj",
        return_value=object(),
    )
    stream_patch = patch(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.DeepresearchAgent._consume_stream_chunks",
        side_effect=_waiting_stream,
    )
    usage_patch = patch(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.pop_workflow_llm_usage",
    )
    log_patch = patch(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.record_interface_log",
    )

    with tools_patch, llm_patch, stream_patch, usage_patch as mock_pop_usage, log_patch:
        emitted = [
            item
            async for item in agent.run(
                message="hello",
                conversation_id=conv_id,
                agent_config=config,
                report_template="",
                interrupt_feedback="",
            )
        ]

    assert emitted == ['{"event":"waiting_user_input"}']
    # The usage pop must not have happened for this non-terminal run.
    mock_pop_usage.assert_not_called()
@pytest.mark.asyncio
async def test_run_pops_workflow_llm_usage_when_stream_reaches_all_end():
    """Verify conversation-level usage is still cleared when ALL END is reached.

    The stream consumer is stubbed to yield a single ``summary_response``
    chunk whose second tuple element is ``True``; the test then checks that
    ``pop_workflow_llm_usage`` was called exactly once with the conversation id.
    """
    agent = DeepresearchAgent()
    # Session release hooks are stubbed so the run needs no real session backend.
    agent.agent.release_session = AsyncMock()
    agent._release_checkpointer_session = AsyncMock()

    config = get_default_agent_config()
    config["stats_info_llm"] = True
    conv_id = "workflow-token-all-end"

    async def _terminal_stream(*args, **kwargs):
        """Simulate one streaming pass that reaches the terminal state."""
        chunk = ('{"event":"summary_response"}', True, {})
        yield chunk

    def _context_tokens(_):
        """Return context-var tokens that can safely be reset later."""
        web_token = web_search_context.set({})
        local_token = local_search_context.set({})
        return web_token, local_token

    # Patchers are only constructed here; they take effect when entered below.
    tools_patch = patch.object(agent, "_initialize_tools", side_effect=_context_tokens)
    llm_patch = patch(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.create_llm_obj",
        return_value=object(),
    )
    stream_patch = patch(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.DeepresearchAgent._consume_stream_chunks",
        side_effect=_terminal_stream,
    )
    usage_patch = patch(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.pop_workflow_llm_usage",
    )
    log_patch = patch(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.record_interface_log",
    )

    with tools_patch, llm_patch, stream_patch, usage_patch as mock_pop_usage, log_patch:
        emitted = [
            item
            async for item in agent.run(
                message="hello",
                conversation_id=conv_id,
                agent_config=config,
                report_template="",
                interrupt_feedback="",
            )
        ]

    assert emitted == ['{"event":"summary_response"}']
    # Reaching the terminal state must clear usage for exactly this conversation.
    mock_pop_usage.assert_called_once_with(conv_id)
@pytest.mark.asyncio
async def test_run_uses_persisted_workflow_usage_snapshot_when_local_usage_empty():
    """Verify the error path falls back to the token snapshot stored in the session.

    Streaming is stubbed to raise, ``get_workflow_llm_usage`` is patched to
    report all-zero usage, and a fake session exposes a persisted snapshot.
    The first emitted item must carry that snapshot as
    ``workflow_llm_token_usage``, and usage must be popped exactly once.
    """
    agent = DeepresearchAgent()
    # Session release hooks are stubbed so the run needs no real session backend.
    agent.agent.release_session = AsyncMock()
    agent._release_checkpointer_session = AsyncMock()

    config = get_default_agent_config()
    config["stats_info_llm"] = True
    conv_id = "workflow-token-resume"

    zero_totals = {
        "input_tokens": 0,
        "output_tokens": 0,
        "total_tokens": 0,
        "llm_call_count": 0,
    }
    empty_usage = {**zero_totals, "agent_name_token_usage": []}

    # The single per-agent entry carries the same numbers as the overall totals.
    snapshot_totals = {
        "input_tokens": 12,
        "output_tokens": 9,
        "total_tokens": 21,
        "llm_call_count": 4,
    }
    snapshot_usage = {
        **snapshot_totals,
        "agent_name_token_usage": [{"agent_name": "entry", **snapshot_totals}],
    }

    async def _failing_stream(*args, **kwargs):
        """Simulate a streaming failure so that run takes its exception branch."""
        raise RuntimeError("mock streaming error")
        # Unreachable, but its presence makes this function an async generator.
        yield

    def _context_tokens(_):
        """Return context-var tokens that can safely be reset later."""
        web_token = web_search_context.set({})
        local_token = local_search_context.set({})
        return web_token, local_token

    class _FakeSession:
        """Stand-in session object that serves the persisted usage snapshot."""

        def get_global_state(self, key):
            """Return the snapshot for the usage key and ``None`` for any other key."""
            usage_key = "search_context.final_result.workflow_llm_token_usage"
            return snapshot_usage if key == usage_key else None

    session_token = session_context.set(_FakeSession())
    try:
        # Patchers are only constructed here; they take effect when entered below.
        tools_patch = patch.object(agent, "_initialize_tools", side_effect=_context_tokens)
        llm_patch = patch(
            "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.create_llm_obj",
            return_value=object(),
        )
        stream_patch = patch(
            "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.Runner.run_agent_streaming",
            side_effect=_failing_stream,
        )
        local_usage_patch = patch(
            "openjiuwen_deepsearch.utils.common_utils.llm_utils.get_workflow_llm_usage",
            return_value=empty_usage,
        )
        usage_patch = patch(
            "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.pop_workflow_llm_usage",
        )
        log_patch = patch(
            "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.record_interface_log",
        )

        with tools_patch, llm_patch, stream_patch, local_usage_patch:
            with usage_patch as mock_pop_usage, log_patch:
                emitted = [
                    item
                    async for item in agent.run(
                        message='{"action":"finish"}',
                        conversation_id=conv_id,
                        agent_config=config,
                        report_template="",
                        interrupt_feedback="",
                    )
                ]
    finally:
        # Always restore the session context var, even if the run raised.
        session_context.reset(session_token)

    assert len(emitted) == 2

    # The first item is a JSON envelope whose "content" field is itself JSON.
    envelope = json.loads(emitted[0])
    error_content = json.loads(envelope["content"])
    assert error_content["workflow_llm_token_usage"] == snapshot_usage
    mock_pop_usage.assert_called_once_with(conv_id)