import json
from concurrent.futures import as_completed, TimeoutError
from typing import Dict, Iterable, List, Optional
from uuid import uuid4
import queue
import re
import threading
import numpy as np
import pandas as pd
import logging
from langchain.agents import AgentExecutor
from langchain.agents.initialize import initialize_agent
from langchain.chains.conversation.memory import ConversationSummaryBufferMemory
from langchain_community.chat_models import ChatLiteLLM, ChatOllama
from langchain_writer import ChatWriter
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.agents import AgentAction, AgentStep
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_core.messages.base import BaseMessage
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import Tool
from mindsdb.integrations.libs.llm.utils import get_llm_config
from mindsdb.integrations.utilities.handler_utils import get_api_key
from mindsdb.integrations.utilities.rag.settings import DEFAULT_RAG_PROMPT_TEMPLATE
from mindsdb.interfaces.agents.event_dispatch_callback_handler import (
EventDispatchCallbackHandler,
)
from mindsdb.interfaces.agents.constants import AGENT_CHUNK_POLLING_INTERVAL_SECONDS
from mindsdb.utilities import log
from mindsdb.utilities.context_executor import ContextThreadPoolExecutor
from mindsdb.interfaces.storage import db
from mindsdb.utilities.context import context as ctx
from .mindsdb_chat_model import ChatMindsdb
from .callback_handlers import LogCallbackHandler, ContextCaptureCallback
from .langfuse_callback_handler import LangfuseCallbackHandler, get_skills
from .safe_output_parser import SafeOutputParser
from .providers import get_bedrock_chat_model
from mindsdb.interfaces.agents.constants import (
OPEN_AI_CHAT_MODELS,
DEFAULT_AGENT_TIMEOUT_SECONDS,
get_default_agent_type,
DEFAULT_EMBEDDINGS_MODEL_PROVIDER,
DEFAULT_MAX_ITERATIONS,
DEFAULT_MAX_TOKENS,
DEFAULT_TIKTOKEN_MODEL_NAME,
SUPPORTED_PROVIDERS,
ANTHROPIC_CHAT_MODELS,
GOOGLE_GEMINI_CHAT_MODELS,
OLLAMA_CHAT_MODELS,
NVIDIA_NIM_CHAT_MODELS,
USER_COLUMN,
ASSISTANT_COLUMN,
CONTEXT_COLUMN,
TRACE_ID_COLUMN,
DEFAULT_AGENT_SYSTEM_PROMPT,
WRITER_CHAT_MODELS,
MINDSDB_PREFIX,
EXPLICIT_FORMAT_INSTRUCTIONS,
)
from mindsdb.interfaces.skills.skill_tool import skill_tool, SkillData
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
from mindsdb.utilities.langfuse import LangfuseClientWrapper
# Message prefixes that identify an output-parsing failure. An error whose text starts with
# one of these is salvaged by LangchainAgent._handle_parsing_errors instead of being reported
# as an agent failure.
_PARSING_ERROR_PREFIXES = [
    "An output parsing error occurred",
    "Could not parse LLM output",
]
# Module-level logger, obtained through the mindsdb.utilities.log helper.
logger = log.getLogger(__name__)
def get_llm_provider(args: Dict) -> str:
    """Resolve the LLM provider name for a set of agent arguments.

    An explicit ``provider`` entry always wins. Otherwise the provider is
    inferred by looking ``model_name`` up in the per-provider model lists.

    Args:
        args: Agent arguments. ``model_name`` is required when ``provider`` is absent.

    Returns:
        The provider identifier, e.g. ``"openai"``.

    Raises:
        KeyError: If neither ``provider`` nor ``model_name`` is present.
        ValueError: If ``model_name`` does not appear in any known model list.
    """
    if "provider" in args:
        return args["provider"]

    model_name = args["model_name"]
    # Checked in this order; the first list containing the model decides the provider.
    known_models = (
        (ANTHROPIC_CHAT_MODELS, "anthropic"),
        (OPEN_AI_CHAT_MODELS, "openai"),
        (OLLAMA_CHAT_MODELS, "ollama"),
        (NVIDIA_NIM_CHAT_MODELS, "nvidia_nim"),
        (GOOGLE_GEMINI_CHAT_MODELS, "google"),
        (WRITER_CHAT_MODELS, "writer"),
    )
    for models, provider_name in known_models:
        if model_name in models:
            return provider_name

    raise ValueError("Invalid model name. Please define a supported llm provider")
