import logging
import time
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
import psycopg2
import psycopg2.errors
import psycopg2.extensions
import psycopg2.extras
import psycopg2.pool
import sqlparse
from dbmind.common.utils import dbmind_assert, write_to_terminal
# Keyword arguments forwarded to every psycopg2 connection created by this
# module, both by psycopg2_connect() and by the connection pool.
_psycopg2_kwargs = {
    'options': "-c session_timeout=15 -c search_path=public",
    'application_name': 'DBMind-openGauss-exporter',
    'sslmode': 'disable',
}
def psycopg2_connect(dsn):
    """Open a new psycopg2 connection to ``dsn``.

    The module-level ``_psycopg2_kwargs`` are forwarded to
    ``psycopg2.connect`` as additional keyword arguments.

    :param dsn: connection string (or URL) handed to ``psycopg2.connect``.
    :return: the connection object returned by ``psycopg2.connect``.
    """
    return psycopg2.connect(dsn, **_psycopg2_kwargs)
class ConnectionPool(psycopg2.pool.ThreadedConnectionPool):
    """A connection pool that works with the threading module and waits for
    a connection to be given back, rather than raising an exception, while
    the pool is exhausted."""

    # Approximate upper bound, in seconds, on the time spent waiting for a
    # connection to be returned before an extra one is opened.
    MAX_RETRY_TIME = 1
    # Seconds slept between two consecutive checks of the pool.
    WAIT_TICK = 0.1

    def _getconn(self, key=None):
        """Return a connection, waiting briefly while the pool is exhausted.

        While the number of connections in use equals ``maxconn`` the method
        polls every ``WAIT_TICK`` seconds, for a little less than
        ``MAX_RETRY_TIME`` seconds in total. If the pool is still exhausted
        afterwards, an additional connection is opened instead of failing.

        :param key: optional key to associate with the connection.
        :return: a database connection.
        """
        max_retry_times = self.MAX_RETRY_TIME / self.WAIT_TICK
        n_retry = 1
        while len(self._used) == self.maxconn and n_retry < max_retry_times:
            # ThreadedConnectionPool.getconn() calls this method while holding
            # self._lock, and putconn() needs the same lock. Release it for
            # the duration of the sleep, otherwise no other thread could ever
            # give a connection back and the wait would be pointless.
            self._lock.release()
            try:
                time.sleep(self.WAIT_TICK)
            finally:
                self._lock.acquire()
            n_retry += 1

        if len(self._used) == self.maxconn:
            # Still exhausted: open one connection beyond maxconn rather than
            # letting the parent class raise a PoolError.
            if key is None:
                key = self._getkey()
            return self._connect(key)
        return super()._getconn(key)
