refactor: enhance docstrings for clarity and consistency across multiple files
This commit is contained in:
@@ -1,27 +1,37 @@
|
||||
"""Generic asynchronous CRUD helpers for SQLModel entities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Mapping
|
||||
from typing import Any, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
from sqlalchemy import delete as sql_delete
|
||||
from sqlalchemy import update as sql_update
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from sqlmodel import SQLModel, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Mapping
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
ModelT = TypeVar("ModelT", bound=SQLModel)
|
||||
|
||||
|
||||
class DoesNotExist(LookupError):
|
||||
pass
|
||||
class DoesNotExistError(LookupError):
|
||||
"""Raised when a query expected one row but found none."""
|
||||
|
||||
|
||||
class MultipleObjectsReturned(LookupError):
|
||||
pass
|
||||
class MultipleObjectsReturnedError(LookupError):
|
||||
"""Raised when a query expected one row but found many."""
|
||||
|
||||
|
||||
DoesNotExist = DoesNotExistError
|
||||
MultipleObjectsReturned = MultipleObjectsReturnedError
|
||||
|
||||
|
||||
async def _flush_or_rollback(session: AsyncSession) -> None:
|
||||
"""Flush changes and rollback on SQLAlchemy errors."""
|
||||
try:
|
||||
await session.flush()
|
||||
except SQLAlchemyError:
|
||||
@@ -30,6 +40,7 @@ async def _flush_or_rollback(session: AsyncSession) -> None:
|
||||
|
||||
|
||||
async def _commit_or_rollback(session: AsyncSession) -> None:
|
||||
"""Commit transaction and rollback on SQLAlchemy errors."""
|
||||
try:
|
||||
await session.commit()
|
||||
except SQLAlchemyError:
|
||||
@@ -37,31 +48,50 @@ async def _commit_or_rollback(session: AsyncSession) -> None:
|
||||
raise
|
||||
|
||||
|
||||
def _lookup_statement(model: type[ModelT], lookup: Mapping[str, Any]) -> SelectOfScalar[ModelT]:
|
||||
def _lookup_statement(
|
||||
model: type[ModelT],
|
||||
lookup: Mapping[str, Any],
|
||||
) -> SelectOfScalar[ModelT]:
|
||||
"""Build a select statement with equality filters from lookup values."""
|
||||
stmt = select(model)
|
||||
for key, value in lookup.items():
|
||||
stmt = stmt.where(getattr(model, key) == value)
|
||||
return stmt
|
||||
|
||||
|
||||
async def get_by_id(session: AsyncSession, model: type[ModelT], obj_id: Any) -> ModelT | None:
|
||||
async def get_by_id(
|
||||
session: AsyncSession,
|
||||
model: type[ModelT],
|
||||
obj_id: object,
|
||||
) -> ModelT | None:
|
||||
"""Fetch one model instance by id or return None."""
|
||||
stmt = _lookup_statement(model, {"id": obj_id}).limit(1)
|
||||
return (await session.exec(stmt)).first()
|
||||
|
||||
|
||||
async def get(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT:
|
||||
async def get(
|
||||
session: AsyncSession,
|
||||
model: type[ModelT],
|
||||
**lookup: object,
|
||||
) -> ModelT:
|
||||
"""Fetch exactly one model instance by lookup values."""
|
||||
stmt = _lookup_statement(model, lookup).limit(2)
|
||||
items = (await session.exec(stmt)).all()
|
||||
if not items:
|
||||
raise DoesNotExist(f"{model.__name__} matching query does not exist.")
|
||||
message = f"{model.__name__} matching query does not exist."
|
||||
raise DoesNotExist(message)
|
||||
if len(items) > 1:
|
||||
raise MultipleObjectsReturned(
|
||||
f"Multiple {model.__name__} objects returned for lookup {lookup!r}."
|
||||
)
|
||||
message = f"Multiple {model.__name__} objects returned for lookup {lookup!r}."
|
||||
raise MultipleObjectsReturned(message)
|
||||
return items[0]
|
||||
|
||||
|
||||
async def get_one_by(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT | None:
|
||||
async def get_one_by(
|
||||
session: AsyncSession,
|
||||
model: type[ModelT],
|
||||
**lookup: object,
|
||||
) -> ModelT | None:
|
||||
"""Fetch the first model instance matching lookup values."""
|
||||
stmt = _lookup_statement(model, lookup)
|
||||
return (await session.exec(stmt)).first()
|
||||
|
||||
@@ -72,8 +102,9 @@ async def create(
|
||||
*,
|
||||
commit: bool = True,
|
||||
refresh: bool = True,
|
||||
**data: Any,
|
||||
**data: object,
|
||||
) -> ModelT:
|
||||
"""Create, flush, optionally commit, and optionally refresh an object."""
|
||||
obj = model.model_validate(data)
|
||||
session.add(obj)
|
||||
await _flush_or_rollback(session)
|
||||
@@ -91,6 +122,7 @@ async def save(
|
||||
commit: bool = True,
|
||||
refresh: bool = True,
|
||||
) -> ModelT:
|
||||
"""Persist an existing object with optional commit and refresh."""
|
||||
session.add(obj)
|
||||
await _flush_or_rollback(session)
|
||||
if commit:
|
||||
@@ -101,6 +133,7 @@ async def save(
|
||||
|
||||
|
||||
async def delete(session: AsyncSession, obj: ModelT, *, commit: bool = True) -> None:
|
||||
"""Delete an object with optional commit."""
|
||||
await session.delete(obj)
|
||||
if commit:
|
||||
await _commit_or_rollback(session)
|
||||
@@ -113,8 +146,9 @@ async def list_by(
|
||||
order_by: Iterable[Any] = (),
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
**lookup: Any,
|
||||
**lookup: object,
|
||||
) -> list[ModelT]:
|
||||
"""List objects by lookup values with optional ordering and pagination."""
|
||||
stmt = _lookup_statement(model, lookup)
|
||||
for ordering in order_by:
|
||||
stmt = stmt.order_by(ordering)
|
||||
@@ -125,11 +159,19 @@ async def list_by(
|
||||
return list(await session.exec(stmt))
|
||||
|
||||
|
||||
async def exists(session: AsyncSession, model: type[ModelT], **lookup: Any) -> bool:
|
||||
return (await session.exec(_lookup_statement(model, lookup).limit(1))).first() is not None
|
||||
async def exists(session: AsyncSession, model: type[ModelT], **lookup: object) -> bool:
|
||||
"""Return whether any object exists for lookup values."""
|
||||
return (
|
||||
(await session.exec(_lookup_statement(model, lookup).limit(1))).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
def _criteria_statement(model: type[ModelT], criteria: tuple[Any, ...]) -> SelectOfScalar[ModelT]:
|
||||
def _criteria_statement(
|
||||
model: type[ModelT],
|
||||
criteria: tuple[Any, ...],
|
||||
) -> SelectOfScalar[ModelT]:
|
||||
"""Build a select statement from variadic where criteria."""
|
||||
stmt = select(model)
|
||||
if criteria:
|
||||
stmt = stmt.where(*criteria)
|
||||
@@ -139,9 +181,10 @@ def _criteria_statement(model: type[ModelT], criteria: tuple[Any, ...]) -> Selec
|
||||
async def list_where(
|
||||
session: AsyncSession,
|
||||
model: type[ModelT],
|
||||
*criteria: Any,
|
||||
*criteria: object,
|
||||
order_by: Iterable[Any] = (),
|
||||
) -> list[ModelT]:
|
||||
"""List objects filtered by explicit SQL criteria."""
|
||||
stmt = _criteria_statement(model, criteria)
|
||||
for ordering in order_by:
|
||||
stmt = stmt.order_by(ordering)
|
||||
@@ -151,9 +194,10 @@ async def list_where(
|
||||
async def delete_where(
|
||||
session: AsyncSession,
|
||||
model: type[ModelT],
|
||||
*criteria: Any,
|
||||
*criteria: object,
|
||||
commit: bool = False,
|
||||
) -> int:
|
||||
"""Delete rows matching criteria and return affected row count."""
|
||||
stmt: Any = sql_delete(model)
|
||||
if criteria:
|
||||
stmt = stmt.where(*criteria)
|
||||
@@ -167,18 +211,24 @@ async def delete_where(
|
||||
async def update_where(
|
||||
session: AsyncSession,
|
||||
model: type[ModelT],
|
||||
*criteria: Any,
|
||||
*criteria: object,
|
||||
updates: Mapping[str, Any] | None = None,
|
||||
commit: bool = False,
|
||||
exclude_none: bool = False,
|
||||
allowed_fields: set[str] | None = None,
|
||||
**update_fields: Any,
|
||||
**options: object,
|
||||
) -> int:
|
||||
"""Apply bulk updates by criteria and return affected row count."""
|
||||
commit = bool(options.pop("commit", False))
|
||||
exclude_none = bool(options.pop("exclude_none", False))
|
||||
allowed_fields_raw = options.pop("allowed_fields", None)
|
||||
allowed_fields = (
|
||||
allowed_fields_raw
|
||||
if isinstance(allowed_fields_raw, set)
|
||||
else None
|
||||
)
|
||||
source_updates: dict[str, Any] = {}
|
||||
if updates:
|
||||
source_updates.update(dict(updates))
|
||||
if update_fields:
|
||||
source_updates.update(update_fields)
|
||||
if options:
|
||||
source_updates.update(options)
|
||||
|
||||
values: dict[str, Any] = {}
|
||||
for key, value in source_updates.items():
|
||||
@@ -207,6 +257,7 @@ def apply_updates(
|
||||
exclude_none: bool = False,
|
||||
allowed_fields: set[str] | None = None,
|
||||
) -> ModelT:
|
||||
"""Apply a mapping of field updates onto an object."""
|
||||
for key, value in updates.items():
|
||||
if allowed_fields is not None and key not in allowed_fields:
|
||||
continue
|
||||
@@ -220,12 +271,18 @@ async def patch(
|
||||
session: AsyncSession,
|
||||
obj: ModelT,
|
||||
updates: Mapping[str, Any],
|
||||
*,
|
||||
exclude_none: bool = False,
|
||||
allowed_fields: set[str] | None = None,
|
||||
commit: bool = True,
|
||||
refresh: bool = True,
|
||||
**options: object,
|
||||
) -> ModelT:
|
||||
"""Apply partial updates and persist object."""
|
||||
exclude_none = bool(options.pop("exclude_none", False))
|
||||
allowed_fields_raw = options.pop("allowed_fields", None)
|
||||
allowed_fields = (
|
||||
allowed_fields_raw
|
||||
if isinstance(allowed_fields_raw, set)
|
||||
else None
|
||||
)
|
||||
commit = bool(options.pop("commit", True))
|
||||
refresh = bool(options.pop("refresh", True))
|
||||
apply_updates(
|
||||
obj,
|
||||
updates,
|
||||
@@ -242,8 +299,9 @@ async def get_or_create(
|
||||
defaults: Mapping[str, Any] | None = None,
|
||||
commit: bool = True,
|
||||
refresh: bool = True,
|
||||
**lookup: Any,
|
||||
**lookup: object,
|
||||
) -> tuple[ModelT, bool]:
|
||||
"""Get one object by lookup, or create it with defaults."""
|
||||
stmt = _lookup_statement(model, lookup)
|
||||
|
||||
existing = (await session.exec(stmt)).first()
|
||||
|
||||
Reference in New Issue
Block a user