def get_embedding_model_provider(args: Dict) -> str:
    """Get the embedding model provider from args.

    ``embedding_model_provider`` is used when present; otherwise the LLM
    ``provider`` (or the default embeddings provider) is used. When the result
    is ``"vllm"``, both ``openai_api_base`` and ``model`` must be set.
    For VLLM, this will use our custom VLLMEmbeddings class from langchain_embedding_handler.

    Args:
        args: Agent arguments.

    Returns:
        The embedding provider identifier.

    Raises:
        ValueError: If the provider is ``"vllm"`` and ``openai_api_base`` or ``model`` is missing.
    """
    explicitly_set = "embedding_model_provider" in args
    if explicitly_set:
        provider = args["embedding_model_provider"]
    else:
        provider = args.get("provider", DEFAULT_EMBEDDINGS_MODEL_PROVIDER)

    if provider != "vllm":
        return provider

    if args.get("openai_api_base") and args.get("model"):
        logger.info("Using custom VLLMEmbeddings class")
        return "vllm"

    # The error text carries one extra hint when vllm was inherited from the LLM provider.
    message_lines = [
        "VLLM embeddings configuration error:",
        "- Missing required parameters: 'openai_api_base' and/or 'model'",
    ]
    if not explicitly_set:
        message_lines.append(
            "- When using VLLM as LLM provider, you must specify the embeddings server location and model"
        )
    message_lines.append("- Example: openai_api_base='http://localhost:8003/v1', model='your-model-name'")
    raise ValueError("\n".join(message_lines))
def get_chat_model_params(args: Dict) -> Dict:
    """Build the keyword arguments used to construct the chat model.

    Copies ``args``, attaches an ``api_keys`` mapping (one non-strict
    ``get_api_key`` lookup per supported provider), asks ``get_llm_config`` for
    the provider configuration and dumps it by alias with ``None`` values removed.

    Args:
        args: Agent arguments; ``provider`` is used when present, otherwise it is
            inferred from ``model_name`` by ``get_llm_provider``.

    Returns:
        Dict of non-``None`` configuration values for the chat model constructor.

    Raises:
        ValueError: Propagated from ``get_llm_provider`` when no provider can be inferred.
    """
    model_config = args.copy()
    model_config["api_keys"] = {p: get_api_key(p, model_config, None, strict=False) for p in SUPPORTED_PROVIDERS}

    # Resolve the provider once. get_llm_provider already returns args["provider"] when it
    # is set, so this equals the previous lookup while also covering the inferred case below.
    provider = get_llm_provider(args)
    llm_config = get_llm_config(provider, model_config)
    config_dict = {k: v for k, v in llm_config.model_dump(by_alias=True).items() if v is not None}

    # Previously this tested args.get("provider"), so the rename was skipped whenever
    # "writer" was inferred from the model name rather than passed explicitly.
    if provider == "writer" and "writer_api_key" in config_dict:
        config_dict["api_key"] = config_dict.pop("writer_api_key")
    return config_dict
def create_chat_model(args: Dict):
model_kwargs = get_chat_model_params(args)
if args["provider"] == "anthropic":
return ChatAnthropic(**model_kwargs)
if args["provider"] == "openai" or args["provider"] == "vllm":
chat_open_ai = ChatOpenAI(**model_kwargs)
try:
chat_open_ai.get_num_tokens_from_messages([])
except NotImplementedError:
chat_open_ai.tiktoken_model_name = DEFAULT_TIKTOKEN_MODEL_NAME
return chat_open_ai
if args["provider"] == "litellm":
return ChatLiteLLM(**model_kwargs)
if args["provider"] == "ollama":
return ChatOllama(**model_kwargs)
if args["provider"] == "nvidia_nim":
return ChatNVIDIA(**model_kwargs)
if args["provider"] == "google":
return ChatGoogleGenerativeAI(**model_kwargs)
if args["provider"] == "writer":
return ChatWriter(**model_kwargs)
if args["provider"] == "bedrock":
ChatBedrock = get_bedrock_chat_model()
return ChatBedrock(**model_kwargs)
if args["provider"] == "mindsdb":
return ChatMindsdb(**model_kwargs)
raise ValueError(f"Unknown provider: {args['provider']}")
def prepare_prompts(df, base_template, input_variables, user_column=USER_COLUMN):
    """Render one prompt per usable row of ``df``.

    The default agent system prompt is prepended to ``base_template`` and
    ``{{var}}`` placeholders are converted to ``{var}``. Rows with at least one
    non-null input variable are rendered through the template. Rows where every
    input variable is null fall back to the raw ``user_column`` value when it is
    truthy, and are skipped otherwise.

    Args:
        df: DataFrame of messages.
        base_template: Prompt template text using ``{{var}}`` placeholders.
        input_variables: Column names referenced by the template.
        user_column: Column holding the raw user message.

    Returns:
        Tuple ``(prompts, empty_prompt_ids)``: the list of prompts and the array of
        row positions whose input variables are all null.
    """
    empty_prompt_ids = np.where(df[input_variables].isna().all(axis=1).values)[0]
    # Set membership avoids scanning the array once per row.
    empty_positions = set(empty_prompt_ids.tolist())

    template_text = f"{DEFAULT_AGENT_SYSTEM_PROMPT}\n\n{base_template}"
    template_text = template_text.replace("{{", "{").replace("}}", "}")

    # Built on first use only: when every row is "empty" the template is never
    # constructed, exactly as before, so an unformattable template still cannot fail here.
    prompt_template = None
    prompts = []
    # np.where yields row POSITIONS, so compare against the position rather than the
    # index label; the two differ whenever df does not carry a default RangeIndex.
    for position, (_, row) in enumerate(df.iterrows()):
        if position not in empty_positions:
            if prompt_template is None:
                prompt_template = PromptTemplate(input_variables=input_variables, template=template_text)
            values = {col: row[col] if row[col] is not None else "" for col in input_variables}
            prompts.append(prompt_template.format(**values))
        elif row.get(user_column):
            prompts.append(row[user_column])
    return prompts, empty_prompt_ids
