import datetime
from typing import Dict, Iterator, List, Union, Tuple, Optional, Any
import copy
from sqlalchemy.orm.attributes import flag_modified
from sqlalchemy import null
import pandas as pd
from mindsdb.interfaces.storage import db
from mindsdb.interfaces.storage.db import Predictor
from mindsdb.utilities.context import context as ctx
from mindsdb.interfaces.database.projects import ProjectController
from mindsdb.interfaces.model.functions import PredictorRecordNotFound
from mindsdb.interfaces.model.model_controller import ModelController
from mindsdb.interfaces.skills.skills_controller import SkillsController
from mindsdb.utilities.config import config
from mindsdb.utilities import log
from mindsdb.utilities.exception import EntityExistsError, EntityNotExistsError
from .constants import ASSISTANT_COLUMN, SUPPORTED_PROVIDERS, PROVIDER_TO_MODELS
from .provider_utils import get_llm_provider
# Module-level logger, named after this module.
logger = log.getLogger(__name__)
# Project name read from the "default_project" config entry; used in this module as the
# default value of the `project_name` argument.
default_project = config.get("default_project")
class AgentsController:
"""Handles CRUD operations at the database level for Agents.

Besides create/read/update/delete, the controller resolves the model/provider pair
for an agent and runs completions through LangchainAgent.
"""
# Re-exported from .constants; its consumers are not visible in this class.
assistant_column = ASSISTANT_COLUMN
# Prefix of the description given to auto-created SQL skills; delete_agent looks for it
# in a skill's description to decide whether the skill should be removed with the agent.
autogenerated_skill_prefix = "Auto-generated SQL skill for agent"
def __init__(
    self,
    project_controller: ProjectController = None,
    skills_controller: SkillsController = None,
    model_controller: ModelController = None,
):
    """
    Stores the controllers this class delegates to, creating defaults where none are given.
    Parameters:
        project_controller (ProjectController): Used to look up projects; a new one is created if None
        skills_controller (SkillsController): Used to look up/create/delete skills; a new one is created if None
        model_controller (ModelController): Used to look up models; a new one is created if None
    """
    # Defaults are built lazily (not in the signature) so each instance gets its own controllers.
    self.project_controller = ProjectController() if project_controller is None else project_controller
    self.skills_controller = SkillsController() if skills_controller is None else skills_controller
    self.model_controller = ModelController() if model_controller is None else model_controller
def check_model_provider(self, model_name: str, provider: str = None) -> Tuple[dict, str]:
    """
    Looks up a model by name and works out which provider should be used for it.
    Parameters:
        model_name (str): The name of the model, optionally carrying a version
        provider (str): The provider supplied by the caller, if any
    Returns:
        model (dict): The stored model, or None when model_name is None or no stored model was found
        provider (str): The stored model's provider ("mindsdb" when it has none), otherwise the
            given provider, otherwise the provider looked up from the model name
    Raises:
        ValueError: No stored model was found, the given provider is not a supported provider,
            and the model is not listed for that provider.
    """
    # Nothing to resolve without a model name: hand the provider back untouched.
    if model_name is None:
        return None, provider

    model = None
    try:
        base_name, version = Predictor.get_name_and_version(model_name)
        model = self.model_controller.get_model(base_name, version=version)
        stored_provider = model.get("provider")
        provider = stored_provider if stored_provider is not None else "mindsdb"
    except PredictorRecordNotFound:
        # Not a stored model: fall back to the provider/model tables.
        if not provider:
            provider = get_llm_provider({"model_name": model_name})
        else:
            unsupported = provider not in SUPPORTED_PROVIDERS
            if unsupported and model_name not in PROVIDER_TO_MODELS.get(provider, []):
                raise ValueError(f"Model with name does not exist for provider {provider}: {model_name}")
    return model, provider