class Driver:
    """Execute SQL statements against a single database.

    A driver has to be set up with :meth:`initialize` before :meth:`query`
    can be used. Connections come from a :class:`ConnectionPool` when a pool
    size was given, otherwise a new connection is opened for each query.
    """

    def __init__(self):
        self._url = None
        self.parsed_dsn = None
        self.initialized = False
        self._pool = None

    def initialize(self, url, pool_size=None):
        """Verify that the database is reachable and prepare the driver.

        :param url: connect to database by using this url (or DSN).
        :param pool_size: connection pool size, if set None, this call
            doesn't use a pool.
        :return: Nothing to return
        :raises ConnectionError: if connecting, the probe query, parsing the
            DSN or creating the pool fails.
        """
        try:
            conn = psycopg2_connect(url)
            try:
                with conn.cursor() as cursor:
                    cursor.execute('select 1;')
            finally:
                # Close the probe connection even when the probe query fails.
                conn.close()
            self._url = url
            self.parsed_dsn = psycopg2.extensions.parse_dsn(url)
            if pool_size:
                self._pool = ConnectionPool(
                    minconn=1, maxconn=pool_size,
                    dsn=url, **_psycopg2_kwargs
                )
            self.initialized = True
        except Exception as e:
            raise ConnectionError(e) from e

    @property
    def address(self):
        """The ``host:port`` pair taken from the parsed DSN."""
        return '%s:%s' % (self.parsed_dsn['host'], self.parsed_dsn['port'])

    @property
    def host(self):
        return self.parsed_dsn['host']

    @property
    def port(self):
        return self.parsed_dsn['port']

    @property
    def dbname(self):
        return self.parsed_dsn['dbname']

    @property
    def username(self):
        return self.parsed_dsn['user']

    @property
    def pwd(self):
        return self.parsed_dsn['password']

    def query(self, stmt, timeout=0, force_connection_db=None,
              return_tuples=False, fetch_all=False, ignore_error=False):
        """Execute ``stmt`` and return its result.

        Errors are logged rather than raised; an empty list is returned
        when the statement cannot be executed.

        :param stmt: SQL text. With ``fetch_all`` it may hold several
            statements, which are split by ``sqlparse``.
        :param timeout: statement timeout in seconds; values <= 0 leave the
            timeout untouched.
        :param force_connection_db: run the statement in this database
            instead of the one the driver was initialized with.
        :param return_tuples: return rows as tuples instead of dict rows.
        :param fetch_all: execute every statement of ``stmt`` and return one
            entry per statement (its rows, or None if it produced none).
        :param ignore_error: only used with ``fetch_all``; a failing
            statement yields None and the remaining ones are still executed.
        :return: a list of rows, or a list of per-statement results when
            ``fetch_all`` is set.
        """
        dbmind_assert(self.initialized)

        cursor_dict = {}
        if not return_tuples:
            cursor_dict['cursor_factory'] = psycopg2.extras.RealDictCursor

        conn = None
        try:
            conn = self.get_conn(force_connection_db)
            with conn.cursor(**cursor_dict) as cursor:
                result = self._run(
                    conn, cursor, stmt, timeout, fetch_all, ignore_error
                )
        except psycopg2.InternalError as e:
            logging.error(
                "Cannot execute '%s' due to internal error: %s.", stmt, e.pgerror
            )
            result = []
        except Exception as e:
            logging.exception(e)
            result = []
        finally:
            # Always give the connection back, including on unexpected
            # errors, so that failed statements cannot drain the pool.
            if conn is not None:
                self._release(conn)
        return result

    @staticmethod
    def _run(conn, cursor, stmt, timeout, fetch_all, ignore_error):
        """Execute ``stmt`` on ``cursor``; map known errors to an empty list."""
        start = time.monotonic()
        try:
            if timeout > 0:
                cursor.execute('SET statement_timeout = %d;' % (timeout * 1000))
            if not fetch_all:
                cursor.execute(stmt)
                result = cursor.fetchall()
            else:
                result = []
                for sql in sqlparse.split(stmt):
                    if ignore_error:
                        try:
                            cursor.execute(sql)
                        except Exception:
                            result.append(None)
                            continue
                        finally:
                            # End the transaction after each statement so a
                            # failure does not affect the following ones.
                            conn.commit()
                    else:
                        cursor.execute(sql)
                    if cursor.pgresult_ptr is not None:
                        result.append(cursor.fetchall())
                    else:
                        result.append(None)
            conn.commit()
            return result
        except psycopg2.extensions.QueryCanceledError as e:
            logging.error('%s: %s.', e.pgerror, stmt)
            logging.info(
                'Time elapsed during execution is %fs '
                'but threshold is %fs.', time.monotonic() - start, timeout
            )
        except psycopg2.errors.FeatureNotSupported:
            logging.warning('FeatureNotSupported while executing %s.', stmt)
        except psycopg2.errors.ObjectNotInPrerequisiteState:
            logging.warning('ObjectNotInPrerequisiteState while executing %s.', stmt)
        except psycopg2.errors.UndefinedParameter:
            logging.warning('UndefinedParameter while executing %s.', stmt)
        except psycopg2.errors.UndefinedColumn:
            logging.warning('UndefinedColumn while executing %s.', stmt)
        return []

    def _release(self, conn):
        """Put ``conn`` away without letting a failure discard a query result."""
        try:
            self.put_conn(conn)
        except Exception as e:
            logging.exception(e)

    def get_conn(self, force_connection_db=None):
        """Return a usable connection.

        * If ``force_connection_db`` names a database other than the
          driver's own, a new, unpooled connection to it is opened.
        * Without a pool, a new connection to the driver's database is opened.
        * Otherwise a pooled connection is taken and probed with
          ``select 1``; a connection that fails the probe is discarded and
          replaced by another one from the pool.

        Connections obtained here should be handed to :meth:`put_conn`.
        """
        # A forced connection to the driver's own database is served like a
        # normal request: put_conn() tells pooled and unpooled connections
        # apart by database name only, so an unpooled connection to the same
        # database would be handed to the pool, which rejects it.
        if force_connection_db and force_connection_db != self.parsed_dsn.get('dbname'):
            parsed_dsn = dict(self.parsed_dsn, dbname=force_connection_db)
            # make_dsn() quotes values containing spaces or quotes, which a
            # plain 'key=value' join would not.
            return psycopg2_connect(psycopg2.extensions.make_dsn(**parsed_dsn))

        if not self._pool:
            return psycopg2_connect(dsn=self._url)

        conn = self._pool.getconn()
        try:
            with conn.cursor() as cursor:
                cursor.execute('select 1;')
            return conn
        except (
                psycopg2.InternalError,
                psycopg2.InterfaceError,
                psycopg2.errors.AdminShutdown,
                psycopg2.OperationalError
        ) as e:
            logging.warning(
                'Cached database connection to openGauss'
                ' has been timeout due to %s.', e
            )
        except Exception as e:
            logging.error('Failed to connect to openGauss '
                          'with cached connection (%s).', e)

        # The pooled connection is unusable: drop it and hand out another
        # one instead of returning the connection that was just closed.
        self._pool.putconn(conn, close=True)
        return self._pool.getconn()

    def put_conn(self, conn, close=False):
        """Put away a connection.

        The connection is closed when the driver has no pool or when it
        points at a different database than the driver's own; otherwise it
        is returned to the pool.

        :param conn: a connection obtained from :meth:`get_conn`.
        :param close: passed to the pool's ``putconn``.
        """
        dbname = psycopg2.extensions.parse_dsn(conn.dsn)['dbname']
        if not self._pool or dbname != self.dbname:
            conn.close()
            return
        self._pool.putconn(conn, close=close)
