import logging
from dataclasses import dataclass
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type
from pydantic import BaseModel, Field, field_validator
from openjiuwen_deepsearch.algorithm.search_nodes.verify_utils import (
compute_overall,
find_relative_info,
has_number_like,
has_placeholder,
norm_val,
render_clues_placeholder,
to_entity_map,
update_entities,
verify_coarse,
verify_coarse_set,
)
from openjiuwen_deepsearch.common.exception import CustomValueException
from openjiuwen_deepsearch.common.status_code import StatusCode
from openjiuwen_deepsearch.framework.openjiuwen.agent.search_context import (
CandidateVerifiedClues,
State,
VerifiedVariableResult,
)
from openjiuwen_deepsearch.utils.log_utils.log_manager import LogManager
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# Pairs a variable identifier with one proposed value for it.
# (Comments are used instead of a class docstring so the pydantic JSON schema
# of the model is left untouched.)
class EntityCandidate(BaseModel):
    variable_id: str  # required
    value: str  # required
# One entity entry of the verify input: an identifier, a type label and the
# clue strings attached to it.
# (Comments are used instead of a class docstring so the pydantic JSON schema
# of the model is left untouched.)
class Entity(BaseModel):
    variable_id: str  # required
    type: str  # required
    clues: List[str]  # required; list of clue strings
    value: Optional[str] = None  # optional, defaults to None
# Options accepted by the verify tool. `mode` is lower-cased on validation and
# replaced by "isolated" when it is not one of the recognised values; every
# other option is optional and defaults to None unless stated otherwise.
# (Comments are used instead of a class docstring so the pydantic JSON schema
# of the model is left untouched.)
class VerifyOptions(BaseModel):
    mode: str = Field(default="isolated", description="isolated | linked | set")
    defer_numeric: bool = Field(default=False, description="whether to defer numeric clues")
    max_links: Optional[int] = None
    use_scholar: Optional[bool] = None
    llm_ctx_max: Optional[int] = None
    llm_model_name: Optional[str] = None
    judge_model: Optional[str] = None
    numeric_model: Optional[str] = None
    judge_max_tokens: Optional[int] = None
    numeric_max_tokens: Optional[int] = None

    @field_validator("mode")
    @classmethod
    def normalize_mode(cls, v: str) -> str:
        """Return ``v`` lower-cased, or ``"isolated"`` if it is not a known mode."""
        lowered = v.lower()
        if lowered in ("isolated", "linked", "set"):
            return lowered
        # Unknown modes silently degrade to the default instead of failing validation.
        return "isolated"
# One candidate answer submitted for verification.
# (Comments are used instead of a class docstring so the pydantic JSON schema
# of the model is left untouched.)
class Candidate(BaseModel):
    variable_id: str  # required; compared (after norm_val) against entity variable_ids
    type: str  # required
    value: str  # required; candidates with an empty value are not verified
    evidence: Optional[str] = ""  # optional; non-empty evidence is copied onto the candidate's report
# Argument schema of the Verify tool (referenced by Verify.args_schema).
# SimpleTool.run validates raw arguments against this model and passes the
# dumped fields to Verify._run as keyword arguments.
# (Comments are used instead of a class docstring so the pydantic JSON schema
# of the model is left untouched.)
class VerifyInput(BaseModel):
    entities: List[Entity]
    candidates: List[Candidate]
    problem_text: str  # forwarded as `query` to the verification helpers
    options: VerifyOptions = Field(default_factory=VerifyOptions)  # fresh default options per instance
