import json
import logging
import math
import os
import random
from collections import Counter, deque
from typing import Deque, List, Optional

from openjiuwen_deepsearch.config.config import ActionSamplingConfig
from openjiuwen_deepsearch.framework.openjiuwen.agent.search_context import Action, Result
from openjiuwen_deepsearch.utils.log_utils.log_manager import LogManager
from openjiuwen_deepsearch.utils.run_telemetry import emit
logger = logging.getLogger(__name__)

# Depth-weighting constants, read by ActionPool._compute_base_score only when
# config.depth_weight is enabled.
# A depth strictly greater than DEPTH_PENALTY_THRESHOLD multiplies the score by
# (MAX_STATE_DEPTH - depth) / DEPTH_PENALTY_DENOMINATOR.
MAX_STATE_DEPTH: int = 12
DEPTH_PENALTY_THRESHOLD: int = 5
DEPTH_PENALTY_DENOMINATOR: int = 8

# A depth strictly below SHALLOW_DEPTH_THRESHOLD adds SHALLOW_DEPTH_BONUS to the score.
SHALLOW_DEPTH_THRESHOLD: int = 2
SHALLOW_DEPTH_BONUS: int = 100

# With config.promote_unique_states, a pending action's sampling score is divided
# by the number of pending actions sharing its state hash only when that number
# exceeds this threshold.
MULTIPLE_STATES_DIVISOR_THRESHOLD: int = 1
class ActionPool:
    """
    Pool of actions for the DeepSearch agent. Used from a single coroutine only
    (no thread-safety; not needed under asyncio).

    Attributes:
        _pool: pending actions awaiting sampling.
        running_actions: actions handed out by ``sample()`` and not yet passed
            to ``record_completed()``.
        completed_actions: ``(action, result)`` pairs in completion order;
            ``result`` may be ``None``.
        successfully_completed_actions: ``(action, result)`` pairs registered
            through ``record_successful_answer()`` (top-k mode).
        immediate_queue: actions placed here are served first (FIFO) when
            ``sample()`` is called, bypassing the scored sampling logic entirely.
        config: sampling options; ``random_sample``, ``depth_weight`` and
            ``promote_unique_states`` are read here.
        log_dir: directory receiving ``action_pool.json``; empty disables it.
        _score_cache: base score per ``action.id``, computed on first access.
    """

    def __init__(self):
        self._pool: List[Action] = []
        self.running_actions: List[Action] = []
        self.completed_actions: List[tuple[Action, Optional[Result]]] = []
        self.successfully_completed_actions: List[tuple[Action, Result]] = []
        self.immediate_queue: Deque[Action] = deque()
        self.config: ActionSamplingConfig = ActionSamplingConfig()
        self.log_dir: str = ""
        self._score_cache: dict[str, float] = {}

    @staticmethod
    def _state_hash(action: Action) -> str:
        """Return a key identifying the action's state by its lower-cased candidates.

        The candidates are encoded as a JSON list so that distinct candidate
        sequences cannot collide (plain concatenation mapped e.g. ["ab", "c"]
        and ["a", "bc"] to the same key).
        """
        return json.dumps(
            [str(v.candidate).lower() for v in action.state.state],
            ensure_ascii=False,
        )

    def _compute_base_score(self, action: Action) -> float:
        """Full base score: proposal score + candidate strengths + depth adjustments.

        The pool-relative promote_unique_states divisor is intentionally excluded
        because it is a sampling weight, not an intrinsic action property.

        Note: with depth_weight enabled, a depth above MAX_STATE_DEPTH makes the
        depth multiplier negative, so this value can be negative; sampling clamps
        weights at zero (see ``_as_weight``).
        """
        score = action.proposal.score
        for v in action.state.state:
            if v.candidate is not None:
                score += v.candidate_strength or 0.0
        depth = action.state.depth
        if self.config.depth_weight:
            if depth > DEPTH_PENALTY_THRESHOLD:
                score *= (MAX_STATE_DEPTH - depth) / DEPTH_PENALTY_DENOMINATOR
            if depth < SHALLOW_DEPTH_THRESHOLD:
                score += SHALLOW_DEPTH_BONUS
        return score

    def _get_score(self, action: Action) -> float:
        """Return the cached base score, computing and storing it on first access.

        The cache is keyed by ``action.id`` and never invalidated, so a later
        change to ``config.depth_weight`` does not affect already-scored actions.
        """
        if action.id not in self._score_cache:
            self._score_cache[action.id] = self._compute_base_score(action)
        return self._score_cache[action.id]

    @staticmethod
    def _as_weight(score: float) -> float:
        """Map a score to a valid ``random.choices`` weight.

        Negative, NaN and infinite scores become 0.0: ``random.choices`` expects
        non-negative weights with a finite, positive total, and negative weights
        would silently skew its cumulative-weight bisection.
        """
        return score if math.isfinite(score) and score > 0.0 else 0.0

    def _sampling_weights(self) -> List[float]:
        """Return one non-negative sampling weight per pending action, aligned with ``_pool``.

        Starts from the cached base score. When ``promote_unique_states`` is on,
        an action whose state hash is shared by more than
        MULTIPLE_STATES_DIVISOR_THRESHOLD pending actions has its score divided
        by that count, so states shared by many pending actions do not dominate
        sampling.
        """
        scores = [self._get_score(a) for a in self._pool]
        if self.config.promote_unique_states:
            # Hash each state once per draw instead of twice.
            hashes = [self._state_hash(a) for a in self._pool]
            counts = Counter(hashes)
            scores = [
                s / counts[h] if counts[h] > MULTIPLE_STATES_DIVISOR_THRESHOLD else s
                for s, h in zip(scores, hashes)
            ]
        return [self._as_weight(s) for s in scores]

    def _action_snapshot(self, action: Action) -> dict:
        """Return a small dict (id, direction, base score, depth) describing ``action``."""
        return {
            "id": action.id,
            "direction": action.proposal.direction,
            "score": self._get_score(action),
            "depth": action.state.depth,
        }

    def _save_pool_json(self) -> None:
        """Write a live snapshot of the pool state to action_pool.json in log_dir.

        Also emits an ``action_pool_snapshot`` telemetry event carrying the same
        snapshot. Best-effort: any failure is logged and swallowed. Does nothing
        when ``log_dir`` is empty.
        """
        if not self.log_dir:
            return
        try:
            snapshot = {
                "pending": [self._action_snapshot(a) for a in self._pool],
                "running": [self._action_snapshot(a) for a in self.running_actions],
                "completed": [
                    {**self._action_snapshot(a), "has_result": r is not None}
                    for a, r in self.completed_actions
                ],
            }
            path = os.path.join(self.log_dir, "action_pool.json")
            # Write to a sibling temp file, then swap it in with os.replace so a
            # reader never observes a half-written snapshot.
            tmp_path = path + ".tmp"
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(snapshot, f, indent=2, ensure_ascii=False)
            os.replace(tmp_path, path)
            emit(
                "action_pool_snapshot",
                {
                    "pending_count": len(snapshot["pending"]),
                    "running_count": len(snapshot["running"]),
                    "completed_count": len(snapshot["completed"]),
                    "snapshot": snapshot,
                },
                source="action_pool._save_pool_json",
                action_id=None,
            )
        except Exception as e:
            # logger.exception already attaches the traceback.
            logger.exception("[ActionPool] _save_pool_json failed: %s", e)

    def add(self, actions: List[Action]) -> None:
        """Add ``actions`` to the pending pool, scoring each one up front.

        Scores are computed before the pool is extended, so if scoring raises,
        the error is logged and the pool is left unchanged.
        """
        if not actions:
            return
        try:
            for action in actions:
                self._get_score(action)
            self._pool.extend(actions)
            self._save_pool_json()
        except Exception as e:
            logger.exception("[ActionPool] add failed, pool unchanged: %s", e)

    def size(self) -> int:
        """Return the number of pending actions in the pool."""
        return len(self._pool)

    def record_completed(self, action: Action, result: Optional[Result]) -> None:
        """Register an action that has finished executing, together with its result.

        The action is removed from ``running_actions`` (matched by id). If it is
        not found there, a warning is logged but the completion is still recorded.
        """
        try:
            idx = next(i for i, a in enumerate(self.running_actions) if a.id == action.id)
            self.running_actions.pop(idx)
        except StopIteration:
            logger.warning(
                "[ActionPool] record_completed: action_id=%s not in running_actions "
                "(possible double completion or pool desync); still recording in completed.",
                "***" if LogManager.is_sensitive() else getattr(action, "id", "?"),
            )
        self.completed_actions.append((action, result))
        self._save_pool_json()

    def record_successful_answer(self, action: Action, result: Result) -> None:
        """Register an action whose result contains a found answer (top-k mode)."""
        self.successfully_completed_actions.append((action, result))

    def successful_answer_count(self) -> int:
        """Return the number of answers collected so far in top-k mode."""
        return len(self.successfully_completed_actions)

    def get_best_answer(self) -> Optional[tuple[Action, Result]]:
        """Return the (action, result) pair whose answer variable has the highest candidate_strength.

        A missing answer variable or a None strength counts as 0.0; on ties the
        earliest registered pair wins. Returns None if no successful answers
        have been collected.
        """
        if not self.successfully_completed_actions:
            return None

        def _score(pair: tuple[Action, Result]) -> float:
            action, _ = pair
            for v in action.state.state:
                if v.id == action.state.answer_variable:
                    return v.candidate_strength or 0.0
            return 0.0

        return max(self.successfully_completed_actions, key=_score)

    def get_best_guess(self) -> Optional[tuple[Action, Optional[Result], str]]:
        """Return ``(action, result, candidate)`` for the strongest answer candidate so far.

        Scans ``completed_actions`` for actions whose answer variable has a
        non-None candidate and picks the one with the highest candidate_strength
        (None counts as 0.0; ties keep the earliest). Actions whose state lacks
        the expected attributes are skipped. Returns None if nothing qualifies.
        """
        candidates: list[tuple[Action, Optional[Result], str, float]] = []
        for action, result in self.completed_actions:
            try:
                answer_var = next(
                    (v for v in action.state.state if v.id == action.state.answer_variable),
                    None,
                )
            except AttributeError:
                continue
            if answer_var is None or answer_var.candidate is None:
                continue
            strength = answer_var.candidate_strength or 0.0
            candidates.append((action, result, answer_var.candidate, strength))
        if not candidates:
            return None
        best = max(candidates, key=lambda t: t[3])
        return best[:3]

    def sample(self, k: int) -> List[Action]:
        """Take up to ``k`` actions out of the pool and mark them as running.

        Actions in ``immediate_queue`` are served first; the rest come from the
        pending pool via ``_sample_impl``. On failure the error is logged, the
        queue and pool are restored to their state before the call, and an empty
        list is returned.
        """
        # _sample_impl pops as it goes; keep copies so a failure cannot drop actions.
        pool_before = list(self._pool)
        queue_before = list(self.immediate_queue)
        try:
            sampled = self._sample_impl(k)
        except Exception as e:
            self._pool[:] = pool_before
            self.immediate_queue.clear()
            self.immediate_queue.extend(queue_before)
            logger.exception("[ActionPool] sample failed, returning empty list: %s", e)
            return []
        self.running_actions.extend(sampled)
        self._save_pool_json()
        return sampled

    def _sample_impl(self, k: int) -> List[Action]:
        """Pop and return up to ``k`` actions: immediate queue first, then the pool.

        Pool draws are uniform without replacement when ``config.random_sample``
        is set; otherwise each draw is weighted by ``_sampling_weights()``. Scored
        sampling stops early when no pending action has a positive weight.
        """
        immediate: List[Action] = []
        while len(immediate) < k and self.immediate_queue:
            immediate.append(self.immediate_queue.popleft())
        remaining = k - len(immediate)
        # remaining < 0 happens for k < 0; previously random.sample raised on it.
        if remaining <= 0 or not self._pool:
            return immediate

        if self.config.random_sample:
            picked = random.sample(range(len(self._pool)), min(remaining, len(self._pool)))
            scored = [self._pool[i] for i in picked]
            # Delete by index, highest first, so earlier indices stay valid.
            for i in sorted(picked, reverse=True):
                del self._pool[i]
            return immediate + scored

        sampled_actions: List[Action] = []
        for _ in range(remaining):
            if not self._pool:
                break
            weights = self._sampling_weights()
            if sum(weights) <= 0:
                # No positive weight left; later iterations would see the same pool.
                break
            sample_idx = random.choices(range(len(self._pool)), weights=weights, k=1)[0]
            sampled_actions.append(self._pool.pop(sample_idx))
        return immediate + sampled_actions