def prepare_callbacks(self, args):
    """Collect the agent callbacks and add a fresh context-capture callback.

    Args:
        self: The agent instance whose ``_get_agent_callbacks`` supplies the base list.
        args: Agent arguments forwarded to ``_get_agent_callbacks``.

    Returns:
        Tuple ``(callbacks, context_callback)`` where ``context_callback`` is the
        newly created capture callback, also appended to ``callbacks``.
    """
    capture = ContextCaptureCallback()
    handlers = self._get_agent_callbacks(args)
    handlers.append(capture)
    return handlers, capture
def handle_agent_error(e, error_message=None):
    """Log an agent error (with traceback) and return the message that was logged.

    Args:
        e: The exception that occurred.
        error_message: Message to log; a generic one is derived from ``e`` when ``None``.

    Returns:
        The message that was logged.
    """
    message = (
        f"An error occurred during agent execution: {str(e)}"
        if error_message is None
        else error_message
    )
    logger.error(message, exc_info=True)
    return message
def process_chunk(chunk):
    """Recursively convert a chunk into plain, serializable Python values.

    Dicts and lists are walked recursively. Strings, numbers, booleans and
    ``None`` pass through unchanged. Anything else is replaced by ``str(value)``.

    Args:
        chunk: Arbitrary value to convert.

    Returns:
        The converted value.
    """
    if isinstance(chunk, dict):
        return {key: process_chunk(value) for key, value in chunk.items()}
    if isinstance(chunk, list):
        return list(map(process_chunk, chunk))

    passthrough_types = (str, int, float, bool, type(None))
    return chunk if isinstance(chunk, passthrough_types) else str(chunk)
class LangchainAgent:
    """Runs a stored agent definition through a LangChain ``AgentExecutor``."""

    def __init__(self, agent: db.Agents, model: dict = None, llm_params: dict = None):
        """Store the agent record and resolve its execution arguments.

        Args:
            agent: Agent record; its ``params``, ``model_name`` and ``provider`` are
                read while the arguments are initialized.
            model: Model record; only consulted when the agent provider is ``"mindsdb"``.
            llm_params: Extra parameters merged over the agent's own params.
        """
        self.agent = agent
        self.model = model

        # Collaborators that are created later, on first use.
        self.llm: Optional[object] = None
        self.embedding_model: Optional[object] = None
        self.run_completion_span: Optional[object] = None
        self.log_callback_handler: Optional[object] = None
        self.langfuse_callback_handler: Optional[object] = None
        self.mdb_langfuse_callback_handler: Optional[object] = None

        self.langfuse_client_wrapper = LangfuseClientWrapper()

        resolved_args = self._initialize_args(llm_params)
        self.args = resolved_args
        self.provider = resolved_args.get("provider", get_llm_provider(resolved_args))
def _initialize_args(self, llm_params: dict = None) -> dict:
    """Initialize the arguments for agent execution.

    Starts from the agent's stored params, overlays ``llm_params``, then forces
    ``model_name`` / ``provider`` from the agent record when they are set there.
    A default prompt template is filled in when none was supplied.

    Args:
        llm_params: Parameters for agent execution, merged over the agent's params.

    Returns:
        dict: Final parameters for agent execution.

    Raises:
        ValueError: If no ``model_name`` could be determined, or if the embedding
            provider configuration is rejected by ``get_embedding_model_provider``.
    """
    # An agent stored without params yields None here; treat that as "no params".
    args = (self.agent.params or {}).copy()
    if llm_params:
        args.update(llm_params)

    # Values on the agent record take precedence over anything in params.
    if self.agent.model_name is not None:
        args["model_name"] = self.agent.model_name
    if self.agent.provider is not None:
        args["provider"] = self.agent.provider

    args["embedding_model_provider"] = args.get("embedding_model", get_embedding_model_provider(args))

    if self.agent.provider == "mindsdb":
        args["model_name"] = self.agent.model_name
        # Tolerate a missing model record or a missing "problem_definition" / "using"
        # section instead of failing with TypeError / KeyError.
        problem_definition = (self.model or {}).get("problem_definition") or {}
        prompt_template = (problem_definition.get("using") or {}).get("prompt_template")
        if prompt_template is not None:
            args["prompt_template"] = prompt_template

    if args.get("prompt_template") is None:
        if args.get("mode") == "retrieval":
            args["prompt_template"] = DEFAULT_RAG_PROMPT_TEMPLATE
            logger.info(f"Using default retrieval prompt template: {DEFAULT_RAG_PROMPT_TEMPLATE[:50]}...")
        else:
            default_prompt = "you are an assistant, answer using the tables connected"
            args["prompt_template"] = default_prompt
            logger.info(f"Using default prompt template: {default_prompt}")

    # A prompt template is always present at this point, so no membership test is needed.
    logger.info(f"Using prompt template: {args['prompt_template'][:50]}...")

    if "model_name" not in args:
        raise ValueError(
            "No model name provided for agent. Provide it in the model parameter or in the default model setup."
        )
    return args
