import os.path
import sqlite3
from typing import Any, Dict, List, Optional, Type, Union, Iterable, get_origin, get_args
# Python type -> SQLite column type. SQLite has no boolean storage class,
# so bool is stored as INTEGER (0/1).
_PY_TYPE_TO_SQLITE = {
    int: "INTEGER",
    float: "REAL",
    str: "TEXT",
    bool: "INTEGER",
    bytes: "BLOB",
}


def _map_py_type_to_sqlite(py_type: Type) -> str:
    """Map a Python type (or type annotation) to an SQLite column type.

    Handles:
      * the plain types in ``_PY_TYPE_TO_SQLITE``;
      * subclasses of those types (e.g. an ``IntEnum`` maps to INTEGER);
      * generic aliases, via their ``get_origin``;
      * optional types written as ``Optional[X]``, ``Union[X, None]`` or,
        on Python 3.10+, ``X | None``.

    Anything else (multi-type unions, containers, ...) falls back to "TEXT".
    """
    import types

    origin = get_origin(py_type) or py_type
    if origin in _PY_TYPE_TO_SQLITE:
        return _PY_TYPE_TO_SQLITE[origin]

    # ``X | None`` (PEP 604) has origin types.UnionType, not typing.Union.
    pep604_union = getattr(types, "UnionType", None)
    if origin is Union or (pep604_union is not None and origin is pep604_union):
        non_none = [t for t in get_args(py_type) if t is not type(None)]
        if len(non_none) == 1:
            return _map_py_type_to_sqlite(non_none[0])
        return "TEXT"

    # Subclasses of a mapped type (IntEnum, str subclasses, ...).
    if isinstance(origin, type):
        for base, sqlite_type in _PY_TYPE_TO_SQLITE.items():
            if issubclass(origin, base):
                return sqlite_type
    return "TEXT"
def _sqlite_type_to_py_type(sqlite_type: str) -> type:
    """Guess the Python type for an SQLite declared column type.

    The declared type is matched case-insensitively by substring, testing the
    rules in this order: "INT" -> int; "CHAR"/"CLOB"/"TEXT" -> str;
    "BLOB" -> bytes; "REAL"/"FLOA"/"DOUB" -> float. The first matching rule
    wins. An empty/None declaration or one matching no rule yields str.
    """
    if not sqlite_type:
        return str
    declared = sqlite_type.upper()
    rules = (
        (("INT",), int),
        (("CHAR", "CLOB", "TEXT"), str),
        (("BLOB",), bytes),
        (("REAL", "FLOA", "DOUB"), float),
    )
    for needles, py_type in rules:
        if any(needle in declared for needle in needles):
            return py_type
    return str
import ast
def _parse_default_value(dflt_str: Optional[str]) -> Any:
    """Convert a ``dflt_value`` string returned by ``PRAGMA table_info`` to Python.

    - ``None`` and the SQL keyword ``NULL`` (any case) map to ``None``.
    - SQL string literals (``'...'`` or ``"..."``) are unquoted using SQL
      rules: an embedded quote is written twice and backslashes are literal.
    - ``'1'`` / ``'0'`` and ``true`` / ``false`` (any case) map to booleans.
    - Other text is tried as a Python literal (numbers, parenthesised values).
      If that fails, the raw text is returned unchanged
      (e.g. ``CURRENT_TIMESTAMP``).
    """
    if dflt_str is None:
        return None
    # Unquote SQL string literals before trying ast.literal_eval. Python would
    # read 'it''s' as two adjacent strings ("its") and would interpret
    # backslash escapes that SQL treats literally.
    if len(dflt_str) >= 2 and dflt_str[0] == dflt_str[-1] and dflt_str[0] in ("'", '"'):
        quote = dflt_str[0]
        return dflt_str[1:-1].replace(quote * 2, quote)
    # NOTE: the column type is unknown here, so '1'/'0' always become booleans,
    # even for INTEGER columns. bool is an int subclass, so True == 1 still holds.
    if dflt_str == '1':
        return True
    if dflt_str == '0':
        return False
    lowered = dflt_str.lower()
    if lowered in ('true', 'false'):
        return lowered == 'true'
    if lowered == 'null':
        return None
    try:
        return ast.literal_eval(dflt_str)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a Python literal: most likely an SQL keyword or expression.
        return dflt_str
