diff --git a/backend/app/api/approvals.py b/backend/app/api/approvals.py index edcada9c..5619e84e 100644 --- a/backend/app/api/approvals.py +++ b/backend/app/api/approvals.py @@ -30,8 +30,10 @@ from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalStatus, from app.schemas.pagination import DefaultLimitOffsetPage from app.services.activity_log import record_activity from app.services.approval_task_links import ( + lock_tasks_for_approval, load_task_ids_by_approval, normalize_task_ids, + pending_approval_conflicts_by_task, replace_approval_task_links, task_counts_for_board, ) @@ -114,6 +116,44 @@ def _serialize_approval(approval: ApprovalRead) -> dict[str, object]: return approval.model_dump(mode="json") +def _pending_conflict_detail(conflicts: dict[UUID, UUID]) -> dict[str, object]: + ordered = sorted(conflicts.items(), key=lambda item: str(item[0])) + return { + "message": "Each task can have only one pending approval.", + "conflicts": [ + { + "task_id": str(task_id), + "approval_id": str(approval_id), + } + for task_id, approval_id in ordered + ], + } + + +async def _ensure_no_pending_approval_conflicts( + session: AsyncSession, + *, + board_id: UUID, + task_ids: Sequence[UUID], + exclude_approval_id: UUID | None = None, +) -> None: + normalized_task_ids = list({*task_ids}) + if not normalized_task_ids: + return + await lock_tasks_for_approval(session, task_ids=normalized_task_ids) + conflicts = await pending_approval_conflicts_by_task( + session, + board_id=board_id, + task_ids=normalized_task_ids, + exclude_approval_id=exclude_approval_id, + ) + if conflicts: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=_pending_conflict_detail(conflicts), + ) + + def _approval_resolution_message( *, board: Board, @@ -324,6 +364,12 @@ async def create_approval( payload=payload.payload, ) task_id = task_ids[0] if task_ids else None + if payload.status == "pending": + await _ensure_no_pending_approval_conflicts( + session, + board_id=board.id, + task_ids=task_ids, + ) approval = Approval( board_id=board.id, task_id=task_id, @@ -360,7 +406,19 @@ async def update_approval( updates = payload.model_dump(exclude_unset=True) prior_status = approval.status if "status" in updates: - approval.status = updates["status"] + target_status = updates["status"] + if target_status == "pending" and prior_status != "pending": + task_ids_by_approval = await load_task_ids_by_approval(session, approval_ids=[approval.id]) + approval_task_ids = task_ids_by_approval.get(approval.id) + if not approval_task_ids and approval.task_id is not None: + approval_task_ids = [approval.task_id] + await _ensure_no_pending_approval_conflicts( + session, + board_id=board.id, + task_ids=approval_task_ids or [], + exclude_approval_id=approval.id, + ) + approval.status = target_status if approval.status != "pending": approval.resolved_at = utcnow() session.add(approval) diff --git a/backend/app/services/approval_task_links.py b/backend/app/services/approval_task_links.py index 5ee684f6..a1dc7a6d 100644 --- a/backend/app/services/approval_task_links.py +++ b/backend/app/services/approval_task_links.py @@ -11,6 +11,7 @@ from sqlmodel import col, select from app.models.approval_task_links import ApprovalTaskLink from app.models.approvals import Approval +from app.models.tasks import Task if TYPE_CHECKING: from sqlmodel.ext.asyncio.session import AsyncSession @@ -121,6 +122,88 @@ async def replace_approval_task_links( session.add(ApprovalTaskLink(approval_id=approval_id, task_id=task_id)) +async def lock_tasks_for_approval( + session: AsyncSession, + *, + task_ids: Sequence[UUID], +) -> None: + """Acquire row locks for task ids in deterministic order within a transaction.""" + normalized_task_ids = sorted({*task_ids}, key=str) + if not normalized_task_ids: + return + statement = ( + select(col(Task.id)) + .where(col(Task.id).in_(normalized_task_ids)) + .order_by(col(Task.id).asc()) + .with_for_update() + ) + # Materialize results so the lock query fully executes before proceeding. + _ = list(await session.exec(statement)) + + +async def pending_approval_conflicts_by_task( + session: AsyncSession, + *, + board_id: UUID, + task_ids: Sequence[UUID], + exclude_approval_id: UUID | None = None, +) -> dict[UUID, UUID]: + """Return the first conflicting pending approval id for each requested task id.""" + normalized_task_ids = list({*task_ids}) + if not normalized_task_ids: + return {} + + linked_statement = ( + select( + col(ApprovalTaskLink.task_id), + col(Approval.id), + col(Approval.created_at), + ) + .join(Approval, col(Approval.id) == col(ApprovalTaskLink.approval_id)) + .where(col(Approval.board_id) == board_id) + .where(col(Approval.status) == "pending") + .where(col(ApprovalTaskLink.task_id).in_(normalized_task_ids)) + .order_by(col(Approval.created_at).asc(), col(Approval.id).asc()) + ) + if exclude_approval_id is not None: + linked_statement = linked_statement.where(col(Approval.id) != exclude_approval_id) + linked_rows = list(await session.exec(linked_statement)) + + conflicts: dict[UUID, UUID] = {} + for task_id, approval_id, _created_at in linked_rows: + conflicts.setdefault(task_id, approval_id) + + legacy_statement = ( + select( + col(Approval.task_id), + col(Approval.id), + col(Approval.created_at), + ) + .where(col(Approval.board_id) == board_id) + .where(col(Approval.status) == "pending") + .where(col(Approval.task_id).is_not(None)) + .where(col(Approval.task_id).in_(normalized_task_ids)) + .where( + ~exists( + select(1) + .where(col(ApprovalTaskLink.approval_id) == col(Approval.id)) + .correlate(Approval), + ), + ) + .order_by(col(Approval.created_at).asc(), col(Approval.id).asc()) + ) + if exclude_approval_id is not None: + legacy_statement = legacy_statement.where(col(Approval.id) != exclude_approval_id) + legacy_rows = list(await session.exec(legacy_statement)) + + for task_id, approval_id, _created_at in legacy_rows: + if task_id is None: + continue + conflicts.setdefault(task_id, approval_id) + + return conflicts + + async def task_counts_for_board( session: AsyncSession, *, diff --git a/backend/tests/test_approvals_pending_conflicts.py b/backend/tests/test_approvals_pending_conflicts.py new file mode 100644 index 00000000..ad5661fd --- /dev/null +++ b/backend/tests/test_approvals_pending_conflicts.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +from uuid import UUID, uuid4 + +import pytest +from fastapi import HTTPException +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from sqlmodel import SQLModel +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.api import approvals as approvals_api +from app.models.boards import Board +from app.models.organizations import Organization +from app.models.tasks import Task +from app.schemas.approvals import ApprovalCreate, ApprovalUpdate + + +async def _make_engine() -> AsyncEngine: + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.connect() as conn, conn.begin(): + await conn.run_sync(SQLModel.metadata.create_all) + return engine + + +async def _make_session(engine: AsyncEngine) -> AsyncSession: + return AsyncSession(engine, expire_on_commit=False) + + +async def _seed_board_with_tasks( + session: AsyncSession, + *, + task_count: int = 2, +) -> tuple[Board, list[UUID]]: + org_id = uuid4() + board = Board(id=uuid4(), organization_id=org_id, name="b", slug="b") + task_ids = [uuid4() for _ in range(task_count)] + + session.add(Organization(id=org_id, name=f"org-{org_id}")) + session.add(board) + for task_id in task_ids: + session.add(Task(id=task_id, board_id=board.id, title=f"task-{task_id}")) + await session.commit() + + return board, task_ids + + +@pytest.mark.asyncio +async def test_create_approval_rejects_duplicate_pending_for_same_task() -> None: + engine = await _make_engine() + try: + async with await _make_session(engine) as session: + board, task_ids = await _seed_board_with_tasks(session, task_count=1) + task_id = task_ids[0] + await approvals_api.create_approval( + payload=ApprovalCreate( + action_type="task.execute", + task_id=task_id, + confidence=80, + status="pending", + ), + board=board, + session=session, + ) + + with pytest.raises(HTTPException) as exc: + await approvals_api.create_approval( + payload=ApprovalCreate( + action_type="task.retry", + task_id=task_id, + confidence=77, + status="pending", + ), + board=board, + session=session, + ) + + assert exc.value.status_code == 409 + detail = exc.value.detail + assert isinstance(detail, dict) + assert detail["message"] == "Each task can have only one pending approval." + assert len(detail["conflicts"]) == 1 + assert detail["conflicts"][0]["task_id"] == str(task_id) + finally: + await engine.dispose() + + +@pytest.mark.asyncio +async def test_create_approval_rejects_pending_conflict_from_linked_task_ids() -> None: + engine = await _make_engine() + try: + async with await _make_session(engine) as session: + board, task_ids = await _seed_board_with_tasks(session, task_count=2) + task_a, task_b = task_ids + await approvals_api.create_approval( + payload=ApprovalCreate( + action_type="task.batch_execute", + task_ids=[task_a, task_b], + confidence=85, + status="pending", + ), + board=board, + session=session, + ) + + with pytest.raises(HTTPException) as exc: + await approvals_api.create_approval( + payload=ApprovalCreate( + action_type="task.execute", + task_id=task_b, + confidence=70, + status="pending", + ), + board=board, + session=session, + ) + + assert exc.value.status_code == 409 + detail = exc.value.detail + assert isinstance(detail, dict) + assert detail["message"] == "Each task can have only one pending approval." + assert len(detail["conflicts"]) == 1 + assert detail["conflicts"][0]["task_id"] == str(task_b) + finally: + await engine.dispose() + + +@pytest.mark.asyncio +async def test_update_approval_rejects_reopening_to_pending_with_existing_pending() -> None: + engine = await _make_engine() + try: + async with await _make_session(engine) as session: + board, task_ids = await _seed_board_with_tasks(session, task_count=1) + task_id = task_ids[0] + pending = await approvals_api.create_approval( + payload=ApprovalCreate( + action_type="task.execute", + task_id=task_id, + confidence=83, + status="pending", + ), + board=board, + session=session, + ) + resolved = await approvals_api.create_approval( + payload=ApprovalCreate( + action_type="task.review", + task_id=task_id, + confidence=90, + status="approved", + ), + board=board, + session=session, + ) + + with pytest.raises(HTTPException) as exc: + await approvals_api.update_approval( + approval_id=resolved.id, # type: ignore[arg-type] + payload=ApprovalUpdate(status="pending"), + board=board, + session=session, + ) + + assert exc.value.status_code == 409 + detail = exc.value.detail + assert isinstance(detail, dict) + assert detail["message"] == "Each task can have only one pending approval." + assert detail["conflicts"] == [ + { + "task_id": str(task_id), + "approval_id": str(pending.id), + }, + ] + finally: + await engine.dispose()