def get_metadata(self) -> Dict:
    """Assemble the metadata dict attached to traces.

    Returns:
        Dict with the provider, model name, embedding provider, the agent's skills
        and the user/session attributes read from the request context.
    """
    metadata = {
        "provider": self.provider,
        "model_name": self.args["model_name"],
    }
    metadata["embedding_model_provider"] = self.args.get(
        "embedding_model_provider", get_embedding_model_provider(self.args)
    )
    metadata["skills"] = get_skills(self.agent)

    # Remaining entries mirror same-named attributes of the request context.
    context_fields = ("user_id", "session_id", "company_id", "user_class", "email_confirmed")
    for field_name in context_fields:
        metadata[field_name] = getattr(ctx, field_name)
    return metadata
def get_tags(self) -> List:
    """Return the tags attached to traces: a one-element list with the provider."""
    tags = [self.provider]
    return tags
def get_completion(self, messages, stream: bool = False, params: dict | None = None):
    """Run the agent on ``messages`` and return its completion.

    A trace and a ``run-completion`` span are opened first. With ``stream=True``
    the chunk generator from ``_get_completion_stream`` is returned and ``params``
    is not applied. Otherwise ``params`` is merged over the stored args, the agent
    is built and run, and the span is closed with the response.

    Args:
        messages: Conversation messages; converted to a DataFrame.
        stream: Whether to return a stream of chunks instead of a DataFrame.
        params: Per-call overrides for the stored agent arguments.

    Returns:
        The DataFrame produced by ``run_agent``, or an iterable of chunks when streaming.
    """
    tracer = self.langfuse_client_wrapper
    metadata = self.get_metadata()
    tags = self.get_tags()
    tracer.setup_trace(
        name="api-completion",
        input=messages,
        tags=tags,
        metadata=metadata,
        user_id=ctx.user_id,
        session_id=ctx.session_id,
    )
    self.run_completion_span = tracer.start_span(name="run-completion", input=messages)

    if stream:
        return self._get_completion_stream(messages)

    # Per-call params win over the stored args.
    args = {**self.args, **(params or {})}

    df = pd.DataFrame(messages)
    logger.info(f"LangchainAgent.get_completion: Received {len(messages)} messages")
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"Messages DataFrame shape: {df.shape}")
        logger.debug(f"Messages DataFrame columns: {df.columns.tolist()}")
        logger.debug(f"Messages DataFrame content: {df.to_dict('records')}")

    self.provider = args.get("provider", get_llm_provider(args))

    df = df.reset_index(drop=True)
    executor = self.create_agent(df)
    response = self.run_agent(df, executor, args)

    tracer.end_span(span=self.run_completion_span, output=response)
    return response
def _get_completion_stream(self, messages: List[dict]) -> Iterable[Dict]:
    """Gets a completion as a stream of chunks from given messages.

    Uses the stored agent arguments as-is, refreshes ``embedding_model_provider``
    and ``provider`` on the instance, builds the agent and returns the generator
    produced by ``stream_agent``.

    Args:
        messages (List[dict]): Messages to get completion chunks for

    Returns:
        chunks (Iterable[object]): Completion chunks
    """
    args = self.args

    df = pd.DataFrame(messages)
    logger.info(f"LangchainAgent._get_completion_stream: Received {len(messages)} messages")
    if "question" in df.columns and "answer" in df.columns:
        logger.debug("DataFrame has question/answer columns for conversation history")
    else:
        # The f-prefix was missing, so the placeholder used to be logged literally.
        logger.warning(f"DataFrame missing question/answer columns! Available columns: {df.columns.tolist()}")

    self.embedding_model_provider = args.get("embedding_model_provider", get_embedding_model_provider(args))
    self.provider = args.get("provider", get_llm_provider(args))

    df = df.reset_index(drop=True)
    agent = self.create_agent(df)
    return self.stream_agent(df, agent, args)
