import gc
gc.disable()
import os
import sys
import time
import atexit
import signal
import psutil
import asyncio
import threading
import shutil
from enum import Enum
from dataclasses import dataclass, field
from typing import Callable, Optional, Tuple, List
from sqlalchemy import func
from sqlalchemy.orm.attributes import flag_modified
from mindsdb.utilities import log
logger = log.getLogger("mindsdb")
logger.debug("Starting MindsDB...")
from mindsdb.__about__ import __version__ as mindsdb_version
from mindsdb.utilities.config import config
from mindsdb.utilities.starters import (
start_http,
start_mysql,
start_ml_task_queue,
start_scheduler,
start_tasks,
start_litellm,
)
from mindsdb.utilities.ps import is_pid_listen_port, get_child_pids
import mindsdb.interfaces.storage.db as db
from mindsdb.utilities.fs import clean_process_marks, clean_unlinked_process_marks, create_pid_file, delete_pid_file
from mindsdb.utilities.context import context as ctx
from mindsdb.utilities.auth import register_oauth_client, get_aws_meta_data
from mindsdb.utilities.sentry import sentry_sdk
from mindsdb.utilities.api_status import set_api_status
# Prefer torch's multiprocessing module when it can be imported; otherwise
# fall back to the standard-library module under the same `mp` alias.
try:
    import torch.multiprocessing as mp
except Exception:
    import multiprocessing as mp
# `set_start_method` raises RuntimeError when a start method has already been
# chosen for this interpreter; that case is only logged.
try:
    mp.set_start_method("spawn")
except RuntimeError:
    logger.info("Torch multiprocessing context already set, ignoring...")
# Re-enable the garbage collector that was disabled at the very top of the file.
gc.enable()
# Shutdown flag: set by `close_api_gracefully`, waited on by
# `do_clean_process_marks` and handed to the resources-log thread.
_stop_event = threading.Event()
class TrunkProcessEnum(Enum):
    """Names of the subprocesses that the launcher is able to start."""

    HTTP = "http"
    MYSQL = "mysql"
    JOBS = "jobs"
    TASKS = "tasks"
    ML_TASK_QUEUE = "ml_task_queue"
    LITELLM = "litellm"

    @classmethod
    def _missing_(cls, value):
        """Handle a lookup by a value that matches no member.

        Instead of letting ``Enum`` raise ``ValueError``, the problem is logged
        and the interpreter is terminated with exit code 1.

        Args:
            value: the value that was looked up and not found.
        """
        message = f'"{value}" is not a valid name of subprocess'
        logger.error(message)
        sys.exit(1)
@dataclass
class TrunkProcessData:
    """Launch settings and runtime state of one subprocess managed by the launcher."""

    # Human-readable process name (also passed as the OS process name).
    name: str
    # Callable executed as the target of the subprocess.
    entrypoint: Callable
    # Whether the process was requested to run.
    need_to_run: bool = False
    # Port the process is expected to listen on; None when it has no port.
    port: Optional[int] = None
    # Handle of the spawned process; None until it is created.
    process: Optional[mp.Process] = None
    started: bool = False
    # Positional arguments for `entrypoint`.
    args: Optional[Tuple] = None
    # Restart policy, see `should_restart` and `request_restart_attempt`.
    restart_on_failure: bool = False
    max_restart_count: int = 3
    max_restart_interval_seconds: int = 60
    _restart_count: int = 0
    # Unix timestamps (whole seconds) of the registered restart attempts.
    _restarts_time: List[int] = field(default_factory=list)

    def request_restart_attempt(self) -> bool:
        """Register a restart attempt and tell whether it is allowed.

        If `max_restart_count` == 0, restarts are not limited at all and the
        attempt is not even recorded.
        If `max_restart_interval_seconds` == 0, attempts never expire, so the
        limit applies to the whole lifetime of this object.

        Returns:
            bool: `True` if the number of attempts recorded within the interval
                  (including this one) does not exceed `max_restart_count`.
        """
        if self.max_restart_count == 0:
            return True

        now = int(time.time())
        self._restarts_time.append(now)

        if self.max_restart_interval_seconds > 0:
            # Forget attempts that are older than the sliding window.
            oldest_allowed = now - self.max_restart_interval_seconds
            self._restarts_time = [moment for moment in self._restarts_time if moment >= oldest_allowed]

        return len(self._restarts_time) <= self.max_restart_count

    @property
    def should_restart(self) -> bool:
        """Tell whether the finished process qualifies for a restart.

        In case of OOM we want to restart the process. The OS kills the process with code 9 on linux when
        an OOM occurs, so on linux/darwin only a SIGKILL exit code qualifies. On other OSes the process is
        restarted regardless of the exit code, but an unlimited restart count is refused there.

        Returns:
            bool: `True` if the process needs to be restarted on failure
        """
        if config.is_cloud:
            return False

        if sys.platform not in ("linux", "darwin"):
            if self.max_restart_count == 0:
                logger.warning("In the current OS, it is not possible to use `max_restart_count=0`")
                return False
            return self.restart_on_failure

        # multiprocessing reports "killed by signal N" as exit code -N
        return self.restart_on_failure and self.process.exitcode == -signal.SIGKILL.value
