57 lines
1.7 KiB
Python
57 lines
1.7 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, replace
|
|
from typing import Any, Generic, TypeVar
|
|
|
|
from fastapi import HTTPException, status
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
from sqlmodel.sql.expression import SelectOfScalar
|
|
|
|
from app.db.queryset import QuerySet, qs
|
|
|
|
# Element type produced by an APIQuerySet (the model class's instances).
ModelT = TypeVar("ModelT")


@dataclass(frozen=True)
class APIQuerySet(Generic[ModelT]):
    """Immutable, API-facing wrapper around a ``QuerySet``.

    Chaining methods (``filter``, ``order_by``, ``limit``, ``offset``) do not
    modify ``self``. Each returns a new wrapper around whatever the matching
    ``QuerySet`` method returned. Whether the underlying ``QuerySet`` copies
    or mutates itself is up to ``QuerySet``.

    Execution methods delegate to the wrapped ``QuerySet`` using the
    caller-supplied session. ``first_or_404`` additionally converts
    "no row" into an HTTP 404.
    """

    # The wrapped query builder; every operation is delegated to it.
    queryset: QuerySet[ModelT]

    def _derive(self, queryset: QuerySet[ModelT]) -> APIQuerySet[ModelT]:
        """Return a copy of ``self`` that wraps *queryset* instead.

        ``dataclasses.replace`` keeps the concrete class of ``self``, so a
        subclass stays that subclass across a chain. It also copies any
        extra fields the subclass declares, which a hard-coded
        ``APIQuerySet(...)`` call would drop.
        """
        return replace(self, queryset=queryset)

    @property
    def statement(self) -> SelectOfScalar[ModelT]:
        """The wrapped ``QuerySet``'s ``statement``, returned unchanged."""
        return self.queryset.statement

    def filter(self, *criteria: Any) -> APIQuerySet[ModelT]:
        """Return a new wrapper with ``QuerySet.filter(*criteria)`` applied."""
        return self._derive(self.queryset.filter(*criteria))

    def order_by(self, *ordering: Any) -> APIQuerySet[ModelT]:
        """Return a new wrapper with ``QuerySet.order_by(*ordering)`` applied."""
        return self._derive(self.queryset.order_by(*ordering))

    def limit(self, value: int) -> APIQuerySet[ModelT]:
        """Return a new wrapper with ``QuerySet.limit(value)`` applied."""
        return self._derive(self.queryset.limit(value))

    def offset(self, value: int) -> APIQuerySet[ModelT]:
        """Return a new wrapper with ``QuerySet.offset(value)`` applied."""
        return self._derive(self.queryset.offset(value))

    async def all(self, session: AsyncSession) -> list[ModelT]:
        """Execute the query on *session* and return every matching row."""
        return await self.queryset.all(session)

    async def first(self, session: AsyncSession) -> ModelT | None:
        """Execute the query on *session*; return the first row or ``None``."""
        return await self.queryset.first(session)

    async def first_or_404(
        self,
        session: AsyncSession,
        *,
        detail: str | None = None,
    ) -> ModelT:
        """Return the first matching row, or raise an HTTP 404.

        Args:
            session: Async session the query is executed on.
            detail: Optional message for the 404 response. Passing ``None``
                is equivalent to omitting it, because ``None`` is
                ``HTTPException``'s own default for ``detail``.

        Returns:
            The first row from ``first``. Falsy rows are still returned;
            only ``None`` counts as "not found".

        Raises:
            HTTPException: status 404 when ``first`` returns ``None``.
        """
        obj = await self.first(session)
        if obj is None:
            # One call covers both the custom-message and default-message
            # cases, since HTTPException's ``detail`` already defaults to None.
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail)
        return obj
|
|
|
|
|
|
def api_qs(model: type[ModelT]) -> APIQuerySet[ModelT]:
    """Build an ``APIQuerySet`` over *model*.

    The base query comes from ``qs(model)``; the returned wrapper can then be
    refined with its chaining methods before being executed.
    """
    base_queryset = qs(model)
    return APIQuerySet(base_queryset)
|