class SqliteColumn:
    """Definition of one SQLite column, renderable as a column DDL clause.

    :param name: column name; emitted back-quoted by ``to_sql_def``
    :param data_type: Python type used to choose the SQLite column type
    :param primary_key: add ``PRIMARY KEY``
    :param autoincrement: add ``AUTOINCREMENT``; requires ``primary_key=True``
        and ``data_type`` being exactly ``int``
    :param not_null: add ``NOT NULL``
    :param unique: add ``UNIQUE``
    :param default: default value; ``None`` means no ``DEFAULT`` clause
    :raises ValueError: for an invalid ``autoincrement`` combination
    """

    def __init__(
        self,
        name: str,
        data_type: Type = str,
        primary_key: bool = False,
        autoincrement: bool = False,
        not_null: bool = False,
        unique: bool = False,
        default: Optional[Any] = None,
    ):
        # SQLite only accepts AUTOINCREMENT on an INTEGER PRIMARY KEY column.
        if autoincrement and not primary_key:
            raise ValueError("autoincrement requires primary_key=True")
        if autoincrement and data_type is not int:
            raise ValueError("autoincrement only supported for INTEGER type")
        self.name = name
        self.data_type = data_type
        self.primary_key = primary_key
        self.autoincrement = autoincrement
        self.not_null = not_null
        self.unique = unique
        self.default = default

    def _format_default(self) -> str:
        """Render ``self.default`` as an SQL literal for a ``DEFAULT`` clause."""
        val = self.default
        if val is None:
            return "NULL"
        # bool must be checked before int: it is an int subclass, stored as 0/1.
        if isinstance(val, bool):
            return "1" if val else "0"
        if isinstance(val, (int, float)):
            return str(val)
        if isinstance(val, (bytes, bytearray)):
            # str(b"..") yields "b'..'", which is not SQL; emit a blob literal.
            return f"X'{bytes(val).hex()}'"
        # Strings and any other object are quoted as text. Single quotes are
        # doubled so the value cannot terminate the literal early.
        text = val if isinstance(val, str) else str(val)
        return "'" + text.replace("'", "''") + "'"

    def to_sql_def(self) -> str:
        """Return this column's definition clause for use inside CREATE TABLE."""
        # Double any embedded back-quote so the identifier stays well-formed.
        quoted_name = "`" + self.name.replace("`", "``") + "`"
        parts = [quoted_name, _map_py_type_to_sqlite(self.data_type)]
        if self.primary_key:
            parts.append("PRIMARY KEY")
        if self.autoincrement:
            parts.append("AUTOINCREMENT")
        if self.not_null:
            parts.append("NOT NULL")
        if self.unique:
            parts.append("UNIQUE")
        if self.default is not None:
            parts.append(f"DEFAULT {self._format_default()}")
        return " ".join(parts)