def get_agent(self, agent_name: str, project_name: str = default_project) -> Optional[db.Agents]:
    """
    Fetches a single non-deleted agent by its name.
    Parameters:
        agent_name (str): The name of the agent
        project_name (str): The name of the containing project - must exist
    Returns:
        agent (Optional[db.Agents]): The matching agent record, or None if there is no match
    """
    project = self.project_controller.get(name=project_name)
    # Restrict to the current company and skip soft-deleted rows (deleted_at is set on delete).
    conditions = (
        db.Agents.name == agent_name,
        db.Agents.project_id == project.id,
        db.Agents.company_id == ctx.company_id,
        db.Agents.deleted_at == null(),
    )
    return db.Agents.query.filter(*conditions).first()
def get_agent_by_id(self, id: int, project_name: str = default_project) -> db.Agents:
    """
    Fetches a single non-deleted agent by its id.
    Parameters:
        id (int): The id of the agent
        project_name (str): The name of the containing project - must exist
    Returns:
        agent (db.Agents): The matching agent record, or None if there is no match
    """
    project = self.project_controller.get(name=project_name)
    # Restrict to the current company and skip soft-deleted rows (deleted_at is set on delete).
    conditions = (
        db.Agents.id == id,
        db.Agents.project_id == project.id,
        db.Agents.company_id == ctx.company_id,
        db.Agents.deleted_at == null(),
    )
    return db.Agents.query.filter(*conditions).first()
def get_agents(self, project_name: str) -> List[dict]:
    """
    Lists the non-deleted agents of the current company.
    Parameters:
        project_name (str): The name of the containing project - must exist.
            When None, agents are not filtered by project.
    Returns:
        all_agents (List[db.Agents]): The matching agent records
    """
    query = db.Agents.query.filter(db.Agents.company_id == ctx.company_id, db.Agents.deleted_at == null())
    if project_name is None:
        return query.all()
    project = self.project_controller.get(name=project_name)
    return query.filter(db.Agents.project_id == project.id).all()
def _create_default_sql_skill(
    self,
    name,
    project_name,
    include_tables: List[str] = None,
    include_knowledge_bases: List[str] = None,
):
    """
    Creates the auto-generated SQL skill "<name>_sql_skill", or syncs its params if it already exists.
    Parameters:
        name: The agent name the skill name and description are derived from
        project_name: The project in which the skill is looked up / created
        include_tables (List[str]): Tables stored on the skill; omitted from the params when empty
        include_knowledge_bases (List[str]): Knowledge bases stored on the skill; omitted when empty
    Returns:
        skill_name (str): The name of the created or updated skill
    Raises:
        ValueError: Any failure while creating or updating the skill (original error is chained).
    """
    skill_name = f"{name}_sql_skill"
    skill_params = {
        "type": "sql",
        "description": f"{self.autogenerated_skill_prefix} {name}",
    }
    optional_params = {
        "include_tables": include_tables,
        "include_knowledge_bases": include_knowledge_bases,
    }
    skill_params.update({key: value for key, value in optional_params.items() if value})

    try:
        existing_skill = self.skills_controller.get_skill(skill_name, project_name)
        if existing_skill is not None:
            # NOTE(review): "type" is still part of skill_params on this path, so it is written
            # into the stored params, while the create path passes it separately - confirm intended.
            stored_params = existing_skill.params
            params_changed = False
            for param_key in skill_params.keys() | stored_params.keys():
                if param_key not in skill_params:
                    # Stored key that is no longer wanted.
                    stored_params.pop(param_key)
                    params_changed = True
                elif stored_params.get(param_key) != skill_params[param_key]:
                    stored_params[param_key] = skill_params[param_key]
                    params_changed = True
            if params_changed:
                # The dict was changed in place, so the attribute is flagged explicitly before commit.
                flag_modified(existing_skill, "params")
                db.session.commit()
        else:
            skill_type = skill_params.pop("type")
            self.skills_controller.add_skill(
                name=skill_name, project_name=project_name, type=skill_type, params=skill_params
            )
    except Exception as e:
        logger.exception("Failed to auto-create or update SQL skill:")
        raise ValueError(f"Failed to auto-create or update SQL skill: {e}") from e
    return skill_name
