import enum
import inspect
from dataclasses import dataclass
from collections import defaultdict
from typing import List, Dict, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseChatModel
from mindsdb_sql_parser.ast import Select, BinaryOperation, Identifier, Constant, Star
from mindsdb.utilities import log
from mindsdb.utilities.cache import get_cache
from mindsdb.utilities.config import config
from mindsdb.interfaces.storage import db
from mindsdb.interfaces.skills.sql_agent import SQLAgent
from mindsdb.integrations.libs.vectordatabase_handler import TableField
from mindsdb.interfaces.agents.constants import DEFAULT_TEXT2SQL_DATABASE
# Row limit used as the LIMIT of the knowledge-base similarity SELECT built in this module.
_DEFAULT_TOP_K_SIMILARITY_SEARCH = 5
# Passed as `max_size` to get_cache() when the "agent" cache for the SQL agent is requested.
_MAX_CACHE_SIZE = 1000
# Module-level logger.
logger = log.getLogger(__name__)
class SkillType(enum.Enum):
    """Kinds of skills that can be turned into agent tools.

    The member value is the string stored in a skill's ``type`` field.
    """

    # Older spelling of the text-to-SQL skill; grouped together with TEXT2SQL
    # when tools are built.
    TEXT2SQL_LEGACY = "text2sql"
    # Text-to-SQL skill.
    TEXT2SQL = "sql"
    # Knowledge base lookup skill (the tool builder logs it as deprecated).
    KNOWLEDGE_BASE = "knowledge_base"
    # Retrieval (RAG) skill.
    RETRIEVAL = "retrieval"
@dataclass
class SkillData:
    """Storage for skill's data

    Attributes:
        name (str): name of the skill
        type (str): skill's type (SkillType)
        params (dict): skill's attributes
        project_id (int): id of the project
        agent_tables_list (Optional[List[str]]): the restriction on available tables for an agent using the skill
    """

    name: str
    type: str
    params: dict
    project_id: int
    agent_tables_list: Optional[List[str]]

    @property
    def restriction_on_tables(self) -> Optional[Dict[str, set]]:
        """Schemas and tables which agent+skill may use.

        The result is the intersection of the skill's table list (``params["tables"]``)
        and the agent's table list. If only one of the two lists is non-empty,
        that list alone defines the restriction.

        Returns:
            Optional[Dict[str, set]]: allowed schemas and tables. Schemas are the keys
                (``None`` for tables given without a schema), sets of table names are the values.
                ``None`` means there are no restrictions.

        Raises:
            ValueError: if a list entry is neither a ``str`` nor a ``dict``, or if both lists
                are non-empty but share no schema or no table (i.e. every table is restricted).
        """

        def list_to_map(tables: Optional[List]) -> Dict:
            """Group table entries by schema; ``None``/empty input gives an empty map."""
            tables_map = defaultdict(set)
            for item in tables or []:
                if isinstance(item, str):
                    # A bare string is a table without a schema.
                    schema_name, table_name = None, item
                elif isinstance(item, dict):
                    table_name = item["table"]
                    schema_name = item.get("schema")
                else:
                    raise ValueError(f"Unexpected value in tables list: {item}")
                tables_map[schema_name].add(table_name)
            return tables_map

        agent_tables_map = list_to_map(self.agent_tables_list)
        # `tables` may be present but explicitly set to None; treat it like a missing key.
        skill_tables_map = list_to_map(self.params.get("tables"))

        if agent_tables_map and skill_tables_map:
            if not set(agent_tables_map) & set(skill_tables_map):
                raise ValueError("Skill's and agent's allowed tables list have no shared schemas.")

            intersection_tables_map = defaultdict(set)
            has_intersection = False
            for schema_name, agent_tables in agent_tables_map.items():
                if schema_name not in skill_tables_map:
                    continue
                shared_tables = agent_tables & skill_tables_map[schema_name]
                # A shared schema is kept in the result even when its table set is empty.
                intersection_tables_map[schema_name] = shared_tables
                if shared_tables:
                    has_intersection = True
            if not has_intersection:
                raise ValueError("Skill's and agent's allowed tables list have no shared tables.")
            return intersection_tables_map

        if skill_tables_map:
            return skill_tables_map
        if agent_tables_map:
            return agent_tables_map
        return None
