278 lines
7.3 KiB
Python
278 lines
7.3 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Iterable, Mapping
|
|
from typing import Any, TypeVar, cast
|
|
|
|
from sqlalchemy import delete as sql_delete
|
|
from sqlalchemy import update as sql_update
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlmodel import SQLModel, select
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
from sqlmodel.sql.expression import SelectOfScalar
|
|
|
|
# Type variable bound to SQLModel so each helper's return type follows the model class passed in.
ModelT = TypeVar("ModelT", bound=SQLModel)
|
|
|
|
|
|
class DoesNotExist(LookupError):
    """Raised by ``get`` when no row matches the requested lookup."""
|
|
|
|
|
|
class MultipleObjectsReturned(LookupError):
    """Raised by ``get`` when more than one row matches the requested lookup."""
|
|
|
|
|
|
async def _flush_or_rollback(session: AsyncSession) -> None:
    """Flush *session*, rolling it back if the flush fails.

    Any ``Exception`` raised by ``session.flush()`` triggers
    ``session.rollback()`` and is then re-raised unchanged.
    """
    try:
        await session.flush()
    except Exception:
        # Undo the failed unit of work first, then let the original error propagate.
        await session.rollback()
        raise
|
|
|
|
|
|
async def _commit_or_rollback(session: AsyncSession) -> None:
    """Commit *session*, rolling it back if the commit fails.

    Any ``Exception`` raised by ``session.commit()`` triggers
    ``session.rollback()`` and is then re-raised unchanged.
    """
    try:
        await session.commit()
    except Exception:
        # Undo the failed transaction first, then let the original error propagate.
        await session.rollback()
        raise
|
|
|
|
|
|
def _lookup_statement(model: type[ModelT], lookup: Mapping[str, Any]) -> SelectOfScalar[ModelT]:
    """Build a ``SELECT`` of *model* filtered by equality on every lookup item.

    Each ``name: value`` pair in *lookup* becomes a ``model.name == value``
    condition, and all conditions are combined in the statement's WHERE
    clause. An empty mapping yields an unfiltered select.
    """
    conditions = [getattr(model, name) == expected for name, expected in lookup.items()]
    query = select(model)
    if not conditions:
        return query
    return query.where(*conditions)
|
|
|
|
|
|
async def get_by_id(session: AsyncSession, model: type[ModelT], obj_id: Any) -> ModelT | None:
    """Return the *model* row whose ``id`` equals *obj_id*, or ``None`` if there is none."""
    query = _lookup_statement(model, {"id": obj_id})
    result = await session.exec(query.limit(1))
    return result.first()
