refactor: improve type coercion functions and enhance type hints across multiple files
This commit is contained in:
@@ -3,13 +3,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, TypeVar
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar
|
||||
|
||||
from sqlalchemy import false
|
||||
from sqlmodel import SQLModel, col
|
||||
|
||||
from app.db.queryset import QuerySet, qs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
ModelT = TypeVar("ModelT", bound=SQLModel)
|
||||
|
||||
|
||||
@@ -49,7 +52,7 @@ class ModelManager(Generic[ModelT]):
|
||||
|
||||
def by_ids(
|
||||
self,
|
||||
obj_ids: list[object] | tuple[object, ...] | set[object],
|
||||
obj_ids: Iterable[object],
|
||||
) -> QuerySet[ModelT]:
|
||||
"""Return queryset filtered by a set/list/tuple of identifiers."""
|
||||
return self.by_field_in(self.id_field, obj_ids)
|
||||
@@ -61,7 +64,7 @@ class ModelManager(Generic[ModelT]):
|
||||
def by_field_in(
|
||||
self,
|
||||
field_name: str,
|
||||
values: list[object] | tuple[object, ...] | set[object],
|
||||
values: Iterable[object],
|
||||
) -> QuerySet[ModelT]:
|
||||
"""Return queryset filtered by `field IN values` semantics."""
|
||||
seq = tuple(values)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
@@ -22,7 +22,11 @@ class QuerySet(Generic[ModelT]):
|
||||
|
||||
def filter(self, *criteria: object) -> QuerySet[ModelT]:
|
||||
"""Return a new queryset with additional SQL criteria."""
|
||||
return replace(self, statement=self.statement.where(*criteria))
|
||||
statement = cast(
|
||||
"SelectOfScalar[ModelT]",
|
||||
cast(Any, self.statement).where(*criteria),
|
||||
)
|
||||
return replace(self, statement=statement)
|
||||
|
||||
def where(self, *criteria: object) -> QuerySet[ModelT]:
|
||||
"""Alias for `filter` to mirror SQLAlchemy naming."""
|
||||
@@ -35,7 +39,11 @@ class QuerySet(Generic[ModelT]):
|
||||
|
||||
def order_by(self, *ordering: object) -> QuerySet[ModelT]:
|
||||
"""Return a new queryset with ordering clauses applied."""
|
||||
return replace(self, statement=self.statement.order_by(*ordering))
|
||||
statement = cast(
|
||||
"SelectOfScalar[ModelT]",
|
||||
cast(Any, self.statement).order_by(*ordering),
|
||||
)
|
||||
return replace(self, statement=statement)
|
||||
|
||||
def limit(self, value: int) -> QuerySet[ModelT]:
|
||||
"""Return a new queryset with a SQL row limit."""
|
||||
|
||||
Reference in New Issue
Block a user