class SkillToolController:
    """Builds agent tools (text-to-SQL, knowledge base, retrieval) from skills."""

    def __init__(self):
        # Created lazily by get_command_executor().
        self.command_executor = None

    def get_command_executor(self):
        """Return the command executor, creating it on first use.

        The executor is bound to a new session whose database is set to the
        configured ``default_project``.
        """
        if self.command_executor is None:
            # Imported lazily, at first use, rather than at module import time.
            from mindsdb.api.executor.command_executor import ExecuteCommands
            from mindsdb.api.executor.controllers import (
                SessionController,
            )

            sql_session = SessionController()
            sql_session.database = config.get("default_project")
            self.command_executor = ExecuteCommands(sql_session)
        return self.command_executor

    @staticmethod
    def _split_names(value) -> List:
        """Normalize a skill parameter holding names into a list.

        A falsy value gives an empty list, a string is split on commas with each
        part stripped, anything else is returned unchanged.
        """
        if not value:
            return []
        if isinstance(value, str):
            return [part.strip() for part in value.split(",")]
        return value

    def _make_text_to_sql_tools(self, skills: List[db.Skills], llm) -> List:
        """
        Uses SQLAgent to execute tool

        Collects the tables and knowledge bases allowed by all given skills,
        builds one SQLAgent over them and returns the tools of the resulting toolkit.

        Args:
            skills: text-to-SQL skills; ``skill.params["database"]`` may be filled in
                (mutated) from a dotted ``include_tables`` entry when it is empty.
            llm: language model passed to the toolkit.

        Returns:
            List: tools produced by ``MindsDBSQLToolkit.get_tools()``.

        Raises:
            ImportError: if the langchain-based modules cannot be imported.
        """
        try:
            from mindsdb.interfaces.agents.mindsdb_database_agent import MindsDBSQL
            from mindsdb.interfaces.skills.custom.text2sql.mindsdb_sql_toolkit import MindsDBSQLToolkit
        except ImportError:
            raise ImportError(
                "To use the text-to-SQL skill, please install langchain with `pip install mindsdb[langchain]`"
            )

        command_executor = self.get_command_executor()

        def escape_table_name(name: str) -> str:
            name = name.strip(" `")
            return f"`{name}`"

        tables_list = []
        knowledge_bases_list = []
        ignore_knowledge_bases_list = []
        # Databases named through dot notation in include_tables.
        extracted_databases = set()
        knowledge_base_database = DEFAULT_TEXT2SQL_DATABASE

        # First pass: settle which databases are referenced before building the lists.
        for skill in skills:
            if skill.params.get("knowledge_base_database"):
                knowledge_base_database = skill.params.get("knowledge_base_database")

            for table in self._split_names(skill.params.get("include_tables")):
                if "." in table:
                    extracted_databases.add(table.split(".")[0])

            for kb in self._split_names(skill.params.get("include_knowledge_bases")):
                if "." in kb:
                    # The last dotted knowledge base seen decides the knowledge base database.
                    knowledge_base_database = kb.split(".")[0]

        # Second pass: collect the table and knowledge base lists.
        for skill in skills:
            database = skill.params.get("database", DEFAULT_TEXT2SQL_DATABASE)
            if not database and extracted_databases:
                # No explicit database: fall back to one taken from dot notation
                # and record it on the skill.
                database = next(iter(extracted_databases))
                skill.params["database"] = database

            for kb in self._split_names(skill.params.get("include_knowledge_bases")):
                knowledge_bases_list.append(kb if "." in kb else f"{knowledge_base_database}.{kb}")

            for kb in self._split_names(skill.params.get("ignore_knowledge_bases")):
                ignore_knowledge_bases_list.append(kb if "." in kb else f"{knowledge_base_database}.{kb}")

            if not database:
                continue

            include_tables = self._split_names(skill.params.get("include_tables"))
            if include_tables:
                for table in include_tables:
                    if "." not in table:
                        tables_list.append(f"{database}.{escape_table_name(table)}")
                    elif "`" in table:
                        # Already escaped by the user: keep as is.
                        tables_list.append(table)
                    else:
                        parts = table.split(".")
                        if len(parts) == 2:
                            # database.table
                            tables_list.append(f"{parts[0]}.{escape_table_name(parts[1])}")
                        elif len(parts) == 3:
                            # database.schema.table
                            tables_list.append(f"{parts[0]}.{parts[1]}.{escape_table_name(parts[2])}")
                        else:
                            tables_list.append(escape_table_name(table))
                # An explicit include_tables list takes precedence over other restrictions.
                continue

            restriction_on_tables = skill.restriction_on_tables

            if restriction_on_tables is None:
                # No restrictions: expose every table the handler reports.
                try:
                    handler = command_executor.session.integration_controller.get_data_handler(database)
                    if "all" in inspect.signature(handler.get_tables).parameters:
                        response = handler.get_tables(all=True)
                    else:
                        response = handler.get_tables()
                    tables_df = response.data_frame
                    columns = [c.lower() for c in tables_df.columns]
                    name_idx = columns.index("table_name") if "table_name" in columns else 0
                    if "table_schema" in tables_df.columns:
                        for _, row in tables_df.iterrows():
                            # name_idx is a position, not a label: use .iloc, because an
                            # integer key on a label-indexed Series is deprecated in pandas.
                            table_name = row.iloc[name_idx]
                            tables_list.append(f"{database}.{row['table_schema']}.{escape_table_name(table_name)}")
                    else:
                        for table_name in tables_df.iloc[:, name_idx]:
                            tables_list.append(f"{database}.{escape_table_name(table_name)}")
                except Exception:
                    # Best effort: a database whose tables cannot be listed is skipped.
                    logger.warning(f"Could not get tables from database {database}:", exc_info=True)
                    continue

            if restriction_on_tables:
                for schema_name, tables in restriction_on_tables.items():
                    for table in tables:
                        if "." in table:
                            tables_list.append(escape_table_name(table))
                        elif schema_name is None:
                            tables_list.append(f"{database}.{escape_table_name(table)}")
                        else:
                            tables_list.append(f"{database}.{schema_name}.{escape_table_name(table)}")

        # Deduplicate; order is not preserved.
        tables_list = list(set(tables_list))
        knowledge_bases_list = list(set(knowledge_bases_list))
        ignore_knowledge_bases_list = list(set(ignore_knowledge_bases_list))

        include_knowledge_bases = knowledge_bases_list or None
        ignore_knowledge_bases = ignore_knowledge_bases_list or None
        if include_knowledge_bases:
            # An include list makes the ignore list irrelevant.
            ignore_knowledge_bases = None

        # No database list is passed; databases are described through databases_struct.
        all_databases = []

        databases_struct = {}
        for skill in skills:
            if skill.params.get("database"):
                databases_struct[skill.params["database"]] = skill.restriction_on_tables
        for db_name in extracted_databases:
            if db_name not in databases_struct:
                databases_struct[db_name] = None

        sql_agent = SQLAgent(
            command_executor=command_executor,
            databases=all_databases,
            databases_struct=databases_struct,
            include_tables=tables_list,
            ignore_tables=None,
            include_knowledge_bases=include_knowledge_bases,
            ignore_knowledge_bases=ignore_knowledge_bases,
            knowledge_base_database=knowledge_base_database,
            sample_rows_in_table_info=3,
            cache=get_cache("agent", max_size=_MAX_CACHE_SIZE),
        )
        # Named sql_database so the module-level `db` import is not shadowed.
        sql_database = MindsDBSQL.custom_init(sql_agent=sql_agent)

        should_include_kb_tools = bool(include_knowledge_bases)
        should_include_tables_tools = bool(databases_struct or tables_list)

        toolkit = MindsDBSQLToolkit(
            db=sql_database,
            llm=llm,
            include_tables_tools=should_include_tables_tools,
            include_knowledge_base_tools=should_include_kb_tools,
        )
        return toolkit.get_tools()

    def _make_retrieval_tools(self, skill: db.Skills, llm, embedding_model):
        """
        creates advanced retrieval tool i.e. RAG

        Args:
            skill: retrieval skill; ``skill.params`` must contain ``description``.
            llm: stored under ``config["llm"]`` when the skill's config has no ``llm``.
            embedding_model: accepted for interface compatibility; not used here.

        Returns:
            Whatever ``build_retrieval_tools`` returns for the tool definition.
        """
        params = skill.params
        # Named tool_config so the module-level `config` import is not shadowed.
        tool_config = params.get("config", {})
        if "llm" not in tool_config:
            tool_config["llm"] = llm
        tool = dict(
            name=params.get("name", skill.name),
            source=params.get("source", None),
            config=tool_config,
            description=f"You must use this tool to get more context or information "
            f"to answer a question about {params['description']}. "
            f"The input should be the exact question the user is asking.",
            type=skill.type,
        )
        pred_args = {"llm": llm}

        from .retrieval_tool import build_retrieval_tools

        return build_retrieval_tools(tool, pred_args, skill)

    def _get_rag_query_function(self, skill: db.Skills):
        """Return a function that answers a question from the skill's knowledge base.

        The returned function selects up to ``_DEFAULT_TOP_K_SIMILARITY_SEARCH`` rows
        from the knowledge base named in ``skill.params["source"]`` whose content
        field is compared to the question, and joins the returned text with newlines.
        """
        session_controller = self.get_command_executor().session

        def _answer_question(question: str) -> str:
            knowledge_base_name = skill.params["source"]
            query = Select(
                targets=[Star()],
                where=BinaryOperation(op="=", args=[Identifier(TableField.CONTENT.value), Constant(question)]),
                limit=Constant(_DEFAULT_TOP_K_SIMILARITY_SEARCH),
            )
            kb_table = session_controller.kb_controller.get_table(knowledge_base_name, skill.project_id)
            res = kb_table.select_query(query)
            # Prefer chunk_content, then content.
            if hasattr(res, "chunk_content"):
                return "\n".join(res.chunk_content)
            if hasattr(res, "content"):
                return "\n".join(res.content)
            return "No content or chunk_content found in knowledge base response"

        return _answer_question

    def _make_knowledge_base_tools(self, skill: db.Skills) -> dict:
        """Build the (deprecated) knowledge base tool definition for a skill.

        Returns:
            dict: with keys ``name``, ``func``, ``description`` and ``type``.
        """
        description = skill.params.get("description", "")

        logger.warning(
            "This skill is deprecated and will be removed in the future. Please use `retrieval` skill instead "
        )

        return dict(
            name="Knowledge Base Retrieval",
            func=self._get_rag_query_function(skill),
            description=f"Use this tool to get more context or information to answer a question about {description}. The input should be the exact question the user is asking.",
            type=skill.type,
        )

    def get_tools_from_skills(
        self, skills_data: List[SkillData], llm: BaseChatModel, embedding_model: Embeddings
    ) -> dict:
        """Creates tools for the given skills, grouped by skill type

        Args:
            skills_data (List[SkillData]): Skills to make a tool from
            llm (BaseChatModel): LLM which will be used by skills
            embedding_model (Embeddings): this model is used by retrieval skill

        Returns:
            dict: maps each SkillType present in ``skills_data`` to the list of tools
                built for it (legacy text2sql skills are grouped under SkillType.TEXT2SQL)

        Raises:
            NotImplementedError: if a skill has a type that is not a SkillType value
        """
        skills_group = defaultdict(list)
        for skill in skills_data:
            try:
                skill_type = SkillType(skill.type)
            except ValueError:
                raise NotImplementedError(
                    f"skill of type {skill.type} is not supported as a tool, supported types are: {list(SkillType._member_names_)}"
                )

            if skill_type == SkillType.TEXT2SQL_LEGACY:
                skill_type = SkillType.TEXT2SQL
            skills_group[skill_type].append(skill)

        tools = {}
        for skill_type, skills in skills_group.items():
            if skill_type == SkillType.TEXT2SQL:
                # All text-to-SQL skills share one toolkit.
                tools[skill_type] = self._make_text_to_sql_tools(skills, llm)
            elif skill_type == SkillType.KNOWLEDGE_BASE:
                tools[skill_type] = [self._make_knowledge_base_tools(skill) for skill in skills]
            elif skill_type == SkillType.RETRIEVAL:
                tools[skill_type] = []
                for skill in skills:
                    tools[skill_type] += self._make_retrieval_tools(skill, llm, embedding_model)
        return tools
# Module-level controller instance, created when this module is imported.
skill_tool = SkillToolController()