def close_api_gracefully(trunc_processes_struct):
    """Stop every managed subprocess together with its child processes.

    Sets the module-level stop event, removes the pid file and then, for each
    started process, signals its children, terminates it and waits for it.

    Args:
        trunc_processes_struct: mapping whose values expose a `process`
            attribute (a multiprocessing process handle or None).

    Raises:
        SystemExit: with code 0 if a KeyboardInterrupt arrives during shutdown.
    """
    _stop_event.set()
    delete_pid_file()
    try:
        for trunc_processes_data in trunc_processes_struct.values():
            process = trunc_processes_data.process
            # A handle whose `start()` never succeeded has `pid is None`; there is
            # nothing to stop, and `terminate()`/`join()` would raise on it.
            if process is None or process.pid is None:
                continue
            try:
                childs = get_child_pids(process.pid)
                for p in childs:
                    try:
                        os.kill(p, signal.SIGTERM)
                    except Exception:
                        # NOTE(review): this fallback requires the items returned by
                        # `get_child_pids` to expose `kill()` - confirm against the helper.
                        p.kill()
                sys.stdout.flush()
                process.terminate()
                process.join()
                sys.stdout.flush()
            except psutil.NoSuchProcess:
                # the process is already gone
                pass
    except KeyboardInterrupt:
        sys.exit(0)
def clean_mindsdb_tmp_dir():
    """Clean the MindsDB tmp dir at exit.

    Every entry directly inside `config["paths"]["tmp"]` is removed: real
    directories recursively, everything else (files and symlinks) by unlinking.
    """
    temp_dir = config["paths"]["tmp"]
    for file in temp_dir.iterdir():
        # `is_dir()` follows symlinks, but `shutil.rmtree` refuses to operate on a
        # symlink, so a link to a directory has to be unlinked like a file.
        if file.is_dir() and not file.is_symlink():
            shutil.rmtree(file)
        else:
            file.unlink()
def set_error_model_status_by_pids(unexisting_pids: List[int]):
    """Mark models whose training process no longer exists as failed.

    Models keep the id of their training process in the 'training_metadata' field
    (key 'process_id'). If that pid is among `unexisting_pids`, the model status is
    set to "error" and, unless an error text is already stored, a generic one is added.

    Note: only for local usage.

    Args:
        unexisting_pids (List[int]): list of 'pids' that do not exist.
    """
    finished_statuses = [db.PREDICTOR_STATUS.COMPLETE, db.PREDICTOR_STATUS.ERROR]
    unfinished_query = db.session.query(db.Predictor).filter(
        db.Predictor.deleted_at.is_(None),
        db.Predictor.status.not_in(finished_statuses),
    )

    for record in unfinished_query.all():
        metadata = record.training_metadata or {}
        if metadata.get("process_id") not in unexisting_pids:
            continue

        record.status = db.PREDICTOR_STATUS.ERROR
        if not isinstance(record.data, dict):
            record.data = {}
        if "error" not in record.data:
            record.data["error"] = "The training process was terminated for unknown reasons"
            # in-place change of the dict has to be reported to the ORM explicitly
            flag_modified(record, "data")
        db.session.commit()