def add_agent(
    self,
    name: str,
    project_name: str = None,
    model_name: Union[str, dict] = None,
    skills: List[Union[str, dict]] = None,
    provider: str = None,
    params: Dict[str, Any] = None,
) -> db.Agents:
    """
    Adds an agent to the database.
    Parameters:
        name (str): The name of the new agent
        project_name (str): The containing project (the default project when None)
        model_name (str | dict): The name of the existing ML model the agent will use,
            or a dict of model parameters which is stored under params["model"]
        skills (List[Union[str, dict]]): List of existing skill names to add to the new agent, or list of dicts
            with one of keys is "name", and other is additional parameters for relationship agent<>skill
        provider (str): The provider of the model
        params (Dict[str, Any]): Parameters to use when running the agent (the passed dict is not modified)
            data: Dict, data sources for an agent, keys:
              - knowledge_bases: List of KBs to use (or a comma-separated string)
              - tables: list of tables to use (or a comma-separated string)
            model: Dict, parameters for the model to use
              - provider: The provider of the model (e.g., 'openai', 'google')
              - Other model-specific parameters like 'api_key', 'model_name', etc.
            <provider>_api_key: API key for the provider (e.g., openai_api_key)
            # Deprecated parameters (rejected with ValueError):
            database, knowledge_base_database, include_tables, ignore_tables,
            include_knowledge_bases, ignore_knowledge_bases
    Returns:
        agent (db.Agents): The created agent
    Raises:
        EntityExistsError: Agent with given name already exists.
        ValueError: A deprecated parameter is used, a skill with given name does not exist,
            the model does not exist for the given provider, or the default SQL skill could not be created.
    """
    if project_name is None:
        project_name = default_project
    project = self.project_controller.get(name=project_name)

    if self.get_agent(name, project_name) is not None:
        raise EntityExistsError("Agent already exists", name)

    # Shallow copy: "model" may be injected below and the caller's dict must stay untouched.
    params = dict(params) if params else {}
    # `skills` defaults to None; normalise so the association loop below never iterates None.
    skills = list(skills) if skills else []

    if isinstance(model_name, dict):
        # A dict is a set of model parameters, not a model name: keep it with the agent params.
        params["model"] = model_name
        model_name = None

    if model_name is not None:
        _, provider = self.check_model_provider(model_name, provider)
    else:
        logger.warning("'model_name' param is not provided. Using default global llm model at runtime.")

    deprecated_params = [
        "database",
        "knowledge_base_database",
        "include_tables",
        "ignore_tables",
        "include_knowledge_bases",
        "ignore_knowledge_bases",
    ]
    if any(param in params for param in deprecated_params):
        raise ValueError(
            f"Parameters {', '.join(deprecated_params)} are deprecated. "
            "Use 'data' parameter with 'tables' and 'knowledge_bases' keys instead."
        )

    include_tables = None
    include_knowledge_bases = None
    if "data" in params:
        include_knowledge_bases = params["data"].get("knowledge_bases")
        include_tables = params["data"].get("tables")
        # Both entries may arrive as comma-separated strings.
        if isinstance(include_tables, str):
            include_tables = [t.strip() for t in include_tables.split(",")]
        if isinstance(include_knowledge_bases, str):
            include_knowledge_bases = [kb.strip() for kb in include_knowledge_bases.split(",")]

    # Data sources without explicit skills get an auto-generated SQL skill.
    if not skills and (include_tables or include_knowledge_bases):
        default_skill = self._create_default_sql_skill(
            name,
            project_name,
            include_tables=include_tables,
            include_knowledge_bases=include_knowledge_bases,
        )
        skills = [default_skill]

    agent = db.Agents(
        name=name,
        project_id=project.id,
        company_id=ctx.company_id,
        user_class=ctx.user_class,
        model_name=model_name,
        provider=provider,
        params=params,
    )

    for skill in skills:
        if isinstance(skill, str):
            skill_name = skill
            parameters = {}
        else:
            parameters = skill.copy()
            skill_name = parameters.pop("name")

        existing_skill = self.skills_controller.get_skill(skill_name, project_name)
        if existing_skill is None:
            # Discard associations already added to the session for this agent.
            db.session.rollback()
            raise ValueError(f"Skill with name does not exist: {skill_name}")

        if existing_skill.type == "sql":
            # SQL skills additionally carry the agent's data sources on the relationship.
            if include_tables:
                parameters["tables"] = include_tables
            if include_knowledge_bases:
                parameters["include_knowledge_bases"] = include_knowledge_bases
                if "include_knowledge_bases" not in existing_skill.params:
                    existing_skill.params["include_knowledge_bases"] = include_knowledge_bases
                    # In-place dict change: flag the attribute explicitly.
                    flag_modified(existing_skill, "params")

        association = db.AgentSkillsAssociation(parameters=parameters, agent=agent, skill=existing_skill)
        db.session.add(association)

    db.session.add(agent)
    db.session.commit()
    return agent