def create_agent(self, df: pd.DataFrame) -> AgentExecutor:
    """Build the ``AgentExecutor`` for this agent, with memory seeded from ``df``.

    The chat model and tools are created from the stored args. Memory starts with
    the prompt template as a system message, followed by the question/answer pairs
    found in every row of ``df`` except the last one.

    Args:
        df: Conversation rows; the last row is the current message and is not
            added to the history.

    Returns:
        The initialized agent executor.
    """
    args = self.args

    llm = create_chat_model(args)
    self.llm = llm

    # "mode" is removed from the stored args once the retrieval mode has been seen.
    if args.get("mode") == "retrieval":
        self.args.pop("mode")

    tools = self._langchain_tools_from_skills(llm)
    prompt_template = args["prompt_template"]

    memory = ConversationSummaryBufferMemory(
        llm=llm,
        input_key="input",
        output_key="output",
        max_token_limit=args.get("max_tokens", DEFAULT_MAX_TOKENS),
        memory_key="chat_history",
    )
    chat_history = memory.chat_memory
    chat_history.messages.insert(0, SystemMessage(content=prompt_template))

    user_column = args.get("user_column", USER_COLUMN)
    assistant_column = args.get("assistant_column", ASSISTANT_COLUMN)

    history_df = df[:-1]
    logger.info(f"Processing conversation history: {len(df)} total messages, {len(history_df)} history messages")
    logger.debug(f"User column: {user_column}, Assistant column: {assistant_column}")
    if len(history_df) == 0:
        logger.debug("No history rows to process - this is normal for first message")

    history_count = 0
    for i, row in enumerate(history_df.to_dict("records")):
        question = row.get(user_column)
        answer = row.get(assistant_column)
        logger.debug(f"Converting history row {i}: question='{question}', answer='{answer}'")

        # Only non-empty strings become history messages.
        if isinstance(question, str) and question:
            chat_history.messages.append(HumanMessage(content=question))
            history_count += 1
            logger.debug(f"Added HumanMessage to memory: {question}")
        if isinstance(answer, str) and answer:
            chat_history.messages.append(AIMessage(content=answer))
            history_count += 1
            logger.debug(f"Added AIMessage to memory: {answer}")

    logger.info(f"Built conversation history with {history_count} history messages + system message")
    logger.debug(f"Final memory messages count: {len(memory.chat_memory.messages)}")
    self._conversation_memory = memory

    agent_type = args.get("agent_type", get_default_agent_type())
    agent_kwargs = {
        "output_parser": SafeOutputParser(),
        "prefix": MINDSDB_PREFIX,
        "format_instructions": EXPLICIT_FORMAT_INSTRUCTIONS,
        "ai_prefix": "AI",
    }
    return initialize_agent(
        tools,
        llm,
        agent=agent_type,
        agent_kwargs=agent_kwargs,
        early_stopping_method="generate",
        handle_parsing_errors=self._handle_parsing_errors,
        max_execution_time=args.get("timeout_seconds", DEFAULT_AGENT_TIMEOUT_SECONDS),
        max_iterations=args.get("max_iterations", DEFAULT_MAX_ITERATIONS),
        memory=memory,
        verbose=args.get("verbose", False),
    )
def _langchain_tools_from_skills(self, llm):
    """Turn the agent's skills into a flat list of LangChain tools.

    Args:
        llm: Chat model passed through to the skill tool factory.

    Returns:
        List of tools; entries the factory returned as dicts are wrapped in ``Tool``.
    """
    skills_data = []
    for rel in self.agent.skills_relationships:
        skill = rel.skill
        skills_data.append(
            SkillData(
                name=skill.name,
                type=skill.type,
                params=skill.params,
                project_id=skill.project_id,
                # Optional per-agent table restriction stored on the relationship.
                agent_tables_list=(rel.parameters or {}).get("tables"),
            )
        )

    tools_groups = skill_tool.get_tools_from_skills(skills_data, llm, self.embedding_model)

    def _as_tool(candidate):
        # Dict-shaped tool descriptions are converted; ready-made tools pass through.
        if isinstance(candidate, dict):
            return Tool(
                name=candidate["name"],
                func=candidate["func"],
                description=candidate["description"],
            )
        return candidate

    return [_as_tool(candidate) for group in tools_groups.values() for candidate in group]
def _get_agent_callbacks(self, args: Dict) -> List:
    """Return the callback handlers for an agent run, creating them on first use.

    The log handler is always included. A trace is set up when none exists yet.
    The Langfuse handler and the MindsDB Langfuse handler are created once and
    appended when available.

    Args:
        args: Agent arguments; ``verbose`` and ``observation_id`` are read.

    Returns:
        A new list of callback handlers.
    """
    tracer = self.langfuse_client_wrapper

    if self.log_callback_handler is None:
        self.log_callback_handler = LogCallbackHandler(logger, verbose=args.get("verbose", True))
    callbacks = [self.log_callback_handler]

    if tracer.trace is None:
        metadata = self.get_metadata()
        tags = self.get_tags()
        tracer.setup_trace(
            name="NativeTrace-MindsDB-AgentExecutor",
            tags=tags,
            metadata=metadata,
            user_id=ctx.user_id,
            session_id=ctx.session_id,
        )

    if self.langfuse_callback_handler is None:
        self.langfuse_callback_handler = tracer.get_langchain_handler()

    if self.mdb_langfuse_callback_handler is None:
        trace_id = tracer.get_trace_id()
        span = self.run_completion_span
        span_id = span.id if span is not None else None
        # Explicit observation id wins, then the current span id, then a random one.
        observation_id = args.get("observation_id", span_id or uuid4().hex)
        self.mdb_langfuse_callback_handler = LangfuseCallbackHandler(
            langfuse=tracer.client,
            trace_id=trace_id,
            observation_id=observation_id,
        )

    if self.langfuse_callback_handler is not None:
        callbacks.append(self.langfuse_callback_handler)
    if self.mdb_langfuse_callback_handler:
        callbacks.append(self.mdb_langfuse_callback_handler)
    return callbacks