|
|
|
|
|
|
async def get(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT:
    """Return the single *model* row matching the keyword *lookup*.

    Raises:
        DoesNotExist: no row matches the lookup.
        MultipleObjectsReturned: more than one row matches the lookup.
    """
    # Two rows are enough to tell "exactly one" apart from "several".
    query = _lookup_statement(model, lookup).limit(2)
    result = await session.exec(query)
    matches = result.all()

    if not matches:
        raise DoesNotExist(f"{model.__name__} matching query does not exist.")
    if len(matches) == 1:
        return matches[0]
    raise MultipleObjectsReturned(
        f"Multiple {model.__name__} objects returned for lookup {lookup!r}."
    )
|
|
|
|
|
|
async def get_one_by(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT | None:
    """Return the first *model* row matching *lookup*, or ``None`` if none match.

    Unlike ``get``, several matching rows are not an error. No ordering is
    applied, so which of several matches is returned is left to the database.
    """
    # ``first()`` only discards surplus rows on the client; capping the query
    # at one row keeps the database from returning every match, as
    # ``get_by_id`` and ``exists`` already do.
    stmt = _lookup_statement(model, lookup).limit(1)
    return (await session.exec(stmt)).first()
|
|
|
|
|
|
async def create(
    session: AsyncSession,
    model: type[ModelT],
    *,
    commit: bool = True,
    refresh: bool = True,
    **data: Any,
) -> ModelT:
    """Validate *data* into a new *model* instance and persist it.

    Args:
        session: Session the new object is added to.
        model: Model class to instantiate via ``model.model_validate``.
        commit: Commit after flushing when true.
        refresh: Call ``session.refresh`` on the object before returning when true.
        **data: Field values for the new object.

    Returns:
        The newly created instance.
    """
    instance = model.model_validate(data)
    # add -> flush -> optional commit -> optional refresh is exactly what
    # ``save`` performs, so reuse it rather than repeating the sequence.
    return await save(session, instance, commit=commit, refresh=refresh)
|
|
|
|
|
|
async def save(
    session: AsyncSession,
    obj: ModelT,
    *,
    commit: bool = True,
    refresh: bool = True,
) -> ModelT:
    """Add *obj* to *session* and flush it, optionally committing and refreshing.

    Args:
        session: Session the object is added to.
        obj: Instance to persist.
        commit: Commit after the flush when true.
        refresh: Call ``session.refresh`` on *obj* before returning when true.

    Returns:
        The same *obj* that was passed in.
    """
    session.add(obj)
    # Flush unconditionally so the pending change is sent to the database
    # even when the caller opts out of committing.
    await _flush_or_rollback(session)
    if commit:
        await _commit_or_rollback(session)
    if not refresh:
        return obj
    await session.refresh(obj)
    return obj
|
|
|
|
|
|
async def delete(session: AsyncSession, obj: ModelT, *, commit: bool = True) -> None:
    """Mark *obj* for deletion in *session* and, by default, commit.

    With ``commit=False`` the deletion is only registered on the session;
    flushing or committing is left to the caller.
    """
    await session.delete(obj)
    if not commit:
        return
    await _commit_or_rollback(session)
|
|
|
|
|
|
async def list_by(
    session: AsyncSession,
    model: type[ModelT],
    *,
    order_by: Iterable[Any] = (),
    limit: int | None = None,
    offset: int | None = None,
    **lookup: Any,
) -> list[ModelT]:
    """Return all *model* rows matching the equality *lookup*.

    Args:
        session: Session used to run the query.
        model: Model class to select.
        order_by: Ordering expressions, applied in the order given.
        limit: Maximum number of rows; ``None`` means no limit.
        offset: Number of rows to skip; ``None`` means no offset.
        **lookup: ``field=value`` equality filters.

    Returns:
        The matching rows as a list.
    """
    query = _lookup_statement(model, lookup)
    for clause in order_by:
        query = query.order_by(clause)
    if limit is not None:
        query = query.limit(limit)
    if offset is not None:
        query = query.offset(offset)
    rows = await session.exec(query)
    return list(rows)
|
|
|
|
|
|
async def exists(session: AsyncSession, model: type[ModelT], **lookup: Any) -> bool:
    """Return ``True`` when at least one *model* row matches *lookup*."""
    # One row is enough to answer the question.
    probe = _lookup_statement(model, lookup).limit(1)
    result = await session.exec(probe)
    return result.first() is not None
|
|
|
|
|
|
def _criteria_statement(model: type[ModelT], criteria: tuple[Any, ...]) -> SelectOfScalar[ModelT]:
    """Build a ``SELECT`` of *model* restricted by the given WHERE *criteria*.

    An empty *criteria* tuple yields an unfiltered select.
    """
    base = select(model)
    if not criteria:
        return base
    return base.where(*criteria)
|
|
|
|
|
|
async def list_where(
    session: AsyncSession,
    model: type[ModelT],
    *criteria: Any,
    order_by: Iterable[Any] = (),
) -> list[ModelT]:
    """Return all *model* rows satisfying the positional WHERE *criteria*.

    Args:
        session: Session used to run the query.
        model: Model class to select.
        *criteria: Expressions passed to the statement's ``where``.
        order_by: Ordering expressions, applied in the order given.

    Returns:
        The matching rows as a list.
    """
    query = _criteria_statement(model, criteria)
    for clause in order_by:
        query = query.order_by(clause)
    rows = await session.exec(query)
    return list(rows)
|
|
|
|
|
|
async def delete_where(
    session: AsyncSession,
    model: type[ModelT],
    *criteria: Any,
    commit: bool = False,
) -> int:
    """Run a bulk ``DELETE`` on *model* and return the reported row count.

    Args:
        session: Session used to execute the statement.
        model: Model class whose rows are deleted.
        *criteria: WHERE expressions; with none, the DELETE is unfiltered.
        commit: Commit after executing when true (defaults to ``False``).

    Returns:
        The result's ``rowcount`` when it is an ``int``, otherwise ``0``.
    """
    statement = sql_delete(model)
    if criteria:
        statement = statement.where(*criteria)
    # cast: the statement is not a select, which is what ``exec`` is typed for.
    outcome = await session.exec(cast(Any, statement))
    if commit:
        await _commit_or_rollback(session)
    affected = getattr(outcome, "rowcount", None)
    if isinstance(affected, int):
        return int(affected)
    return 0
|
|
|
|
|
|
async def update_where(
    session: AsyncSession,
    model: type[ModelT],
    *criteria: Any,
    updates: Mapping[str, Any] | None = None,
    commit: bool = False,
    exclude_none: bool = False,
    allowed_fields: set[str] | None = None,
    **update_fields: Any,
) -> int:
    """Run a bulk ``UPDATE`` on *model* and return the reported row count.

    Args:
        session: Session used to execute the statement.
        model: Model class whose rows are updated.
        *criteria: WHERE expressions; with none, the UPDATE is unfiltered.
        updates: Mapping of column name to new value.
        commit: Commit after executing when true (defaults to ``False``).
        exclude_none: Drop entries whose value is ``None`` when true.
        allowed_fields: When given, drop entries whose key is not in this set.
        **update_fields: Extra column values; these override *updates* on
            key collisions.

    Returns:
        The result's ``rowcount`` when it is an ``int``, otherwise ``0``.
        Also ``0`` — without executing or committing anything — when no
        values remain after filtering.
    """
    # Keyword values are merged last so they win over the mapping.
    merged: dict[str, Any] = dict(updates) if updates else {}
    merged.update(update_fields)

    values: dict[str, Any] = {
        name: candidate
        for name, candidate in merged.items()
        if (allowed_fields is None or name in allowed_fields)
        and not (exclude_none and candidate is None)
    }
    if not values:
        return 0

    statement = sql_update(model).values(**values)
    if criteria:
        statement = statement.where(*criteria)
    # cast: the statement is not a select, which is what ``exec`` is typed for.
    outcome = await session.exec(cast(Any, statement))
    if commit:
        await _commit_or_rollback(session)
    affected = getattr(outcome, "rowcount", None)
    if isinstance(affected, int):
        return int(affected)
    return 0
|
|
|
|
|
|
def apply_updates(
    obj: ModelT,
    updates: Mapping[str, Any],
    *,
    exclude_none: bool = False,
    allowed_fields: set[str] | None = None,
) -> ModelT:
    """Set attributes on *obj* from *updates* and return the same object.

    Args:
        obj: Object to mutate in place.
        updates: Mapping of attribute name to new value.
        exclude_none: Skip entries whose value is ``None`` when true.
        allowed_fields: When given, skip entries whose key is not in this set.

    Returns:
        *obj* itself, after the selected attributes have been assigned.
    """
    for field_name, new_value in updates.items():
        permitted = allowed_fields is None or field_name in allowed_fields
        dropped_as_none = exclude_none and new_value is None
        if permitted and not dropped_as_none:
            setattr(obj, field_name, new_value)
    return obj
|
|
|
|
|
|
async def patch(
    session: AsyncSession,
    obj: ModelT,
    updates: Mapping[str, Any],
    *,
    exclude_none: bool = False,
    allowed_fields: set[str] | None = None,
    commit: bool = True,
    refresh: bool = True,
) -> ModelT:
    """Apply *updates* to *obj* and persist it.

    The filtering options (*exclude_none*, *allowed_fields*) are forwarded to
    ``apply_updates``; *commit* and *refresh* are forwarded to ``save``.

    Returns:
        Whatever ``save`` returns for *obj*.
    """
    apply_updates(obj, updates, exclude_none=exclude_none, allowed_fields=allowed_fields)
    persisted = await save(session, obj, commit=commit, refresh=refresh)
    return persisted
|
|
|
|
|
|
async def get_or_create(
    session: AsyncSession,
    model: type[ModelT],
    *,
    defaults: Mapping[str, Any] | None = None,
    commit: bool = True,
    refresh: bool = True,
    **lookup: Any,
) -> tuple[ModelT, bool]:
    """Return the row matching *lookup*, creating it when none exists.

    Args:
        session: Session used for the lookup and the insert.
        model: Model class to query and, if needed, instantiate.
        defaults: Extra field values used only when creating; keys already
            present in *lookup* are not overridden.
        commit: Commit after flushing the new row when true.
        refresh: Call ``session.refresh`` on a newly created row when true.
        **lookup: ``field=value`` equality filters identifying the row.

    Returns:
        ``(obj, created)`` where *created* is ``False`` for a pre-existing row
        and ``True`` for a row inserted by this call.

    Raises:
        IntegrityError: The insert failed and no row matching *lookup* could
            be found afterwards.
        Exception: Any other flush/commit failure, re-raised after rollback.
    """
    # Only one row is ever consumed, so cap the query rather than letting
    # ``first()`` discard surplus matches on the client.
    stmt = _lookup_statement(model, lookup).limit(1)

    existing = (await session.exec(stmt)).first()
    if existing is not None:
        return existing, False

    # Lookup values win; defaults only fill in keys the lookup did not set.
    payload: dict[str, Any] = dict(lookup)
    for key, value in (defaults or {}).items():
        payload.setdefault(key, value)

    obj = model.model_validate(payload)
    session.add(obj)
    try:
        await session.flush()
        if commit:
            await session.commit()
    except IntegrityError:
        # If another concurrent request inserted the same unique row, surface that row.
        # NOTE: this rolls back the session's current transaction before re-querying.
        await session.rollback()
        existing = (await session.exec(stmt)).first()
        if existing is not None:
            return existing, False
        # The violation was not explained by a matching row; propagate it.
        raise
    except Exception:
        await session.rollback()
        raise

    if refresh:
        await session.refresh(obj)
    return obj, True
|