feat: add validation for minimum length on various fields and update type definitions
This commit is contained in:
125
backend/app/db/crud.py
Normal file
125
backend/app/db/crud.py
Normal file
@@ -0,0 +1,125 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel import SQLModel, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
ModelT = TypeVar("ModelT", bound=SQLModel)
|
||||
|
||||
|
||||
class DoesNotExist(LookupError):
|
||||
pass
|
||||
|
||||
|
||||
class MultipleObjectsReturned(LookupError):
|
||||
pass
|
||||
|
||||
|
||||
async def get_by_id(session: AsyncSession, model: type[ModelT], obj_id: Any) -> ModelT | None:
|
||||
return await session.get(model, obj_id)
|
||||
|
||||
|
||||
async def get(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT:
|
||||
stmt = select(model)
|
||||
for key, value in lookup.items():
|
||||
stmt = stmt.where(getattr(model, key) == value)
|
||||
stmt = stmt.limit(2)
|
||||
items = (await session.exec(stmt)).all()
|
||||
if not items:
|
||||
raise DoesNotExist(f"{model.__name__} matching query does not exist.")
|
||||
if len(items) > 1:
|
||||
raise MultipleObjectsReturned(
|
||||
f"Multiple {model.__name__} objects returned for lookup {lookup!r}."
|
||||
)
|
||||
return items[0]
|
||||
|
||||
|
||||
async def get_one_by(session: AsyncSession, model: type[ModelT], **lookup: Any) -> ModelT | None:
|
||||
stmt = select(model)
|
||||
for key, value in lookup.items():
|
||||
stmt = stmt.where(getattr(model, key) == value)
|
||||
return (await session.exec(stmt)).first()
|
||||
|
||||
|
||||
async def create(
|
||||
session: AsyncSession,
|
||||
model: type[ModelT],
|
||||
*,
|
||||
commit: bool = True,
|
||||
refresh: bool = True,
|
||||
**data: Any,
|
||||
) -> ModelT:
|
||||
obj = model.model_validate(data)
|
||||
session.add(obj)
|
||||
await session.flush()
|
||||
if commit:
|
||||
await session.commit()
|
||||
if refresh:
|
||||
await session.refresh(obj)
|
||||
return obj
|
||||
|
||||
|
||||
async def save(
|
||||
session: AsyncSession,
|
||||
obj: ModelT,
|
||||
*,
|
||||
commit: bool = True,
|
||||
refresh: bool = True,
|
||||
) -> ModelT:
|
||||
session.add(obj)
|
||||
await session.flush()
|
||||
if commit:
|
||||
await session.commit()
|
||||
if refresh:
|
||||
await session.refresh(obj)
|
||||
return obj
|
||||
|
||||
|
||||
async def delete(session: AsyncSession, obj: ModelT, *, commit: bool = True) -> None:
|
||||
await session.delete(obj)
|
||||
if commit:
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def get_or_create(
|
||||
session: AsyncSession,
|
||||
model: type[ModelT],
|
||||
*,
|
||||
defaults: Mapping[str, Any] | None = None,
|
||||
commit: bool = True,
|
||||
refresh: bool = True,
|
||||
**lookup: Any,
|
||||
) -> tuple[ModelT, bool]:
|
||||
stmt = select(model)
|
||||
for key, value in lookup.items():
|
||||
stmt = stmt.where(getattr(model, key) == value)
|
||||
|
||||
existing = (await session.exec(stmt)).first()
|
||||
if existing is not None:
|
||||
return existing, False
|
||||
|
||||
payload: dict[str, Any] = dict(lookup)
|
||||
if defaults:
|
||||
for key, value in defaults.items():
|
||||
payload.setdefault(key, value)
|
||||
|
||||
obj = model.model_validate(payload)
|
||||
session.add(obj)
|
||||
try:
|
||||
await session.flush()
|
||||
if commit:
|
||||
await session.commit()
|
||||
except IntegrityError:
|
||||
# If another concurrent request inserted the same unique row, surface that row.
|
||||
await session.rollback()
|
||||
existing = (await session.exec(stmt)).first()
|
||||
if existing is not None:
|
||||
return existing, False
|
||||
raise
|
||||
|
||||
if refresh:
|
||||
await session.refresh(obj)
|
||||
return obj, True
|
||||
Reference in New Issue
Block a user