82 lines
2.7 KiB
Python
82 lines
2.7 KiB
Python
"""Model manager descriptor utilities for query-set style access."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Generic, TypeVar
|
|
|
|
from sqlalchemy import false
|
|
from sqlmodel import SQLModel, col
|
|
|
|
from app.db.queryset import QuerySet, qs
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Iterable
|
|
|
|
ModelT = TypeVar("ModelT", bound=SQLModel)
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class ModelManager(Generic[ModelT]):
|
|
"""Convenience query manager bound to a SQLModel class."""
|
|
|
|
model: type[ModelT]
|
|
id_field: str = "id"
|
|
|
|
def all(self) -> QuerySet[ModelT]:
|
|
"""Return an unfiltered queryset for the bound model."""
|
|
return qs(self.model)
|
|
|
|
def none(self) -> QuerySet[ModelT]:
|
|
"""Return a queryset that yields no rows."""
|
|
return qs(self.model).filter(false())
|
|
|
|
def filter(self, *criteria: object) -> QuerySet[ModelT]:
|
|
"""Return queryset filtered by SQL criteria expressions."""
|
|
return self.all().filter(*criteria)
|
|
|
|
def where(self, *criteria: object) -> QuerySet[ModelT]:
|
|
"""Alias for `filter`."""
|
|
return self.filter(*criteria)
|
|
|
|
def filter_by(self, **kwargs: object) -> QuerySet[ModelT]:
|
|
"""Return queryset filtered by model field equality values."""
|
|
queryset = self.all()
|
|
for field_name, value in kwargs.items():
|
|
queryset = queryset.filter(col(getattr(self.model, field_name)) == value)
|
|
return queryset
|
|
|
|
def by_id(self, obj_id: object) -> QuerySet[ModelT]:
|
|
"""Return queryset filtered by primary identifier field."""
|
|
return self.by_field(self.id_field, obj_id)
|
|
|
|
def by_ids(
|
|
self,
|
|
obj_ids: Iterable[object],
|
|
) -> QuerySet[ModelT]:
|
|
"""Return queryset filtered by a set/list/tuple of identifiers."""
|
|
return self.by_field_in(self.id_field, obj_ids)
|
|
|
|
def by_field(self, field_name: str, value: object) -> QuerySet[ModelT]:
|
|
"""Return queryset filtered by a single field equality check."""
|
|
return self.filter(col(getattr(self.model, field_name)) == value)
|
|
|
|
def by_field_in(
|
|
self,
|
|
field_name: str,
|
|
values: Iterable[object],
|
|
) -> QuerySet[ModelT]:
|
|
"""Return queryset filtered by `field IN values` semantics."""
|
|
seq = tuple(values)
|
|
if not seq:
|
|
return self.none()
|
|
return self.filter(col(getattr(self.model, field_name)).in_(seq))
|
|
|
|
|
|
class ManagerDescriptor(Generic[ModelT]):
|
|
"""Descriptor that exposes a model-bound `ModelManager` as `.objects`."""
|
|
|
|
def __get__(self, instance: object, owner: type[ModelT]) -> ModelManager[ModelT]:
|
|
"""Return a fresh manager bound to the owning model class."""
|
|
return ModelManager(owner)
|