Files
openclaw-mission-control/backend/app/db/session.py

84 lines
2.6 KiB
Python

from __future__ import annotations
import logging
from collections.abc import AsyncGenerator
from pathlib import Path
import anyio
from alembic import command
from alembic.config import Config
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
from app import models as _models
from app.core.config import settings
# Import model modules so SQLModel metadata is fully registered at startup.
# The imported ``app.models`` package is bound to a module-level name so a live
# reference to it is kept here; presumably this also stops linters from
# flagging the import as unused -- verify before removing.
_MODEL_REGISTRY = _models
def _normalize_database_url(database_url: str) -> str:
if "://" not in database_url:
return database_url
scheme, rest = database_url.split("://", 1)
if scheme == "postgresql":
return f"postgresql+psycopg://{rest}"
return database_url
# Module-wide async engine, created once at import time from the configured
# database URL after its scheme has been normalized.
async_engine: AsyncEngine = create_async_engine(
_normalize_database_url(settings.database_url),
# Ask the pool to check a connection is alive when it is checked out
# (SQLAlchemy's ``pool_pre_ping`` option).
pool_pre_ping=True,
)
# Factory producing ``AsyncSession`` objects bound to ``async_engine``.
async_session_maker = async_sessionmaker(
async_engine,
class_=AsyncSession,
# Do not expire loaded attributes on commit, so objects stay readable
# afterwards without another round trip.
expire_on_commit=False,
)
# Logger named after this module; used by the functions below.
logger = logging.getLogger(__name__)
def _alembic_config() -> Config:
    """Build the Alembic ``Config`` for this project's ``alembic.ini``.

    The ini file is looked up three directory levels above this module's
    resolved path. The ``configure_logger`` attribute is set to ``False`` on
    the returned config; presumably the project's Alembic ``env.py`` reads it
    to skip reconfiguring logging -- confirm there.

    Returns:
        A ``Config`` pointing at ``alembic.ini`` with ``configure_logger``
        disabled.
    """
    ini_path = Path(__file__).resolve().parents[2] / "alembic.ini"
    cfg = Config(str(ini_path))
    cfg.attributes["configure_logger"] = False
    return cfg
def run_migrations() -> None:
    """Upgrade the database to the ``head`` Alembic revision.

    Synchronous: the call blocks until ``command.upgrade`` returns. An info
    message is logged before and after the upgrade.
    """
    logger.info("Running database migrations.")
    config = _alembic_config()
    command.upgrade(config, "head")
    logger.info("Database migrations complete.")
async def init_db() -> None:
    """Prepare the database schema.

    When ``settings.db_auto_migrate`` is enabled and at least one ``*.py``
    file exists under ``migrations/versions`` (three levels above this
    module), the migrations are run in a worker thread and the function
    returns. Otherwise -- auto-migrate disabled, or enabled with no revision
    files -- the tables are created directly from ``SQLModel.metadata``.
    """
    if settings.db_auto_migrate:
        project_root = Path(__file__).resolve().parents[2]
        revision_files = (project_root / "migrations" / "versions").glob("*.py")
        if not any(revision_files):
            logger.warning("No migration revisions found; falling back to create_all")
        else:
            logger.info("Running migrations on startup")
            # run_migrations is blocking, so keep it off the event loop.
            await anyio.to_thread.run_sync(run_migrations)
            return

    async with async_engine.begin() as connection:
        await connection.run_sync(SQLModel.metadata.create_all)
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield one ``AsyncSession`` and roll back whatever it leaves open.

    A session is opened from ``async_session_maker`` and handed to the
    consumer. When the consumer finishes, normally or with an exception, any
    transaction still in progress is rolled back before the session context
    exits. ``SQLAlchemyError`` raised while inspecting the transaction state
    or rolling back is logged and suppressed, so cleanup never masks the
    consumer's own exception.

    Yields:
        The open ``AsyncSession``.
    """
    async with async_session_maker() as db_session:
        try:
            yield db_session
        finally:
            # If the state cannot be inspected, treat it as "nothing to roll back".
            try:
                still_open = bool(db_session.in_transaction())
            except SQLAlchemyError:
                still_open = False
                logger.exception("Failed to inspect session transaction state.")

            if still_open:
                try:
                    await db_session.rollback()
                except SQLAlchemyError:
                    logger.exception("Failed to rollback session after request error.")