def _build_clue_items(clues: List[str], entity_map: Dict[str, Any], mode: str) -> List[Dict[str, Any]]:
if mode == "linked":
return render_clues_placeholder(clues, entity_map)
return [{"text": c, "has_placeholder": has_placeholder(c)} for c in clues]
def _split_numeric_clues(clue_items: List[Dict[str, Any]], defer_numeric: bool):
if not defer_numeric:
return clue_items, [], []
numeric, filtered = [], []
for item in clue_items:
if has_number_like(item.get("text", "")):
numeric.append(item)
else:
filtered.append(item)
return filtered, numeric, filtered
def process_verify_report(
    cand_reports: Dict[str, Any],
    clue_items: List[Dict[str, Any]],
    defer_numeric: bool,
    numeric_clues: List[Dict[str, Any]],
    candidates: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Post-process judged candidate reports and key them by candidate variable id.

    For every report under ``cand_reports["candidates"]``:

    * each judged clue's ``status`` is stripped and lower-cased and written
      back; a ``"pass"`` on a clue that has a placeholder is downgraded to
      ``"pending"`` and recorded as a promotion;
    * when ``defer_numeric`` is set, one ``"pending"`` entry per item of
      ``numeric_clues`` is appended to the report's clue list;
    * ``overall`` is computed with ``compute_overall`` from the judged statuses
      (the appended numeric entries are not included) and the promotion flag;
    * the report is attached to a candidate: the only candidate if exactly one
      was given, otherwise the first whose normalised ``variable_id`` equals the
      report's normalised ``id``; non-empty candidate ``evidence`` is copied
      onto the report and the report's ``id`` key is removed.

    Args:
        cand_reports: Judge output; a missing/``None`` value or a missing/``None``
            ``"candidates"`` entry is treated as "no reports".
        clue_items: Clue items that were judged, index-aligned with each
            report's ``"clues"``; supplies the ``has_placeholder`` flag. Clues
            beyond its length fall back to ``has_placeholder(text)``.
        defer_numeric: Whether numeric clues were held back from judging.
        numeric_clues: The held-back clue items.
        candidates: Candidate dicts the reports belong to.

    Returns:
        Mapping of candidate ``variable_id`` to its (mutated) report dict.
        Reports that match no candidate are skipped with a warning.
    """
    reports = (cand_reports or {}).get("candidates") or []
    clue_items = clue_items or []
    results: Dict[str, Any] = {}
    for cand in reports:
        # A report without (or with a null) "clues" entry is judged as having no clues.
        judged = cand.get("clues") or []
        fixed_statuses = []
        promote_flag = False
        for idx, clue_judge in enumerate(judged):
            status = (clue_judge.get("status") or "").strip().lower()
            text = (clue_judge.get("text") or "").strip()
            has_ph = clue_items[idx]["has_placeholder"] if idx < len(clue_items) else has_placeholder(text)
            if has_ph and status == "pass":
                # A pass on a clue that still contains a placeholder is not final.
                promote_flag = True
                status = "pending"
            clue_judge["status"] = status
            fixed_statuses.append(status)
        if defer_numeric and numeric_clues:
            if not isinstance(cand.get("clues"), list):
                # Previously a KeyError: the report carried no clue list to append to.
                cand["clues"] = list(judged)
            cand["clues"].extend(
                {
                    "text": item.get("text", ""),
                    "status": "pending",
                    "reason": "contains digit, requiring careful analysis",
                }
                for item in numeric_clues
            )
        cand["overall"] = compute_overall(fixed_statuses, promote_flag)
        if len(candidates) == 1:
            candidate = candidates[0]
        else:
            report_id = norm_val(cand.get("id"))
            candidate = next(
                (c for c in candidates if norm_val(c.get("variable_id")) == report_id),
                None,
            )
        if candidate is None:
            # Storing such reports under the key None made unrelated reports
            # overwrite each other; drop them explicitly instead.
            logger.warning("[process_verify_report] report id matches no candidate, skipped")
            continue
        if candidate.get("evidence"):
            cand["evidence"] = candidate["evidence"]
        cand.pop("id", None)
        results[candidate.get("variable_id")] = cand
    return results
@dataclass
class VerifyEntityConfig:
    """Arguments bundle for ``verify_entity`` (verification of one candidate)."""

    entities: Any  # iterable of dict-like entities; "variable_id" and "clues" are read
    candidate: Any  # dict-like; "variable_id", "value" and "evidence" are read
    query: str  # forwarded as VerifyCoarseConfig.query
    model_name: str  # forwarded as VerifyCoarseConfig.model_name
    mode: str = "linked"  # "linked" renders clues through render_clues_placeholder
    defer_numeric: bool = False  # True holds number-like clues back and reports them as pending
async def verify_entity(verify_config: VerifyEntityConfig):
    """Verify a single candidate against the clues of its target entity.

    The candidate's normalised ``variable_id`` must be non-empty and match one
    of the entities, and the candidate must carry a non-empty ``value``;
    otherwise nothing is verified.

    Args:
        verify_config: Entities, candidate, query, model name and clue options.

    Returns:
        ``(results, input_tokens, output_tokens)`` where ``results`` is the
        output of ``process_verify_report``; ``({}, 0, 0)`` when the candidate
        is rejected before judging.
    """
    cfg = verify_config
    candidate = cfg.candidate
    entity_map = to_entity_map(cfg.entities, key_name="variable_id")
    focus = norm_val(candidate.get("variable_id"))
    known_ids = [norm_val(entity.get("variable_id")) for entity in cfg.entities if entity.get("variable_id")]
    if not (focus and focus in known_ids):
        logger.error("[verify_entity] invalid candidate variable_id")
        return {}, 0, 0

    raw_clues = entity_map[focus].get("clues", [])
    all_items = _build_clue_items(raw_clues, entity_map, cfg.mode)
    judged_items, deferred_items, _ = _split_numeric_clues(all_items, cfg.defer_numeric)

    candidate_value = candidate.get("value")
    evidence = candidate.get("evidence")
    if not candidate_value:
        return {}, 0, 0

    from openjiuwen_deepsearch.algorithm.search_nodes.verify_utils import VerifyCoarseConfig

    coarse_config = VerifyCoarseConfig(
        query=cfg.query,
        entities=cfg.entities,
        focus=focus,
        candidate=candidate_value,
        clue_items=judged_items,
        evidence=evidence,
        model_name=cfg.model_name,
    )
    cand_reports, in_tok, out_tok = await verify_coarse(coarse_config)

    # Only the judged (non-deferred) items are index-aligned with the report.
    results = process_verify_report(
        cand_reports,
        judged_items,
        cfg.defer_numeric,
        deferred_items,
        [candidate],
    )
    return results, in_tok, out_tok
@dataclass
class VerifyEntitySetConfig:
    """Arguments bundle for ``verify_entity_set`` (joint verification of several candidates)."""

    entities: Any  # iterable of dict-like entities; "variable_id" and "clues" are read
    candidates: Any  # iterable of candidate dicts; accepted ones get a "clue_items" key added
    query: str  # forwarded to verify_coarse_set
    model_name: str  # forwarded to verify_coarse_set
    mode: str = "linked"  # "linked" renders clues through render_clues_placeholder
    defer_numeric: bool = False  # True holds number-like clues back and reports them as pending
async def verify_entity_set(verify_config: VerifyEntitySetConfig):
    """Verify several candidates in one ``verify_coarse_set`` call.

    Candidates whose normalised ``variable_id`` is empty or unknown are logged
    and skipped; candidates without a ``value`` are skipped silently. Each
    accepted candidate dict is mutated: its judged clue items are stored under
    ``"clue_items"`` before it is handed to ``verify_coarse_set``.

    Every returned report is post-processed with the clue items and deferred
    numeric clues of the candidate it belongs to. Reports are attached to the
    only prepared candidate when there is exactly one, otherwise to the first
    prepared candidate whose normalised ``variable_id`` equals the report's
    normalised ``id``; reports matching none are skipped with a warning.

    Args:
        verify_config: Entities, candidates, query, model name and clue options.

    Returns:
        ``(results, input_tokens, output_tokens)`` with ``results`` keyed by
        candidate ``variable_id``; ``({}, 0, 0)`` when no candidate is usable.
    """
    entity_map = to_entity_map(verify_config.entities, key_name="variable_id")
    entity_names = [norm_val(e.get("variable_id")) for e in verify_config.entities if e.get("variable_id")]
    prepared_candidates = []
    # Parallel to prepared_candidates: (focus, judged clue items, deferred numeric clues).
    # Kept per candidate so that post-processing never reuses another candidate's clues.
    plans: List[Tuple[Any, List[Dict[str, Any]], List[Dict[str, Any]]]] = []
    for candidate in verify_config.candidates:
        focus = norm_val(candidate.get("variable_id"))
        if not focus or focus not in entity_names:
            logger.error("[verify_entity_set] invalid candidate variable_id")
            continue
        if not candidate.get("value"):
            continue
        focus_entity = entity_map[focus]
        clue_items = _build_clue_items(focus_entity.get("clues", []), entity_map, verify_config.mode)
        used_clues, numeric_clues, _ = _split_numeric_clues(clue_items, verify_config.defer_numeric)
        candidate["clue_items"] = used_clues
        prepared_candidates.append(candidate)
        plans.append((focus, used_clues, numeric_clues))
    if not prepared_candidates:
        return {}, 0, 0
    cand_reports, in_tok, out_tok = await verify_coarse_set(
        query=verify_config.query,
        entities=verify_config.entities,
        candidates=prepared_candidates,
        model_name=verify_config.model_name,
    )
    results: Dict[str, Any] = {}
    for report in (cand_reports or {}).get("candidates") or []:
        if len(plans) == 1:
            index = 0
        else:
            report_id = norm_val(report.get("id"))
            index = next((i for i, plan in enumerate(plans) if plan[0] == report_id), None)
        if index is None:
            logger.warning("[verify_entity_set] report id matches no prepared candidate, skipped")
            continue
        _, used_clues, numeric_clues = plans[index]
        # One report at a time, paired with exactly the clue lists that were
        # built for its own candidate (index-aligned with the judged clues).
        results.update(
            process_verify_report(
                {"candidates": [report]},
                used_clues,
                verify_config.defer_numeric,
                numeric_clues,
                [prepared_candidates[index]],
            )
        )
    return results, in_tok, out_tok
class SimpleTool:
    """Minimal async tool base: validate raw arguments, then dispatch to ``_run``.

    Subclasses provide ``name``, ``description``, an ``args_schema`` model and
    an ``_run`` implementation.
    """

    name: str
    description: str
    args_schema: Type[BaseModel]

    async def run(self, raw_args: Dict[str, Any]):
        """Validate ``raw_args`` with ``args_schema`` and await ``_run`` on the dumped fields."""
        validated = self.args_schema.model_validate(raw_args)
        call_kwargs = validated.model_dump()
        return await self._run(**call_kwargs)

    async def _run(self, **kwargs):
        """Tool body; not implemented on the base class."""
        raise NotImplementedError
class Verify(SimpleTool):
    """Tool that checks candidate values against entity clues.

    Arguments are validated against ``VerifyInput`` by ``SimpleTool.run`` and
    arrive in ``_run`` as plain dicts/lists.
    """

    name: str = "verify"
    description: str = "Clue verification via retrieval and LLM reasoning."
    args_schema: ClassVar[Type[BaseModel]] = VerifyInput

    async def _run(
        self,
        entities: List[Dict[str, Any]],
        candidates: List[Dict[str, Any]],
        problem_text: str,
        options: Dict[str, Any],
        **kwargs,
    ):
        """Normalise the inputs and run set-wise or per-candidate verification.

        Args:
            entities: Entity dicts; ``variable_id`` (and a truthy ``value``) are
                normalised in place with ``norm_val``.
            candidates: Candidate dicts; ``variable_id`` and ``value`` are
                normalised in place with ``norm_val``.
            problem_text: Passed on as the verification query.
            options: Option dict; ``mode``, ``judge_model``, ``llm_model_name``
                and ``defer_numeric`` are read.
            **kwargs: Accepted and ignored.

        Returns:
            ``(verify_results, total_input_tokens, total_output_tokens)``.
        """
        for entity in entities:
            entity["variable_id"] = norm_val(entity["variable_id"])
            if entity.get("value"):
                entity["value"] = norm_val(entity["value"])
        for cand in candidates:
            cand["variable_id"] = norm_val(cand["variable_id"])
            cand["value"] = norm_val(cand["value"])

        mode = (options.get("mode") or "isolated").lower()
        # judge_model takes precedence over llm_model_name.
        model_name = options.get("judge_model") or options.get("llm_model_name")
        defer_numeric = options.get("defer_numeric", False)
        entities = update_entities(entities, candidates)

        # NOTE(review): only "set" is distinguished here. "set" runs with clue
        # mode "isolated"; every other option value (including "isolated") runs
        # per candidate with clue mode "linked" — confirm this is intended.
        if mode == "set":
            set_config = VerifyEntitySetConfig(
                entities=entities,
                candidates=candidates,
                query=problem_text,
                mode="isolated",
                defer_numeric=defer_numeric,
                model_name=model_name,
            )
            verify_results, total_in, total_out = await verify_entity_set(set_config)
            return verify_results, total_in, total_out

        verify_results: Dict[str, Any] = {}
        total_in = 0
        total_out = 0
        for candidate in candidates:
            single_config = VerifyEntityConfig(
                entities=entities,
                candidate=candidate,
                query=problem_text,
                mode="linked",
                defer_numeric=defer_numeric,
                model_name=model_name,
            )
            res, in_tok, out_tok = await verify_entity(single_config)
            total_in += in_tok
            total_out += out_tok
            verify_results.update(res)
        return verify_results, total_in, total_out

    async def _arun(self, *args, **kwargs):
        """Not implemented."""
        raise NotImplementedError
async def validate_new_state(
    new_state: State, validator_model_name: str, query: str
) -> Tuple[List[VerifiedVariableResult], int, int]:
    """Verify every variable of a state that carries a candidate value.

    For each state entry whose candidate is set and does not normalise to
    ``"none"``, relevant evidence is collected with ``find_relative_info`` and
    the candidate is queued. All queued candidates are then verified through
    the ``Verify`` tool in ``"linked"`` mode without numeric deferral.

    Args:
        new_state: A ``State`` instance, or a dict used to construct one.
        validator_model_name: Model name used for evidence lookup and judging.
        query: Problem text passed to the verify tool.

    Returns:
        ``(verify_results, total_input_tokens, total_output_tokens)``: one
        ``VerifiedVariableResult`` per verified variable plus the token counts
        summed over all calls.
    """
    total_in, total_out = 0, 0
    verify_results: List[VerifiedVariableResult] = []
    entity_mapping: Dict[str, Any] = {}
    entities: List[dict] = []
    focus_candidates: List[dict] = []
    if isinstance(new_state, dict):
        new_state = State(**new_state)
    for s in new_state.state:
        if s.candidate and norm_val(s.candidate) != "none":
            entity_mapping[norm_val(s.id)] = s
            entities.append({"variable_id": str(s.id), "type": s.type, "clues": s.question_clues})
            discovered, in_tok, out_tok = await find_relative_info(
                candidate=s.candidate,
                clue_items=s.discovered_clues,
                model_name=validator_model_name,
            )
            total_in += in_tok or 0
            total_out += out_tok or 0
            focus_candidates.append(
                {
                    "variable_id": str(s.id),
                    "type": s.type,
                    "value": s.candidate,
                    # Guarded like the token counts above: a None result means no evidence.
                    "evidence": " | ".join(discovered or []),
                }
            )
    verify_result = {}
    if focus_candidates:
        verify_tool = Verify()
        verify_result, in_tok, out_tok = await verify_tool.run(
            {
                "entities": entities,
                "candidates": focus_candidates,
                "problem_text": query,
                "options": {
                    "mode": "linked",
                    "defer_numeric": False,
                    "llm_model_name": validator_model_name,
                },
            }
        )
        total_in += in_tok or 0
        total_out += out_tok or 0
    for key, result in verify_result.items():
        # Results whose key does not map back to a queued variable are dropped.
        variable = entity_mapping.get(norm_val(key))
        if not variable:
            continue
        verify_results.append(
            VerifiedVariableResult(
                id=variable.id,
                type=variable.type,
                candidate_verified_clues=CandidateVerifiedClues.model_validate(result),
            )
        )
    logger.info(
        "[validate_new_state] verify_results: %s",
        "***" if LogManager.is_sensitive() else verify_results,
    )
    return verify_results, total_in, total_out
async def run_validations(
    new_states: List[Any],
    llm_model_name: str,
    query: str,
) -> List[Tuple[Any, Any, int, int]]:
    """Validate each state in order and collect the outcomes.

    Args:
        new_states: States to validate, one ``validate_new_state`` call each.
        llm_model_name: Model name forwarded to ``validate_new_state``.
        query: Problem text forwarded to ``validate_new_state``.

    Returns:
        One ``(state, verify_results, input_tokens, output_tokens)`` tuple per
        input state, in input order.

    Raises:
        CustomValueException: On the first state whose validation fails; the
            message names the state index and the original error is chained.
    """
    outcomes: List[Tuple[Any, Any, int, int]] = []
    for index, state in enumerate(new_states):
        try:
            validated = await validate_new_state(state, llm_model_name, query)
            verify_results, input_tokens, output_tokens = validated
        except Exception as exc:
            raise CustomValueException(
                StatusCode.VALIDATE_NEW_STATE_ERROR.code,
                StatusCode.VALIDATE_NEW_STATE_ERROR.errmsg.format(e=f"state[{index}]: {exc}"),
            ) from exc
        else:
            outcomes.append((state, verify_results, input_tokens, output_tokens))
    return outcomes