def update_agent(
    self,
    agent_name: str,
    project_name: str = default_project,
    name: str = None,
    model_name: Union[str, dict] = None,
    skills_to_add: List[Union[str, dict]] = None,
    skills_to_remove: List[str] = None,
    skills_to_rewrite: List[Union[str, dict]] = None,
    provider: str = None,
    params: Dict[str, str] = None,
):
    """
    Updates an agent in the database.
    Parameters:
        agent_name (str): The name of the existing agent to update
        project_name (str): The containing project
        name (str): The updated name of the agent
        model_name (str | dict): The name of the existing ML model the agent will use,
            or a dict of model parameters which is stored under params["model"]
        skills_to_add (List[Union[str, dict]]): List of skill names to add to the agent, or list of dicts
            with one of keys is "name", and other is additional parameters for relationship agent<>skill
        skills_to_remove (List[str]): List of skill names to remove from the agent
        skills_to_rewrite (List[Union[str, dict]]): new list of skills for the agent
        provider (str): The provider of the model
        params (Dict[str, str]): Parameters to merge into the agent's params; keys whose value is
            None are dropped. A "data" key (with "tables" / "knowledge_bases") re-creates the
            auto-generated SQL skill and makes it the agent's only skill.
    Returns:
        agent (db.Agents): The updated agent
    Raises:
        EntityExistsError: if agent with new name already exists
        EntityNotExistsError: if agent with name or skill not found
        ValueError: if conflict in skills list, if 'data' is combined with skills_to_add/skills_to_remove,
            or if a protected property of a demo agent would change
    """
    skills_to_add = skills_to_add or []
    skills_to_remove = skills_to_remove or []
    skills_to_rewrite = skills_to_rewrite or []

    if skills_to_rewrite and (skills_to_remove or skills_to_add):
        raise ValueError(
            "'skills_to_rewrite' and 'skills_to_add' (or 'skills_to_remove') cannot be used at the same time"
        )

    existing_agent = self.get_agent(agent_name, project_name=project_name)
    if existing_agent is None:
        raise EntityNotExistsError(f"Agent with name not found: {agent_name}")
    existing_params = existing_agent.params or {}

    # Demo agents only allow their prompt template to be changed.
    is_demo = existing_params.get("is_demo", False)
    if is_demo and (
        (name is not None and name != agent_name)
        or (model_name is not None and existing_agent.model_name != model_name)
        or (provider is not None and existing_agent.provider != provider)
        or (isinstance(params, dict) and len(params) > 0 and "prompt_template" not in params)
    ):
        raise ValueError("It is forbidden to change properties of the demo object")

    if name is not None and name != agent_name:
        if self.get_agent(name, project_name=project_name) is not None:
            raise EntityExistsError(f"Agent with updated name already exists: {name}")
        existing_agent.name = name

    # The stored params must be rewritten when the caller passed params, or when a model dict
    # is moved into them below (otherwise that in-place change would never be flagged/persisted).
    params_dirty = params is not None

    if model_name or provider:
        if isinstance(model_name, dict):
            existing_params["model"] = model_name
            model_name = None
            params_dirty = True
        _, provider = self.check_model_provider(model_name, provider)
        # NOTE(review): when only `provider` is supplied, model_name is reset to None here -
        # confirm this is intended.
        existing_agent.model_name = model_name
        existing_agent.provider = provider

    # `params` defaults to None, so guard before the membership test.
    if params is not None and "data" in params:
        if skills_to_add or skills_to_remove:
            raise ValueError(
                "'data' parameter cannot be used with 'skills_to_remove' or 'skills_to_add' parameters"
            )
        include_knowledge_bases = params["data"].get("knowledge_bases")
        include_tables = params["data"].get("tables")
        # Accept comma-separated strings, the same way add_agent does.
        if isinstance(include_tables, str):
            include_tables = [t.strip() for t in include_tables.split(",")]
        if isinstance(include_knowledge_bases, str):
            include_knowledge_bases = [kb.strip() for kb in include_knowledge_bases.split(",")]
        skill = self._create_default_sql_skill(
            agent_name,
            project_name,
            include_tables=include_tables,
            include_knowledge_bases=include_knowledge_bases,
        )
        skills_to_rewrite = [{"name": skill}]

    # Resolve every referenced skill up front so a missing one fails before any change is staged.
    skill_name_to_record_map = {}
    for skill_meta in skills_to_add + skills_to_remove + skills_to_rewrite:
        skill_name = skill_meta["name"] if isinstance(skill_meta, dict) else skill_meta
        if skill_name not in skill_name_to_record_map:
            skill_record = self.skills_controller.get_skill(skill_name, project_name)
            if skill_record is None:
                raise EntityNotExistsError(f"Skill with name does not exist: {skill_name}")
            skill_name_to_record_map[skill_name] = skill_record

    if skills_to_add or skills_to_remove:
        skills_to_add = [{"name": x} if isinstance(x, str) else x for x in skills_to_add]
        # First occurrence of a name wins when the same skill is listed twice.
        add_parameters = {}
        for skill_meta in skills_to_add:
            add_parameters.setdefault(
                skill_meta["name"], {k: v for k, v in skill_meta.items() if k != "name"}
            )

        names_to_remove = set(skills_to_remove)
        if not names_to_remove.isdisjoint(add_parameters):
            raise ValueError("Conflict between skills to add and skills to remove.")

        existing_agent_skills_names = {rel.skill.name for rel in existing_agent.skills_relationships}

        for rel in existing_agent.skills_relationships:
            if rel.skill.name in names_to_remove:
                db.session.delete(rel)

        # Skills already attached to the agent are left as they are.
        for skill_name, skill_parameters in add_parameters.items():
            if skill_name in existing_agent_skills_names:
                continue
            association = db.AgentSkillsAssociation(
                parameters=skill_parameters, agent=existing_agent, skill=skill_name_to_record_map[skill_name]
            )
            db.session.add(association)
    elif skills_to_rewrite:
        # Plain skill names are allowed here too; normalise them to the dict form first.
        skills_to_rewrite = [{"name": x} if isinstance(x, str) else x for x in skills_to_rewrite]
        skill_name_to_parameters = {
            x["name"]: {k: v for k, v in x.items() if k != "name"} for x in skills_to_rewrite
        }
        existing_skill_names = set()
        for rel in existing_agent.skills_relationships:
            if rel.skill.name not in skill_name_to_parameters:
                db.session.delete(rel)
            else:
                existing_skill_names.add(rel.skill.name)
                rel.parameters = skill_name_to_parameters[rel.skill.name]
                flag_modified(rel, "parameters")
        for new_skill_name in set(skill_name_to_parameters) - existing_skill_names:
            association = db.AgentSkillsAssociation(
                parameters=skill_name_to_parameters[new_skill_name],
                agent=existing_agent,
                skill=skill_name_to_record_map[new_skill_name],
            )
            db.session.add(association)

    if params_dirty:
        if params is not None:
            existing_params.update(params)
        # A None value means "remove this key".
        existing_agent.params = {k: v for k, v in existing_params.items() if v is not None}
        flag_modified(existing_agent, "params")

    db.session.commit()
    return existing_agent
