feat(approvals): implement task locking and conflict resolution for pending approvals

This commit is contained in:
Abhimanyu Saharan
2026-02-12 14:46:31 +05:30
parent c427a8240f
commit e93b1864e5
3 changed files with 316 additions and 1 deletions

View File

@@ -30,8 +30,10 @@ from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalStatus,
from app.schemas.pagination import DefaultLimitOffsetPage
from app.services.activity_log import record_activity
from app.services.approval_task_links import (
lock_tasks_for_approval,
load_task_ids_by_approval,
normalize_task_ids,
pending_approval_conflicts_by_task,
replace_approval_task_links,
task_counts_for_board,
)
@@ -114,6 +116,44 @@ def _serialize_approval(approval: ApprovalRead) -> dict[str, object]:
return approval.model_dump(mode="json")
def _pending_conflict_detail(conflicts: dict[UUID, UUID]) -> dict[str, object]:
ordered = sorted(conflicts.items(), key=lambda item: str(item[0]))
return {
"message": "Each task can have only one pending approval.",
"conflicts": [
{
"task_id": str(task_id),
"approval_id": str(approval_id),
}
for task_id, approval_id in ordered
],
}
async def _ensure_no_pending_approval_conflicts(
session: AsyncSession,
*,
board_id: UUID,
task_ids: Sequence[UUID],
exclude_approval_id: UUID | None = None,
) -> None:
normalized_task_ids = list({*task_ids})
if not normalized_task_ids:
return
await lock_tasks_for_approval(session, task_ids=normalized_task_ids)
conflicts = await pending_approval_conflicts_by_task(
session,
board_id=board_id,
task_ids=normalized_task_ids,
exclude_approval_id=exclude_approval_id,
)
if conflicts:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=_pending_conflict_detail(conflicts),
)
def _approval_resolution_message(
*,
board: Board,
@@ -324,6 +364,12 @@ async def create_approval(
payload=payload.payload,
)
task_id = task_ids[0] if task_ids else None
if payload.status == "pending":
await _ensure_no_pending_approval_conflicts(
session,
board_id=board.id,
task_ids=task_ids,
)
approval = Approval(
board_id=board.id,
task_id=task_id,
@@ -360,7 +406,19 @@ async def update_approval(
updates = payload.model_dump(exclude_unset=True)
prior_status = approval.status
if "status" in updates:
approval.status = updates["status"]
target_status = updates["status"]
if target_status == "pending" and prior_status != "pending":
task_ids_by_approval = await load_task_ids_by_approval(session, approval_ids=[approval.id])
approval_task_ids = task_ids_by_approval.get(approval.id)
if not approval_task_ids and approval.task_id is not None:
approval_task_ids = [approval.task_id]
await _ensure_no_pending_approval_conflicts(
session,
board_id=board.id,
task_ids=approval_task_ids or [],
exclude_approval_id=approval.id,
)
approval.status = target_status
if approval.status != "pending":
approval.resolved_at = utcnow()
session.add(approval)

View File

@@ -11,6 +11,7 @@ from sqlmodel import col, select
from app.models.approval_task_links import ApprovalTaskLink
from app.models.approvals import Approval
from app.models.tasks import Task
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -121,6 +122,88 @@ async def replace_approval_task_links(
session.add(ApprovalTaskLink(approval_id=approval_id, task_id=task_id))
async def lock_tasks_for_approval(
session: AsyncSession,
*,
task_ids: Sequence[UUID],
) -> None:
"""Acquire row locks for task ids in deterministic order within a transaction."""
normalized_task_ids = sorted({*task_ids}, key=str)
if not normalized_task_ids:
return
statement = (
select(col(Task.id))
.where(col(Task.id).in_(normalized_task_ids))
.order_by(col(Task.id).asc())
.with_for_update()
)
# Materialize results so the lock query fully executes before proceeding.
_ = list(await session.exec(statement))
async def pending_approval_conflicts_by_task(
session: AsyncSession,
*,
board_id: UUID,
task_ids: Sequence[UUID],
exclude_approval_id: UUID | None = None,
) -> dict[UUID, UUID]:
"""Return the first conflicting pending approval id for each requested task id."""
normalized_task_ids = list({*task_ids})
if not normalized_task_ids:
return {}
linked_statement = (
select(
col(ApprovalTaskLink.task_id),
col(Approval.id),
col(Approval.created_at),
)
.join(Approval, col(Approval.id) == col(ApprovalTaskLink.approval_id))
.where(col(Approval.board_id) == board_id)
.where(col(Approval.status) == "pending")
.where(col(ApprovalTaskLink.task_id).in_(normalized_task_ids))
.order_by(col(Approval.created_at).asc(), col(Approval.id).asc())
)
if exclude_approval_id is not None:
linked_statement = linked_statement.where(col(Approval.id) != exclude_approval_id)
linked_rows = list(await session.exec(linked_statement))
conflicts: dict[UUID, UUID] = {}
for task_id, approval_id, _created_at in linked_rows:
conflicts.setdefault(task_id, approval_id)
legacy_statement = (
select(
col(Approval.task_id),
col(Approval.id),
col(Approval.created_at),
)
.where(col(Approval.board_id) == board_id)
.where(col(Approval.status) == "pending")
.where(col(Approval.task_id).is_not(None))
.where(col(Approval.task_id).in_(normalized_task_ids))
.where(
~exists(
select(1)
.where(col(ApprovalTaskLink.approval_id) == col(Approval.id))
.correlate(Approval),
),
)
.order_by(col(Approval.created_at).asc(), col(Approval.id).asc())
)
if exclude_approval_id is not None:
legacy_statement = legacy_statement.where(col(Approval.id) != exclude_approval_id)
legacy_rows = list(await session.exec(legacy_statement))
for task_id, approval_id, _created_at in legacy_rows:
if task_id is None:
continue
conflicts.setdefault(task_id, approval_id)
return conflicts
async def task_counts_for_board(
session: AsyncSession,
*,

View File

@@ -0,0 +1,174 @@
from __future__ import annotations
from uuid import UUID, uuid4
import pytest
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
from app.api import approvals as approvals_api
from app.models.boards import Board
from app.models.organizations import Organization
from app.models.tasks import Task
from app.schemas.approvals import ApprovalCreate, ApprovalUpdate
async def _make_engine() -> AsyncEngine:
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
async with engine.connect() as conn, conn.begin():
await conn.run_sync(SQLModel.metadata.create_all)
return engine
async def _make_session(engine: AsyncEngine) -> AsyncSession:
return AsyncSession(engine, expire_on_commit=False)
async def _seed_board_with_tasks(
session: AsyncSession,
*,
task_count: int = 2,
) -> tuple[Board, list[UUID]]:
org_id = uuid4()
board = Board(id=uuid4(), organization_id=org_id, name="b", slug="b")
task_ids = [uuid4() for _ in range(task_count)]
session.add(Organization(id=org_id, name=f"org-{org_id}"))
session.add(board)
for task_id in task_ids:
session.add(Task(id=task_id, board_id=board.id, title=f"task-{task_id}"))
await session.commit()
return board, task_ids
@pytest.mark.asyncio
async def test_create_approval_rejects_duplicate_pending_for_same_task() -> None:
engine = await _make_engine()
try:
async with await _make_session(engine) as session:
board, task_ids = await _seed_board_with_tasks(session, task_count=1)
task_id = task_ids[0]
await approvals_api.create_approval(
payload=ApprovalCreate(
action_type="task.execute",
task_id=task_id,
confidence=80,
status="pending",
),
board=board,
session=session,
)
with pytest.raises(HTTPException) as exc:
await approvals_api.create_approval(
payload=ApprovalCreate(
action_type="task.retry",
task_id=task_id,
confidence=77,
status="pending",
),
board=board,
session=session,
)
assert exc.value.status_code == 409
detail = exc.value.detail
assert isinstance(detail, dict)
assert detail["message"] == "Each task can have only one pending approval."
assert len(detail["conflicts"]) == 1
assert detail["conflicts"][0]["task_id"] == str(task_id)
finally:
await engine.dispose()
@pytest.mark.asyncio
async def test_create_approval_rejects_pending_conflict_from_linked_task_ids() -> None:
engine = await _make_engine()
try:
async with await _make_session(engine) as session:
board, task_ids = await _seed_board_with_tasks(session, task_count=2)
task_a, task_b = task_ids
await approvals_api.create_approval(
payload=ApprovalCreate(
action_type="task.batch_execute",
task_ids=[task_a, task_b],
confidence=85,
status="pending",
),
board=board,
session=session,
)
with pytest.raises(HTTPException) as exc:
await approvals_api.create_approval(
payload=ApprovalCreate(
action_type="task.execute",
task_id=task_b,
confidence=70,
status="pending",
),
board=board,
session=session,
)
assert exc.value.status_code == 409
detail = exc.value.detail
assert isinstance(detail, dict)
assert detail["message"] == "Each task can have only one pending approval."
assert len(detail["conflicts"]) == 1
assert detail["conflicts"][0]["task_id"] == str(task_b)
finally:
await engine.dispose()
@pytest.mark.asyncio
async def test_update_approval_rejects_reopening_to_pending_with_existing_pending() -> None:
engine = await _make_engine()
try:
async with await _make_session(engine) as session:
board, task_ids = await _seed_board_with_tasks(session, task_count=1)
task_id = task_ids[0]
pending = await approvals_api.create_approval(
payload=ApprovalCreate(
action_type="task.execute",
task_id=task_id,
confidence=83,
status="pending",
),
board=board,
session=session,
)
resolved = await approvals_api.create_approval(
payload=ApprovalCreate(
action_type="task.review",
task_id=task_id,
confidence=90,
status="approved",
),
board=board,
session=session,
)
with pytest.raises(HTTPException) as exc:
await approvals_api.update_approval(
approval_id=resolved.id, # type: ignore[arg-type]
payload=ApprovalUpdate(status="pending"),
board=board,
session=session,
)
assert exc.value.status_code == 409
detail = exc.value.detail
assert isinstance(detail, dict)
assert detail["message"] == "Each task can have only one pending approval."
assert detail["conflicts"] == [
{
"task_id": str(task_id),
"approval_id": str(pending.id),
},
]
finally:
await engine.dispose()