import asyncio
import json
import logging
import re
from typing import Any, Dict, Tuple
from openjiuwen_deepsearch.algorithm.search_agent.deepsearch_agent import unwrap_workflow_result_payload
from openjiuwen_deepsearch.algorithm.search_nodes.llm_utils import RunLLMConfig, run_llm
from openjiuwen_deepsearch.common.exception import CustomValueException
from openjiuwen_deepsearch.common.status_code import StatusCode
from openjiuwen_deepsearch.framework.openjiuwen.agent.search_context import State
from openjiuwen_deepsearch.utils.constants_utils.node_constants import AgentLlmName
from openjiuwen_deepsearch.utils.log_utils.log_manager import LogManager
# Logger scoped to this module's import name.
logger = logging.getLogger(name=__name__)

# Regex source (not compiled) for a bracketed clue reference such as "[3_search_result]":
# one or more digits, an underscore, then lowercase letters/digits/underscores.
# Group 1 captures the id without the surrounding brackets.
# NOTE(review): not referenced in the visible part of this file — presumably used by importers.
QUESTION_CLUES_REF_PATTERN = r"\[([0-9]+_[a-z0-9_]+)\]"
# Reasoning blocks some models emit before their answer, e.g. "<think>...</think>".
# The backreference \1 makes the closing tag match the opening one.
_REASONING_BLOCK_RE = re.compile(
    r"<(think|reasoning|thinking)\b[^>]*>.*?</\1>",
    flags=re.IGNORECASE | re.DOTALL,
)
# A whole response wrapped in one Markdown code fence such as "```json\n{...}\n```".
# Group 1 is the fenced body.
_CODE_FENCE_RE = re.compile(r"\A```[\w-]*\s*(.*?)\s*```\Z", flags=re.DOTALL)


def _clean_init_state_text(content: str) -> str:
    """Strip reasoning blocks and one enclosing Markdown code fence from LLM output."""
    text = _REASONING_BLOCK_RE.sub("", content).strip()
    fenced = _CODE_FENCE_RE.match(text)
    if fenced is not None:
        # Without this, a fenced-but-valid JSON answer would fail json.loads below.
        text = fenced.group(1)
    return text


def parse_and_validate_init_state(content: str):
    """Parse the LLM's initial-state answer into the root ``State``.

    The raw text is cleaned and parsed as follows:

    1. ``<think>``/``<reasoning>``/``<thinking>`` blocks are removed.
    2. A surrounding Markdown code fence, if any, is removed.
    3. The remainder is parsed as JSON.
    4. The result is passed through ``unwrap_workflow_result_payload`` with the
       key ``"answer_variable"``. Presumably this extracts the workflow answer
       payload; see that helper.
    5. A ``"variables"`` key is renamed to ``"state"`` when ``"state"`` is absent.

    The resulting mapping is used as keyword arguments to ``State``.

    Args:
        content: Raw text returned by the LLM.

    Returns:
        A ``State`` with ``id`` set to ``"0"`` and ``depth`` set to ``0``.

    Raises:
        CustomValueException: With ``StatusCode.INIT_STATE_FAILED`` when the
            content is not a string, is not valid JSON, cannot be built into a
            ``State``, or yields a falsy state.
    """
    try:
        if not isinstance(content, str):
            # e.g. llm_result.get("content") returned None; give a readable cause
            # instead of re's "expected string or bytes-like object".
            raise TypeError(f"expected LLM content to be str, got {type(content).__name__}")
        init_state_dict = unwrap_workflow_result_payload(
            json.loads(_clean_init_state_text(content)), "answer_variable"
        )
        # Accept "variables" as an alias for "state", without clobbering an explicit "state".
        if "variables" in init_state_dict and "state" not in init_state_dict:
            init_state_dict["state"] = init_state_dict.pop("variables")
        init_state = State(**init_state_dict)
    except Exception as e:
        raise CustomValueException(
            StatusCode.INIT_STATE_FAILED.code,
            StatusCode.INIT_STATE_FAILED.errmsg.format(e=e),
        ) from e
    if not init_state:
        raise CustomValueException(
            StatusCode.INIT_STATE_FAILED.code,
            StatusCode.INIT_STATE_FAILED.errmsg.format(e="invalid or empty state"),
        )
    # The freshly parsed state is the root of the search tree.
    init_state.id = "0"
    init_state.depth = 0
    return init_state
async def run_initialize_state(
    llm_config: dict,
    query: str,
    *,
    total_input_tokens: int = 0,
    total_output_tokens: int = 0,
) -> Dict[str, Any]:
    """Ask the LLM for the initial search ``State`` of *query*.

    The call renders the ``deepsearch_initialize_state`` prompt template with
    a single user message carrying the query. It runs the LLM without
    streaming, under the ``AgentLlmName.INITIAL_STATE`` agent name. The reply
    is parsed with ``parse_and_validate_init_state``.

    Args:
        llm_config: LLM configuration forwarded to ``RunLLMConfig``.
        query: The user query to initialize the state for.
        total_input_tokens: Running input-token count to add this call's usage to.
        total_output_tokens: Running output-token count to add this call's usage to.

    Returns:
        A dict with these keys:

        - ``init_state``: the parsed state.
        - ``total_input_tokens`` / ``total_output_tokens``: the updated counters.
        - ``success``: always ``True`` when returned.
        - ``messages``: the user message followed by the assistant's raw reply.

    Raises:
        CustomValueException: Propagated from ``parse_and_validate_init_state``
            when the reply cannot be turned into a state.
    """
    user_message = {
        "role": "user",
        "content": f"Please apply the above instructions to the following query:\n{query}",
    }
    context_vars = {"messages": [user_message]}

    llm_call = RunLLMConfig(
        config=llm_config,
        prompt_template_file="deepsearch_initialize_state",
        context_vars=context_vars,
        need_stream_out=False,
        agent_name=AgentLlmName.INITIAL_STATE.value,
    )
    llm_result, _, used_input, used_output = await run_llm(llm_call)

    total_input_tokens += used_input
    total_output_tokens += used_output

    # run_llm may hand back either a message dict or the bare content.
    if isinstance(llm_result, dict):
        content = llm_result.get("content")
    else:
        content = llm_result

    init_state = parse_and_validate_init_state(content)

    shown_state = "***" if LogManager.is_sensitive() else init_state
    logger.info("[initialize_state] parsed State: %s", shown_state)

    # The conversation transcript keeps the raw assistant reply (reasoning tags included).
    messages = [*context_vars["messages"], {"role": "assistant", "content": content}]
    return {
        "init_state": init_state,
        "total_input_tokens": total_input_tokens,
        "total_output_tokens": total_output_tokens,
        "success": True,
        "messages": messages,
    }