def delete_agent(self, agent_name: str, project_name: str = default_project):
    """
    Deletes an agent by name.
    The agent is soft-deleted (its deleted_at is set); auto-generated SQL skills attached to it
    are deleted through the skills controller.
    Parameters:
        agent_name (str): The name of the agent to delete
        project_name (str): The name of the containing project
    Raises:
        ValueError: Agent does not exist, or it is a demo object.
    """
    agent = self.get_agent(agent_name, project_name)
    if agent is None:
        raise ValueError(f"Agent with name does not exist: {agent_name}")
    if isinstance(agent.params, dict) and agent.params.get("is_demo") is True:
        raise ValueError("Unable to delete demo object")

    # Iterate over a snapshot: delete_skill may touch the relationship collection (not visible here).
    for rel in list(agent.skills_relationships):
        skill_params = rel.skill.params
        if skill_params is None:
            continue
        # A skill without a (string) description cannot be an auto-generated one.
        description = skill_params.get("description")
        if isinstance(description, str) and self.autogenerated_skill_prefix in description:
            self.skills_controller.delete_skill(rel.skill.name, project_name, strict_case=True)

    agent.deleted_at = datetime.datetime.now()
    db.session.commit()
def get_agent_llm_params(self, agent_params: dict):
    """
    Get agent LLM parameters by combining default config with user provided parameters.
    Similar to how knowledge bases handle default parameters.
    Parameters:
        agent_params (dict): The agent's params. Its "model" entry is used when present,
            otherwise the whole dict is treated as model parameters. May be None or empty.
    Returns:
        combined_model_params (dict): A fresh dict with the configured "default_llm" values
            overridden by the agent's model parameters
    """
    # Deep copy so the shared config is never mutated; tolerate a missing/None default.
    combined_model_params = copy.deepcopy(config.get("default_llm") or {})
    if not agent_params:
        return combined_model_params

    model_params = agent_params["model"] if "model" in agent_params else agent_params
    if model_params:
        combined_model_params.update(model_params)
    return combined_model_params
