refactor: replace exec_dml with CRUD operations in various files and improve session handling

This commit is contained in:
Abhimanyu Saharan
2026-02-09 02:17:34 +05:30
parent 228b99bc9b
commit fafcac1e16
12 changed files with 392 additions and 156 deletions

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
from collections.abc import Iterable, Mapping
from typing import Any, TypeVar
from typing import Any, TypeVar, cast
from sqlalchemy import delete as sql_delete
from sqlalchemy import update as sql_update
from sqlalchemy.exc import IntegrityError
from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -19,6 +21,22 @@ class MultipleObjectsReturned(LookupError):
pass
async def _flush_or_rollback(session: AsyncSession) -> None:
try:
await session.flush()
except Exception:
await session.rollback()
raise
async def _commit_or_rollback(session: AsyncSession) -> None:
try:
await session.commit()
except Exception:
await session.rollback()
raise
def _lookup_statement(model: type[ModelT], lookup: Mapping[str, Any]) -> SelectOfScalar[ModelT]:
stmt = select(model)
for key, value in lookup.items():
@@ -58,9 +76,9 @@ async def create(
) -> ModelT:
obj = model.model_validate(data)
session.add(obj)
await session.flush()
await _flush_or_rollback(session)
if commit:
await session.commit()
await _commit_or_rollback(session)
if refresh:
await session.refresh(obj)
return obj
@@ -74,9 +92,9 @@ async def save(
refresh: bool = True,
) -> ModelT:
session.add(obj)
await session.flush()
await _flush_or_rollback(session)
if commit:
await session.commit()
await _commit_or_rollback(session)
if refresh:
await session.refresh(obj)
return obj
@@ -85,7 +103,7 @@ async def save(
async def delete(session: AsyncSession, obj: ModelT, *, commit: bool = True) -> None:
await session.delete(obj)
if commit:
await session.commit()
await _commit_or_rollback(session)
async def list_by(
@@ -111,6 +129,77 @@ async def exists(session: AsyncSession, model: type[ModelT], **lookup: Any) -> b
return (await session.exec(_lookup_statement(model, lookup).limit(1))).first() is not None
def _criteria_statement(model: type[ModelT], criteria: tuple[Any, ...]) -> SelectOfScalar[ModelT]:
stmt = select(model)
if criteria:
stmt = stmt.where(*criteria)
return stmt
async def list_where(
session: AsyncSession,
model: type[ModelT],
*criteria: Any,
order_by: Iterable[Any] = (),
) -> list[ModelT]:
stmt = _criteria_statement(model, criteria)
for ordering in order_by:
stmt = stmt.order_by(ordering)
return list(await session.exec(stmt))
async def delete_where(
session: AsyncSession,
model: type[ModelT],
*criteria: Any,
commit: bool = False,
) -> int:
stmt = sql_delete(model)
if criteria:
stmt = stmt.where(*criteria)
result = await session.exec(cast(Any, stmt))
if commit:
await _commit_or_rollback(session)
rowcount = getattr(result, "rowcount", None)
return int(rowcount) if isinstance(rowcount, int) else 0
async def update_where(
session: AsyncSession,
model: type[ModelT],
*criteria: Any,
updates: Mapping[str, Any] | None = None,
commit: bool = False,
exclude_none: bool = False,
allowed_fields: set[str] | None = None,
**update_fields: Any,
) -> int:
source_updates: dict[str, Any] = {}
if updates:
source_updates.update(dict(updates))
if update_fields:
source_updates.update(update_fields)
values: dict[str, Any] = {}
for key, value in source_updates.items():
if allowed_fields is not None and key not in allowed_fields:
continue
if exclude_none and value is None:
continue
values[key] = value
if not values:
return 0
stmt = sql_update(model).values(**values)
if criteria:
stmt = stmt.where(*criteria)
result = await session.exec(cast(Any, stmt))
if commit:
await _commit_or_rollback(session)
rowcount = getattr(result, "rowcount", None)
return int(rowcount) if isinstance(rowcount, int) else 0
def apply_updates(
obj: ModelT,
updates: Mapping[str, Any],
@@ -179,6 +268,9 @@ async def get_or_create(
if existing is not None:
return existing, False
raise
except Exception:
await session.rollback()
raise
if refresh:
await session.refresh(obj)

View File

@@ -11,9 +11,12 @@ from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
from app import models # noqa: F401
from app import models as _models
from app.core.config import settings
# Import model modules so SQLModel metadata is fully registered at startup.
_MODEL_REGISTRY = _models
def _normalize_database_url(database_url: str) -> str:
if "://" not in database_url:
@@ -64,4 +67,11 @@ async def init_db() -> None:
async def get_session() -> AsyncGenerator[AsyncSession, None]:
async with async_session_maker() as session:
yield session
try:
yield session
except Exception:
try:
await session.rollback()
except Exception:
logger.exception("Failed to rollback session after request error.")
raise

View File

@@ -1,9 +0,0 @@
from __future__ import annotations
from sqlalchemy.sql.base import Executable
from sqlmodel.ext.asyncio.session import AsyncSession
async def exec_dml(session: AsyncSession, statement: Executable) -> None:
# SQLModel's AsyncSession typing only overloads exec() for SELECT statements.
await session.exec(statement) # type: ignore[call-overload]