"""
通用数据库模型同步工具
自动检测模型定义与数据库表结构的差异:同步新增字段,并尝试修改类型不一致的字段
"""
import re
import logging
from datetime import datetime
from typing import Any, Dict, List
from sqlalchemy import inspect, text
from server.core.database import engine
# Module-level logger for the schema sync tool (named after this module).
logger = logging.getLogger(__name__)
class DatabaseSync:
    """Synchronise ORM model definitions with existing database tables.

    For every model whose table already exists it:

    * adds columns that are declared on the model but missing from the table;
    * tries to change the type of columns whose declared type differs from the
      reflected one (type changes can lose data -- review the logs).

    Tables that do not exist are skipped; creating them is left to other code.
    """

    # Modifiers stripped before comparing type strings. Collation names may be
    # quoted and may contain dots (e.g. PostgreSQL "en_US.utf8").
    _COLLATE_RE = re.compile(r'\s*\bcollate\s+("[^"]*"|`[^`]*`|[\w.]+)', re.IGNORECASE)
    _CHARSET_RE = re.compile(r'\s*\bcharacter\s+set\s+("[^"]*"|`[^`]*`|\w+)', re.IGNORECASE)
    # Older MySQL versions reflect integer display widths ("INTEGER(11)") that the
    # model-side rendering never contains; the width does not affect storage.
    _INT_DISPLAY_WIDTH_RE = re.compile(
        r'\b(tinyint|smallint|mediumint|int|integer|bigint)\s*\(\s*\d+\s*\)'
    )

    def __init__(self, db_engine):
        """Bind the synchroniser to *db_engine* and create a reflection inspector."""
        self.engine = db_engine
        self.inspector = inspect(db_engine)

    @staticmethod
    def _compile_type(type_, dialect=None) -> str:
        """Render a SQLAlchemy type as DDL text.

        When *dialect* is given the type is compiled for that dialect (so e.g. an
        ``Enum`` becomes ``ENUM(...)`` on MySQL instead of the generic
        ``VARCHAR(n)``); if that fails, or no dialect is given, ``str(type_)``
        is used.
        """
        if dialect is not None:
            from sqlalchemy.exc import CompileError
            try:
                return str(type_.compile(dialect=dialect))
            except CompileError:
                # e.g. String() without a length on MySQL; fall back to generic text
                pass
        return str(type_)

    @staticmethod
    def get_model_columns(model_class, dialect=None) -> Dict[str, Any]:
        """Return ``{column_name: info}`` for the columns declared on *model_class*.

        ``info`` has the keys ``type``, ``nullable``, ``default`` and ``comment``.
        ``type`` is rendered for *dialect* when one is given (needed to compare it
        with reflected types); without a dialect it is ``str(column.type)``.
        """
        return {
            column.name: {
                'type': DatabaseSync._compile_type(column.type, dialect),
                'nullable': column.nullable,
                'default': column.default,
                'comment': getattr(column, 'comment', None),
            }
            for column in model_class.__table__.columns
        }

    @staticmethod
    def _types_match(model_type: str, db_type: str) -> bool:
        """Return True if two rendered type strings denote the same column type.

        Comparison is case-insensitive and ignores COLLATE / CHARACTER SET
        modifiers, extra whitespace and integer display widths. BOOLEAN/BOOL and
        TINYINT are treated as equivalent (MySQL stores booleans as TINYINT(1)).
        """
        def normalize_type(type_str: str) -> str:
            type_str = DatabaseSync._COLLATE_RE.sub('', type_str)
            type_str = DatabaseSync._CHARSET_RE.sub('', type_str)
            type_str = re.sub(r'\s+', ' ', type_str.lower()).strip()
            return DatabaseSync._INT_DISPLAY_WIDTH_RE.sub(r'\1', type_str)

        model_normalized = normalize_type(model_type)
        db_normalized = normalize_type(db_type)
        if model_normalized == db_normalized:
            return True

        def is_bool(type_str: str) -> bool:
            return type_str in ('bool', 'boolean')

        def is_tinyint(type_str: str) -> bool:
            return type_str.startswith('tinyint')

        return (
            (is_bool(model_normalized) and is_tinyint(db_normalized))
            or (is_tinyint(model_normalized) and is_bool(db_normalized))
        )

    def get_table_columns(self, table_name: str) -> Dict[str, Any]:
        """Return ``{column_name: info}`` for the columns reflected from *table_name*.

        ``info`` has the keys ``type`` (``str`` of the reflected type),
        ``nullable``, ``default`` and ``comment``. On any reflection error a
        warning is logged and an empty dict is returned.
        """
        try:
            return {
                column['name']: {
                    'type': str(column['type']),
                    'nullable': column.get('nullable', True),
                    'default': column.get('default', None),
                    'comment': column.get('comment', None),
                }
                for column in self.inspector.get_columns(table_name)
            }
        except Exception as e:
            logger.warning(f"无法获取表 {table_name} 的列信息: {e}")
            return {}

    def get_missing_columns(self, model_class) -> List[str]:
        """Return names of columns declared on the model but absent from its table."""
        table_columns = self.get_table_columns(model_class.__tablename__)
        return [
            column_name
            for column_name in self.get_model_columns(model_class)
            if column_name not in table_columns
        ]

    def get_type_mismatched_columns(self, model_class) -> Dict[str, Dict[str, str]]:
        """Return columns present in both model and table whose types differ.

        Model types are rendered for the engine's dialect so they are comparable
        with reflected types.

        Returns:
            ``{column_name: {'model_type': <model type>, 'db_type': <db type>}}``
        """
        model_columns = self.get_model_columns(model_class, self.engine.dialect)
        table_columns = self.get_table_columns(model_class.__tablename__)
        mismatched_columns = {}
        for column_name, model_info in model_columns.items():
            if column_name not in table_columns:
                continue
            model_type = model_info['type']
            db_type = table_columns[column_name]['type']
            if not self._types_match(model_type, db_type):
                mismatched_columns[column_name] = {
                    'model_type': model_type,
                    'db_type': db_type,
                }
        return mismatched_columns

    def _quote(self, identifier: str) -> str:
        """Quote *identifier* using the engine dialect's identifier preparer.

        The identifier is always quoted (with embedded quote characters escaped),
        so reserved words and mixed-case names are safe on every dialect.
        """
        return self.engine.dialect.identifier_preparer.quote_identifier(identifier)

    @staticmethod
    def _sql_literal(value, dialect_name: str) -> str:
        """Render a Python scalar as an inline SQL literal for DDL text."""
        from datetime import date, time
        if isinstance(value, bool):
            # PostgreSQL rejects integer defaults on BOOLEAN columns
            if dialect_name == 'postgresql':
                return 'TRUE' if value else 'FALSE'
            return '1' if value else '0'
        if isinstance(value, (str, date, time)):
            # Quote dates/times too (unquoted 2024-01-02 is arithmetic) and double
            # embedded single quotes so values such as "it's" stay valid SQL
            return "'" + str(value).replace("'", "''") + "'"
        return str(value)

    def _default_sql(self, column):
        """Return the SQL for a column's DEFAULT clause, or None if none applies.

        A ``server_default`` takes precedence. Otherwise a non-callable Python-side
        ``default`` is inlined so existing rows receive a value. Callable defaults
        cannot be expressed in DDL and are skipped.
        """
        dialect = self.engine.dialect
        server_arg = getattr(getattr(column, 'server_default', None), 'arg', None)
        if isinstance(server_arg, str):
            return self._sql_literal(server_arg, dialect.name)
        if server_arg is not None and hasattr(server_arg, 'compile'):
            # text(...) or SQL expression server defaults
            return str(server_arg.compile(dialect=dialect, compile_kwargs={'literal_binds': True}))

        arg = getattr(getattr(column, 'default', None), 'arg', None)
        if arg is None or callable(arg):
            return None
        if hasattr(arg, 'compile'):
            # SQL expression default such as func.now()
            return str(arg.compile(dialect=dialect, compile_kwargs={'literal_binds': True}))
        return self._sql_literal(arg, dialect.name)

    def _column_options_sql(self, column) -> str:
        """Build the ``NULL``/``NOT NULL``, ``DEFAULT`` and (MySQL only) ``COMMENT`` suffix."""
        dialect_name = self.engine.dialect.name
        parts = ['NULL' if column.nullable else 'NOT NULL']
        default_sql = self._default_sql(column)
        if default_sql is not None:
            parts.append(f'DEFAULT {default_sql}')
        comment = getattr(column, 'comment', None)
        if dialect_name == 'mysql' and comment:
            parts.append(f'COMMENT {self._sql_literal(comment, dialect_name)}')
        return ' '.join(parts)

    def _execute_ddl(self, sql: str) -> None:
        """Execute a single DDL statement in its own transaction."""
        with self.engine.begin() as conn:
            # exec_driver_sql sends the text verbatim (text() would treat ":name" in
            # comments/defaults as a bind parameter); no_parameters stops
            # format-paramstyle drivers from interpreting "%" characters.
            conn.exec_driver_sql(sql, execution_options={'no_parameters': True})

    def add_column_to_table(self, model_class, column_name: str):
        """Add *column_name* (as declared on the model) to the model's table.

        Raises:
            Exception: any database error, after logging it.
        """
        table_name = model_class.__tablename__
        column = model_class.__table__.columns[column_name]
        column_type = self._compile_type(column.type, self.engine.dialect)
        alter_sql = (
            f"ALTER TABLE {self._quote(table_name)} ADD COLUMN {self._quote(column_name)} "
            f"{column_type} {self._column_options_sql(column)}"
        )
        try:
            self._execute_ddl(alter_sql)
            logger.info(f"✅ 成功添加列 {column_name} 到表 {table_name}")
        except Exception as e:
            logger.error(f"❌ 添加列失败 {column_name} 到表 {table_name}: {e}")
            raise

    def modify_column_type(self, model_class, column_name: str, old_type: str, new_type: str):
        """Change the type of *column_name* to the type declared on the model.

        Warning: type conversion may lose data. *old_type*/*new_type* are only
        used in log messages. SQLite and unknown dialects only log a warning.
        On MySQL, primary-key integer columns that may be AUTO_INCREMENT are
        skipped, because MODIFY COLUMN would drop AUTO_INCREMENT.

        Raises:
            Exception: any database error, after logging it.
        """
        table_name = model_class.__tablename__
        column = model_class.__table__.columns[column_name]
        dialect_name = self.engine.dialect.name
        column_type = self._compile_type(column.type, self.engine.dialect)

        if dialect_name == 'mysql':
            from sqlalchemy import Integer
            if column.primary_key and column.autoincrement is not False and isinstance(column.type, Integer):
                logger.warning(
                    f"⚠️ 跳过修改主键列 {table_name}.{column_name} 的类型 "
                    f"({old_type} -> {new_type}):MODIFY COLUMN 会丢失 AUTO_INCREMENT,请手动迁移。"
                )
                return
            # MODIFY COLUMN replaces the whole column definition, so nullability,
            # default and comment must be restated or MySQL silently drops them.
            alter_sql = (
                f"ALTER TABLE {self._quote(table_name)} MODIFY COLUMN {self._quote(column_name)} "
                f"{column_type} {self._column_options_sql(column)}"
            )
        elif dialect_name == 'sqlite':
            logger.warning(
                f"⚠️ SQLite 不支持直接修改列类型 {table_name}.{column_name} "
                f"从 {old_type} 到 {new_type}。建议手动重建表或使用数据库迁移工具。"
            )
            return
        elif dialect_name == 'postgresql':
            alter_sql = (
                f"ALTER TABLE {self._quote(table_name)} ALTER COLUMN "
                f"{self._quote(column_name)} TYPE {column_type}"
            )
        else:
            logger.warning(f"⚠️ 不支持的数据库方言: {dialect_name}")
            return

        try:
            self._execute_ddl(alter_sql)
            logger.info(f"✅ 成功修改列 {table_name}.{column_name} 类型: {old_type} -> {new_type}")
        except Exception as e:
            logger.error(f"❌ 修改列类型失败 {table_name}.{column_name}: {e}")
            raise

    def sync_model(self, model_class):
        """Add missing columns and fix mismatched column types for one model.

        Tables that do not exist are skipped. Raises (after logging) on failure.
        """
        table_name = model_class.__tablename__
        try:
            # The Inspector caches reflection results; use a fresh one so repeated
            # syncs on the same instance see columns added earlier.
            self.inspector = inspect(self.engine)
            if not self.inspector.has_table(table_name):
                logger.info(f"📋 表 {table_name} 不存在,跳过字段同步")
                return

            missing_columns = self.get_missing_columns(model_class)
            if missing_columns:
                logger.info(f"🔄 检测到表 {table_name} 缺少字段: {missing_columns}")
                for column_name in missing_columns:
                    self.add_column_to_table(model_class, column_name)
                logger.info(f"✅ 表 {table_name} 字段同步完成")
            else:
                logger.info(f"✅ 表 {table_name} 字段已同步")

            mismatched_columns = self.get_type_mismatched_columns(model_class)
            if mismatched_columns:
                logger.info(f"🔄 检测到表 {table_name} 字段类型不匹配: {mismatched_columns}")
                for column_name, type_info in mismatched_columns.items():
                    self.modify_column_type(
                        model_class,
                        column_name,
                        type_info['db_type'],
                        type_info['model_type'],
                    )
                logger.info(f"✅ 表 {table_name} 字段类型同步完成")
            else:
                logger.info(f"✅ 表 {table_name} 字段类型已匹配")
        except Exception as e:
            logger.error(f"❌ 同步表 {table_name} 失败: {e}")
            raise

    def sync_all_models(self, model_classes: List):
        """Sync every model in *model_classes*; a failing model is logged and skipped."""
        logger.info("🔄 开始数据库模型同步...")
        for model_class in model_classes:
            try:
                self.sync_model(model_class)
            except Exception as e:
                logger.error(f"❌ 同步模型 {model_class.__name__} 失败: {e}")
        logger.info("✅ 数据库模型同步完成")
def get_all_model_classes():
    """Return the list of ORM model classes this project keeps in sync.

    The model modules are imported inside the function rather than at module
    level (presumably to avoid import cycles at load time -- TODO confirm).
    """
    from server.deepsearch.core.models.report_template import ReportTemplateDB
    from server.deepsearch.core.models.web_search_engine_model import WebSearchEngineModel

    model_classes = [ReportTemplateDB, WebSearchEngineModel]
    return model_classes
def run_database_sync():
    """Synchronise all project models against the shared ``engine``.

    Any exception is logged and then re-raised to the caller.
    """
    try:
        DatabaseSync(engine).sync_all_models(get_all_model_classes())
    except Exception as exc:
        logger.error(f"❌ 数据库同步失败: {exc}")
        raise
if __name__ == "__main__":
    # When run as a script nothing configures logging, so the root logger's
    # default WARNING threshold would hide every INFO progress message.
    # basicConfig is a no-op if the root logger already has handlers.
    logging.basicConfig(level=logging.INFO)
    run_database_sync()