def get_completion(
    self,
    agent: db.Agents,
    messages: list[Dict[str, str]],
    project_name: str = default_project,
    tools: list = None,
    stream: bool = False,
    params: dict | None = None,
) -> Union[Iterator[object], pd.DataFrame]:
    """
    Asks an agent for a completion.
    Parameters:
        agent (db.Agents): Existing agent to get completion from
        messages (list[Dict[str, str]]): Chat history to send to the agent
        project_name (str): Project the agent belongs to; only forwarded to the streaming path
        tools (list[BaseTool]): Tools to use while getting the completion; only forwarded to the streaming path
        stream (bool): Whether to stream the response
        params (dict | None): params to redefine agent params
    Returns:
        response (Union[Iterator[object], pd.DataFrame]): Completion as a DataFrame or iterator of completion chunks
    Raises:
        ValueError: Agent's model does not exist.
    """
    if stream:
        return self._get_completion_stream(agent, messages, project_name=project_name, tools=tools, params=params)

    # Imported here rather than at module level, as in the rest of this file's completion code.
    from .langchain_agent import LangchainAgent

    model_record, resolved_provider = self.check_model_provider(agent.model_name, agent.provider)
    # Persist the resolved provider the first time it becomes known.
    if resolved_provider is not None and agent.provider is None:
        agent.provider = resolved_provider
        db.session.commit()

    runner = LangchainAgent(agent, model_record, llm_params=self.get_agent_llm_params(agent.params))
    return runner.get_completion(messages, params=params)
def _get_completion_stream(
    self,
    agent: db.Agents,
    messages: list[Dict[str, str]],
    project_name: str = default_project,
    tools: list = None,
    params: dict | None = None,
) -> Iterator[object]:
    """
    Asks an agent for a completion and returns it as a stream of chunks.
    Parameters:
        agent (db.Agents): Existing agent to get completion from
        messages (list[Dict[str, str]]): Chat history to send to the agent
        project_name (str): Project the agent belongs to (accepted but not used in this method)
        tools (list[BaseTool]): Tools to use while getting the completion (accepted but not used in this method)
        params (dict | None): params to redefine agent params
    Returns:
        chunks (Iterator[object]): Completion chunks as an iterator
    Raises:
        ValueError: Agent's model does not exist.
    """
    from .langchain_agent import LangchainAgent

    model_record, resolved_provider = self.check_model_provider(agent.model_name, agent.provider)
    # Persist the resolved provider the first time it becomes known.
    if resolved_provider is not None and agent.provider is None:
        agent.provider = resolved_provider
        db.session.commit()

    runner = LangchainAgent(agent, model=model_record, llm_params=self.get_agent_llm_params(agent.params))
    return runner.get_completion(messages, stream=True, params=params)