class DriverBundle:
    """A set of :class:`Driver` objects, one per monitored database.

    The driver created from the given URL is the main driver; further
    databases are discovered through it and refreshed at most once every
    ``UPDATE_PERIOD`` seconds.
    """

    __main_db_name__ = 'postgres'
    # Minimum number of seconds between two database discoveries.
    UPDATE_PERIOD = 300
    # Shared by all instances; used to query the databases concurrently.
    _thread_pool_executor = ThreadPoolExecutor(
        thread_name_prefix='DriverBundleWorker'
    )

    def __init__(
            self, url,
            include_db_list=None,
            exclude_db_list=None,
            each_db_max_connections=None,
            log_to_terminal=True
    ):
        """
        :param url: URL (or DSN) of the main database.
        :param include_db_list: if given, only these databases are used.
        :param exclude_db_list: if given, these databases are skipped.
        :param each_db_max_connections: pool size for each driver; no pool
            is used when None.
        :param log_to_terminal: also write start-up warnings to the terminal.
        :raises ConnectionError: if the main database cannot be reached.
        """
        self.main_driver = Driver()
        self.main_driver.initialize(url, each_db_max_connections)
        if self.main_dbname != DriverBundle.__main_db_name__:
            self._warn(
                'The default connection database of the exporter is not postgres, '
                'so it is possible that some database metric information '
                'cannot be collected, such as slow SQL queries.',
                log_to_terminal
            )
        if not self.guarantee_access():
            self._warn(
                'The current user does not have the Monitoradmin/Sysadmin privilege, '
                'which will cause many metrics to fail to obtain. '
                'Please consider granting this privilege to the connecting user.',
                log_to_terminal
            )

        self._bundle = dict()
        self._include_db_list = include_db_list
        self._exclude_db_list = exclude_db_list
        self._each_db_max_connections = each_db_max_connections
        # -inf instead of 0: the reference point of time.monotonic() is
        # undefined, so comparing against 0 could skip the first update.
        self._last_updated = float('-inf')
        self._update_lock = threading.RLock()
        self.update()

    @staticmethod
    def _warn(msg, log_to_terminal):
        """Log ``msg`` as a warning and optionally echo it to the terminal."""
        logging.warning(msg)
        if log_to_terminal:
            write_to_terminal(msg)

    def update(self):
        """Rediscover the databases if the last discovery is too old.

        Does nothing when the previous discovery started less than
        ``UPDATE_PERIOD`` seconds ago.
        """
        # Cheap check without the lock for the common case.
        if time.monotonic() - self._last_updated < self.UPDATE_PERIOD:
            return

        with self._update_lock:
            # Re-check: another thread may have refreshed while we waited.
            now = time.monotonic()
            if now - self._last_updated < self.UPDATE_PERIOD:
                return
            self._last_updated = now

            # Build the new mapping aside and swap it in with a single
            # assignment, so that query() never iterates over a dict that
            # is being cleared or only partially filled.
            bundle = {self.main_driver.dbname: self.main_driver}
            for dbname in self._discover_databases(
                    self._include_db_list, self._exclude_db_list):
                if dbname in bundle:
                    continue
                try:
                    driver = Driver()
                    driver.initialize(self._splice_url_for_other_db(dbname),
                                      self._each_db_max_connections)
                    bundle[dbname] = driver
                except ConnectionError:
                    logging.warning(
                        'Cannot connect to the database %s by using the given user.', dbname
                    )
            self._bundle = bundle

    def _discover_databases(self, include_dbs, exclude_dbs):
        """Return the names of the databases to monitor.

        ``template0`` and ``template1`` are always skipped. At most one of
        ``include_dbs`` and ``exclude_dbs`` may be non-empty.
        """
        include_dbs = set(include_dbs or ())
        exclude_dbs = set(exclude_dbs or ())
        dbmind_assert(not (include_dbs and exclude_dbs))

        all_db_list = self.main_driver.query(
            'SELECT datname FROM pg_catalog.pg_database;',
            return_tuples=True
        )
        discovered = {
            row[0] for row in all_db_list
            if row[0] not in ('template0', 'template1')
        }
        if include_dbs:
            return discovered.intersection(include_dbs)
        return discovered - exclude_dbs

    def _splice_url_for_other_db(self, dbname):
        """Return the main driver's DSN with the database replaced by ``dbname``."""
        parsed_dsn = dict(self.main_driver.parsed_dsn, dbname=dbname)
        # make_dsn() quotes values containing spaces or quotes, which a
        # plain 'key=value' join would not.
        return psycopg2.extensions.make_dsn(**parsed_dsn)

    def query(self, stmt, timeout=0, force_connection_db=None, return_tuples=False):
        """A wrapper for Driver.query. If the caller sets
        the parameter `force_connection_db`, this method only returns
        the query result from this specified database (an empty list if
        that database is not part of the bundle).
        Otherwise, the statement is executed in every database
        concurrently and the de-duplicated union of the rows is returned.

        This method needs to guarantee thread safety.

        :return: a list of tuples if ``return_tuples`` is set, otherwise a
            list of dicts. The order of the merged rows is not defined.
        """
        self.update()
        # update() replaces the dict instead of mutating it, so this
        # reference is a stable snapshot.
        bundle = self._bundle

        if force_connection_db is not None:
            driver = bundle.get(force_connection_db)
            if driver is None:
                return []
            return driver.query(stmt, timeout, None, return_tuples)

        futures = [
            DriverBundle._thread_pool_executor.submit(
                driver.query, stmt, timeout, None, return_tuples
            )
            for driver in bundle.values()
        ]

        seen = set()
        seen_unhashable = []
        merged = []
        for future in as_completed(futures):
            try:
                for row in future.result():
                    key = tuple(row) if return_tuples else tuple(row.items())
                    try:
                        if key in seen:
                            continue
                        seen.add(key)
                    except TypeError:
                        # The row holds an unhashable value (e.g. a list):
                        # fall back to a linear duplicate check instead of
                        # losing the row.
                        if key in seen_unhashable:
                            continue
                        seen_unhashable.append(key)
                    merged.append(key if return_tuples else dict(key))
            except Exception as e:
                logging.exception(e)
        return merged

    @property
    def address(self):
        return self.main_driver.address

    @property
    def host(self):
        return self.main_driver.host

    @property
    def port(self):
        return self.main_driver.port

    @property
    def main_dbname(self):
        return self.main_driver.dbname

    @property
    def username(self):
        return self.main_driver.username

    def is_monitor_admin(self):
        """Return the ``rolmonitoradmin`` flag of the current user.

        Returns False when the query yields no row (Driver.query returns an
        empty list on errors).
        """
        r = self.main_driver.query(
            'select rolmonitoradmin from pg_roles where rolname = CURRENT_USER;',
            return_tuples=True
        )
        return r[0][0] if r else False

    def is_system_admin(self):
        """Return the ``rolsystemadmin`` flag of the current user.

        Returns False when the query yields no row (Driver.query returns an
        empty list on errors).
        """
        res = self.main_driver.query(
            'select rolsystemadmin from pg_roles where rolname = CURRENT_USER;',
            return_tuples=True
        )
        return res[0][0] if res else False

    def guarantee_access(self):
        """Make sure the current user has the monitor-admin privilege.

        A system administrator without that privilege tries to grant it
        to itself.

        :return: whether the user has the monitor-admin privilege afterwards.
        """
        if self.is_monitor_admin():
            return True
        if self.is_system_admin():
            # Quote the role name so that names with upper-case or special
            # characters are referenced exactly as given.
            quoted_name = '"%s"' % self.username.replace('"', '""')
            self.main_driver.query(
                'ALTER USER %s monadmin;' % quoted_name,
                return_tuples=True,
                fetch_all=True,
                ignore_error=True
            )
            return self.is_monitor_admin()
        return False

    def is_standby(self):
        """Return the result of ``pg_is_in_recovery()`` on the main database.

        :raises IndexError: if the query yields no row.
        """
        r = self.main_driver.query(
            'select pg_catalog.pg_is_in_recovery();',
            return_tuples=True
        )
        return r[0][0]