refactor: replace SQLModel with QueryModel in various models and update query methods

This commit is contained in:
Abhimanyu Saharan
2026-02-09 02:04:14 +05:30
parent e19e47106b
commit 228b99bc9b
40 changed files with 413 additions and 419 deletions

View File

@@ -27,7 +27,8 @@ def _lookup_statement(model: type[ModelT], lookup: Mapping[str, Any]) -> SelectO
async def get_by_id(session: AsyncSession, model: type[ModelT], obj_id: Any) -> ModelT | None:
return await session.get(model, obj_id)
stmt = _lookup_statement(model, {"id": obj_id}).limit(1)
return (await session.exec(stmt)).first()
async def get(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT:

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Generic, TypeVar
from sqlalchemy import false
from sqlmodel import SQLModel, col
from app.db.queryset import QuerySet, qs
ModelT = TypeVar("ModelT", bound=SQLModel)
@dataclass(frozen=True)
class ModelManager(Generic[ModelT]):
model: type[ModelT]
id_field: str = "id"
def all(self) -> QuerySet[ModelT]:
return qs(self.model)
def none(self) -> QuerySet[ModelT]:
return qs(self.model).filter(false())
def filter(self, *criteria: Any) -> QuerySet[ModelT]:
return self.all().filter(*criteria)
def where(self, *criteria: Any) -> QuerySet[ModelT]:
return self.filter(*criteria)
def filter_by(self, **kwargs: Any) -> QuerySet[ModelT]:
queryset = self.all()
for field_name, value in kwargs.items():
queryset = queryset.filter(col(getattr(self.model, field_name)) == value)
return queryset
def by_id(self, obj_id: Any) -> QuerySet[ModelT]:
return self.by_field(self.id_field, obj_id)
def by_ids(self, obj_ids: list[Any] | tuple[Any, ...] | set[Any]) -> QuerySet[ModelT]:
return self.by_field_in(self.id_field, obj_ids)
def by_field(self, field_name: str, value: Any) -> QuerySet[ModelT]:
return self.filter(col(getattr(self.model, field_name)) == value)
def by_field_in(
self,
field_name: str,
values: list[Any] | tuple[Any, ...] | set[Any],
) -> QuerySet[ModelT]:
seq = tuple(values)
if not seq:
return self.none()
return self.filter(col(getattr(self.model, field_name)).in_(seq))
class ManagerDescriptor(Generic[ModelT]):
def __get__(self, instance: object, owner: type[ModelT]) -> ModelManager[ModelT]:
return ModelManager(owner)

View File

@@ -17,6 +17,13 @@ class QuerySet(Generic[ModelT]):
def filter(self, *criteria: Any) -> QuerySet[ModelT]:
return replace(self, statement=self.statement.where(*criteria))
def where(self, *criteria: Any) -> QuerySet[ModelT]:
return self.filter(*criteria)
def filter_by(self, **kwargs: Any) -> QuerySet[ModelT]:
statement = self.statement.filter_by(**kwargs)
return replace(self, statement=statement)
def order_by(self, *ordering: Any) -> QuerySet[ModelT]:
return replace(self, statement=self.statement.order_by(*ordering))