class SqliteTable:
    """Description of an SQLite table plus helpers to create it and insert rows.

    ``column_dict`` maps column name -> column, in the order columns were added.
    """

    name: str
    column_dict: Dict[str, "SqliteColumn"]

    def __init__(self, table_name: str, columns: Optional[Iterable["SqliteColumn"]] = None):
        self.name = table_name
        self.column_dict = {}
        if columns:
            for column in columns:
                self.column_dict[column.name] = column

    @staticmethod
    def _quote_identifier(identifier: str) -> str:
        """Back-quote an SQL identifier, doubling any embedded back-quotes."""
        return "`" + identifier.replace("`", "``") + "`"

    def to_sql_def(self, delete_if_exists: bool = False) -> str:
        """
        Build the SQL that creates this table.

        :param delete_if_exists: if True, emit ``DROP TABLE IF EXISTS`` followed
            by ``CREATE TABLE``; otherwise emit ``CREATE TABLE IF NOT EXISTS``
        :raises ValueError: if the table has no columns (SQLite rejects that DDL)
        """
        if not self.column_dict:
            raise ValueError(f"Table '{self.name}' has no columns.")
        # Quote the table name like column names, so keywords (e.g. "order")
        # and names with spaces produce valid SQL.
        table = self._quote_identifier(self.name)
        column_defs = ", ".join(col.to_sql_def() for col in self.column_dict.values())
        if delete_if_exists:
            return f"DROP TABLE IF EXISTS {table};\nCREATE TABLE {table} ({column_defs});"
        return f"CREATE TABLE IF NOT EXISTS {table} ({column_defs});"

    def create_table(self, conn: sqlite3.Connection, delete_if_exists: bool = False):
        """
        Create this table in the database and commit.

        :param conn: open sqlite3 connection
        :param delete_if_exists: drop an existing table of the same name first
        """
        sql = self.to_sql_def(delete_if_exists=delete_if_exists)
        # executescript is required because the DROP + CREATE form holds two
        # statements. NOTE: it may implicitly commit a pending transaction
        # first (see the sqlite3 docs).
        conn.executescript(sql)
        conn.commit()

    def _build_insert_sql(self, record: Dict[str, Any]) -> str:
        """Build a parameterised INSERT statement for the keys of *record*."""
        table = self._quote_identifier(self.name)
        if not record:
            # "() VALUES ()" is invalid SQL; insert a row of column defaults.
            return f"INSERT INTO {table} DEFAULT VALUES"
        columns = SqliteTable.get_insert_columns_by_record(record)
        placeholders = SqliteTable.get_insert_placeholder_by_record(record)
        return f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders})"

    def insert_record(self, conn: sqlite3.Connection, record: Dict[str, Any]):
        """Insert one record (column name -> value) and commit."""
        sql = self._build_insert_sql(record)
        conn.execute(sql, tuple(record.values()))
        conn.commit()

    def insert_records(self, conn: sqlite3.Connection, records: List[Dict[str, Any]]):
        """Insert many records in one ``executemany`` call and commit.

        The column set is taken from ``records[0]``; see
        ``get_insert_values_by_records`` for how the other records are read.
        """
        if not records:
            return
        sql = self._build_insert_sql(records[0])
        values = SqliteTable.get_insert_values_by_records(records)
        conn.executemany(sql, values)
        conn.commit()

    @staticmethod
    def get_insert_columns_by_record(record: Dict[str, Any]):
        """Return the record's keys as back-quoted column identifiers."""
        return [SqliteTable._quote_identifier(key) for key in record.keys()]

    @staticmethod
    def get_insert_placeholder_by_record(record: Dict[str, Any]):
        """Return one "?" placeholder per key, joined by ", "."""
        return ', '.join(['?' for _ in record.keys()])

    @staticmethod
    def get_insert_values_by_records(records: List[Dict[str, Any]]):
        """Return one value tuple per record, ordered by the keys of ``records[0]``.

        A record missing one of those keys raises KeyError.
        NOTE(review): keys absent from ``records[0]`` are silently ignored.
        """
        if not records:
            return []
        keys = list(records[0].keys())
        return [tuple(r[k] for k in keys) for r in records]
