refactor: replace exec_dml with CRUD operations in various files and improve session handling
This commit is contained in:
@@ -9,7 +9,7 @@ from typing import Any, cast
|
|||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||||
from sqlalchemy import asc, or_, update
|
from sqlalchemy import asc, or_
|
||||||
from sqlalchemy.sql.elements import ColumnElement
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
@@ -19,9 +19,9 @@ from app.api.deps import ActorContext, require_admin_or_agent, require_org_admin
|
|||||||
from app.core.agent_tokens import generate_agent_token, hash_agent_token
|
from app.core.agent_tokens import generate_agent_token, hash_agent_token
|
||||||
from app.core.auth import AuthContext, get_auth_context
|
from app.core.auth import AuthContext, get_auth_context
|
||||||
from app.core.time import utcnow
|
from app.core.time import utcnow
|
||||||
|
from app.db import crud
|
||||||
from app.db.pagination import paginate
|
from app.db.pagination import paginate
|
||||||
from app.db.session import async_session_maker, get_session
|
from app.db.session import async_session_maker, get_session
|
||||||
from app.db.sqlmodel_exec import exec_dml
|
|
||||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||||
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
|
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
|
||||||
from app.models.activity_events import ActivityEvent
|
from app.models.activity_events import ActivityEvent
|
||||||
@@ -101,7 +101,7 @@ async def _require_board(
|
|||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
board_id: UUID | str | None,
|
board_id: UUID | str | None,
|
||||||
*,
|
*,
|
||||||
user: object | None = None,
|
user: User | None = None,
|
||||||
write: bool = False,
|
write: bool = False,
|
||||||
) -> Board:
|
) -> Board:
|
||||||
if not board_id:
|
if not board_id:
|
||||||
@@ -113,7 +113,7 @@ async def _require_board(
|
|||||||
if board is None:
|
if board is None:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found")
|
||||||
if user is not None:
|
if user is not None:
|
||||||
await require_board_access(session, user=user, board=board, write=write) # type: ignore[arg-type]
|
await require_board_access(session, user=user, board=board, write=write)
|
||||||
return board
|
return board
|
||||||
|
|
||||||
|
|
||||||
@@ -972,31 +972,32 @@ async def delete_agent(
|
|||||||
agent_id=None,
|
agent_id=None,
|
||||||
)
|
)
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
await exec_dml(
|
await crud.update_where(
|
||||||
session,
|
session,
|
||||||
update(Task)
|
Task,
|
||||||
.where(col(Task.assigned_agent_id) == agent.id)
|
col(Task.assigned_agent_id) == agent.id,
|
||||||
.where(col(Task.status) == "in_progress")
|
col(Task.status) == "in_progress",
|
||||||
.values(
|
assigned_agent_id=None,
|
||||||
assigned_agent_id=None,
|
status="inbox",
|
||||||
status="inbox",
|
in_progress_at=None,
|
||||||
in_progress_at=None,
|
updated_at=now,
|
||||||
updated_at=now,
|
commit=False,
|
||||||
),
|
|
||||||
)
|
)
|
||||||
await exec_dml(
|
await crud.update_where(
|
||||||
session,
|
session,
|
||||||
update(Task)
|
Task,
|
||||||
.where(col(Task.assigned_agent_id) == agent.id)
|
col(Task.assigned_agent_id) == agent.id,
|
||||||
.where(col(Task.status) != "in_progress")
|
col(Task.status) != "in_progress",
|
||||||
.values(
|
assigned_agent_id=None,
|
||||||
assigned_agent_id=None,
|
updated_at=now,
|
||||||
updated_at=now,
|
commit=False,
|
||||||
),
|
|
||||||
)
|
)
|
||||||
await exec_dml(
|
await crud.update_where(
|
||||||
session,
|
session,
|
||||||
update(ActivityEvent).where(col(ActivityEvent.agent_id) == agent.id).values(agent_id=None),
|
ActivityEvent,
|
||||||
|
col(ActivityEvent.agent_id) == agent.id,
|
||||||
|
agent_id=None,
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await session.delete(agent)
|
await session.delete(agent)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import Any, cast
|
|||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from sqlalchemy import delete, func, update
|
from sqlalchemy import func
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
@@ -14,7 +14,6 @@ from app.core.time import utcnow
|
|||||||
from app.db import crud
|
from app.db import crud
|
||||||
from app.db.pagination import paginate
|
from app.db.pagination import paginate
|
||||||
from app.db.session import get_session
|
from app.db.session import get_session
|
||||||
from app.db.sqlmodel_exec import exec_dml
|
|
||||||
from app.models.agents import Agent
|
from app.models.agents import Agent
|
||||||
from app.models.board_group_memory import BoardGroupMemory
|
from app.models.board_group_memory import BoardGroupMemory
|
||||||
from app.models.board_groups import BoardGroup
|
from app.models.board_groups import BoardGroup
|
||||||
@@ -276,14 +275,16 @@ async def delete_board_group(
|
|||||||
await _require_group_access(session, group_id=group_id, member=ctx.member, write=True)
|
await _require_group_access(session, group_id=group_id, member=ctx.member, write=True)
|
||||||
|
|
||||||
# Boards reference groups, so clear the FK first to keep deletes simple.
|
# Boards reference groups, so clear the FK first to keep deletes simple.
|
||||||
await exec_dml(
|
await crud.update_where(
|
||||||
session,
|
session,
|
||||||
update(Board).where(col(Board.board_group_id) == group_id).values(board_group_id=None),
|
Board,
|
||||||
|
col(Board.board_group_id) == group_id,
|
||||||
|
board_group_id=None,
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await exec_dml(
|
await crud.delete_where(
|
||||||
session,
|
session, BoardGroupMemory, col(BoardGroupMemory.board_group_id) == group_id, commit=False
|
||||||
delete(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id) == group_id),
|
|
||||||
)
|
)
|
||||||
await exec_dml(session, delete(BoardGroup).where(col(BoardGroup.id) == group_id))
|
await crud.delete_where(session, BoardGroup, col(BoardGroup.id) == group_id, commit=False)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return OkResponse()
|
return OkResponse()
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import re
|
|||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
from sqlalchemy import delete, func
|
from sqlalchemy import func
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
@@ -19,7 +19,6 @@ from app.core.time import utcnow
|
|||||||
from app.db import crud
|
from app.db import crud
|
||||||
from app.db.pagination import paginate
|
from app.db.pagination import paginate
|
||||||
from app.db.session import get_session
|
from app.db.session import get_session
|
||||||
from app.db.sqlmodel_exec import exec_dml
|
|
||||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||||
from app.integrations.openclaw_gateway import (
|
from app.integrations.openclaw_gateway import (
|
||||||
OpenClawGatewayError,
|
OpenClawGatewayError,
|
||||||
@@ -307,43 +306,38 @@ async def delete_board(
|
|||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
if task_ids:
|
if task_ids:
|
||||||
await exec_dml(
|
await crud.delete_where(
|
||||||
session, delete(ActivityEvent).where(col(ActivityEvent.task_id).in_(task_ids))
|
session, ActivityEvent, col(ActivityEvent.task_id).in_(task_ids), commit=False
|
||||||
)
|
)
|
||||||
await exec_dml(session, delete(TaskDependency).where(col(TaskDependency.board_id) == board.id))
|
await crud.delete_where(session, TaskDependency, col(TaskDependency.board_id) == board.id)
|
||||||
await exec_dml(
|
await crud.delete_where(session, TaskFingerprint, col(TaskFingerprint.board_id) == board.id)
|
||||||
session, delete(TaskFingerprint).where(col(TaskFingerprint.board_id) == board.id)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Approvals can reference tasks and agents, so delete before both.
|
# Approvals can reference tasks and agents, so delete before both.
|
||||||
await exec_dml(session, delete(Approval).where(col(Approval.board_id) == board.id))
|
await crud.delete_where(session, Approval, col(Approval.board_id) == board.id)
|
||||||
|
|
||||||
await exec_dml(session, delete(BoardMemory).where(col(BoardMemory.board_id) == board.id))
|
await crud.delete_where(session, BoardMemory, col(BoardMemory.board_id) == board.id)
|
||||||
await exec_dml(
|
await crud.delete_where(
|
||||||
session,
|
session, BoardOnboardingSession, col(BoardOnboardingSession.board_id) == board.id
|
||||||
delete(BoardOnboardingSession).where(col(BoardOnboardingSession.board_id) == board.id),
|
|
||||||
)
|
)
|
||||||
await exec_dml(
|
await crud.delete_where(
|
||||||
session,
|
session, OrganizationBoardAccess, col(OrganizationBoardAccess.board_id) == board.id
|
||||||
delete(OrganizationBoardAccess).where(col(OrganizationBoardAccess.board_id) == board.id),
|
|
||||||
)
|
)
|
||||||
await exec_dml(
|
await crud.delete_where(
|
||||||
session,
|
session,
|
||||||
delete(OrganizationInviteBoardAccess).where(
|
OrganizationInviteBoardAccess,
|
||||||
col(OrganizationInviteBoardAccess.board_id) == board.id
|
col(OrganizationInviteBoardAccess.board_id) == board.id,
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tasks reference agents (assigned_agent_id) and have dependents (fingerprints/dependencies), so
|
# Tasks reference agents (assigned_agent_id) and have dependents (fingerprints/dependencies), so
|
||||||
# delete tasks before agents.
|
# delete tasks before agents.
|
||||||
await exec_dml(session, delete(Task).where(col(Task.board_id) == board.id))
|
await crud.delete_where(session, Task, col(Task.board_id) == board.id)
|
||||||
|
|
||||||
if agents:
|
if agents:
|
||||||
agent_ids = [agent.id for agent in agents]
|
agent_ids = [agent.id for agent in agents]
|
||||||
await exec_dml(
|
await crud.delete_where(
|
||||||
session, delete(ActivityEvent).where(col(ActivityEvent.agent_id).in_(agent_ids))
|
session, ActivityEvent, col(ActivityEvent.agent_id).in_(agent_ids), commit=False
|
||||||
)
|
)
|
||||||
await exec_dml(session, delete(Agent).where(col(Agent.id).in_(agent_ids)))
|
await crud.delete_where(session, Agent, col(Agent.id).in_(agent_ids))
|
||||||
await session.delete(board)
|
await session.delete(board)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return OkResponse()
|
return OkResponse()
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from app.integrations.openclaw_gateway_protocol import (
|
|||||||
)
|
)
|
||||||
from app.models.boards import Board
|
from app.models.boards import Board
|
||||||
from app.models.gateways import Gateway
|
from app.models.gateways import Gateway
|
||||||
|
from app.models.users import User
|
||||||
from app.schemas.common import OkResponse
|
from app.schemas.common import OkResponse
|
||||||
from app.schemas.gateway_api import (
|
from app.schemas.gateway_api import (
|
||||||
GatewayCommandsResponse,
|
GatewayCommandsResponse,
|
||||||
@@ -43,7 +44,7 @@ async def _resolve_gateway(
|
|||||||
gateway_token: str | None,
|
gateway_token: str | None,
|
||||||
gateway_main_session_key: str | None,
|
gateway_main_session_key: str | None,
|
||||||
*,
|
*,
|
||||||
user: object | None = None,
|
user: User | None = None,
|
||||||
) -> tuple[Board | None, GatewayClientConfig, str | None]:
|
) -> tuple[Board | None, GatewayClientConfig, str | None]:
|
||||||
if gateway_url:
|
if gateway_url:
|
||||||
return (
|
return (
|
||||||
@@ -59,8 +60,8 @@ async def _resolve_gateway(
|
|||||||
board = await Board.objects.by_id(board_id).first(session)
|
board = await Board.objects.by_id(board_id).first(session)
|
||||||
if board is None:
|
if board is None:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Board not found")
|
||||||
if isinstance(user, object) and user is not None:
|
if user is not None:
|
||||||
await require_board_access(session, user=user, board=board, write=False) # type: ignore[arg-type]
|
await require_board_access(session, user=user, board=board, write=False)
|
||||||
if not board.gateway_id:
|
if not board.gateway_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
@@ -85,7 +86,7 @@ async def _resolve_gateway(
|
|||||||
|
|
||||||
|
|
||||||
async def _require_gateway(
|
async def _require_gateway(
|
||||||
session: AsyncSession, board_id: str | None, *, user: object | None = None
|
session: AsyncSession, board_id: str | None, *, user: User | None = None
|
||||||
) -> tuple[Board, GatewayClientConfig, str | None]:
|
) -> tuple[Board, GatewayClientConfig, str | None]:
|
||||||
board, config, main_session = await _resolve_gateway(
|
board, config, main_session = await _resolve_gateway(
|
||||||
session,
|
session,
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import Any, Sequence
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from sqlalchemy import delete, func, update
|
from sqlalchemy import func
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
@@ -15,7 +15,6 @@ from app.core.time import utcnow
|
|||||||
from app.db import crud
|
from app.db import crud
|
||||||
from app.db.pagination import paginate
|
from app.db.pagination import paginate
|
||||||
from app.db.session import get_session
|
from app.db.session import get_session
|
||||||
from app.db.sqlmodel_exec import exec_dml
|
|
||||||
from app.models.activity_events import ActivityEvent
|
from app.models.activity_events import ActivityEvent
|
||||||
from app.models.agents import Agent
|
from app.models.agents import Agent
|
||||||
from app.models.approvals import Approval
|
from app.models.approvals import Approval
|
||||||
@@ -214,67 +213,85 @@ async def delete_my_org(
|
|||||||
)
|
)
|
||||||
group_ids = select(BoardGroup.id).where(col(BoardGroup.organization_id) == org_id)
|
group_ids = select(BoardGroup.id).where(col(BoardGroup.organization_id) == org_id)
|
||||||
|
|
||||||
await exec_dml(session, delete(ActivityEvent).where(col(ActivityEvent.task_id).in_(task_ids)))
|
await crud.delete_where(
|
||||||
await exec_dml(session, delete(ActivityEvent).where(col(ActivityEvent.agent_id).in_(agent_ids)))
|
session, ActivityEvent, col(ActivityEvent.task_id).in_(task_ids), commit=False
|
||||||
await exec_dml(
|
|
||||||
session, delete(TaskDependency).where(col(TaskDependency.board_id).in_(board_ids))
|
|
||||||
)
|
)
|
||||||
await exec_dml(
|
await crud.delete_where(
|
||||||
|
session, ActivityEvent, col(ActivityEvent.agent_id).in_(agent_ids), commit=False
|
||||||
|
)
|
||||||
|
await crud.delete_where(
|
||||||
|
session, TaskDependency, col(TaskDependency.board_id).in_(board_ids), commit=False
|
||||||
|
)
|
||||||
|
await crud.delete_where(
|
||||||
|
session, TaskFingerprint, col(TaskFingerprint.board_id).in_(board_ids), commit=False
|
||||||
|
)
|
||||||
|
await crud.delete_where(session, Approval, col(Approval.board_id).in_(board_ids), commit=False)
|
||||||
|
await crud.delete_where(
|
||||||
|
session, BoardMemory, col(BoardMemory.board_id).in_(board_ids), commit=False
|
||||||
|
)
|
||||||
|
await crud.delete_where(
|
||||||
session,
|
session,
|
||||||
delete(TaskFingerprint).where(col(TaskFingerprint.board_id).in_(board_ids)),
|
BoardOnboardingSession,
|
||||||
|
col(BoardOnboardingSession.board_id).in_(board_ids),
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await exec_dml(session, delete(Approval).where(col(Approval.board_id).in_(board_ids)))
|
await crud.delete_where(
|
||||||
await exec_dml(session, delete(BoardMemory).where(col(BoardMemory.board_id).in_(board_ids)))
|
|
||||||
await exec_dml(
|
|
||||||
session,
|
session,
|
||||||
delete(BoardOnboardingSession).where(col(BoardOnboardingSession.board_id).in_(board_ids)),
|
OrganizationBoardAccess,
|
||||||
|
col(OrganizationBoardAccess.board_id).in_(board_ids),
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await exec_dml(
|
await crud.delete_where(
|
||||||
session,
|
session,
|
||||||
delete(OrganizationBoardAccess).where(col(OrganizationBoardAccess.board_id).in_(board_ids)),
|
OrganizationInviteBoardAccess,
|
||||||
|
col(OrganizationInviteBoardAccess.board_id).in_(board_ids),
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await exec_dml(
|
await crud.delete_where(
|
||||||
session,
|
session,
|
||||||
delete(OrganizationInviteBoardAccess).where(
|
OrganizationBoardAccess,
|
||||||
col(OrganizationInviteBoardAccess.board_id).in_(board_ids)
|
col(OrganizationBoardAccess.organization_member_id).in_(member_ids),
|
||||||
),
|
commit=False,
|
||||||
)
|
)
|
||||||
await exec_dml(
|
await crud.delete_where(
|
||||||
session,
|
session,
|
||||||
delete(OrganizationBoardAccess).where(
|
OrganizationInviteBoardAccess,
|
||||||
col(OrganizationBoardAccess.organization_member_id).in_(member_ids)
|
col(OrganizationInviteBoardAccess.organization_invite_id).in_(invite_ids),
|
||||||
),
|
commit=False,
|
||||||
)
|
)
|
||||||
await exec_dml(
|
await crud.delete_where(session, Task, col(Task.board_id).in_(board_ids), commit=False)
|
||||||
|
await crud.delete_where(session, Agent, col(Agent.board_id).in_(board_ids), commit=False)
|
||||||
|
await crud.delete_where(session, Board, col(Board.organization_id) == org_id, commit=False)
|
||||||
|
await crud.delete_where(
|
||||||
session,
|
session,
|
||||||
delete(OrganizationInviteBoardAccess).where(
|
BoardGroupMemory,
|
||||||
col(OrganizationInviteBoardAccess.organization_invite_id).in_(invite_ids)
|
col(BoardGroupMemory.board_group_id).in_(group_ids),
|
||||||
),
|
commit=False,
|
||||||
)
|
)
|
||||||
await exec_dml(session, delete(Task).where(col(Task.board_id).in_(board_ids)))
|
await crud.delete_where(
|
||||||
await exec_dml(session, delete(Agent).where(col(Agent.board_id).in_(board_ids)))
|
session, BoardGroup, col(BoardGroup.organization_id) == org_id, commit=False
|
||||||
await exec_dml(session, delete(Board).where(col(Board.organization_id) == org_id))
|
)
|
||||||
await exec_dml(
|
await crud.delete_where(session, Gateway, col(Gateway.organization_id) == org_id, commit=False)
|
||||||
|
await crud.delete_where(
|
||||||
session,
|
session,
|
||||||
delete(BoardGroupMemory).where(col(BoardGroupMemory.board_group_id).in_(group_ids)),
|
OrganizationInvite,
|
||||||
|
col(OrganizationInvite.organization_id) == org_id,
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await exec_dml(session, delete(BoardGroup).where(col(BoardGroup.organization_id) == org_id))
|
await crud.delete_where(
|
||||||
await exec_dml(session, delete(Gateway).where(col(Gateway.organization_id) == org_id))
|
|
||||||
await exec_dml(
|
|
||||||
session,
|
session,
|
||||||
delete(OrganizationInvite).where(col(OrganizationInvite.organization_id) == org_id),
|
OrganizationMember,
|
||||||
|
col(OrganizationMember.organization_id) == org_id,
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await exec_dml(
|
await crud.update_where(
|
||||||
session,
|
session,
|
||||||
delete(OrganizationMember).where(col(OrganizationMember.organization_id) == org_id),
|
User,
|
||||||
|
col(User.active_organization_id) == org_id,
|
||||||
|
active_organization_id=None,
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await exec_dml(
|
await crud.delete_where(session, Organization, col(Organization.id) == org_id, commit=False)
|
||||||
session,
|
|
||||||
update(User)
|
|
||||||
.where(col(User.active_organization_id) == org_id)
|
|
||||||
.values(active_organization_id=None),
|
|
||||||
)
|
|
||||||
await exec_dml(session, delete(Organization).where(col(Organization.id) == org_id))
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return OkResponse()
|
return OkResponse()
|
||||||
|
|
||||||
@@ -425,11 +442,11 @@ async def remove_org_member(
|
|||||||
detail="Organization must have at least one owner",
|
detail="Organization must have at least one owner",
|
||||||
)
|
)
|
||||||
|
|
||||||
await exec_dml(
|
await crud.delete_where(
|
||||||
session,
|
session,
|
||||||
delete(OrganizationBoardAccess).where(
|
OrganizationBoardAccess,
|
||||||
col(OrganizationBoardAccess.organization_member_id) == member.id
|
col(OrganizationBoardAccess.organization_member_id) == member.id,
|
||||||
),
|
commit=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
user = await User.objects.by_id(member.user_id).first(session)
|
user = await User.objects.by_id(member.user_id).first(session)
|
||||||
@@ -532,11 +549,11 @@ async def revoke_org_invite(
|
|||||||
organization_id=ctx.organization.id,
|
organization_id=ctx.organization.id,
|
||||||
invite_id=invite_id,
|
invite_id=invite_id,
|
||||||
)
|
)
|
||||||
await exec_dml(
|
await crud.delete_where(
|
||||||
session,
|
session,
|
||||||
delete(OrganizationInviteBoardAccess).where(
|
OrganizationInviteBoardAccess,
|
||||||
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
|
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id,
|
||||||
),
|
commit=False,
|
||||||
)
|
)
|
||||||
await crud.delete(session, invite)
|
await crud.delete(session, invite)
|
||||||
return OrganizationInviteRead.model_validate(invite, from_attributes=True)
|
return OrganizationInviteRead.model_validate(invite, from_attributes=True)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from typing import cast
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||||
from sqlalchemy import asc, delete, desc, or_
|
from sqlalchemy import asc, desc, or_
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
from sqlmodel.sql.expression import Select
|
from sqlmodel.sql.expression import Select
|
||||||
@@ -25,9 +25,9 @@ from app.api.deps import (
|
|||||||
)
|
)
|
||||||
from app.core.auth import AuthContext
|
from app.core.auth import AuthContext
|
||||||
from app.core.time import utcnow
|
from app.core.time import utcnow
|
||||||
|
from app.db import crud
|
||||||
from app.db.pagination import paginate
|
from app.db.pagination import paginate
|
||||||
from app.db.session import async_session_maker, get_session
|
from app.db.session import async_session_maker, get_session
|
||||||
from app.db.sqlmodel_exec import exec_dml
|
|
||||||
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
from app.integrations.openclaw_gateway import GatewayConfig as GatewayClientConfig
|
||||||
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
|
from app.integrations.openclaw_gateway import OpenClawGatewayError, ensure_session, send_message
|
||||||
from app.models.activity_events import ActivityEvent
|
from app.models.activity_events import ActivityEvent
|
||||||
@@ -997,17 +997,21 @@ async def delete_task(
|
|||||||
if auth.user is None:
|
if auth.user is None:
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||||
await require_board_access(session, user=auth.user, board=board, write=True)
|
await require_board_access(session, user=auth.user, board=board, write=True)
|
||||||
await exec_dml(session, delete(ActivityEvent).where(col(ActivityEvent.task_id) == task.id))
|
await crud.delete_where(
|
||||||
await exec_dml(session, delete(TaskFingerprint).where(col(TaskFingerprint.task_id) == task.id))
|
session, ActivityEvent, col(ActivityEvent.task_id) == task.id, commit=False
|
||||||
await exec_dml(session, delete(Approval).where(col(Approval.task_id) == task.id))
|
)
|
||||||
await exec_dml(
|
await crud.delete_where(
|
||||||
|
session, TaskFingerprint, col(TaskFingerprint.task_id) == task.id, commit=False
|
||||||
|
)
|
||||||
|
await crud.delete_where(session, Approval, col(Approval.task_id) == task.id, commit=False)
|
||||||
|
await crud.delete_where(
|
||||||
session,
|
session,
|
||||||
delete(TaskDependency).where(
|
TaskDependency,
|
||||||
or_(
|
or_(
|
||||||
col(TaskDependency.task_id) == task.id,
|
col(TaskDependency.task_id) == task.id,
|
||||||
col(TaskDependency.depends_on_task_id) == task.id,
|
col(TaskDependency.depends_on_task_id) == task.id,
|
||||||
)
|
|
||||||
),
|
),
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
await session.delete(task)
|
await session.delete(task)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Iterable, Mapping
|
from collections.abc import Iterable, Mapping
|
||||||
from typing import Any, TypeVar
|
from typing import Any, TypeVar, cast
|
||||||
|
|
||||||
|
from sqlalchemy import delete as sql_delete
|
||||||
|
from sqlalchemy import update as sql_update
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlmodel import SQLModel, select
|
from sqlmodel import SQLModel, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
@@ -19,6 +21,22 @@ class MultipleObjectsReturned(LookupError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def _flush_or_rollback(session: AsyncSession) -> None:
|
||||||
|
try:
|
||||||
|
await session.flush()
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def _commit_or_rollback(session: AsyncSession) -> None:
|
||||||
|
try:
|
||||||
|
await session.commit()
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def _lookup_statement(model: type[ModelT], lookup: Mapping[str, Any]) -> SelectOfScalar[ModelT]:
|
def _lookup_statement(model: type[ModelT], lookup: Mapping[str, Any]) -> SelectOfScalar[ModelT]:
|
||||||
stmt = select(model)
|
stmt = select(model)
|
||||||
for key, value in lookup.items():
|
for key, value in lookup.items():
|
||||||
@@ -58,9 +76,9 @@ async def create(
|
|||||||
) -> ModelT:
|
) -> ModelT:
|
||||||
obj = model.model_validate(data)
|
obj = model.model_validate(data)
|
||||||
session.add(obj)
|
session.add(obj)
|
||||||
await session.flush()
|
await _flush_or_rollback(session)
|
||||||
if commit:
|
if commit:
|
||||||
await session.commit()
|
await _commit_or_rollback(session)
|
||||||
if refresh:
|
if refresh:
|
||||||
await session.refresh(obj)
|
await session.refresh(obj)
|
||||||
return obj
|
return obj
|
||||||
@@ -74,9 +92,9 @@ async def save(
|
|||||||
refresh: bool = True,
|
refresh: bool = True,
|
||||||
) -> ModelT:
|
) -> ModelT:
|
||||||
session.add(obj)
|
session.add(obj)
|
||||||
await session.flush()
|
await _flush_or_rollback(session)
|
||||||
if commit:
|
if commit:
|
||||||
await session.commit()
|
await _commit_or_rollback(session)
|
||||||
if refresh:
|
if refresh:
|
||||||
await session.refresh(obj)
|
await session.refresh(obj)
|
||||||
return obj
|
return obj
|
||||||
@@ -85,7 +103,7 @@ async def save(
|
|||||||
async def delete(session: AsyncSession, obj: ModelT, *, commit: bool = True) -> None:
|
async def delete(session: AsyncSession, obj: ModelT, *, commit: bool = True) -> None:
|
||||||
await session.delete(obj)
|
await session.delete(obj)
|
||||||
if commit:
|
if commit:
|
||||||
await session.commit()
|
await _commit_or_rollback(session)
|
||||||
|
|
||||||
|
|
||||||
async def list_by(
|
async def list_by(
|
||||||
@@ -111,6 +129,77 @@ async def exists(session: AsyncSession, model: type[ModelT], **lookup: Any) -> b
|
|||||||
return (await session.exec(_lookup_statement(model, lookup).limit(1))).first() is not None
|
return (await session.exec(_lookup_statement(model, lookup).limit(1))).first() is not None
|
||||||
|
|
||||||
|
|
||||||
|
def _criteria_statement(model: type[ModelT], criteria: tuple[Any, ...]) -> SelectOfScalar[ModelT]:
|
||||||
|
stmt = select(model)
|
||||||
|
if criteria:
|
||||||
|
stmt = stmt.where(*criteria)
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
|
async def list_where(
|
||||||
|
session: AsyncSession,
|
||||||
|
model: type[ModelT],
|
||||||
|
*criteria: Any,
|
||||||
|
order_by: Iterable[Any] = (),
|
||||||
|
) -> list[ModelT]:
|
||||||
|
stmt = _criteria_statement(model, criteria)
|
||||||
|
for ordering in order_by:
|
||||||
|
stmt = stmt.order_by(ordering)
|
||||||
|
return list(await session.exec(stmt))
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_where(
|
||||||
|
session: AsyncSession,
|
||||||
|
model: type[ModelT],
|
||||||
|
*criteria: Any,
|
||||||
|
commit: bool = False,
|
||||||
|
) -> int:
|
||||||
|
stmt = sql_delete(model)
|
||||||
|
if criteria:
|
||||||
|
stmt = stmt.where(*criteria)
|
||||||
|
result = await session.exec(cast(Any, stmt))
|
||||||
|
if commit:
|
||||||
|
await _commit_or_rollback(session)
|
||||||
|
rowcount = getattr(result, "rowcount", None)
|
||||||
|
return int(rowcount) if isinstance(rowcount, int) else 0
|
||||||
|
|
||||||
|
|
||||||
|
async def update_where(
|
||||||
|
session: AsyncSession,
|
||||||
|
model: type[ModelT],
|
||||||
|
*criteria: Any,
|
||||||
|
updates: Mapping[str, Any] | None = None,
|
||||||
|
commit: bool = False,
|
||||||
|
exclude_none: bool = False,
|
||||||
|
allowed_fields: set[str] | None = None,
|
||||||
|
**update_fields: Any,
|
||||||
|
) -> int:
|
||||||
|
source_updates: dict[str, Any] = {}
|
||||||
|
if updates:
|
||||||
|
source_updates.update(dict(updates))
|
||||||
|
if update_fields:
|
||||||
|
source_updates.update(update_fields)
|
||||||
|
|
||||||
|
values: dict[str, Any] = {}
|
||||||
|
for key, value in source_updates.items():
|
||||||
|
if allowed_fields is not None and key not in allowed_fields:
|
||||||
|
continue
|
||||||
|
if exclude_none and value is None:
|
||||||
|
continue
|
||||||
|
values[key] = value
|
||||||
|
if not values:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
stmt = sql_update(model).values(**values)
|
||||||
|
if criteria:
|
||||||
|
stmt = stmt.where(*criteria)
|
||||||
|
result = await session.exec(cast(Any, stmt))
|
||||||
|
if commit:
|
||||||
|
await _commit_or_rollback(session)
|
||||||
|
rowcount = getattr(result, "rowcount", None)
|
||||||
|
return int(rowcount) if isinstance(rowcount, int) else 0
|
||||||
|
|
||||||
|
|
||||||
def apply_updates(
|
def apply_updates(
|
||||||
obj: ModelT,
|
obj: ModelT,
|
||||||
updates: Mapping[str, Any],
|
updates: Mapping[str, Any],
|
||||||
@@ -179,6 +268,9 @@ async def get_or_create(
|
|||||||
if existing is not None:
|
if existing is not None:
|
||||||
return existing, False
|
return existing, False
|
||||||
raise
|
raise
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
if refresh:
|
if refresh:
|
||||||
await session.refresh(obj)
|
await session.refresh(obj)
|
||||||
|
|||||||
@@ -11,9 +11,12 @@ from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async
|
|||||||
from sqlmodel import SQLModel
|
from sqlmodel import SQLModel
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app import models # noqa: F401
|
from app import models as _models
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
|
# Import model modules so SQLModel metadata is fully registered at startup.
|
||||||
|
_MODEL_REGISTRY = _models
|
||||||
|
|
||||||
|
|
||||||
def _normalize_database_url(database_url: str) -> str:
|
def _normalize_database_url(database_url: str) -> str:
|
||||||
if "://" not in database_url:
|
if "://" not in database_url:
|
||||||
@@ -64,4 +67,11 @@ async def init_db() -> None:
|
|||||||
|
|
||||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
async with async_session_maker() as session:
|
async with async_session_maker() as session:
|
||||||
yield session
|
try:
|
||||||
|
yield session
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
await session.rollback()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to rollback session after request error.")
|
||||||
|
raise
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from sqlalchemy.sql.base import Executable
|
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
async def exec_dml(session: AsyncSession, statement: Executable) -> None:
|
|
||||||
# SQLModel's AsyncSession typing only overloads exec() for SELECT statements.
|
|
||||||
await session.exec(statement) # type: ignore[call-overload]
|
|
||||||
@@ -5,13 +5,13 @@ from typing import Iterable
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
from sqlalchemy import delete, func, or_
|
from sqlalchemy import func, or_
|
||||||
from sqlalchemy.sql.elements import ColumnElement
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.core.time import utcnow
|
from app.core.time import utcnow
|
||||||
from app.db.sqlmodel_exec import exec_dml
|
from app.db import crud
|
||||||
from app.models.boards import Board
|
from app.models.boards import Board
|
||||||
from app.models.organization_board_access import OrganizationBoardAccess
|
from app.models.organization_board_access import OrganizationBoardAccess
|
||||||
from app.models.organization_invite_board_access import OrganizationInviteBoardAccess
|
from app.models.organization_invite_board_access import OrganizationInviteBoardAccess
|
||||||
@@ -328,11 +328,11 @@ async def apply_member_access_update(
|
|||||||
member.updated_at = now
|
member.updated_at = now
|
||||||
session.add(member)
|
session.add(member)
|
||||||
|
|
||||||
await exec_dml(
|
await crud.delete_where(
|
||||||
session,
|
session,
|
||||||
delete(OrganizationBoardAccess).where(
|
OrganizationBoardAccess,
|
||||||
col(OrganizationBoardAccess.organization_member_id) == member.id
|
col(OrganizationBoardAccess.organization_member_id) == member.id,
|
||||||
),
|
commit=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if update.all_boards_read or update.all_boards_write:
|
if update.all_boards_read or update.all_boards_write:
|
||||||
@@ -359,11 +359,11 @@ async def apply_invite_board_access(
|
|||||||
invite: OrganizationInvite,
|
invite: OrganizationInvite,
|
||||||
entries: Iterable[OrganizationBoardAccessSpec],
|
entries: Iterable[OrganizationBoardAccessSpec],
|
||||||
) -> None:
|
) -> None:
|
||||||
await exec_dml(
|
await crud.delete_where(
|
||||||
session,
|
session,
|
||||||
delete(OrganizationInviteBoardAccess).where(
|
OrganizationInviteBoardAccess,
|
||||||
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id
|
col(OrganizationInviteBoardAccess.organization_invite_id) == invite.id,
|
||||||
),
|
commit=False,
|
||||||
)
|
)
|
||||||
if invite.all_boards_read or invite.all_boards_write:
|
if invite.all_boards_read or invite.all_boards_write:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -6,11 +6,10 @@ from typing import Final
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
from sqlalchemy import delete
|
|
||||||
from sqlmodel import col, select
|
from sqlmodel import col, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.db.sqlmodel_exec import exec_dml
|
from app.db import crud
|
||||||
from app.models.task_dependencies import TaskDependency
|
from app.models.task_dependencies import TaskDependency
|
||||||
from app.models.tasks import Task
|
from app.models.tasks import Task
|
||||||
|
|
||||||
@@ -195,11 +194,12 @@ async def replace_task_dependencies(
|
|||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
depends_on_task_ids=depends_on_task_ids,
|
depends_on_task_ids=depends_on_task_ids,
|
||||||
)
|
)
|
||||||
await exec_dml(
|
await crud.delete_where(
|
||||||
session,
|
session,
|
||||||
delete(TaskDependency)
|
TaskDependency,
|
||||||
.where(col(TaskDependency.board_id) == board_id)
|
col(TaskDependency.board_id) == board_id,
|
||||||
.where(col(TaskDependency.task_id) == task_id),
|
col(TaskDependency.task_id) == task_id,
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
for dep_id in normalized:
|
for dep_id in normalized:
|
||||||
session.add(
|
session.add(
|
||||||
|
|||||||
125
backend/tests/test_db_transaction_safety.py
Normal file
125
backend/tests/test_db_transaction_safety.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlmodel import Field, SQLModel
|
||||||
|
|
||||||
|
from app.db import crud
|
||||||
|
from app.db import session as db_session
|
||||||
|
|
||||||
|
|
||||||
|
class _Thing(SQLModel, table=True):
|
||||||
|
__tablename__ = "_test_thing"
|
||||||
|
|
||||||
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _SessionCtx:
|
||||||
|
session: Any
|
||||||
|
entered: int = 0
|
||||||
|
exited: int = 0
|
||||||
|
|
||||||
|
async def __aenter__(self) -> Any:
|
||||||
|
self.entered += 1
|
||||||
|
return self.session
|
||||||
|
|
||||||
|
async def __aexit__(self, _exc_type: Any, _exc: Any, _tb: Any) -> bool:
|
||||||
|
self.exited += 1
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _Maker:
|
||||||
|
ctx: _SessionCtx
|
||||||
|
|
||||||
|
def __call__(self) -> _SessionCtx:
|
||||||
|
return self.ctx
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_session_rolls_back_on_dependency_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
fake_session = SimpleNamespace(rollbacks=0)
|
||||||
|
|
||||||
|
async def _rollback() -> None:
|
||||||
|
fake_session.rollbacks += 1
|
||||||
|
|
||||||
|
fake_session.rollback = _rollback
|
||||||
|
ctx = _SessionCtx(fake_session)
|
||||||
|
monkeypatch.setattr(db_session, "async_session_maker", _Maker(ctx))
|
||||||
|
|
||||||
|
generator = db_session.get_session()
|
||||||
|
yielded = await anext(generator)
|
||||||
|
assert yielded is fake_session
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="boom"):
|
||||||
|
await generator.athrow(RuntimeError("boom"))
|
||||||
|
|
||||||
|
assert fake_session.rollbacks == 1
|
||||||
|
assert ctx.entered == 1
|
||||||
|
assert ctx.exited == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_rolls_back_when_commit_fails() -> None:
|
||||||
|
@dataclass
|
||||||
|
class _FailCommitSession:
|
||||||
|
rollback_calls: int = 0
|
||||||
|
added: list[Any] = None
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
if self.added is None:
|
||||||
|
self.added = []
|
||||||
|
|
||||||
|
def add(self, value: Any) -> None:
|
||||||
|
self.added.append(value)
|
||||||
|
|
||||||
|
async def flush(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def commit(self) -> None:
|
||||||
|
raise RuntimeError("commit failed")
|
||||||
|
|
||||||
|
async def rollback(self) -> None:
|
||||||
|
self.rollback_calls += 1
|
||||||
|
|
||||||
|
async def refresh(self, _value: Any) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
session = _FailCommitSession()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="commit failed"):
|
||||||
|
await crud.create(session, _Thing, name="demo")
|
||||||
|
|
||||||
|
assert session.rollback_calls == 1
|
||||||
|
assert len(session.added) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_where_rolls_back_when_commit_fails() -> None:
|
||||||
|
@dataclass
|
||||||
|
class _FailCommitDmlSession:
|
||||||
|
rollback_calls: int = 0
|
||||||
|
exec_calls: int = 0
|
||||||
|
|
||||||
|
async def exec(self, _statement: Any) -> Any:
|
||||||
|
self.exec_calls += 1
|
||||||
|
return SimpleNamespace(rowcount=3)
|
||||||
|
|
||||||
|
async def commit(self) -> None:
|
||||||
|
raise RuntimeError("commit failed")
|
||||||
|
|
||||||
|
async def rollback(self) -> None:
|
||||||
|
self.rollback_calls += 1
|
||||||
|
|
||||||
|
session = _FailCommitDmlSession()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="commit failed"):
|
||||||
|
await crud.delete_where(session, _Thing, commit=True)
|
||||||
|
|
||||||
|
assert session.exec_calls == 1
|
||||||
|
assert session.rollback_calls == 1
|
||||||
Reference in New Issue
Block a user