def _handle_parsing_errors(self, error: Exception) -> str:
    """Convert an agent parsing error into text for the agent loop.

    When the error message starts with a known parsing-error prefix, the
    second-to-last backtick-delimited fragment of the message (when there are at
    least two fragments) is reused as the response and wrapped in the
    "Thought/AI" format. Any other error is reported as an agent failure.

    Args:
        error: The exception raised while parsing the LLM output.

    Returns:
        The salvaged, reformatted response or a failure message.
    """
    response = str(error)
    is_parsing_error = any(response.startswith(prefix) for prefix in _PARSING_ERROR_PREFIXES)
    if not is_parsing_error:
        return f"Agent failed with error:\n{str(error)}..."

    logger.info("Handling parsing error, salvaging response...")
    fragments = response.split("`")
    if len(fragments) >= 2:
        response = fragments[-2]
    return f"Thought: Do I need to use a tool? No\nAI: {response}"
def run_agent(self, df: pd.DataFrame, agent: AgentExecutor, args: Dict) -> pd.DataFrame:
    """Run the agent on the last prompt derived from ``df`` and collect the result.

    Prompts are prepared for every row, but only the last one is submitted to
    the executor. The response, the captured context (JSON-encoded) and the
    trace id are returned as a DataFrame.

    Args:
        df: Conversation rows.
        agent: The agent executor to invoke.
        args: Agent arguments; reads ``prompt_template``, ``return_context``,
            ``user_column``, ``max_workers`` and ``timeout``.

    Returns:
        DataFrame with the assistant column, the trace id column and, unless
        ``return_context`` is false, the context column.

    Raises:
        KeyError: If ``prompt_template`` is missing from ``args``.
    """
    base_template = args["prompt_template"]
    return_context = args.get("return_context", True)
    input_variables = re.findall(r"{{(.*?)}}", base_template)

    prompts, empty_prompt_ids = prepare_prompts(
        df, base_template, input_variables, args.get("user_column", USER_COLUMN)
    )

    def _invoke_agent_executor_with_prompt(agent_executor, prompt):
        """Invoke the executor once; errors are converted into an error-message response."""
        if not prompt:
            return {CONTEXT_COLUMN: [], ASSISTANT_COLUMN: ""}
        try:
            callbacks, context_callback = prepare_callbacks(self, args)

            conversation_messages = getattr(self, "_conversation_messages", None)
            if conversation_messages:
                full_messages = conversation_messages + [HumanMessage(content=prompt)]
                # Routine diagnostic: previously emitted at CRITICAL level.
                logger.info(f"🔍 INVOKING AGENT with {len(full_messages)} messages (including history)")
                if logger.isEnabledFor(logging.DEBUG):
                    # str() keeps the preview safe when message content is not a plain string.
                    preview = [f"{type(msg).__name__}: {str(msg.content)[:100]}..." for msg in full_messages]
                    logger.debug(f"Full conversation messages: {preview}")
            else:
                # Nothing in this class assigns _conversation_messages, so this is the
                # normal path; it is logged at debug level rather than as a warning.
                logger.debug("No conversation messages found - using simple prompt")

            # Both paths invoke the executor the same way; history reaches it through memory.
            result = agent_executor.invoke({"input": prompt}, config={"callbacks": callbacks})
            captured_context = context_callback.get_contexts()
            output = result["output"] if isinstance(result, dict) and "output" in result else str(result)
            return {CONTEXT_COLUMN: captured_context, ASSISTANT_COLUMN: output}
        except Exception as e:
            error_message = str(e)
            if "API key" in error_message and ("not found" in error_message or "missing" in error_message):
                logger.error(f"API Key Error: {error_message}")
                error_message = f"API Key Error: {error_message}"
            return {
                CONTEXT_COLUMN: [],
                ASSISTANT_COLUMN: handle_agent_error(e, error_message),
            }

    completions = []
    contexts = []

    max_workers = args.get("max_workers", None)
    # NOTE(review): create_agent reads "timeout_seconds" while this reads "timeout";
    # confirm whether the two keys are meant to differ.
    agent_timeout_seconds = args.get("timeout", DEFAULT_AGENT_TIMEOUT_SECONDS)

    with ContextThreadPoolExecutor(max_workers=max_workers) as executor:
        if prompts:
            # Only the most recent prompt is executed.
            current_prompt = prompts[-1]
            futures = [executor.submit(_invoke_agent_executor_with_prompt, agent, current_prompt)]
        else:
            logger.error("No prompts found to process")
            futures = []
        try:
            for future in as_completed(futures, timeout=agent_timeout_seconds):
                result = future.result()
                if result is None:
                    result = {
                        CONTEXT_COLUMN: [],
                        ASSISTANT_COLUMN: "No response generated",
                    }
                completions.append(result[ASSISTANT_COLUMN])
                contexts.append(result[CONTEXT_COLUMN])
        except TimeoutError:
            timeout_message = (
                f"I'm sorry! I couldn't generate a response within the allotted time ({agent_timeout_seconds} seconds). "
                "If you need more time for processing, you can adjust the timeout settings. "
                "Please refer to the documentation for instructions on how to change the timeout value. "
                "Feel free to try your request again."
            )
            logger.warning(f"Agent execution timed out after {agent_timeout_seconds} seconds")
            # Fill in a timeout message for every future that produced no result.
            for _ in range(len(futures) - len(completions)):
                completions.append(timeout_message)
                contexts.append([])

    for i in sorted(empty_prompt_ids)[:-1]:
        completions.insert(i, None)
        contexts.insert(i, [])

    pred_df = pd.DataFrame(
        {
            ASSISTANT_COLUMN: completions,
            # Loop variable renamed so it no longer shadows the module-level `ctx`.
            CONTEXT_COLUMN: [json.dumps(context) for context in contexts],
            TRACE_ID_COLUMN: self.langfuse_client_wrapper.get_trace_id(),
        }
    )
    if not return_context:
        pred_df = pred_df.drop(columns=[CONTEXT_COLUMN])
    return pred_df