class SqliteDB:
    """Wrapper around an ``sqlite3.Connection`` that caches table schemas.

    Usable as a context manager. On exit the pending transaction is committed
    (or rolled back if an exception escaped) and the connection is closed.
    """

    # sqlite3 treats this filename as a private in-memory database.
    _MEMORY_PATH = ":memory:"

    path: str  # absolute database file path, or ":memory:"
    conn: sqlite3.Connection
    table_cache: Dict[str, "SqliteTable"]  # table name -> schema object

    def __init__(self, path: str, auto_create: bool = True):
        """Open the database at *path*.

        :param path: database file path, or ``":memory:"`` for an in-memory DB
        :param auto_create: if False, a missing database file is an error.
            Otherwise missing parent directories are created before connecting.
        :raises FileNotFoundError: the file is missing and ``auto_create`` is False
        """
        if path == self._MEMORY_PATH:
            # os.path.realpath() would turn ":memory:" into an on-disk file in the cwd.
            self.path = path
        else:
            self.path = os.path.realpath(path)
            if not os.path.exists(self.path):
                if not auto_create:
                    raise FileNotFoundError(f"Db file not found: {self.path}.")
                os.makedirs(os.path.dirname(self.path), exist_ok=True)
        self.conn = sqlite3.connect(self.path)
        self.table_cache = {}

    def close(self) -> None:
        """Close the underlying connection."""
        self.conn.close()

    def __enter__(self) -> "SqliteDB":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        try:
            if exc_type is None:
                self.conn.commit()
            else:
                self.conn.rollback()
        finally:
            self.close()

    def create_table(self, table: "SqliteTable", delete_if_exists: bool = True):
        """Create *table* in this database and cache its schema.

        NOTE(review): the default here is ``delete_if_exists=True``, so an
        existing table of the same name is dropped together with its data. The
        default is kept for backward compatibility; pass
        ``delete_if_exists=False`` to keep existing data.
        """
        table.create_table(self.conn, delete_if_exists)
        self.table_cache[table.name] = table

    def is_table_exists(self, table_name: str) -> bool:
        """Return True if a table named *table_name* exists.

        A cache hit returns True without querying the database, so a table
        dropped outside this object is still reported as existing. When the
        table is found in the database, its schema is loaded into the cache.
        """
        if table_name in self.table_cache:
            return True
        row = self.conn.execute(
            """
            SELECT name
            FROM sqlite_master
            WHERE type = 'table'
              AND name = ?
            """,
            (table_name,),
        ).fetchone()
        if row is None:
            return False
        self.get_table_by_name(table_name)  # also stores it in self.table_cache
        return True

    def get_table_by_name(self, table_name: str) -> "SqliteTable":
        """Read a table's schema from the database as a ``SqliteTable``.

        These properties come from ``PRAGMA table_info``: column name, declared
        type, NOT NULL, default value and primary-key membership. AUTOINCREMENT
        is detected from the table's stored CREATE TABLE text. UNIQUE
        constraints are not reconstructed. The result is cached in
        ``self.table_cache``.

        NOTE(review): for a composite primary key, every member column gets
        ``primary_key=True``. Such a schema cannot be re-emitted as valid DDL
        by ``SqliteColumn.to_sql_def``.

        :raises ValueError: if the table does not exist
        """
        if table_name in self.table_cache:
            return self.table_cache[table_name]
        # Embed the name as an escaped SQL string literal rather than raw text,
        # so quotes in the name cannot break out of the statement.
        name_literal = "'" + table_name.replace("'", "''") + "'"
        rows = self.conn.execute(f"PRAGMA table_info({name_literal});").fetchall()
        if not rows:
            raise ValueError(f"Table '{table_name}' does not exist.")

        # sqlite_sequence only gets a row after the first insert, so inspect the
        # stored DDL instead. SQLite allows AUTOINCREMENT only on a single
        # INTEGER PRIMARY KEY column.
        ddl_row = self.conn.execute(
            "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?;",
            (table_name,),
        ).fetchone()
        ddl = (ddl_row[0] or "") if ddl_row else ""
        has_autoincrement = "AUTOINCREMENT" in ddl.upper()
        pk_count = sum(1 for row in rows if row[5])

        table = SqliteTable(table_name)
        for _cid, name, declared_type, notnull, dflt_value, pk in rows:
            py_type = _sqlite_type_to_py_type(declared_type)
            autoincrement = (
                bool(pk) and pk_count == 1 and py_type is int and has_autoincrement
            )
            column = SqliteColumn(
                name=name,
                data_type=py_type,
                primary_key=bool(pk),
                autoincrement=autoincrement,
                not_null=bool(notnull),
                default=_parse_default_value(dflt_value),
            )
            table.column_dict[column.name] = column
        self.table_cache[table_name] = table
        return table