refactor: replace exec_dml with CRUD operations in various files and improve session handling

This commit is contained in:
Abhimanyu Saharan
2026-02-09 02:17:34 +05:30
parent 228b99bc9b
commit fafcac1e16
12 changed files with 392 additions and 156 deletions

View File

@@ -9,7 +9,7 @@ from typing import Any, cast
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import asc, or_, update from sqlalchemy import asc, or_
from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -19,9 +19,9 @@ from app.api.deps import ActorContext, require_admin_or_agent, require_org_admin
from app.core.agent_tokens import generate_agent_token, hash_agent_token from app.core.agent_tokens import generate_agent_token, hash_agent_token
from app.core.auth import AuthContext, get_auth_context from app.core.auth import AuthContext, get_auth_context
from app.core.time import utcnow from app.core.time import utcnow
from app.db import crud
from app.db.pagination import paginate from app.db.pagination import paginate
from app.db.session import async_session_maker, get_session from app.db.session import async_session_maker, get_session
from app.db.sqlmodel_exec import exec_dml
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
from app.models.activity_events import ActivityEvent from app.models.activity_events import ActivityEvent
@@ -101,7 +101,7 @@ async def _require_board(
session: AsyncSession, session: AsyncSession,
board_id: UUID | str | None, board_id: UUID | str | None,
*, *,
user: object | None = None, user: User | None = None,
write: bool = False, write: bool = False,
) -> Board: ) -> Board:
if not board_id: if not board_id:
@@ -113,7 +113,7 @@ async def _require_board(
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found")
if user is not None: if user is not None:
await require_board_access(session, user=user, board=board, write=write) # type: ignore[arg-type] await require_board_access(session, user=user, board=board, write=write)
return board return board
@@ -972,31 +972,32 @@ async def delete_agent(
agent_id=None, agent_id=None,
) )
now = datetime.now() now = datetime.now()
await exec_dml( await crud.update_where(
session, session,
update(Task) Task,
.where(col(Task.assigned_agent_id) == agent.id) col(Task.assigned_agent_id) == agent.id,
.where(col(Task.status) == "in_progress") col(Task.status) == "in_progress",
.values( assigned_agent_id=None,
assigned_agent_id=None, status="inbox",
status="inbox", in_progress_at=None,
in_progress_at=None, updated_at=now,
updated_at=now, commit=False,
),
) )
await exec_dml( await crud.update_where(
session, session,
update(Task) Task,
.where(col(Task.assigned_agent_id) == agent.id) col(Task.assigned_agent_id) == agent.id,
.where(col(Task.status) != "in_progress") col(Task.status) != "in_progress",
.values( assigned_agent_id=None,
assigned_agent_id=None, updated_at=now,
updated_at=now, commit=False,
),
) )
await exec_dml( await crud.update_where(
session, session,
update(ActivityEvent).where(col(ActivityEvent.agent_id) == agent.id).values(agent_id=None), ActivityEvent,
col(ActivityEvent.agent_id) == agent.id,
agent_id=None,
commit=False,
) )
await session.delete(agent) await session.delete(agent)
await session.commit() await session.commit()

View File