def add_chunk_metadata(self, chunk: Dict) -> Dict:
    """Stamp ``chunk`` with the current trace id (in place) and return it.

    Args:
        chunk: Chunk dict to annotate; it is mutated.

    Returns:
        The same dict with a ``trace_id`` entry.
    """
    tracer = self.langfuse_client_wrapper
    logger.debug(f"Adding metadata to chunk: {chunk}")
    logger.debug(f"Trace ID: {tracer.get_trace_id()}")
    chunk["trace_id"] = tracer.get_trace_id()
    return chunk
def _stream_agent_executor(
    self,
    agent_executor: AgentExecutor,
    prompt: str,
    callbacks: List[BaseCallbackHandler],
):
    """Stream chunks from the agent executor through a background worker thread.

    A worker thread iterates ``agent_executor.stream`` and pushes every chunk (or
    an error chunk describing a failure) onto a queue. This generator polls the
    queue, converts each chunk with ``process_chunk``, adds the trace id and
    yields it.

    Args:
        agent_executor: Executor whose ``stream`` method is consumed.
        prompt: Input passed to ``stream``.
        callbacks: Callback handlers; an event-dispatch handler feeding the same
            queue is appended to this list.

    Yields:
        Processed chunk dicts carrying a ``trace_id`` entry.
    """
    chunk_queue = queue.Queue()
    event_dispatch_callback_handler = EventDispatchCallbackHandler(chunk_queue)
    callbacks.append(event_dispatch_callback_handler)
    stream_iterator = agent_executor.stream(prompt, config={"callbacks": callbacks})

    agent_executor_finished_event = threading.Event()

    def stream_worker(context: dict):
        """Drain the stream into the queue; failures become error chunks."""
        try:
            ctx.load(context)
            for chunk in stream_iterator:
                chunk_queue.put(chunk)
        except TimeoutError as e:
            error_message = f"Request timed out: The agent took too long to respond. {str(e)}"
            logger.error(f"Timeout error during streaming: {error_message}", exc_info=True)
            error_chunk = {
                "type": "error",
                "content": handle_agent_error(e, error_message),
                "error_type": "timeout",
            }
            chunk_queue.put(error_chunk)
        except ConnectionError as e:
            error_message = f"Connection error: Failed to connect to the service. {str(e)}"
            logger.error(f"Connection error during streaming: {error_message}", exc_info=True)
            error_chunk = {
                "type": "error",
                "content": handle_agent_error(e, error_message),
                "error_type": "connection",
            }
            chunk_queue.put(error_chunk)
        except Exception as e:
            # Classify the failure from its message text.
            error_message = str(e)
            if "API key" in error_message and ("not found" in error_message or "missing" in error_message):
                logger.error(f"API Key Error: {error_message}")
                error_message = f"API Key Error: {error_message}"
                error_type = "authentication"
            elif "404" in error_message or "not found" in error_message.lower():
                logger.error(f"Model Error: {error_message}")
                error_type = "not_found"
            elif "rate limit" in error_message.lower() or "429" in error_message:
                logger.error(f"Rate Limit Error: {error_message}")
                error_message = f"Rate limit exceeded: {error_message}"
                error_type = "rate_limit"
            else:
                logger.error(f"LLM chain encountered an error during streaming: {error_message}", exc_info=True)
                error_type = "general"
            if not error_message or not error_message.strip():
                error_message = f"An unknown error occurred during streaming: {type(e).__name__}"
            error_chunk = {
                "type": "error",
                "content": handle_agent_error(e, error_message),
                "error_type": error_type,
            }
            chunk_queue.put(error_chunk)
        finally:
            # Set only after every chunk, including any error chunk, has been enqueued.
            agent_executor_finished_event.set()

    executor_stream_thread = threading.Thread(
        target=stream_worker,
        daemon=True,
        args=(ctx.dump(),),
        name="LangchainAgent.stream_worker",
    )
    executor_stream_thread.start()

    while True:
        # Read the flag BEFORE touching the queue. The worker enqueues everything before
        # setting it, so "finished and queue empty" means no further chunk will arrive.
        # Checking the flag alone used to drop chunks that were still queued when the
        # worker finished.
        worker_finished = agent_executor_finished_event.is_set()
        try:
            if worker_finished:
                chunk = chunk_queue.get_nowait()
            else:
                chunk = chunk_queue.get(block=True, timeout=AGENT_CHUNK_POLLING_INTERVAL_SECONDS)
        except queue.Empty:
            if worker_finished:
                break
            continue

        logger.debug(f"Processing streaming chunk {chunk}")
        processed_chunk = self.process_chunk(chunk)
        logger.info(f"Processed chunk: {processed_chunk}")
        yield self.add_chunk_metadata(processed_chunk)
        chunk_queue.task_done()
