Files
openclaw-mission-control/backend/app/db/crud.py

325 lines
9.2 KiB
Python
Raw Normal View History

"""Generic asynchronous CRUD helpers for SQLModel entities."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, TypeVar
from sqlalchemy import delete as sql_delete
from sqlalchemy import update as sql_update
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlmodel import SQLModel, select
if TYPE_CHECKING:
from collections.abc import Iterable, Mapping
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar
ModelT = TypeVar("ModelT", bound=SQLModel)
class DoesNotExistError(LookupError):
"""Raised when a query expected one row but found none."""
class MultipleObjectsReturnedError(LookupError):
"""Raised when a query expected one row but found many."""
DoesNotExist = DoesNotExistError
MultipleObjectsReturned = MultipleObjectsReturnedError
async def _flush_or_rollback(session: AsyncSession) -> None:
"""Flush changes and rollback on SQLAlchemy errors."""
try:
await session.flush()
except SQLAlchemyError:
await session.rollback()
raise
async def _commit_or_rollback(session: AsyncSession) -> None:
"""Commit transaction and rollback on SQLAlchemy errors."""
try:
await session.commit()
except SQLAlchemyError:
await session.rollback()
raise
def _lookup_statement(
model: type[ModelT],
lookup: Mapping[str, Any],
) -> SelectOfScalar[ModelT]:
"""Build a select statement with equality filters from lookup values."""
stmt = select(model)
for key, value in lookup.items():
stmt = stmt.where(getattr(model, key) == value)
return stmt
async def get_by_id(
session: AsyncSession,
model: type[ModelT],
obj_id: object,
) -> ModelT | None:
"""Fetch one model instance by id or return None."""
stmt = _lookup_statement(model, {"id": obj_id}).limit(1)
return (await session.exec(stmt)).first()
async def get(
session: AsyncSession,
model: type[ModelT],
**lookup: object,
) -> ModelT:
"""Fetch exactly one model instance by lookup values."""
stmt = _lookup_statement(model, lookup).limit(2)
items = (await session.exec(stmt)).all()
if not items:
message = f"{model.__name__} matching query does not exist."
raise DoesNotExist(message)
if len(items) > 1:
message = f"Multiple {model.__name__} objects returned for lookup {lookup!r}."
raise MultipleObjectsReturned(message)
return items[0]
async def get_one_by(
session: AsyncSession,
model: type[ModelT],
**lookup: object,
) -> ModelT | None:
"""Fetch the first model instance matching lookup values."""
stmt = _lookup_statement(model, lookup)
return (await session.exec(stmt)).first()
async def create(
session: AsyncSession,
model: type[ModelT],
*,
commit: bool = True,
refresh: bool = True,
**data: object,
) -> ModelT:
"""Create, flush, optionally commit, and optionally refresh an object."""
obj = model.model_validate(data)
session.add(obj)
await _flush_or_rollback(session)
if commit:
await _commit_or_rollback(session)
if refresh:
await session.refresh(obj)
return obj
async def save(
session: AsyncSession,
obj: ModelT,
*,
commit: bool = True,
refresh: bool = True,
) -> ModelT:
"""Persist an existing object with optional commit and refresh."""
session.add(obj)
await _flush_or_rollback(session)
if commit:
await _commit_or_rollback(session)
if refresh:
await session.refresh(obj)
return obj
async def delete(session: AsyncSession, obj: SQLModel, *, commit: bool = True) -> None:
"""Delete an object with optional commit."""
await session.delete(obj)
if commit:
await _commit_or_rollback(session)
async def list_by(
session: AsyncSession,
model: type[ModelT],
*,
order_by: Iterable[Any] = (),
limit: int | None = None,
offset: int | None = None,
**lookup: object,
) -> list[ModelT]:
"""List objects by lookup values with optional ordering and pagination."""
stmt = _lookup_statement(model, lookup)
for ordering in order_by:
stmt = stmt.order_by(ordering)
if offset is not None:
stmt = stmt.offset(offset)
if limit is not None:
stmt = stmt.limit(limit)
return list(await session.exec(stmt))
async def exists(session: AsyncSession, model: type[ModelT], **lookup: object) -> bool:
"""Return whether any object exists for lookup values."""
return (await session.exec(_lookup_statement(model, lookup).limit(1))).first() is not None
def _criteria_statement(
model: type[ModelT],
criteria: tuple[Any, ...],
) -> SelectOfScalar[ModelT]:
"""Build a select statement from variadic where criteria."""
stmt = select(model)
if criteria:
stmt = stmt.where(*criteria)
return stmt
async def list_where(
session: AsyncSession,
model: type[ModelT],
*criteria: object,
order_by: Iterable[Any] = (),
) -> list[ModelT]:
"""List objects filtered by explicit SQL criteria."""
stmt = _criteria_statement(model, criteria)
for ordering in order_by:
stmt = stmt.order_by(ordering)
return list(await session.exec(stmt))
async def delete_where(
session: AsyncSession,
model: type[ModelT],
*criteria: object,
commit: bool = False,
) -> int:
"""Delete rows matching criteria and return affected row count."""
stmt: Any = sql_delete(model)
if criteria:
stmt = stmt.where(*criteria)
result = await session.exec(stmt)
if commit:
await _commit_or_rollback(session)
rowcount = getattr(result, "rowcount", None)
return int(rowcount) if isinstance(rowcount, int) else 0
async def update_where(
session: AsyncSession,
model: type[ModelT],
*criteria: object,
updates: Mapping[str, Any] | None = None,
**options: object,
) -> int:
"""Apply bulk updates by criteria and return affected row count."""
commit = bool(options.pop("commit", False))
exclude_none = bool(options.pop("exclude_none", False))
allowed_fields_raw = options.pop("allowed_fields", None)
allowed_fields = allowed_fields_raw if isinstance(allowed_fields_raw, set) else None
source_updates: dict[str, Any] = {}
if updates:
source_updates.update(dict(updates))
if options:
source_updates.update(options)
values: dict[str, Any] = {}
for key, value in source_updates.items():
if allowed_fields is not None and key not in allowed_fields:
continue
if exclude_none and value is None:
continue
values[key] = value
if not values:
return 0
stmt: Any = sql_update(model).values(**values)
if criteria:
stmt = stmt.where(*criteria)
result = await session.exec(stmt)
if commit:
await _commit_or_rollback(session)
rowcount = getattr(result, "rowcount", None)
return int(rowcount) if isinstance(rowcount, int) else 0
def apply_updates(
obj: ModelT,
updates: Mapping[str, Any],
*,
exclude_none: bool = False,
allowed_fields: set[str] | None = None,
) -> ModelT:
"""Apply a mapping of field updates onto an object."""
for key, value in updates.items():
if allowed_fields is not None and key not in allowed_fields:
continue
if exclude_none and value is None:
continue
setattr(obj, key, value)
return obj
async def patch(
session: AsyncSession,
obj: ModelT,
updates: Mapping[str, Any],
**options: object,
) -> ModelT:
"""Apply partial updates and persist object."""
exclude_none = bool(options.pop("exclude_none", False))
allowed_fields_raw = options.pop("allowed_fields", None)
allowed_fields = allowed_fields_raw if isinstance(allowed_fields_raw, set) else None
commit = bool(options.pop("commit", True))
refresh = bool(options.pop("refresh", True))
apply_updates(
obj,
updates,
exclude_none=exclude_none,
allowed_fields=allowed_fields,
)
return await save(session, obj, commit=commit, refresh=refresh)
async def get_or_create(
session: AsyncSession,
model: type[ModelT],
*,
defaults: Mapping[str, Any] | None = None,
commit: bool = True,
refresh: bool = True,
**lookup: object,
) -> tuple[ModelT, bool]:
"""Get one object by lookup, or create it with defaults."""
stmt = _lookup_statement(model, lookup)
existing = (await session.exec(stmt)).first()
if existing is not None:
return existing, False
payload: dict[str, Any] = dict(lookup)
if defaults:
for key, value in defaults.items():
payload.setdefault(key, value)
obj = model.model_validate(payload)
session.add(obj)
try:
await session.flush()
if commit:
await session.commit()
except IntegrityError:
# If another concurrent request inserted the same unique row, surface that row.
await session.rollback()
existing = (await session.exec(stmt)).first()
if existing is not None:
return existing, False
raise
except SQLAlchemyError:
await session.rollback()
raise
if refresh:
await session.refresh(obj)
return obj, True