import os
import logging
import sqlite3
logging.basicConfig(level=logging.INFO,
format='\n%(asctime)s %(filename)s [line:%(lineno)d] [%(levelname)s] %(message)s')
class EmptyClass:
"""
Empty class
"""
def __init__(self: any, info: str = "") -> None:
self._info = info
@classmethod
def __bool__(cls: any) -> bool:
return False
@classmethod
def __str__(cls: any) -> str:
return ""
@property
def info(self: any) -> str:
"""
get info
:return: _info
"""
return self._info
@staticmethod
def is_empty() -> bool:
"""
return this is a empty class
"""
return True
class DBManager:
"""
class to manage DB operation
"""
FETCH_SIZE = 10000
INSERT_SIZE = 10000
TENNSTONS = 10
NSTOUS = 1000
MAX_ROW_COUNT = 100000000
@staticmethod
def create_connect_db(db_path: str) -> tuple:
"""
create and connect database
"""
try:
conn = sqlite3.connect(db_path)
except sqlite3.Error as err:
logging.error(str(err), exc_info=False)
return EmptyClass("empty conn"), EmptyClass("empty curs")
try:
if isinstance(conn, sqlite3.Connection):
curs = conn.cursor()
os.chmod(db_path, 0o640)
return conn, curs
except sqlite3.Error:
logging.info("conn sqlite error")
return EmptyClass("empty conn"), EmptyClass("empty curs")
return EmptyClass("empty conn"), EmptyClass("empty curs")
@staticmethod
def destroy_db_connect(conn: any, cur: any) -> None:
"""
destroy the db connect
"""
try:
if isinstance(cur, sqlite3.Cursor):
cur.close()
except sqlite3.Error as error:
logging.error(str(error), exc_info=False)
try:
if isinstance(conn, sqlite3.Connection):
conn.close()
except sqlite3.Error as error:
logging.error(str(error), exc_info=False)
@staticmethod
def judge_table_exists(curs: any, table_name: str) -> any:
"""
judge table exists
"""
if not isinstance(curs, sqlite3.Cursor):
return False
try:
curs.execute("select count(*) from sqlite_master where type='table' and name=?", (table_name,))
return curs.fetchone()[0]
except sqlite3.Error as err:
logging.error(err)
return False
@classmethod
def fetch_all_data(cls: any, curs: any, sql: str, param: tuple = None) -> list:
"""
fetch 10000 num of data each time to get all data
"""
if not isinstance(curs, sqlite3.Cursor):
return []
data = []
try:
if param:
res = curs.execute(sql, param)
else:
res = curs.execute(sql)
except sqlite3.Error as _err:
logging.error("%s", str(_err), exc_info=False)
logging.debug("%s, sql: %s", str(_err), sql, exc_info=False)
curs.row_factory = None
return []
try:
while True:
res = curs.fetchmany(cls.FETCH_SIZE)
data += res
if len(data) > cls.MAX_ROW_COUNT:
logging.error("Please check the record counts in %s's table",
os.path.basename(curs.execute("PRAGMA database_list;").fetchone()[-1]))
if len(res) < cls.FETCH_SIZE:
break
return data
except sqlite3.Error as _err:
logging.error(str(_err), exc_info=False)
return []
finally:
curs.row_factory = None
@classmethod
def check_item_in_table(cls: any, db_path: str, table_name: str, col: str, item: any):
"""
check if item is in table
"""
conn, curs = DBManager.create_connect_db(db_path)
if not (conn and curs):
return False
sql = "select * from {table_name} where {col}='{item}' limit 1".format(table_name=table_name,
col=col, item=item)
try:
data = DBManager.fetch_all_data(curs, sql)
except sqlite3.Error as _err:
raise RuntimeError(f"Failed to fetch data, ERROR: {_err}")
finally:
DBManager.destroy_db_connect(conn, curs)
if not data:
raise ValueError("Item is not in table.")
@classmethod
def fetch_all_field_name_in_table(cls: any, db_path: str, table_name: str) -> list:
"""
check all field name in table
"""
conn, curs = DBManager.create_connect_db(db_path)
if not (conn and curs):
return []
sql = f"PRAGMA table_info({table_name})"
try:
data = DBManager.fetch_all_data(curs, sql)
return [info[1] for info in data]
except sqlite3.Error as _err:
raise RuntimeError(f"Failed to fetch data, ERROR: {_err}")
finally:
DBManager.destroy_db_connect(conn, curs)
@classmethod
def check_tables_in_db(cls, db_path: any, *tables: any) -> bool:
if os.path.exists(db_path):
conn, curs = cls.create_connect_db(db_path)
if not (conn and curs):
return False
res = True
for table in tables:
if not cls.judge_table_exists(curs, table):
res = False
break
cls.destroy_db_connect(conn, curs)
return res
return False
@classmethod
def check_table_field_content(cls, db_path: str, table_field_content: dict) -> None:
"""
Check if the specified fields in the tables contain all the required content.
Args:
db_path (str): Path to the SQLite database file.
table_field_content (dict): Dictionary where keys are table names and values are dictionaries of field names and required content lists.
Raises:
AssertionError: If any required content is missing.
Example:
table_field_content = {
"table1": {
"field1": ["value1", "value2", "value3"],
"field2": ["value4", "value5"]
},
"table2": {
"field3": ["value6", "value7"]
}
}
FileChecker.check_table_field_content(db_path, table_field_content)
"""
conn, curs = DBManager.create_connect_db(db_path)
if not (conn and curs):
return
try:
for table_name, field_content_dict in table_field_content.items():
for field_name, required_content in field_content_dict.items():
sql = f"SELECT {field_name} FROM {table_name}"
data = DBManager.fetch_all_data(curs, sql)
actual_content = [row[0] for row in data]
missing_content = [content for content in required_content if content not in actual_content]
if missing_content:
raise ValueError(f"Field '{field_name}' in table '{table_name}' in database '{db_path}' "
f"is missing required content: {missing_content}")
except sqlite3.Error as _err:
raise RuntimeError(f"Failed to fetch data, ERROR: {_err}")
finally:
DBManager.destroy_db_connect(conn, curs)
@classmethod
def check_table_field_values(cls, db_path: str, table_name: str, field_names: list, comparison_func) -> None:
"""
Check if all values in the specified fields of the table meet the given condition.
Args:
db_path (str): Path to the SQLite database file.
table_name (str): Name of the table to check.
field_names (list): List of field names to check.
comparison_func (callable): A function that takes a value and returns True if the value meets the condition, otherwise False.
Raises:
AssertionError: If any value in the fields does not meet the condition.
Example:
table_name = "ClusterCommunicationBandwidth"
field_names = ["value1", "value2", "value3"]
# Define the comparison function
def is_greater_than_zero(value):
return value > 0
FileChecker.check_field_values(db_path, table_name, field_names, is_greater_than_zero)
"""
conn, curs = DBManager.create_connect_db(db_path)
if not (conn and curs):
return
try:
fields_str = ", ".join(field_names)
sql = f"SELECT {fields_str} FROM {table_name}"
data = DBManager.fetch_all_data(curs, sql)
invalid_values = []
for row in data:
for value in row:
if not comparison_func(value):
invalid_values.append(value)
if invalid_values:
raise ValueError(f"Fields '{', '.join(field_names)}' in table '{table_name}' in database '{db_path}' "
f"contain values that do not meet the condition: {invalid_values}")
except sqlite3.Error as _err:
raise RuntimeError(f"Failed to fetch data, ERROR: {_err}")
finally:
DBManager.destroy_db_connect(conn, curs)