import enum
from openjiuwen.core.session.node import Session
from openjiuwen_deepsearch.utils.constants_utils.node_constants import NodeId
class LlmConfigCategory(enum.Enum):
    """Categories under which LLM settings are grouped.

    Each member's value is the string spliced into the
    ``config.llm_config.<category>.model_name`` lookup path by the
    model-name helpers in this module.
    """

    # Category the helpers fall back to when no dedicated one applies.
    GENERAL = "general"
    # Category for intent recognition, outline and plan reasoning nodes.
    PLAN_UNDERSTANDING = "plan_understanding"
    # Category for the info collector node.
    INFO_COLLECTING = "info_collecting"
    # Category for the sub reporter node.
    WRITING_CHECKING = "writing_checking"
    # Category for the VLM chart generator node.
    VLM_CHART_GENERATING = "vlm_chart_generating"
# Lookup table from a node id (``NodeId.<X>.value``) to the value of the
# ``LlmConfigCategory`` member that node should use. Node ids that do not
# appear here have no dedicated category. The grouping below is flattened in
# order, so the resulting dict keeps one entry per listed node id.
NODE_LLM_MAPPING = {
    node_id.value: category.value
    for category, node_ids in (
        (
            LlmConfigCategory.PLAN_UNDERSTANDING,
            (NodeId.INTENT_RECOGNITION, NodeId.OUTLINE, NodeId.PLAN_REASONING),
        ),
        (LlmConfigCategory.INFO_COLLECTING, (NodeId.INFO_COLLECTOR,)),
        (LlmConfigCategory.WRITING_CHECKING, (NodeId.SUB_REPORTER,)),
        (LlmConfigCategory.VLM_CHART_GENERATING, (NodeId.VLM_CHART_GENERATOR,)),
    )
    for node_id in node_ids
}
def adapt_llm_model_name(session: Session, node_name) -> str:
    """Return the LLM model name that the given node should use.

    The node is translated to a config category through ``NODE_LLM_MAPPING``.
    The ``general`` category is used instead when:

    * the node has no entry in ``NODE_LLM_MAPPING``;
    * the node's category is not present in ``config.llm_config``;
    * ``config.llm_config`` itself is ``None``.

    Args:
        session: Session whose global state holds ``config.llm_config``.
        node_name: Node identifier, looked up as a key of
            ``NODE_LLM_MAPPING``.

    Returns:
        The value the session stores at
        ``config.llm_config.<category>.model_name``.
    """
    general_category = LlmConfigCategory.GENERAL.value

    model_category = NODE_LLM_MAPPING.get(node_name)
    if model_category is None:
        # Unmapped node: the config content is irrelevant, so it is not read.
        model_category = general_category
    else:
        llm_config = session.get_global_state("config.llm_config")
        # Guard against a None config so a mapped node degrades to the general
        # category exactly like an unmapped one, instead of raising TypeError
        # on the membership test.
        if llm_config is None or model_category not in llm_config:
            model_category = general_category

    return session.get_global_state(f"config.llm_config.{model_category}.model_name")
def adapt_vlm_model_name(session: Session, node_name) -> str:
    """Return the VLM model name configured for the given node.

    The node is translated to a config category through ``NODE_LLM_MAPPING``.
    Unlike the general-purpose lookup there is no fallback category: the
    sentinel string ``"NO VLM"`` is returned when the node has no mapping,
    when its category is not present in ``config.llm_config``, or when
    ``config.llm_config`` is ``None``.

    Args:
        session: Session whose global state holds ``config.llm_config``.
        node_name: Node identifier, looked up as a key of
            ``NODE_LLM_MAPPING``.

    Returns:
        The value the session stores at
        ``config.llm_config.<category>.model_name``, or ``"NO VLM"``.
    """
    model_category = NODE_LLM_MAPPING.get(node_name)
    if model_category is None:
        return "NO VLM"

    llm_config = session.get_global_state("config.llm_config")
    # A None config is treated like a config without the category; the
    # membership test would otherwise raise TypeError.
    if llm_config is None or model_category not in llm_config:
        return "NO VLM"

    return session.get_global_state(f"config.llm_config.{model_category}.model_name")