def set_error_model_status_for_unfinished():
    """Set error status on every non-deleted model whose status is neither 'complete' nor 'error'.

    A generic error text is stored unless the model data already contains one.

    Note: only for local usage.
    """
    finished_statuses = [db.PREDICTOR_STATUS.COMPLETE, db.PREDICTOR_STATUS.ERROR]
    unfinished_query = db.session.query(db.Predictor).filter(
        db.Predictor.deleted_at.is_(None),
        db.Predictor.status.not_in(finished_statuses),
    )

    for record in unfinished_query.all():
        record.status = db.PREDICTOR_STATUS.ERROR
        if not isinstance(record.data, dict):
            record.data = {}
        if "error" not in record.data:
            record.data["error"] = "Unknown error"
            # in-place change of the dict has to be reported to the ORM explicitly
            flag_modified(record, "data")
        db.session.commit()
def do_clean_process_marks():
    """Periodically delete unexisting 'process marks' until the stop event is set.

    Every 5 seconds the unlinked marks are cleaned; outside of the cloud, models
    that belonged to the vanished pids are switched to the error status.
    """
    while True:
        # `wait` returns True as soon as the event is set, False on timeout
        if _stop_event.wait(timeout=5):
            break
        vanished_pids = clean_unlinked_process_marks()
        if config.is_cloud or len(vanished_pids) == 0:
            continue
        set_error_model_status_by_pids(vanished_pids)
def create_permanent_integrations():
    """
    Create permanent integrations, for now only the 'files' integration.
    Nothing is done when a record with that name and no company already exists.
    A failed commit is logged and rolled back instead of being raised.
    NOTE: this is intentional to avoid importing integration_controller
    """
    integration_name = "files"

    lookup = db.session.query(db.Integration).filter_by(name=integration_name, company_id=None)
    if lookup.first() is not None:
        return

    db.session.add(
        db.Integration(
            name=integration_name,
            data={},
            engine=integration_name,
            company_id=None,
        )
    )
    try:
        db.session.commit()
    except Exception:
        logger.exception(f"Failed to create permanent integration '{integration_name}' in the internal database.")
        db.session.rollback()
def validate_default_project() -> None:
    """Handle 'default_project' config option.

    Project with the name specified in 'default_project' must exist and be marked with
    'is_default' metadata:
      - when no project is marked yet, the project with the configured name gets the mark;
        if there is no such project the process is terminated with an error;
      - when a marked project has a different name, it is renamed to the configured name;
        if that name is already taken the process is terminated with an error.

    Note: this can be done using 'project_controller', but we want to save init time and used RAM.
    """
    new_default_project_name = config.get("default_project")
    logger.debug(f"Checking if default project {new_default_project_name} exists")

    filter_company_id = "0" if ctx.company_id is None else ctx.company_id

    def find_project_with_new_name():
        # case-insensitive lookup of the configured name within the company
        return db.Project.query.filter(
            db.Project.company_id == filter_company_id,
            func.lower(db.Project.name) == func.lower(new_default_project_name),
        ).first()

    current_default_project: db.Project | None = db.Project.query.filter(
        db.Project.company_id == filter_company_id,
        db.Project.metadata_["is_default"].as_boolean() == True,  # noqa: E712 - SQL expression, not a Python bool
    ).first()

    if current_default_project is None:
        # no default yet: promote the project that carries the configured name
        project_to_mark = find_project_with_new_name()
        if project_to_mark is None:
            logger.critical(f"A project with the name '{new_default_project_name}' does not exist")
            sys.exit(1)
        project_to_mark.metadata_ = {"is_default": True}
        flag_modified(project_to_mark, "metadata_")
        db.session.commit()
        return

    if current_default_project.name == new_default_project_name:
        return

    # the default project has to be renamed: the new name must be free
    if find_project_with_new_name() is not None:
        logger.critical(f"A project with the name '{new_default_project_name}' already exists")
        sys.exit(1)
    current_default_project.name = new_default_project_name
    db.session.commit()
def start_process(trunc_process_data: TrunkProcessData) -> None:
    """Start a process.

    The process is created in the "spawn" multiprocessing context and stored in
    `trunc_process_data.process` before it is started.

    Args:
        trunc_process_data (TrunkProcessData): The data of the process to start.

    Raises:
        Exception: any error raised while creating or starting the process is
            logged, all managed processes are shut down, and the error is re-raised.
    """
    mp_ctx = mp.get_context("spawn")
    logger.info(f"{trunc_process_data.name} API: starting...")
    try:
        trunc_process_data.process = mp_ctx.Process(
            target=trunc_process_data.entrypoint,
            # `args` defaults to None in TrunkProcessData, but Process needs an iterable
            args=trunc_process_data.args or (),
            name=trunc_process_data.name,
        )
        trunc_process_data.process.start()
    except Exception:
        logger.exception(f"Failed to start '{trunc_process_data.name}' API process due to unexpected error:")
        # `trunc_processes_struct` is the module-level mapping built in the `__main__` section
        close_api_gracefully(trunc_processes_struct)
        raise
