refactor: replace exec_dml with CRUD operations in various files and improve session handling
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Mapping
|
||||
from typing import Any, TypeVar
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
from sqlalchemy import delete as sql_delete
|
||||
from sqlalchemy import update as sql_update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel import SQLModel, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -19,6 +21,22 @@ class MultipleObjectsReturned(LookupError):
|
||||
pass
|
||||
|
||||
|
||||
async def _flush_or_rollback(session: AsyncSession) -> None:
|
||||
try:
|
||||
await session.flush()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
async def _commit_or_rollback(session: AsyncSession) -> None:
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def _lookup_statement(model: type[ModelT], lookup: Mapping[str, Any]) -> SelectOfScalar[ModelT]:
|
||||
stmt = select(model)
|
||||
for key, value in lookup.items():
|
||||
@@ -58,9 +76,9 @@ async def create(
|
||||
) -> ModelT:
|
||||
obj = model.model_validate(data)
|
||||
session.add(obj)
|
||||
await session.flush()
|
||||
await _flush_or_rollback(session)
|
||||
if commit:
|
||||
await session.commit()
|
||||
await _commit_or_rollback(session)
|
||||
if refresh:
|
||||
await session.refresh(obj)
|
||||
return obj
|
||||
@@ -74,9 +92,9 @@ async def save(
|
||||
refresh: bool = True,
|
||||
) -> ModelT:
|
||||
session.add(obj)
|
||||
await session.flush()
|
||||
await _flush_or_rollback(session)
|
||||
if commit:
|
||||
await session.commit()
|
||||
await _commit_or_rollback(session)
|
||||
if refresh:
|
||||
await session.refresh(obj)
|
||||
return obj
|
||||
@@ -85,7 +103,7 @@ async def save(
|
||||
async def delete(session: AsyncSession, obj: ModelT, *, commit: bool = True) -> None:
|
||||
await session.delete(obj)
|
||||
if commit:
|
||||
await session.commit()
|
||||
await _commit_or_rollback(session)
|
||||
|
||||
|
||||
async def list_by(
|
||||
@@ -111,6 +129,77 @@ async def exists(session: AsyncSession, model: type[ModelT], **lookup: Any) -> b
|
||||
return (await session.exec(_lookup_statement(model, lookup).limit(1))).first() is not None
|
||||
|
||||
|
||||
def _criteria_statement(model: type[ModelT], criteria: tuple[Any, ...]) -> SelectOfScalar[ModelT]:
|
||||
stmt = select(model)
|
||||
if criteria:
|
||||
stmt = stmt.where(*criteria)
|
||||
return stmt
|
||||
|
||||
|
||||
async def list_where(
|
||||
session: AsyncSession,
|
||||
model: type[ModelT],
|
||||
*criteria: Any,
|
||||
order_by: Iterable[Any] = (),
|
||||
) -> list[ModelT]:
|
||||
stmt = _criteria_statement(model, criteria)
|
||||
for ordering in order_by:
|
||||
stmt = stmt.order_by(ordering)
|
||||
return list(await session.exec(stmt))
|
||||
|
||||
|
||||
async def delete_where(
|
||||
session: AsyncSession,
|
||||
model: type[ModelT],
|
||||
*criteria: Any,
|
||||
commit: bool = False,
|
||||
) -> int:
|
||||
stmt = sql_delete(model)
|
||||
if criteria:
|
||||
stmt = stmt.where(*criteria)
|
||||
result = await session.exec(cast(Any, stmt))
|
||||
if commit:
|
||||
await _commit_or_rollback(session)
|
||||
rowcount = getattr(result, "rowcount", None)
|
||||
return int(rowcount) if isinstance(rowcount, int) else 0
|
||||
|
||||
|
||||
async def update_where(
|
||||
session: AsyncSession,
|
||||
model: type[ModelT],
|
||||
*criteria: Any,
|
||||
updates: Mapping[str, Any] | None = None,
|
||||
commit: bool = False,
|
||||
exclude_none: bool = False,
|
||||
allowed_fields: set[str] | None = None,
|
||||
**update_fields: Any,
|
||||
) -> int:
|
||||
source_updates: dict[str, Any] = {}
|
||||
if updates:
|
||||
source_updates.update(dict(updates))
|
||||
if update_fields:
|
||||
source_updates.update(update_fields)
|
||||
|
||||
values: dict[str, Any] = {}
|
||||
for key, value in source_updates.items():
|
||||
if allowed_fields is not None and key not in allowed_fields:
|
||||
continue
|
||||
if exclude_none and value is None:
|
||||
continue
|
||||
values[key] = value
|
||||
if not values:
|
||||
return 0
|
||||
|
||||
stmt = sql_update(model).values(**values)
|
||||
if criteria:
|
||||
stmt = stmt.where(*criteria)
|
||||
result = await session.exec(cast(Any, stmt))
|
||||
if commit:
|
||||
await _commit_or_rollback(session)
|
||||
rowcount = getattr(result, "rowcount", None)
|
||||
return int(rowcount) if isinstance(rowcount, int) else 0
|
||||
|
||||
|
||||
def apply_updates(
|
||||
obj: ModelT,
|
||||
updates: Mapping[str, Any],
|
||||
@@ -179,6 +268,9 @@ async def get_or_create(
|
||||
if existing is not None:
|
||||
return existing, False
|
||||
raise
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
if refresh:
|
||||
await session.refresh(obj)
|
||||
|
||||
@@ -11,9 +11,12 @@ from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app import models # noqa: F401
|
||||
from app import models as _models
|
||||
from app.core.config import settings
|
||||
|
||||
# Import model modules so SQLModel metadata is fully registered at startup.
|
||||
_MODEL_REGISTRY = _models
|
||||
|
||||
|
||||
def _normalize_database_url(database_url: str) -> str:
|
||||
if "://" not in database_url:
|
||||
@@ -64,4 +67,11 @@ async def init_db() -> None:
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
try:
|
||||
yield session
|
||||
except Exception:
|
||||
try:
|
||||
await session.rollback()
|
||||
except Exception:
|
||||
logger.exception("Failed to rollback session after request error.")
|
||||
raise
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.sql.base import Executable
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
async def exec_dml(session: AsyncSession, statement: Executable) -> None:
|
||||
# SQLModel's AsyncSession typing only overloads exec() for SELECT statements.
|
||||
await session.exec(statement) # type: ignore[call-overload]
|
||||
Reference in New Issue
Block a user