import importlib
import logging
import os
import sys
from logging.config import fileConfig
from os.path import abspath, dirname
from alembic import context
from dotenv import load_dotenv
from sqlalchemy import engine_from_config, pool
# Backend root: four dirname() levels above this file; project root is its parent.
BACKEND_ROOT = dirname(dirname(dirname(dirname(abspath(__file__)))))
PROJECT_ROOT = dirname(BACKEND_ROOT)

# Make the project package (openjiuwen_studio) importable. Guarded so that
# executing this env script several times in one process (e.g. repeated
# programmatic Alembic commands) does not keep appending duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Load environment variables from the project's .env file, if it exists.
# This script later reads DB_TYPE, DB_*, OPS_* and SQLITE_DB_PATH from the
# environment; loading happens before the model import below in case that
# import also reads the environment.
dotenv_path = os.path.join(PROJECT_ROOT, ".env")
if os.path.exists(dotenv_path):
    load_dotenv(dotenv_path=dotenv_path)

# Deliberately imported only after sys.path and .env are set up above.
from openjiuwen_studio.ops.modules.prompt.infra.repositories import orm_repo  # noqa: E402

# Alembic Config object for the .ini file in use.
config = context.config

# Configure Python logging from the Alembic config file, when one is given.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

logger = logging.getLogger(__name__)
def rename_sqlite_indexes():
    """Rename SQLite indexes to avoid name conflicts between tables.

    SQLite index names share a single namespace per database (MySQL scopes
    them per table), so ORM models that reuse an index name on several
    tables collide on SQLite. When the ``DB_TYPE`` environment variable
    (default ``"mysql"``, compared case-insensitively) is ``"sqlite"``, every
    named index in ``orm_repo.Base.metadata`` is renamed in place to
    ``"<index name>_<table name>"``. For any other ``DB_TYPE`` this is a no-op.

    Safe to call more than once: ``orm_repo`` is cached in ``sys.modules``, so
    its metadata outlives a single execution of this env script. An index
    that was already renamed is marked in its ``info`` dict and skipped,
    instead of having the table suffix appended a second time. Unnamed
    indexes (``name is None``) are left untouched.
    """
    db_type = os.getenv("DB_TYPE", "mysql").lower()
    if db_type != "sqlite":
        return

    # Key stored in Index.info to remember that an index was already renamed.
    marker = "alembic_sqlite_index_renamed"

    for table in orm_repo.Base.metadata.tables.values():
        # Snapshot the index collection before mutating its members.
        for idx in list(getattr(table, "indexes", ())):
            old_name = idx.name
            if old_name is None:
                # Nothing to suffix; formatting would yield "None_<table>".
                continue
            if idx.info.get(marker):
                # Already renamed by an earlier run in this process.
                continue
            new_name = f"{old_name}_{table.name}"
            idx.name = new_name
            idx.info[marker] = True
            logger.info(
                "[Alembic] Renamed index: %s.%s -> %s",
                table.name,
                old_name,
                new_name,
            )
# Metadata handed to context.configure() below. It is the very object that
# rename_sqlite_indexes() mutates, so binding it first changes nothing.
target_metadata = orm_repo.Base.metadata
rename_sqlite_indexes()
def _resolve_database_url():
    """Build the database URL for this migration environment.

    ``DB_TYPE`` (default ``"mysql"``, compared case-insensitively) selects
    the backend:

    * ``"sqlite"``: uses the file named by ``OPS_SQLITE_DB`` (default
      ``"ops.db"``). A relative file name is placed under ``SQLITE_DB_PATH``
      (default ``"data/databases"``). That directory is resolved against
      ``BACKEND_ROOT`` when relative and is created if missing.
    * anything else: builds a ``mysql+pymysql`` URL from ``DB_USER``,
      ``DB_PASSWORD``, ``DB_HOST``, ``DB_PORT`` and ``OPS_DB_NAME``, only
      when all five are set and non-empty. User and password are
      percent-encoded so characters such as ``@``, ``:``, ``/`` or ``%``
      cannot corrupt the URL.

    If no URL could be built from the environment, falls back to the
    ``sqlalchemy.url`` option of the Alembic config.

    Returns:
        A ``(url, is_sqlite)`` tuple.

    Raises:
        RuntimeError: if neither the environment nor the Alembic config
            provides a URL.
    """
    from urllib.parse import quote

    db_type = os.getenv("DB_TYPE", "mysql").lower()
    is_sqlite = db_type == "sqlite"
    url = None

    if is_sqlite:
        sqlite_db = os.getenv("OPS_SQLITE_DB", "ops.db")
        sqlite_db_path = os.getenv("SQLITE_DB_PATH", "data/databases")
        if not os.path.isabs(sqlite_db):
            if not os.path.isabs(sqlite_db_path):
                sqlite_db_path = os.path.join(BACKEND_ROOT, sqlite_db_path)
            os.makedirs(sqlite_db_path, exist_ok=True)
            sqlite_db = os.path.join(sqlite_db_path, sqlite_db)
        url = f"sqlite:///{sqlite_db}"
    else:
        db_user = os.getenv("DB_USER")
        db_password = os.getenv("DB_PASSWORD")
        db_host = os.getenv("DB_HOST")
        db_port = os.getenv("DB_PORT")
        db_name = os.getenv("OPS_DB_NAME")
        if all([db_user, db_password, db_host, db_port, db_name]):
            # quote(..., safe="") percent-encodes every reserved character;
            # SQLAlchemy unquotes user/password when it parses the URL.
            user = quote(db_user, safe="")
            password = quote(db_password, safe="")
            url = f"mysql+pymysql://{user}:{password}@{db_host}:{db_port}/{db_name}"

    if url is None:
        url = config.get_main_option("sqlalchemy.url")
    if not url:
        raise RuntimeError(
            "No database URL configured: set DB_TYPE=sqlite, or DB_USER, "
            "DB_PASSWORD, DB_HOST, DB_PORT and OPS_DB_NAME, or "
            "sqlalchemy.url in the Alembic config."
        )
    return url, is_sqlite


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    The context is configured with a URL only (no live connection), and
    ``literal_binds=True`` renders bound values inline in the emitted SQL.
    """
    url, _ = _resolve_database_url()
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )
    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode against a live database connection.

    For SQLite, ``render_as_batch`` is enabled so that Alembic uses its batch
    mode for table alterations, since SQLite's ALTER TABLE support is limited.
    """
    from sqlalchemy import create_engine

    url, is_sqlite = _resolve_database_url()
    # The engine only serves this one migration run; NullPool closes the
    # connection when it is released instead of keeping it pooled.
    connectable = create_engine(url, poolclass=pool.NullPool)
    with connectable.connect() as connection:
        context.configure(
            connection=connection,
            target_metadata=target_metadata,
            render_as_batch=is_sqlite,
        )
        with context.begin_transaction():
            context.run_migrations()
# Entry point: Alembic executes this module, and the mode of the current
# command decides which runner is used.
if not context.is_offline_mode():
    run_migrations_online()
else:
    run_migrations_offline()