if __name__ == "__main__":
    # Launcher entry point: check the environment, prepare the internal database,
    # start the requested API / worker processes and supervise them until they exit.
    mp.freeze_support()

    # warn if less than 1 GiB (1 << 30 bytes) of RAM is available
    if psutil.virtual_memory().available < (1 << 30):
        logger.warning(
            "The system is running low on memory. " + "This may impact the stability and performance of the program."
        )

    ctx.set_default()

    # ---- check python version ----
    # Tuple comparison stays correct for any future major version as well.
    if sys.version_info < (3, 10):
        print(
            """
MindsDB requires Python >= 3.10 to run
Once you have supported Python version installed you can start mindsdb as follows:
1. create and activate venv:
python3 -m venv venv
source venv/bin/activate
2. install MindsDB:
pip3 install mindsdb
3. Run MindsDB
python3 -m mindsdb
More instructions in https://docs.mindsdb.com
"""
        )
        # `sys.exit` instead of the `site`-provided `exit`, which is not guaranteed to exist
        sys.exit(1)

    # ---- one-shot commands: do the job and exit without starting any API ----
    if config.cmd_args.version:
        print(f"MindsDB {mindsdb_version}")
        sys.exit(0)

    if config.cmd_args.update_gui or config.cmd_args.load_tokenizer:
        if config.cmd_args.update_gui:
            from mindsdb.api.http.initialize import initialize_static

            logger.info("Updating the GUI version")
            initialize_static()

        if config.cmd_args.load_tokenizer:
            try:
                from langchain_core.language_models import get_tokenizer

                get_tokenizer()
                logger.info("Tokenizer successfully loaded")
            except Exception:
                logger.info("Failed to load tokenizer: ", exc_info=True)

        sys.exit(0)

    config.raise_warnings(logger=logger)
    os.environ["MINDSDB_RUNTIME"] = "1"

    if os.environ.get("ARROW_DEFAULT_MEMORY_POOL") is None:
        try:
            # It seems like snowflake handler have memory issue that related to pyarrow. Memory usage keep
            # growing with requests. This is related to 'memory pool' that is 'mimalloc' by default: it is
            # fastest but use a lot of ram
            import pyarrow as pa

            try:
                pa.jemalloc_memory_pool()
                os.environ["ARROW_DEFAULT_MEMORY_POOL"] = "jemalloc"
            except NotImplementedError:
                # jemalloc is not available in this pyarrow build: use the system allocator
                pa.system_memory_pool()
                os.environ["ARROW_DEFAULT_MEMORY_POOL"] = "system"
        except Exception:
            # best effort: pyarrow is optional here
            pass

    db.init()

    environment = config["environment"]
    if environment == "aws_marketplace":
        try:
            register_oauth_client()
        except Exception:
            logger.exception("Something went wrong during client register:")
    elif environment != "local":
        try:
            aws_meta_data = get_aws_meta_data()
            config.update({"aws_meta_data": aws_meta_data})
        except Exception:
            # best effort: the metadata is simply not added to the config
            pass

    # ---- which APIs to start: the env variable wins over the command-line option ----
    apis = os.getenv("MINDSDB_APIS") or config.cmd_args.api

    if apis is None:  # nothing specified: start the default APIs
        api_arr = [TrunkProcessEnum.HTTP, TrunkProcessEnum.MYSQL]
    elif apis == "":  # explicitly blank: don't start any API
        api_arr = []
    else:  # comma-separated list of names; an unknown name terminates the program
        api_arr = [TrunkProcessEnum(name) for name in apis.split(",")]

    logger.info(f"Version: {mindsdb_version}")
    logger.info(f"Configuration file: {config.config_path or 'absent'}")
    logger.info(f"Storage path: {config.paths['root']}")
    log.log_system_info(logger)
    logger.debug(f"User config: {config.user_config}")
    logger.debug(f"System config: {config.auto_config}")
    logger.debug(f"Env config: {config.env_config}")

    # ---- prepare the internal database (local installations only) ----
    is_cloud = config.is_cloud
    unexisting_pids = clean_unlinked_process_marks()
    if not is_cloud:
        try:
            from mindsdb.migrations import migrate

            migrate.migrate_to_head()
        except Exception:
            logger.exception("Failed to apply database migrations. This may prevent MindsDB from operating correctly:")

        validate_default_project()

        # models that were still training when the previous run ended cannot finish anymore
        if len(unexisting_pids) > 0:
            set_error_model_status_by_pids(unexisting_pids)
        set_error_model_status_for_unfinished()
        create_permanent_integrations()

    clean_process_marks()

    # ---- describe every process the launcher is able to start ----
    http_api_config = config.get("api", {}).get("http", {})
    mysql_api_config = config.get("api", {}).get("mysql", {})
    litellm_api_config = config.get("api", {}).get("litellm", {})
    trunc_processes_struct = {
        TrunkProcessEnum.HTTP: TrunkProcessData(
            name=TrunkProcessEnum.HTTP.value,
            entrypoint=start_http,
            port=http_api_config["port"],
            args=(config.cmd_args.verbose,),
            restart_on_failure=http_api_config.get("restart_on_failure", False),
            max_restart_count=http_api_config.get("max_restart_count", TrunkProcessData.max_restart_count),
            max_restart_interval_seconds=http_api_config.get(
                "max_restart_interval_seconds", TrunkProcessData.max_restart_interval_seconds
            ),
        ),
        TrunkProcessEnum.MYSQL: TrunkProcessData(
            name=TrunkProcessEnum.MYSQL.value,
            entrypoint=start_mysql,
            port=mysql_api_config["port"],
            args=(config.cmd_args.verbose,),
            restart_on_failure=mysql_api_config.get("restart_on_failure", False),
            max_restart_count=mysql_api_config.get("max_restart_count", TrunkProcessData.max_restart_count),
            max_restart_interval_seconds=mysql_api_config.get(
                "max_restart_interval_seconds", TrunkProcessData.max_restart_interval_seconds
            ),
        ),
        TrunkProcessEnum.JOBS: TrunkProcessData(
            name=TrunkProcessEnum.JOBS.value, entrypoint=start_scheduler, args=(config.cmd_args.verbose,)
        ),
        TrunkProcessEnum.TASKS: TrunkProcessData(
            name=TrunkProcessEnum.TASKS.value, entrypoint=start_tasks, args=(config.cmd_args.verbose,)
        ),
        TrunkProcessEnum.ML_TASK_QUEUE: TrunkProcessData(
            name=TrunkProcessEnum.ML_TASK_QUEUE.value, entrypoint=start_ml_task_queue, args=(config.cmd_args.verbose,)
        ),
        TrunkProcessEnum.LITELLM: TrunkProcessData(
            name=TrunkProcessEnum.LITELLM.value,
            entrypoint=start_litellm,
            port=litellm_api_config.get("port", 8000),
            args=(config.cmd_args.verbose,),
            restart_on_failure=litellm_api_config.get("restart_on_failure", False),
            max_restart_count=litellm_api_config.get("max_restart_count", TrunkProcessData.max_restart_count),
            max_restart_interval_seconds=litellm_api_config.get(
                "max_restart_interval_seconds", TrunkProcessData.max_restart_interval_seconds
            ),
        ),
    }

    # ---- decide what has to run ----
    for api_enum in api_arr:
        if api_enum in trunc_processes_struct:
            trunc_processes_struct[api_enum].need_to_run = True
        else:
            logger.error(f"ERROR: {api_enum} API is not a valid api in config")

    if config["jobs"]["disable"] is False:
        trunc_processes_struct[TrunkProcessEnum.JOBS].need_to_run = True

    if config["tasks"]["disable"] is False:
        trunc_processes_struct[TrunkProcessEnum.TASKS].need_to_run = True

    if config.cmd_args.ml_task_queue_consumer is True:
        trunc_processes_struct[TrunkProcessEnum.ML_TASK_QUEUE].need_to_run = True

    create_pid_file(config)

    # ---- start the processes ----
    for trunc_process_data in trunc_processes_struct.values():
        if trunc_process_data.started is True or trunc_process_data.need_to_run is False:
            continue
        start_process(trunc_process_data)
        # processes without a port never go through `wait_api_start`, so report them right away
        if trunc_process_data.port is None:
            set_api_status(trunc_process_data.name, True)

    atexit.register(close_api_gracefully, trunc_processes_struct=trunc_processes_struct)
    atexit.register(clean_mindsdb_tmp_dir)

    async def wait_api_start(api_name, pid, port):
        """Poll (every 0.5 s, for at most 60 s) until process `pid` listens on `port`.

        The final result is also published through `set_api_status`.

        Returns:
            tuple: (api_name, port, started)
        """
        timeout = 60
        start_time = time.time()
        started = is_pid_listen_port(pid, port)
        while (time.time() - start_time) < timeout and started is False:
            await asyncio.sleep(0.5)
            started = is_pid_listen_port(pid, port)

        set_api_status(api_name, started)

        return api_name, port, started

    async def wait_apis_start():
        """Wait for every running process that has a port and log the outcome of each."""
        futures = [
            wait_api_start(
                trunc_process_data.name,
                trunc_process_data.process.pid,
                trunc_process_data.port,
            )
            for trunc_process_data in trunc_processes_struct.values()
            if trunc_process_data.port is not None and trunc_process_data.need_to_run is True
        ]
        for future in asyncio.as_completed(futures):
            api_name, port, started = await future
            if started:
                logger.info(f"{api_name} API: started on {port}")
            else:
                logger.error(f"ERROR: {api_name} API cant start on {port}")

    async def join_process(trunc_process_data: TrunkProcessData):
        """Supervise one process: wait until it exits and restart it while its policy allows.

        Args:
            trunc_process_data (TrunkProcessData): the process to supervise.
        """
        finish = False
        while not finish:
            process = trunc_process_data.process
            try:
                while process.is_alive():
                    # join with a timeout, then yield to the other supervising coroutines
                    process.join(1)
                    await asyncio.sleep(0)
            except KeyboardInterrupt:
                logger.info("Got keyboard interrupt, stopping APIs")
                close_api_gracefully(trunc_processes_struct)
            finally:
                # NOTE(review): this clause also runs after the KeyboardInterrupt branch above
                if trunc_process_data.should_restart:
                    if trunc_process_data.request_restart_attempt():
                        logger.warning(f"{trunc_process_data.name} API: stopped unexpectedly, restarting")
                        trunc_process_data.process = None
                        if trunc_process_data.name == TrunkProcessEnum.HTTP.value:
                            # the extra positional arguments are interpreted by `start_http`
                            # (defined elsewhere) - TODO confirm their meaning
                            trunc_process_data.args = (config.cmd_args.verbose, None, True)
                        start_process(trunc_process_data)
                        api_name, port, started = await wait_api_start(
                            trunc_process_data.name,
                            trunc_process_data.process.pid,
                            trunc_process_data.port,
                        )
                        if started:
                            logger.info(f"{api_name} API: started on {port}")
                        else:
                            logger.error(f"ERROR: {api_name} API cant start on {port}")
                    else:
                        finish = True
                        logger.error(
                            f'The "{trunc_process_data.name}" process could not restart after failure. '
                            "There will be no further attempts to restart."
                        )
                else:
                    finish = True
                    logger.info(f"{trunc_process_data.name} API: stopped")

    async def gather_apis():
        """Supervise all running processes concurrently until each of them has finished."""
        await asyncio.gather(
            *[
                join_process(trunc_process_data)
                for trunc_process_data in trunc_processes_struct.values()
                if trunc_process_data.need_to_run is True
            ],
            return_exceptions=False,
        )

    # ---- run: wait for the APIs to come up, start helper threads, then supervise ----
    ioloop = asyncio.new_event_loop()
    ioloop.run_until_complete(wait_apis_start())

    threading.Thread(target=do_clean_process_marks, name="clean_process_marks").start()
    if config["logging"]["resources_log"]["enabled"] is True:
        threading.Thread(
            target=log.resources_log_thread,
            args=(_stop_event, config["logging"]["resources_log"]["interval"]),
            name="resources_log",
        ).start()

    ioloop.run_until_complete(gather_apis())
    ioloop.close()