diff --git a/backend/app/api/agents.py b/backend/app/api/agents.py index 4bd58938..6a81efb1 100644 --- a/backend/app/api/agents.py +++ b/backend/app/api/agents.py @@ -9,7 +9,7 @@ from typing import Any, cast from uuid import UUID, uuid4 from fastapi import APIRouter, Depends, HTTPException, Query, Request, status -from sqlalchemy import asc, or_, update +from sqlalchemy import asc, or_ from sqlalchemy.sql.elements import ColumnElement from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -19,9 +19,9 @@ from app.api.deps import ActorContext, require_admin_or_agent, require_org_admin from app.core.agent_tokens import generate_agent_token, hash_agent_token from app.core.auth import AuthContext, get_auth_context from app.core.time import utcnow +from app.db import crud from app.db.pagination import paginate from app.db.session import async_session_maker, get_session -from app.db.sqlmodel_exec import exec_dml from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message from app.models.activity_events import ActivityEvent @@ -101,7 +101,7 @@ async def _require_board( session: AsyncSession, board_id: UUID | str | None, *, - user: object | None = None, + user: User | None = None, write: bool = False, ) -> Board: if not board_id: @@ -113,7 +113,7 @@ async def _require_board( if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found") if user is not None: - await require_board_access(session, user=user, board=board, write=write) # type: ignore[arg-type] + await require_board_access(session, user=user, board=board, write=write) return board @@ -972,31 +972,32 @@ async def delete_agent( agent_id=None, ) now = datetime.now() - await exec_dml( + await crud.update_where( session, - update(Task) - .where(col(Task.assigned_agent_id) == agent.id) - .where(col(Task.status) == "in_progress") - .values( - assigned_agent_id=None, - status="inbox", - in_progress_at=None, - updated_at=now, - ), + Task, + col(Task.assigned_agent_id) == agent.id, + col(Task.status) == "in_progress", + assigned_agent_id=None, + status="inbox", + in_progress_at=None, + updated_at=now, + commit=False, ) - await exec_dml( + await crud.update_where( session, - update(Task) - .where(col(Task.assigned_agent_id) == agent.id) - .where(col(Task.status) != "in_progress") - .values( - assigned_agent_id=None, - updated_at=now, - ), + Task, + col(Task.assigned_agent_id) == agent.id, + col(Task.status) != "in_progress", + assigned_agent_id=None, + updated_at=now, + commit=False, ) - await exec_dml( + await crud.update_where( session, - update(ActivityEvent).where(col(ActivityEvent.agent_id) == agent.id).values(agent_id=None), + ActivityEvent, + col(ActivityEvent.agent_id) == agent.id, + agent_id=None, + commit=False, ) await session.delete(agent) await session.commit() diff --git a/backend/app/api/board_groups.py b/backend/app/api/board_groups.py index a64e1515..2cbef618 100644 --- a/backend/app/api/board_groups.py +++ b/backend/app/api/board_groups.py @@ -5,7 +5,7 @@ from typing import Any, cast from uuid import UUID, uuid4 from fastapi import APIRouter, Depends, HTTPException, status -from sqlalchemy import delete, func, update +from sqlalchemy import func from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -14,7 +14,6 @@ from app.core.time import utcnow from app.db import crud from app.db.pagination import paginate from app.db.session import get_session -from app.db.sqlmodel_exec import exec_dml from app.models.agents import Agent from app.models.board_group_memory import BoardGroupMemory from app.models.board_groups import BoardGroup @@ -276,14 +275,16 @@ async def delete_board_group( await _require_group_access(session, group_id=group_id, member=ctx.member, write=True) # Boards reference groups, so clear the FK first to keep deletes simple. - await exec_dml( + await crud.update_where( session, - update(Board).where(col(Board.board_group_id) == group_id).values(board_group_id=None), + Board, + col(Board.board_group_id) == group_id, + board_group_id=None, + commit=False, ) - await exec_dml( - session, - delete(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id) == group_id), + await crud.delete_where( + session, BoardGroupMemory, col(BoardGroupMemory.board_group_id) == group_id, commit=False ) - await exec_dml(session, delete(BoardGroup).where(col(BoardGroup.id) == group_id)) + await crud.delete_where(session, BoardGroup, col(BoardGroup.id) == group_id, commit=False) await session.commit() return OkResponse() diff --git a/backend/app/api/boards.py b/backend/app/api/boards.py index f3415722..b0f72198 100644 --- a/backend/app/api/boards.py +++ b/backend/app/api/boards.py @@ -4,7 +4,7 @@ import re from uuid import UUID, uuid4 from fastapi import APIRouter, Depends, HTTPException, Query, status -from sqlalchemy import delete, func +from sqlalchemy import func from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -19,7 +19,6 @@ from app.core.time import utcnow from app.db import crud from app.db.pagination import paginate from app.db.session import get_session -from app.db.sqlmodel_exec import exec_dml from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import ( OpenClawGatewayError, @@ -307,43 +306,38 @@ async def delete_board( ) from exc if task_ids: - await exec_dml( - session, delete(ActivityEvent).where(col(ActivityEvent.task_id).in_(task_ids)) + await crud.delete_where( + session, ActivityEvent, col(ActivityEvent.task_id).in_(task_ids), commit=False ) - await exec_dml(session, delete(TaskDependency).where(col(TaskDependency.board_id) == board.id)) - await exec_dml( - session, delete(TaskFingerprint).where(col(TaskFingerprint.board_id) == board.id) - ) + await crud.delete_where(session, TaskDependency, col(TaskDependency.board_id) == board.id) + await crud.delete_where(session, TaskFingerprint, col(TaskFingerprint.board_id) == board.id) # Approvals can reference tasks and agents, so delete before both. - await exec_dml(session, delete(Approval).where(col(Approval.board_id) == board.id)) + await crud.delete_where(session, Approval, col(Approval.board_id) == board.id) - await exec_dml(session, delete(BoardMemory).where(col(BoardMemory.board_id) == board.id)) - await exec_dml( - session, - delete(BoardOnboardingSession).where(col(BoardOnboardingSession.board_id) == board.id), + await crud.delete_where(session, BoardMemory, col(BoardMemory.board_id) == board.id) + await crud.delete_where( + session, BoardOnboardingSession, col(BoardOnboardingSession.board_id) == board.id ) - await exec_dml( - session, - delete(OrganizationBoardAccess).where(col(OrganizationBoardAccess.board_id) == board.id), + await crud.delete_where( + session, OrganizationBoardAccess, col(OrganizationBoardAccess.board_id) == board.id ) - await exec_dml( + await crud.delete_where( session, - delete(OrganizationInviteBoardAccess).where( - col(OrganizationInviteBoardAccess.board_id) == board.id - ), + OrganizationInviteBoardAccess, + col(OrganizationInviteBoardAccess.board_id) == board.id, ) # Tasks reference agents (assigned_agent_id) and have dependents (fingerprints/dependencies), so # delete tasks before agents. - await exec_dml(session, delete(Task).where(col(Task.board_id) == board.id)) + await crud.delete_where(session, Task, col(Task.board_id) == board.id) if agents: agent_ids = [agent.id for agent in agents] - await exec_dml( - session, delete(ActivityEvent).where(col(ActivityEvent.agent_id).in_(agent_ids)) + await crud.delete_where( + session, ActivityEvent, col(ActivityEvent.agent_id).in_(agent_ids), commit=False ) - await exec_dml(session, delete(Agent).where(col(Agent.id).in_(agent_ids))) + await crud.delete_where(session, Agent, col(Agent.id).in_(agent_ids)) await session.delete(board) await session.commit() return OkResponse() diff --git a/backend/app/api/gateway.py b/backend/app/api/gateway.py index 06ccf543..f22a5b4e 100644 --- a/backend/app/api/gateway.py +++ b/backend/app/api/gateway.py @@ -21,6 +21,7 @@ from app.integrations.openclaw_gateway_protocol import ( ) from app.models.boards import Board from app.models.gateways import Gateway +from app.models.users import User from app.schemas.common import OkResponse from app.schemas.gateway_api import ( GatewayCommandsResponse, @@ -43,7 +44,7 @@ async def _resolve_gateway( gateway_token: str | None, gateway_main_session_key: str | None, *, - user: object | None = None, + user: User | None = None, ) -> tuple[Board | None, GatewayClientConfig, str | None]: if gateway_url: return ( @@ -59,8 +60,8 @@ async def _resolve_gateway( board = await Board.objects.by_id(board_id).first(session) if board is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found") - if isinstance(user, object) and user is not None: - await require_board_access(session, user=user, board=board, write=False) # type: ignore[arg-type] + if user is not None: + await require_board_access(session, user=user, board=board, write=False) if not board.gateway_id: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, @@ -85,7 +86,7 @@ async def _resolve_gateway( async def _require_gateway( - session: AsyncSession, board_id: str | None, *, user: object | None = None + session: AsyncSession, board_id: str | None, *, user: User | None = None ) -> tuple[Board, GatewayClientConfig, str | None]: board, config, main_session = await _resolve_gateway( session, diff --git a/backend/app/api/organizations.py b/backend/app/api/organizations.py index 6561dcd8..2a6db257 100644 --- a/backend/app/api/organizations.py +++ b/backend/app/api/organizations.py @@ -5,7 +5,7 @@ from typing import Any, Sequence from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, status -from sqlalchemy import delete, func, update +from sqlalchemy import func from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -15,7 +15,6 @@ from app.core.time import utcnow from app.db import crud from app.db.pagination import paginate from app.db.session import get_session -from app.db.sqlmodel_exec import exec_dml from app.models.activity_events import ActivityEvent from app.models.agents import Agent from app.models.approvals import Approval @@ -214,67 +213,85 @@ async def delete_my_org( ) group_ids = select(BoardGroup.id).where(col(BoardGroup.organization_id) == org_id) - await exec_dml(session, delete(ActivityEvent).where(col(ActivityEvent.task_id).in_(task_ids))) - await exec_dml(session, delete(ActivityEvent).where(col(ActivityEvent.agent_id).in_(agent_ids))) - await exec_dml( - session, delete(TaskDependency).where(col(TaskDependency.board_id).in_(board_ids)) + await crud.delete_where( + session, ActivityEvent, col(ActivityEvent.task_id).in_(task_ids), commit=False ) - await exec_dml( + await crud.delete_where( + session, ActivityEvent, col(ActivityEvent.agent_id).in_(agent_ids), commit=False + ) + await crud.delete_where( + session, TaskDependency, col(TaskDependency.board_id).in_(board_ids), commit=False + ) + await crud.delete_where( + session, TaskFingerprint, col(TaskFingerprint.board_id).in_(board_ids), commit=False + ) + await crud.delete_where(session, Approval, col(Approval.board_id).in_(board_ids), commit=False) + await crud.delete_where( + session, BoardMemory, col(BoardMemory.board_id).in_(board_ids), commit=False + ) + await crud.delete_where( session, - delete(TaskFingerprint).where(col(TaskFingerprint.board_id).in_(board_ids)), + BoardOnboardingSession, + col(BoardOnboardingSession.board_id).in_(board_ids), + commit=False, ) - await exec_dml(session, delete(Approval).where(col(Approval.board_id).in_(board_ids))) - await exec_dml(session, delete(BoardMemory).where(col(BoardMemory.board_id).in_(board_ids))) - await exec_dml( + await crud.delete_where( session, - delete(BoardOnboardingSession).where(col(BoardOnboardingSession.board_id).in_(board_ids)), + OrganizationBoardAccess, + col(OrganizationBoardAccess.board_id).in_(board_ids), + commit=False, ) - await exec_dml( + await crud.delete_where( session, - delete(OrganizationBoardAccess).where(col(OrganizationBoardAccess.board_id).in_(board_ids)), + OrganizationInviteBoardAccess, + col(OrganizationInviteBoardAccess.board_id).in_(board_ids), + commit=False, ) - await exec_dml( + await crud.delete_where( session, - delete(OrganizationInviteBoardAccess).where( - col(OrganizationInviteBoardAccess.board_id).in_(board_ids) - ), + OrganizationBoardAccess, + col(OrganizationBoardAccess.organization_member_id).in_(member_ids), + commit=False, ) - await exec_dml( + await crud.delete_where( session, - delete(OrganizationBoardAccess).where( - col(OrganizationBoardAccess.organization_member_id).in_(member_ids) - ), + OrganizationInviteBoardAccess, + col(OrganizationInviteBoardAccess.organization_invite_id).in_(invite_ids), + commit=False, ) - await exec_dml( + await crud.delete_where(session, Task, col(Task.board_id).in_(board_ids), commit=False) + await crud.delete_where(session, Agent, col(Agent.board_id).in_(board_ids), commit=False) + await crud.delete_where(session, Board, col(Board.organization_id) == org_id, commit=False) + await crud.delete_where( session, - delete(OrganizationInviteBoardAccess).where( - col(OrganizationInviteBoardAccess.organization_invite_id).in_(invite_ids) - ), + BoardGroupMemory, + col(BoardGroupMemory.board_group_id).in_(group_ids), + commit=False, ) - await exec_dml(session, delete(Task).where(col(Task.board_id).in_(board_ids))) - await exec_dml(session, delete(Agent).where(col(Agent.board_id).in_(board_ids))) - await exec_dml(session, delete(Board).where(col(Board.organization_id) == org_id)) - await exec_dml( + await crud.delete_where( + session, BoardGroup, col(BoardGroup.organization_id) == org_id, commit=False + ) + await crud.delete_where(session, Gateway, col(Gateway.organization_id) == org_id, commit=False) + await crud.delete_where( session, - delete(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id).in_(group_ids)), + OrganizationInvite, + col(OrganizationInvite.organization_id) == org_id, + commit=False, ) - await exec_dml(session, delete(BoardGroup).where(col(BoardGroup.organization_id) == org_id)) - await exec_dml(session, delete(Gateway).where(col(Gateway.organization_id) == org_id)) - await exec_dml( + await crud.delete_where( session, - delete(OrganizationInvite).where(col(OrganizationInvite.organization_id) == org_id), + OrganizationMember, + col(OrganizationMember.organization_id) == org_id, + commit=False, ) - await exec_dml( + await crud.update_where( session, - delete(OrganizationMember).where(col(OrganizationMember.organization_id) == org_id), + User, + col(User.active_organization_id) == org_id, + active_organization_id=None, + commit=False, ) - await exec_dml( - session, - update(User) - .where(col(User.active_organization_id) == org_id) - .values(active_organization_id=None), - ) - await exec_dml(session, delete(Organization).where(col(Organization.id) == org_id)) + await crud.delete_where(session, Organization, col(Organization.id) == org_id, commit=False) await session.commit() return OkResponse() @@ -425,11 +442,11 @@ async def remove_org_member( detail="Organization must have at least one owner", ) - await exec_dml( + await crud.delete_where( session, - delete(OrganizationBoardAccess).where( - col(OrganizationBoardAccess.organization_member_id) == member.id - ), + OrganizationBoardAccess, + col(OrganizationBoardAccess.organization_member_id) == member.id, + commit=False, ) user = await User.objects.by_id(member.user_id).first(session) @@ -532,11 +549,11 @@ async def revoke_org_invite( organization_id=ctx.organization.id, invite_id=invite_id, ) - await exec_dml( + await crud.delete_where( session, - delete(OrganizationInviteBoardAccess).where( - col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id - ), + OrganizationInviteBoardAccess, + col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id, + commit=False, ) await crud.delete(session, invite) return OrganizationInviteRead.model_validate(invite, from_attributes=True) diff --git a/backend/app/api/tasks.py b/backend/app/api/tasks.py index 579cc52c..1e514a9c 100644 --- a/backend/app/api/tasks.py +++ b/backend/app/api/tasks.py @@ -9,7 +9,7 @@ from typing import cast from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query, Request, status -from sqlalchemy import asc, delete, desc, or_ +from sqlalchemy import asc, desc, or_ from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.sql.expression import Select @@ -25,9 +25,9 @@ from app.api.deps import ( ) from app.core.auth import AuthContext from app.core.time import utcnow +from app.db import crud from app.db.pagination import paginate from app.db.session import async_session_maker, get_session -from app.db.sqlmodel_exec import exec_dml from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message from app.models.activity_events import ActivityEvent @@ -997,17 +997,21 @@ async def delete_task( if auth.user is None: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) await require_board_access(session, user=auth.user, board=board, write=True) - await exec_dml(session, delete(ActivityEvent).where(col(ActivityEvent.task_id) == task.id)) - await exec_dml(session, delete(TaskFingerprint).where(col(TaskFingerprint.task_id) == task.id)) - await exec_dml(session, delete(Approval).where(col(Approval.task_id) == task.id)) - await exec_dml( + await crud.delete_where( + session, ActivityEvent, col(ActivityEvent.task_id) == task.id, commit=False + ) + await crud.delete_where( + session, TaskFingerprint, col(TaskFingerprint.task_id) == task.id, commit=False + ) + await crud.delete_where(session, Approval, col(Approval.task_id) == task.id, commit=False) + await crud.delete_where( session, - delete(TaskDependency).where( - or_( - col(TaskDependency.task_id) == task.id, - col(TaskDependency.depends_on_task_id) == task.id, - ) + TaskDependency, + or_( + col(TaskDependency.task_id) == task.id, + col(TaskDependency.depends_on_task_id) == task.id, ), + commit=False, ) await session.delete(task) await session.commit() diff --git a/backend/app/db/crud.py b/backend/app/db/crud.py index e3221e28..bd865b22 100644 --- a/backend/app/db/crud.py +++ b/backend/app/db/crud.py @@ -1,8 +1,10 @@ from __future__ import annotations from collections.abc import Iterable, Mapping -from typing import Any, TypeVar +from typing import Any, TypeVar, cast +from sqlalchemy import delete as sql_delete +from sqlalchemy import update as sql_update from sqlalchemy.exc import IntegrityError from sqlmodel import SQLModel, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -19,6 +21,22 @@ class MultipleObjectsReturned(LookupError): pass +async def _flush_or_rollback(session: AsyncSession) -> None: + try: + await session.flush() + except Exception: + await session.rollback() + raise + + +async def _commit_or_rollback(session: AsyncSession) -> None: + try: + await session.commit() + except Exception: + await session.rollback() + raise + + def _lookup_statement(model: type[ModelT], lookup: Mapping[str, Any]) -> SelectOfScalar[ModelT]: stmt = select(model) for key, value in lookup.items(): @@ -58,9 +76,9 @@ async def create( ) -> ModelT: obj = model.model_validate(data) session.add(obj) - await session.flush() + await _flush_or_rollback(session) if commit: - await session.commit() + await _commit_or_rollback(session) if refresh: await session.refresh(obj) return obj @@ -74,9 +92,9 @@ async def save( refresh: bool = True, ) -> ModelT: session.add(obj) - await session.flush() + await _flush_or_rollback(session) if commit: - await session.commit() + await _commit_or_rollback(session) if refresh: await session.refresh(obj) return obj @@ -85,7 +103,7 @@ async def save( async def delete(session: AsyncSession, obj: ModelT, *, commit: bool = True) -> None: await session.delete(obj) if commit: - await session.commit() + await _commit_or_rollback(session) async def list_by( @@ -111,6 +129,77 @@ async def exists(session: AsyncSession, model: type[ModelT], **lookup: Any) -> b return (await session.exec(_lookup_statement(model, lookup).limit(1))).first() is not None +def _criteria_statement(model: type[ModelT], criteria: tuple[Any, ...]) -> SelectOfScalar[ModelT]: + stmt = select(model) + if criteria: + stmt = stmt.where(*criteria) + return stmt + + +async def list_where( + session: AsyncSession, + model: type[ModelT], + *criteria: Any, + order_by: Iterable[Any] = (), +) -> list[ModelT]: + stmt = _criteria_statement(model, criteria) + for ordering in order_by: + stmt = stmt.order_by(ordering) + return list(await session.exec(stmt)) + + +async def delete_where( + session: AsyncSession, + model: type[ModelT], + *criteria: Any, + commit: bool = False, +) -> int: + stmt = sql_delete(model) + if criteria: + stmt = stmt.where(*criteria) + result = await session.exec(cast(Any, stmt)) + if commit: + await _commit_or_rollback(session) + rowcount = getattr(result, "rowcount", None) + return int(rowcount) if isinstance(rowcount, int) else 0 + + +async def update_where( + session: AsyncSession, + model: type[ModelT], + *criteria: Any, + updates: Mapping[str, Any] | None = None, + commit: bool = False, + exclude_none: bool = False, + allowed_fields: set[str] | None = None, + **update_fields: Any, +) -> int: + source_updates: dict[str, Any] = {} + if updates: + source_updates.update(dict(updates)) + if update_fields: + source_updates.update(update_fields) + + values: dict[str, Any] = {} + for key, value in source_updates.items(): + if allowed_fields is not None and key not in allowed_fields: + continue + if exclude_none and value is None: + continue + values[key] = value + if not values: + return 0 + + stmt = sql_update(model).values(**values) + if criteria: + stmt = stmt.where(*criteria) + result = await session.exec(cast(Any, stmt)) + if commit: + await _commit_or_rollback(session) + rowcount = getattr(result, "rowcount", None) + return int(rowcount) if isinstance(rowcount, int) else 0 + + def apply_updates( obj: ModelT, updates: Mapping[str, Any], @@ -179,6 +268,9 @@ async def get_or_create( if existing is not None: return existing, False raise + except Exception: + await session.rollback() + raise if refresh: await session.refresh(obj) diff --git a/backend/app/db/session.py b/backend/app/db/session.py index 0f80aed6..8a71ba1b 100644 --- a/backend/app/db/session.py +++ b/backend/app/db/session.py @@ -11,9 +11,12 @@ from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async from sqlmodel import SQLModel from sqlmodel.ext.asyncio.session import AsyncSession -from app import models # noqa: F401 +from app import models as _models from app.core.config import settings +# Import model modules so SQLModel metadata is fully registered at startup. +_MODEL_REGISTRY = _models + def _normalize_database_url(database_url: str) -> str: if "://" not in database_url: @@ -64,4 +67,11 @@ async def init_db() -> None: async def get_session() -> AsyncGenerator[AsyncSession, None]: async with async_session_maker() as session: - yield session + try: + yield session + except Exception: + try: + await session.rollback() + except Exception: + logger.exception("Failed to rollback session after request error.") + raise diff --git a/backend/app/db/sqlmodel_exec.py b/backend/app/db/sqlmodel_exec.py deleted file mode 100644 index 1ac22403..00000000 --- a/backend/app/db/sqlmodel_exec.py +++ /dev/null @@ -1,9 +0,0 @@ -from __future__ import annotations - -from sqlalchemy.sql.base import Executable -from sqlmodel.ext.asyncio.session import AsyncSession - - -async def exec_dml(session: AsyncSession, statement: Executable) -> None: - # SQLModel's AsyncSession typing only overloads exec() for SELECT statements. - await session.exec(statement) # type: ignore[call-overload] diff --git a/backend/app/services/organizations.py b/backend/app/services/organizations.py index b5efdd32..5cfe72a3 100644 --- a/backend/app/services/organizations.py +++ b/backend/app/services/organizations.py @@ -5,13 +5,13 @@ from typing import Iterable from uuid import UUID from fastapi import HTTPException, status -from sqlalchemy import delete, func, or_ +from sqlalchemy import func, or_ from sqlalchemy.sql.elements import ColumnElement from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession from app.core.time import utcnow -from app.db.sqlmodel_exec import exec_dml +from app.db import crud from app.models.boards import Board from app.models.organization_board_access import OrganizationBoardAccess from app.models.organization_invite_board_access import OrganizationInviteBoardAccess @@ -328,11 +328,11 @@ async def apply_member_access_update( member.updated_at = now session.add(member) - await exec_dml( + await crud.delete_where( session, - delete(OrganizationBoardAccess).where( - col(OrganizationBoardAccess.organization_member_id) == member.id - ), + OrganizationBoardAccess, + col(OrganizationBoardAccess.organization_member_id) == member.id, + commit=False, ) if update.all_boards_read or update.all_boards_write: @@ -359,11 +359,11 @@ async def apply_invite_board_access( invite: OrganizationInvite, entries: Iterable[OrganizationBoardAccessSpec], ) -> None: - await exec_dml( + await crud.delete_where( session, - delete(OrganizationInviteBoardAccess).where( - col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id - ), + OrganizationInviteBoardAccess, + col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id, + commit=False, ) if invite.all_boards_read or invite.all_boards_write: return diff --git a/backend/app/services/task_dependencies.py b/backend/app/services/task_dependencies.py index 166c60ed..aaa37aaa 100644 --- a/backend/app/services/task_dependencies.py +++ b/backend/app/services/task_dependencies.py @@ -6,11 +6,10 @@ from typing import Final from uuid import UUID from fastapi import HTTPException, status -from sqlalchemy import delete from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession -from app.db.sqlmodel_exec import exec_dml +from app.db import crud from app.models.task_dependencies import TaskDependency from app.models.tasks import Task @@ -195,11 +194,12 @@ async def replace_task_dependencies( task_id=task_id, depends_on_task_ids=depends_on_task_ids, ) - await exec_dml( + await crud.delete_where( session, - delete(TaskDependency) - .where(col(TaskDependency.board_id) == board_id) - .where(col(TaskDependency.task_id) == task_id), + TaskDependency, + col(TaskDependency.board_id) == board_id, + col(TaskDependency.task_id) == task_id, + commit=False, ) for dep_id in normalized: session.add( diff --git a/backend/tests/test_db_transaction_safety.py b/backend/tests/test_db_transaction_safety.py new file mode 100644 index 00000000..348a6bb7 --- /dev/null +++ b/backend/tests/test_db_transaction_safety.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any + +import pytest +from sqlmodel import Field, SQLModel + +from app.db import crud +from app.db import session as db_session + + +class _Thing(SQLModel, table=True): + __tablename__ = "_test_thing" + + id: int | None = Field(default=None, primary_key=True) + name: str + + +@dataclass +class _SessionCtx: + session: Any + entered: int = 0 + exited: int = 0 + + async def __aenter__(self) -> Any: + self.entered += 1 + return self.session + + async def __aexit__(self, _exc_type: Any, _exc: Any, _tb: Any) -> bool: + self.exited += 1 + return False + + +@dataclass +class _Maker: + ctx: _SessionCtx + + def __call__(self) -> _SessionCtx: + return self.ctx + + +@pytest.mark.asyncio +async def test_get_session_rolls_back_on_dependency_error(monkeypatch: pytest.MonkeyPatch) -> None: + fake_session = SimpleNamespace(rollbacks=0) + + async def _rollback() -> None: + fake_session.rollbacks += 1 + + fake_session.rollback = _rollback + ctx = _SessionCtx(fake_session) + monkeypatch.setattr(db_session, "async_session_maker", _Maker(ctx)) + + generator = db_session.get_session() + yielded = await anext(generator) + assert yielded is fake_session + + with pytest.raises(RuntimeError, match="boom"): + await generator.athrow(RuntimeError("boom")) + + assert fake_session.rollbacks == 1 + assert ctx.entered == 1 + assert ctx.exited == 1 + + +@pytest.mark.asyncio +async def test_create_rolls_back_when_commit_fails() -> None: + @dataclass + class _FailCommitSession: + rollback_calls: int = 0 + added: list[Any] = None + + def __post_init__(self) -> None: + if self.added is None: + self.added = [] + + def add(self, value: Any) -> None: + self.added.append(value) + + async def flush(self) -> None: + return None + + async def commit(self) -> None: + raise RuntimeError("commit failed") + + async def rollback(self) -> None: + self.rollback_calls += 1 + + async def refresh(self, _value: Any) -> None: + return None + + session = _FailCommitSession() + + with pytest.raises(RuntimeError, match="commit failed"): + await crud.create(session, _Thing, name="demo") + + assert session.rollback_calls == 1 + assert len(session.added) == 1 + + +@pytest.mark.asyncio +async def test_delete_where_rolls_back_when_commit_fails() -> None: + @dataclass + class _FailCommitDmlSession: + rollback_calls: int = 0 + exec_calls: int = 0 + + async def exec(self, _statement: Any) -> Any: + self.exec_calls += 1 + return SimpleNamespace(rowcount=3) + + async def commit(self) -> None: + raise RuntimeError("commit failed") + + async def rollback(self) -> None: + self.rollback_calls += 1 + + session = _FailCommitDmlSession() + + with pytest.raises(RuntimeError, match="commit failed"): + await crud.delete_where(session, _Thing, commit=True) + + assert session.exec_calls == 1 + assert session.rollback_calls == 1