@@ -5,7 +5,7 @@ from typing import Any, cast
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import delete, func, update from sqlalchemy import func
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -14,7 +14,6 @@ from app.core.time import utcnow
from app.db import crud from app.db import crud
from app.db.pagination import paginate from app.db.pagination import paginate
from app.db.session import get_session from app.db.session import get_session
from app.db.sqlmodel_exec import exec_dml
from app.models.agents import Agent from app.models.agents import Agent
from app.models.board_group_memory import BoardGroupMemory from app.models.board_group_memory import BoardGroupMemory
from app.models.board_groups import BoardGroup from app.models.board_groups import BoardGroup
@@ -276,14 +275,16 @@ async def delete_board_group(
await _require_group_access(session, group_id=group_id, member=ctx.member, write=True) await _require_group_access(session, group_id=group_id, member=ctx.member, write=True)
# Boards reference groups, so clear the FK first to keep deletes simple. # Boards reference groups, so clear the FK first to keep deletes simple.
await exec_dml( await crud.update_where(
session, session,
update(Board).where(col(Board.board_group_id) == group_id).values(board_group_id=None), Board,
col(Board.board_group_id) == group_id,
board_group_id=None,
commit=False,
) )
await exec_dml( await crud.delete_where(
session, session, BoardGroupMemory, col(BoardGroupMemory.board_group_id) == group_id, commit=False
delete(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id) == group_id),
) )
await exec_dml(session, delete(BoardGroup).where(col(BoardGroup.id) == group_id)) await crud.delete_where(session, BoardGroup, col(BoardGroup.id) == group_id, commit=False)
await session.commit() await session.commit()
return OkResponse() return OkResponse()

View File