def stream_agent(self, df: pd.DataFrame, agent_executor: AgentExecutor, args: Dict) -> Iterable[Dict]:
    """Stream the agent's response for the last prompt derived from ``df``.

    Emits a ``start`` chunk, then every chunk from ``_stream_agent_executor``,
    then an optional ``context`` chunk and an optional ``sql`` chunk, and finally
    closes the streaming span.

    Args:
        df: Conversation rows.
        agent_executor: Executor to stream from.
        args: Agent arguments; reads ``prompt_template``, ``return_context`` and
            ``user_column``.

    Yields:
        Chunk dicts.

    Raises:
        AttributeError: If ``agent_executor`` has no callable ``stream`` attribute.
        KeyError: If ``prompt_template`` is missing from ``args``.
    """
    base_template = args["prompt_template"]
    input_variables = re.findall(r"{{(.*?)}}", base_template)
    return_context = args.get("return_context", True)

    prompts, _ = prepare_prompts(df, base_template, input_variables, args.get("user_column", USER_COLUMN))
    callbacks, context_callback = prepare_callbacks(self, args)

    # Only the most recent prompt is streamed; an empty string when there is none.
    current_prompt = prompts[-1] if prompts else ""
    yield self.add_chunk_metadata({"type": "start", "prompt": current_prompt})

    if not callable(getattr(agent_executor, "stream", None)):
        raise AttributeError("The agent_executor does not have a 'stream' method")

    for streamed_chunk in self._stream_agent_executor(agent_executor, current_prompt, callbacks):
        yield streamed_chunk

    if return_context:
        captured_context = context_callback.get_contexts()
        if captured_context:
            yield {"type": "context", "content": captured_context}

    generated_sql = self.log_callback_handler.generated_sql
    if generated_sql:
        yield self.add_chunk_metadata({"type": "sql", "content": self.log_callback_handler.generated_sql})

    self.langfuse_client_wrapper.end_span_stream(span=self.run_completion_span)
@staticmethod
def process_chunk(chunk):
    """Recursively convert a streamed chunk into plain, serializable values.

    Dicts and lists are walked recursively. ``AgentAction`` and ``AgentStep``
    become dicts of their processed fields. Messages are reduced to their
    ``content``. Strings, numbers, booleans and ``None`` pass through. Anything
    else is replaced by ``str(value)``.

    Args:
        chunk: Arbitrary streamed value.

    Returns:
        The converted value.
    """
    if isinstance(chunk, dict):
        return {key: LangchainAgent.process_chunk(value) for key, value in chunk.items()}
    if isinstance(chunk, list):
        return [LangchainAgent.process_chunk(element) for element in chunk]

    if isinstance(chunk, AgentAction):
        converted_action = {}
        converted_action["tool"] = LangchainAgent.process_chunk(chunk.tool)
        converted_action["tool_input"] = LangchainAgent.process_chunk(chunk.tool_input)
        converted_action["log"] = LangchainAgent.process_chunk(chunk.log)
        return converted_action

    if isinstance(chunk, AgentStep):
        converted_step = {"action": LangchainAgent.process_chunk(chunk.action)}
        # A falsy observation is reported as an empty string.
        if chunk.observation:
            converted_step["observation"] = LangchainAgent.process_chunk(chunk.observation)
        else:
            converted_step["observation"] = ""
        return converted_step

    if isinstance(chunk, BaseMessage):
        return {"content": chunk.content}

    passthrough_types = (str, int, float, bool, type(None))
    return chunk if isinstance(chunk, passthrough_types) else str(chunk)