@@ -4,7 +4,7 @@ import re
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import delete, func from sqlalchemy import func
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -19,7 +19,6 @@ from app.core.time import utcnow
from app.db import crud from app.db import crud
from app.db.pagination import paginate from app.db.pagination import paginate
from app.db.session import get_session from app.db.session import get_session
from app.db.sqlmodel_exec import exec_dml
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import ( from app.integrations.openclaw_gateway import (
OpenClawGatewayError, OpenClawGatewayError,
@@ -307,43 +306,38 @@ async def delete_board(
) from exc ) from exc
if task_ids: if task_ids:
await exec_dml( await crud.delete_where(
session, delete(ActivityEvent).where(col(ActivityEvent.task_id).in_(task_ids)) session, ActivityEvent, col(ActivityEvent.task_id).in_(task_ids), commit=False
) )
await exec_dml(session, delete(TaskDependency).where(col(TaskDependency.board_id) == board.id)) await crud.delete_where(session, TaskDependency, col(TaskDependency.board_id) == board.id)
await exec_dml( await crud.delete_where(session, TaskFingerprint, col(TaskFingerprint.board_id) == board.id)
session, delete(TaskFingerprint).where(col(TaskFingerprint.board_id) == board.id)
)
# Approvals can reference tasks and agents, so delete before both. # Approvals can reference tasks and agents, so delete before both.
await exec_dml(session, delete(Approval).where(col(Approval.board_id) == board.id)) await crud.delete_where(session, Approval, col(Approval.board_id) == board.id)
await exec_dml(session, delete(BoardMemory).where(col(BoardMemory.board_id) == board.id)) await crud.delete_where(session, BoardMemory, col(BoardMemory.board_id) == board.id)
await exec_dml( await crud.delete_where(
session, session, BoardOnboardingSession, col(BoardOnboardingSession.board_id) == board.id
delete(BoardOnboardingSession).where(col(BoardOnboardingSession.board_id) == board.id),
) )
await exec_dml( await crud.delete_where(
session, session, OrganizationBoardAccess, col(OrganizationBoardAccess.board_id) == board.id
delete(OrganizationBoardAccess).where(col(OrganizationBoardAccess.board_id) == board.id),
) )
await exec_dml( await crud.delete_where(
session, session,
delete(OrganizationInviteBoardAccess).where( OrganizationInviteBoardAccess,
col(OrganizationInviteBoardAccess.board_id) == board.id col(OrganizationInviteBoardAccess.board_id) == board.id,
),
) )
# Tasks reference agents (assigned_agent_id) and have dependents (fingerprints/dependencies), so # Tasks reference agents (assigned_agent_id) and have dependents (fingerprints/dependencies), so
# delete tasks before agents. # delete tasks before agents.
await exec_dml(session, delete(Task).where(col(Task.board_id) == board.id)) await crud.delete_where(session, Task, col(Task.board_id) == board.id)
if agents: if agents:
agent_ids = [agent.id for agent in agents] agent_ids = [agent.id for agent in agents]
await exec_dml( await crud.delete_where(
session, delete(ActivityEvent).where(col(ActivityEvent.agent_id).in_(agent_ids)) session, ActivityEvent, col(ActivityEvent.agent_id).in_(agent_ids), commit=False
) )
await exec_dml(session, delete(Agent).where(col(Agent.id).in_(agent_ids))) await crud.delete_where(session, Agent, col(Agent.id).in_(agent_ids))
await session.delete(board) await session.delete(board)
await session.commit() await session.commit()
return OkResponse() return OkResponse()

View File

@@ -21,6 +21,7 @@ from app.integrations.openclaw_gateway_protocol import (
) )
from app.models.boards import Board from app.models.boards import Board
from app.models.gateways import Gateway from app.models.gateways import Gateway
from app.models.users import User
from app.schemas.common import OkResponse from app.schemas.common import OkResponse
from app.schemas.gateway_api import ( from app.schemas.gateway_api import (
GatewayCommandsResponse, GatewayCommandsResponse,
@@ -43,7 +44,7 @@ async def _resolve_gateway(
gateway_token: str | None, gateway_token: str | None,
gateway_main_session_key: str | None, gateway_main_session_key: str | None,
*, *,
user: object | None = None, user: User | None = None,
) -> tuple[Board | None, GatewayClientConfig, str | None]: ) -> tuple[Board | None, GatewayClientConfig, str | None]:
if gateway_url: if gateway_url:
return ( return (
@@ -59,8 +60,8 @@ async def _resolve_gateway(
board = await Board.objects.by_id(board_id).first(session) board = await Board.objects.by_id(board_id).first(session)
if board is None: if board is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found")
if isinstance(user, object) and user is not None: if user is not None:
await require_board_access(session, user=user, board=board, write=False) # type: ignore[arg-type] await require_board_access(session, user=user, board=board, write=False)
if not board.gateway_id: if not board.gateway_id:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
@@ -85,7 +86,7 @@ async def _resolve_gateway(
async def _require_gateway( async def _require_gateway(
session: AsyncSession, board_id: str | None, *, user: object | None = None session: AsyncSession, board_id: str | None, *, user: User | None = None
) -> tuple[Board, GatewayClientConfig, str | None]: ) -> tuple[Board, GatewayClientConfig, str | None]:
board, config, main_session = await _resolve_gateway( board, config, main_session = await _resolve_gateway(
session, session,

View File

@@ -5,7 +5,7 @@ from typing import Any, Sequence
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import delete, func, update from sqlalchemy import func
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -15,7 +15,6 @@ from app.core.time import utcnow
from app.db import crud from app.db import crud
from app.db.pagination import paginate from app.db.pagination import paginate
from app.db.session import get_session from app.db.session import get_session
from app.db.sqlmodel_exec import exec_dml
from app.models.activity_events import ActivityEvent from app.models.activity_events import ActivityEvent
from app.models.agents import Agent from app.models.agents import Agent
from app.models.approvals import Approval from app.models.approvals import Approval
@@ -214,67 +213,85 @@ async def delete_my_org(
) )
group_ids = select(BoardGroup.id).where(col(BoardGroup.organization_id) == org_id) group_ids = select(BoardGroup.id).where(col(BoardGroup.organization_id) == org_id)
await exec_dml(session, delete(ActivityEvent).where(col(ActivityEvent.task_id).in_(task_ids))) await crud.delete_where(
await exec_dml(session, delete(ActivityEvent).where(col(ActivityEvent.agent_id).in_(agent_ids))) session, ActivityEvent, col(ActivityEvent.task_id).in_(task_ids), commit=False
await exec_dml(
session, delete(TaskDependency).where(col(TaskDependency.board_id).in_(board_ids))
) )
await exec_dml( await crud.delete_where(
session, ActivityEvent, col(ActivityEvent.agent_id).in_(agent_ids), commit=False
)
await crud.delete_where(
session, TaskDependency, col(TaskDependency.board_id).in_(board_ids), commit=False
)
await crud.delete_where(
session, TaskFingerprint, col(TaskFingerprint.board_id).in_(board_ids), commit=False
)
await crud.delete_where(session, Approval, col(Approval.board_id).in_(board_ids), commit=False)
await crud.delete_where(
session, BoardMemory, col(BoardMemory.board_id).in_(board_ids), commit=False
)
await crud.delete_where(
session, session,
delete(TaskFingerprint).where(col(TaskFingerprint.board_id).in_(board_ids)), BoardOnboardingSession,
col(BoardOnboardingSession.board_id).in_(board_ids),
commit=False,
) )
await exec_dml(session, delete(Approval).where(col(Approval.board_id).in_(board_ids))) await crud.delete_where(
await exec_dml(session, delete(BoardMemory).where(col(BoardMemory.board_id).in_(board_ids)))
await exec_dml(
session, session,
delete(BoardOnboardingSession).where(col(BoardOnboardingSession.board_id).in_(board_ids)), OrganizationBoardAccess,
col(OrganizationBoardAccess.board_id).in_(board_ids),
commit=False,
) )
await exec_dml( await crud.delete_where(
session, session,
delete(OrganizationBoardAccess).where(col(OrganizationBoardAccess.board_id).in_(board_ids)), OrganizationInviteBoardAccess,
col(OrganizationInviteBoardAccess.board_id).in_(board_ids),
commit=False,
) )
await exec_dml( await crud.delete_where(
session, session,
delete(OrganizationInviteBoardAccess).where( OrganizationBoardAccess,
col(OrganizationInviteBoardAccess.board_id).in_(board_ids) col(OrganizationBoardAccess.organization_member_id).in_(member_ids),
), commit=False,
) )
await exec_dml( await crud.delete_where(
session, session,
delete(OrganizationBoardAccess).where( OrganizationInviteBoardAccess,
col(OrganizationBoardAccess.organization_member_id).in_(member_ids) col(OrganizationInviteBoardAccess.organization_invite_id).in_(invite_ids),
), commit=False,
) )
await exec_dml( await crud.delete_where(session, Task, col(Task.board_id).in_(board_ids), commit=False)
await crud.delete_where(session, Agent, col(Agent.board_id).in_(board_ids), commit=False)
await crud.delete_where(session, Board, col(Board.organization_id) == org_id, commit=False)
await crud.delete_where(
session, session,
delete(OrganizationInviteBoardAccess).where( BoardGroupMemory,
col(OrganizationInviteBoardAccess.organization_invite_id).in_(invite_ids) col(BoardGroupMemory.board_group_id).in_(group_ids),
), commit=False,
) )
await exec_dml(session, delete(Task).where(col(Task.board_id).in_(board_ids))) await crud.delete_where(
await exec_dml(session, delete(Agent).where(col(Agent.board_id).in_(board_ids))) session, BoardGroup, col(BoardGroup.organization_id) == org_id, commit=False
await exec_dml(session, delete(Board).where(col(Board.organization_id) == org_id)) )
await exec_dml( await crud.delete_where(session, Gateway, col(Gateway.organization_id) == org_id, commit=False)
await crud.delete_where(
session, session,
delete(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id).in_(group_ids)), OrganizationInvite,
col(OrganizationInvite.organization_id) == org_id,
commit=False,
) )
await exec_dml(session, delete(BoardGroup).where(col(BoardGroup.organization_id) == org_id)) await crud.delete_where(
await exec_dml(session, delete(Gateway).where(col(Gateway.organization_id) == org_id))
await exec_dml(
session, session,
delete(OrganizationInvite).where(col(OrganizationInvite.organization_id) == org_id), OrganizationMember,
col(OrganizationMember.organization_id) == org_id,
commit=False,
) )
await exec_dml( await crud.update_where(
session, session,
delete(OrganizationMember).where(col(OrganizationMember.organization_id) == org_id), User,
col(User.active_organization_id) == org_id,
active_organization_id=None,
commit=False,
) )
await exec_dml( await crud.delete_where(session, Organization, col(Organization.id) == org_id, commit=False)
session,
update(User)
.where(col(User.active_organization_id) == org_id)
.values(active_organization_id=None),
)
await exec_dml(session, delete(Organization).where(col(Organization.id) == org_id))
await session.commit() await session.commit()
return OkResponse() return OkResponse()
@@ -425,11 +442,11 @@ async def remove_org_member(
detail="Organization must have at least one owner", detail="Organization must have at least one owner",
) )
await exec_dml( await crud.delete_where(
session, session,
delete(OrganizationBoardAccess).where( OrganizationBoardAccess,
col(OrganizationBoardAccess.organization_member_id) == member.id col(OrganizationBoardAccess.organization_member_id) == member.id,
), commit=False,
) )
user = await User.objects.by_id(member.user_id).first(session) user = await User.objects.by_id(member.user_id).first(session)
@@ -532,11 +549,11 @@ async def revoke_org_invite(
organization_id=ctx.organization.id, organization_id=ctx.organization.id,
invite_id=invite_id, invite_id=invite_id,
) )
await exec_dml( await crud.delete_where(
session, session,
delete(OrganizationInviteBoardAccess).where( OrganizationInviteBoardAccess,
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id,
), commit=False,
) )
await crud.delete(session, invite) await crud.delete(session, invite)
return OrganizationInviteRead.model_validate(invite, from_attributes=True) return OrganizationInviteRead.model_validate(invite, from_attributes=True)

View File

@@ -9,7 +9,7 @@ from typing import cast
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import asc, delete, desc, or_ from sqlalchemy import asc, desc, or_
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import Select from sqlmodel.sql.expression import Select
@@ -25,9 +25,9 @@ from app.api.deps import (
) )
from app.core.auth import AuthContext from app.core.auth import AuthContext
from app.core.time import utcnow from app.core.time import utcnow
from app.db import crud
from app.db.pagination import paginate from app.db.pagination import paginate
from app.db.session import async_session_maker, get_session from app.db.session import async_session_maker, get_session
from app.db.sqlmodel_exec import exec_dml
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
from app.models.activity_events import ActivityEvent from app.models.activity_events import ActivityEvent
@@ -997,17 +997,21 @@ async def delete_task(
if auth.user is None: if auth.user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
await require_board_access(session, user=auth.user, board=board, write=True) await require_board_access(session, user=auth.user, board=board, write=True)
await exec_dml(session, delete(ActivityEvent).where(col(ActivityEvent.task_id) == task.id)) await crud.delete_where(
await exec_dml(session, delete(TaskFingerprint).where(col(TaskFingerprint.task_id) == task.id)) session, ActivityEvent, col(ActivityEvent.task_id) == task.id, commit=False
await exec_dml(session, delete(Approval).where(col(Approval.task_id) == task.id)) )
await exec_dml( await crud.delete_where(
session, TaskFingerprint, col(TaskFingerprint.task_id) == task.id, commit=False
)
await crud.delete_where(session, Approval, col(Approval.task_id) == task.id, commit=False)
await crud.delete_where(
session, session,
delete(TaskDependency).where( TaskDependency,
or_( or_(
col(TaskDependency.task_id) == task.id, col(TaskDependency.task_id) == task.id,
col(TaskDependency.depends_on_task_id) == task.id, col(TaskDependency.depends_on_task_id) == task.id,
)
), ),
commit=False,
) )
await session.delete(task) await session.delete(task)
await session.commit() await session.commit()

View File

@@ -1,8 +1,10 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from typing import Any, TypeVar from typing import Any, TypeVar, cast
from sqlalchemy import delete as sql_delete
from sqlalchemy import update as sql_update
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlmodel import SQLModel, select from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -19,6 +21,22 @@ class MultipleObjectsReturned(LookupError):
pass pass
async def _flush_or_rollback(session: AsyncSession) -> None:
try:
await session.flush()
except Exception:
await session.rollback()
raise
async def _commit_or_rollback(session: AsyncSession) -> None:
try:
await session.commit()
except Exception:
await session.rollback()
raise
def _lookup_statement(model: type[ModelT], lookup: Mapping[str, Any]) -> SelectOfScalar[ModelT]: def _lookup_statement(model: type[ModelT], lookup: Mapping[str, Any]) -> SelectOfScalar[ModelT]:
stmt = select(model) stmt = select(model)
for key, value in lookup.items(): for key, value in lookup.items():
@@ -58,9 +76,9 @@ async def create(
) -> ModelT: ) -> ModelT:
obj = model.model_validate(data) obj = model.model_validate(data)
session.add(obj) session.add(obj)
await session.flush() await _flush_or_rollback(session)
if commit: if commit:
await session.commit() await _commit_or_rollback(session)
if refresh: if refresh:
await session.refresh(obj) await session.refresh(obj)
return obj return obj
@@ -74,9 +92,9 @@ async def save(
refresh: bool = True, refresh: bool = True,
) -> ModelT: ) -> ModelT:
session.add(obj) session.add(obj)
await session.flush() await _flush_or_rollback(session)
if commit: if commit:
await session.commit() await _commit_or_rollback(session)
if refresh: if refresh:
await session.refresh(obj) await session.refresh(obj)
return obj return obj
@@ -85,7 +103,7 @@ async def save(
async def delete(session: AsyncSession, obj: ModelT, *, commit: bool = True) -> None: async def delete(session: AsyncSession, obj: ModelT, *, commit: bool = True) -> None:
await session.delete(obj) await session.delete(obj)
if commit: if commit:
await session.commit() await _commit_or_rollback(session)
async def list_by( async def list_by(
@@ -111,6 +129,77 @@ async def exists(session: AsyncSession, model: type[ModelT], **lookup: Any) -> b
return (await session.exec(_lookup_statement(model, lookup).limit(1))).first() is not None return (await session.exec(_lookup_statement(model, lookup).limit(1))).first() is not None
def _criteria_statement(model: type[ModelT], criteria: tuple[Any, ...]) -> SelectOfScalar[ModelT]:
stmt = select(model)
if criteria:
stmt = stmt.where(*criteria)
return stmt
async def list_where(
session: AsyncSession,
model: type[ModelT],
*criteria: Any,
order_by: Iterable[Any] = (),
) -> list[ModelT]:
stmt = _criteria_statement(model, criteria)
for ordering in order_by:
stmt = stmt.order_by(ordering)
return list(await session.exec(stmt))
async def delete_where(
session: AsyncSession,
model: type[ModelT],
*criteria: Any,
commit: bool = False,
) -> int:
stmt = sql_delete(model)
if criteria:
stmt = stmt.where(*criteria)
result = await session.exec(cast(Any, stmt))
if commit:
await _commit_or_rollback(session)
rowcount = getattr(result, "rowcount", None)
return int(rowcount) if isinstance(rowcount, int) else 0
async def update_where(
session: AsyncSession,
model: type[ModelT],
*criteria: Any,
updates: Mapping[str, Any] | None = None,
commit: bool = False,
exclude_none: bool = False,
allowed_fields: set[str] | None = None,
**update_fields: Any,
) -> int:
source_updates: dict[str, Any] = {}
if updates:
source_updates.update(dict(updates))
if update_fields:
source_updates.update(update_fields)
values: dict[str, Any] = {}
for key, value in source_updates.items():
if allowed_fields is not None and key not in allowed_fields:
continue
if exclude_none and value is None:
continue
values[key] = value
if not values:
return 0
stmt = sql_update(model).values(**values)
if criteria:
stmt = stmt.where(*criteria)
result = await session.exec(cast(Any, stmt))
if commit:
await _commit_or_rollback(session)
rowcount = getattr(result, "rowcount", None)
return int(rowcount) if isinstance(rowcount, int) else 0
def apply_updates( def apply_updates(
obj: ModelT, obj: ModelT,
updates: Mapping[str, Any], updates: Mapping[str, Any],
@@ -179,6 +268,9 @@ async def get_or_create(
if existing is not None: if existing is not None:
return existing, False return existing, False
raise raise
except Exception:
await session.rollback()
raise
if refresh: if refresh:
await session.refresh(obj) await session.refresh(obj)

View File

@@ -11,9 +11,12 @@ from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async
from sqlmodel import SQLModel from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app import models # noqa: F401 from app import models as _models
from app.core.config import settings from app.core.config import settings
# Import model modules so SQLModel metadata is fully registered at startup.
_MODEL_REGISTRY = _models
def _normalize_database_url(database_url: str) -> str: def _normalize_database_url(database_url: str) -> str:
if "://" not in database_url: if "://" not in database_url:
@@ -64,4 +67,11 @@ async def init_db() -> None:
async def get_session() -> AsyncGenerator[AsyncSession, None]: async def get_session() -> AsyncGenerator[AsyncSession, None]:
async with async_session_maker() as session: async with async_session_maker() as session:
yield session try:
yield session
except Exception:
try:
await session.rollback()
except Exception:
logger.exception("Failed to rollback session after request error.")
raise

View File

@@ -1,9 +0,0 @@
from __future__ import annotations
from sqlalchemy.sql.base import Executable
from sqlmodel.ext.asyncio.session import AsyncSession
async def exec_dml(session: AsyncSession, statement: Executable) -> None:
# SQLModel's AsyncSession typing only overloads exec() for SELECT statements.
await session.exec(statement) # type: ignore[call-overload]

View File

@@ -5,13 +5,13 @@ from typing import Iterable
from uuid import UUID from uuid import UUID
from fastapi import HTTPException, status from fastapi import HTTPException, status
from sqlalchemy import delete, func, or_ from sqlalchemy import func, or_
from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.time import utcnow from app.core.time import utcnow
from app.db.sqlmodel_exec import exec_dml from app.db import crud
from app.models.boards import Board from app.models.boards import Board
from app.models.organization_board_access import OrganizationBoardAccess from app.models.organization_board_access import OrganizationBoardAccess
from app.models.organization_invite_board_access import OrganizationInviteBoardAccess from app.models.organization_invite_board_access import OrganizationInviteBoardAccess
@@ -328,11 +328,11 @@ async def apply_member_access_update(
member.updated_at = now member.updated_at = now
session.add(member) session.add(member)
await exec_dml( await crud.delete_where(
session, session,
delete(OrganizationBoardAccess).where( OrganizationBoardAccess,
col(OrganizationBoardAccess.organization_member_id) == member.id col(OrganizationBoardAccess.organization_member_id) == member.id,
), commit=False,
) )
if update.all_boards_read or update.all_boards_write: if update.all_boards_read or update.all_boards_write:
@@ -359,11 +359,11 @@ async def apply_invite_board_access(
invite: OrganizationInvite, invite: OrganizationInvite,
entries: Iterable[OrganizationBoardAccessSpec], entries: Iterable[OrganizationBoardAccessSpec],
) -> None: ) -> None:
await exec_dml( await crud.delete_where(
session, session,
delete(OrganizationInviteBoardAccess).where( OrganizationInviteBoardAccess,
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id,
), commit=False,
) )
if invite.all_boards_read or invite.all_boards_write: if invite.all_boards_read or invite.all_boards_write:
return return

View File

@@ -6,11 +6,10 @@ from typing import Final
from uuid import UUID from uuid import UUID
from fastapi import HTTPException, status from fastapi import HTTPException, status
from sqlalchemy import delete
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.db.sqlmodel_exec import exec_dml from app.db import crud
from app.models.task_dependencies import TaskDependency from app.models.task_dependencies import TaskDependency
from app.models.tasks import Task from app.models.tasks import Task
@@ -195,11 +194,12 @@ async def replace_task_dependencies(
task_id=task_id, task_id=task_id,
depends_on_task_ids=depends_on_task_ids, depends_on_task_ids=depends_on_task_ids,
) )
await exec_dml( await crud.delete_where(
session, session,
delete(TaskDependency) TaskDependency,
.where(col(TaskDependency.board_id) == board_id) col(TaskDependency.board_id) == board_id,
.where(col(TaskDependency.task_id) == task_id), col(TaskDependency.task_id) == task_id,
commit=False,
) )
for dep_id in normalized: for dep_id in normalized:
session.add( session.add(

View File

@@ -0,0 +1,125 @@
from __future__ import annotations
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any
import pytest
from sqlmodel import Field, SQLModel
from app.db import crud
from app.db import session as db_session
class _Thing(SQLModel, table=True):
__tablename__ = "_test_thing"
id: int | None = Field(default=None, primary_key=True)
name: str
@dataclass
class _SessionCtx:
session: Any
entered: int = 0
exited: int = 0
async def __aenter__(self) -> Any:
self.entered += 1
return self.session
async def __aexit__(self, _exc_type: Any, _exc: Any, _tb: Any) -> bool:
self.exited += 1
return False
@dataclass
class _Maker:
ctx: _SessionCtx
def __call__(self) -> _SessionCtx:
return self.ctx
@pytest.mark.asyncio
async def test_get_session_rolls_back_on_dependency_error(monkeypatch: pytest.MonkeyPatch) -> None:
fake_session = SimpleNamespace(rollbacks=0)
async def _rollback() -> None:
fake_session.rollbacks += 1
fake_session.rollback = _rollback
ctx = _SessionCtx(fake_session)
monkeypatch.setattr(db_session, "async_session_maker", _Maker(ctx))
generator = db_session.get_session()
yielded = await anext(generator)
assert yielded is fake_session
with pytest.raises(RuntimeError, match="boom"):
await generator.athrow(RuntimeError("boom"))
assert fake_session.rollbacks == 1
assert ctx.entered == 1
assert ctx.exited == 1
@pytest.mark.asyncio
async def test_create_rolls_back_when_commit_fails() -> None:
@dataclass
class _FailCommitSession:
rollback_calls: int = 0
added: list[Any] = None
def __post_init__(self) -> None:
if self.added is None:
self.added = []
def add(self, value: Any) -> None:
self.added.append(value)
async def flush(self) -> None:
return None
async def commit(self) -> None:
raise RuntimeError("commit failed")
async def rollback(self) -> None:
self.rollback_calls += 1
async def refresh(self, _value: Any) -> None:
return None
session = _FailCommitSession()
with pytest.raises(RuntimeError, match="commit failed"):
await crud.create(session, _Thing, name="demo")
assert session.rollback_calls == 1
assert len(session.added) == 1
@pytest.mark.asyncio
async def test_delete_where_rolls_back_when_commit_fails() -> None:
@dataclass
class _FailCommitDmlSession:
rollback_calls: int = 0
exec_calls: int = 0
async def exec(self, _statement: Any) -> Any:
self.exec_calls += 1
return SimpleNamespace(rowcount=3)
async def commit(self) -> None:
raise RuntimeError("commit failed")
async def rollback(self) -> None:
self.rollback_calls += 1
session = _FailCommitDmlSession()
with pytest.raises(RuntimeError, match="commit failed"):
await crud.delete_where(session, _Thing, commit=True)
assert session.exec_calls == 1
assert session.rollback_calls == 1