feat(approvals): implement task locking and conflict resolution for pending approvals
This commit is contained in:
@@ -30,8 +30,10 @@ from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalStatus,
|
||||
from app.schemas.pagination import DefaultLimitOffsetPage
|
||||
from app.services.activity_log import record_activity
|
||||
from app.services.approval_task_links import (
|
||||
lock_tasks_for_approval,
|
||||
load_task_ids_by_approval,
|
||||
normalize_task_ids,
|
||||
pending_approval_conflicts_by_task,
|
||||
replace_approval_task_links,
|
||||
task_counts_for_board,
|
||||
)
|
||||
@@ -114,6 +116,44 @@ def _serialize_approval(approval: ApprovalRead) -> dict[str, object]:
|
||||
return approval.model_dump(mode="json")
|
||||
|
||||
|
||||
def _pending_conflict_detail(conflicts: dict[UUID, UUID]) -> dict[str, object]:
|
||||
ordered = sorted(conflicts.items(), key=lambda item: str(item[0]))
|
||||
return {
|
||||
"message": "Each task can have only one pending approval.",
|
||||
"conflicts": [
|
||||
{
|
||||
"task_id": str(task_id),
|
||||
"approval_id": str(approval_id),
|
||||
}
|
||||
for task_id, approval_id in ordered
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
async def _ensure_no_pending_approval_conflicts(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
board_id: UUID,
|
||||
task_ids: Sequence[UUID],
|
||||
exclude_approval_id: UUID | None = None,
|
||||
) -> None:
|
||||
normalized_task_ids = list({*task_ids})
|
||||
if not normalized_task_ids:
|
||||
return
|
||||
await lock_tasks_for_approval(session, task_ids=normalized_task_ids)
|
||||
conflicts = await pending_approval_conflicts_by_task(
|
||||
session,
|
||||
board_id=board_id,
|
||||
task_ids=normalized_task_ids,
|
||||
exclude_approval_id=exclude_approval_id,
|
||||
)
|
||||
if conflicts:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=_pending_conflict_detail(conflicts),
|
||||
)
|
||||
|
||||
|
||||
def _approval_resolution_message(
|
||||
*,
|
||||
board: Board,
|
||||
@@ -324,6 +364,12 @@ async def create_approval(
|
||||
payload=payload.payload,
|
||||
)
|
||||
task_id = task_ids[0] if task_ids else None
|
||||
if payload.status == "pending":
|
||||
await _ensure_no_pending_approval_conflicts(
|
||||
session,
|
||||
board_id=board.id,
|
||||
task_ids=task_ids,
|
||||
)
|
||||
approval = Approval(
|
||||
board_id=board.id,
|
||||
task_id=task_id,
|
||||
@@ -360,7 +406,19 @@ async def update_approval(
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
prior_status = approval.status
|
||||
if "status" in updates:
|
||||
approval.status = updates["status"]
|
||||
target_status = updates["status"]
|
||||
if target_status == "pending" and prior_status != "pending":
|
||||
task_ids_by_approval = await load_task_ids_by_approval(session, approval_ids=[approval.id])
|
||||
approval_task_ids = task_ids_by_approval.get(approval.id)
|
||||
if not approval_task_ids and approval.task_id is not None:
|
||||
approval_task_ids = [approval.task_id]
|
||||
await _ensure_no_pending_approval_conflicts(
|
||||
session,
|
||||
board_id=board.id,
|
||||
task_ids=approval_task_ids or [],
|
||||
exclude_approval_id=approval.id,
|
||||
)
|
||||
approval.status = target_status
|
||||
if approval.status != "pending":
|
||||
approval.resolved_at = utcnow()
|
||||
session.add(approval)
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlmodel import col, select
|
||||
|
||||
from app.models.approval_task_links import ApprovalTaskLink
|
||||
from app.models.approvals import Approval
|
||||
from app.models.tasks import Task
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -121,6 +122,88 @@ async def replace_approval_task_links(
|
||||
session.add(ApprovalTaskLink(approval_id=approval_id, task_id=task_id))
|
||||
|
||||
|
||||
async def lock_tasks_for_approval(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
task_ids: Sequence[UUID],
|
||||
) -> None:
|
||||
"""Acquire row locks for task ids in deterministic order within a transaction."""
|
||||
normalized_task_ids = sorted({*task_ids}, key=str)
|
||||
if not normalized_task_ids:
|
||||
return
|
||||
statement = (
|
||||
select(col(Task.id))
|
||||
.where(col(Task.id).in_(normalized_task_ids))
|
||||
.order_by(col(Task.id).asc())
|
||||
.with_for_update()
|
||||
)
|
||||
# Materialize results so the lock query fully executes before proceeding.
|
||||
_ = list(await session.exec(statement))
|
||||
|
||||
|
||||
async def pending_approval_conflicts_by_task(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
board_id: UUID,
|
||||
task_ids: Sequence[UUID],
|
||||
exclude_approval_id: UUID | None = None,
|
||||
) -> dict[UUID, UUID]:
|
||||
"""Return the first conflicting pending approval id for each requested task id."""
|
||||
normalized_task_ids = list({*task_ids})
|
||||
if not normalized_task_ids:
|
||||
return {}
|
||||
|
||||
linked_statement = (
|
||||
select(
|
||||
col(ApprovalTaskLink.task_id),
|
||||
col(Approval.id),
|
||||
col(Approval.created_at),
|
||||
)
|
||||
.join(Approval, col(Approval.id) == col(ApprovalTaskLink.approval_id))
|
||||
.where(col(Approval.board_id) == board_id)
|
||||
.where(col(Approval.status) == "pending")
|
||||
.where(col(ApprovalTaskLink.task_id).in_(normalized_task_ids))
|
||||
.order_by(col(Approval.created_at).asc(), col(Approval.id).asc())
|
||||
)
|
||||
if exclude_approval_id is not None:
|
||||
linked_statement = linked_statement.where(col(Approval.id) != exclude_approval_id)
|
||||
linked_rows = list(await session.exec(linked_statement))
|
||||
|
||||
conflicts: dict[UUID, UUID] = {}
|
||||
for task_id, approval_id, _created_at in linked_rows:
|
||||
conflicts.setdefault(task_id, approval_id)
|
||||
|
||||
legacy_statement = (
|
||||
select(
|
||||
col(Approval.task_id),
|
||||
col(Approval.id),
|
||||
col(Approval.created_at),
|
||||
)
|
||||
.where(col(Approval.board_id) == board_id)
|
||||
.where(col(Approval.status) == "pending")
|
||||
.where(col(Approval.task_id).is_not(None))
|
||||
.where(col(Approval.task_id).in_(normalized_task_ids))
|
||||
.where(
|
||||
~exists(
|
||||
select(1)
|
||||
.where(col(ApprovalTaskLink.approval_id) == col(Approval.id))
|
||||
.correlate(Approval),
|
||||
),
|
||||
)
|
||||
.order_by(col(Approval.created_at).asc(), col(Approval.id).asc())
|
||||
)
|
||||
if exclude_approval_id is not None:
|
||||
legacy_statement = legacy_statement.where(col(Approval.id) != exclude_approval_id)
|
||||
legacy_rows = list(await session.exec(legacy_statement))
|
||||
|
||||
for task_id, approval_id, _created_at in legacy_rows:
|
||||
if task_id is None:
|
||||
continue
|
||||
conflicts.setdefault(task_id, approval_id)
|
||||
|
||||
return conflicts
|
||||
|
||||
|
||||
async def task_counts_for_board(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
|
||||
174
backend/tests/test_approvals_pending_conflicts.py
Normal file
174
backend/tests/test_approvals_pending_conflicts.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.api import approvals as approvals_api
|
||||
from app.models.boards import Board
|
||||
from app.models.organizations import Organization
|
||||
from app.models.tasks import Task
|
||||
from app.schemas.approvals import ApprovalCreate, ApprovalUpdate
|
||||
|
||||
|
||||
async def _make_engine() -> AsyncEngine:
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||
async with engine.connect() as conn, conn.begin():
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
return engine
|
||||
|
||||
|
||||
async def _make_session(engine: AsyncEngine) -> AsyncSession:
|
||||
return AsyncSession(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
async def _seed_board_with_tasks(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
task_count: int = 2,
|
||||
) -> tuple[Board, list[UUID]]:
|
||||
org_id = uuid4()
|
||||
board = Board(id=uuid4(), organization_id=org_id, name="b", slug="b")
|
||||
task_ids = [uuid4() for _ in range(task_count)]
|
||||
|
||||
session.add(Organization(id=org_id, name=f"org-{org_id}"))
|
||||
session.add(board)
|
||||
for task_id in task_ids:
|
||||
session.add(Task(id=task_id, board_id=board.id, title=f"task-{task_id}"))
|
||||
await session.commit()
|
||||
|
||||
return board, task_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_approval_rejects_duplicate_pending_for_same_task() -> None:
|
||||
engine = await _make_engine()
|
||||
try:
|
||||
async with await _make_session(engine) as session:
|
||||
board, task_ids = await _seed_board_with_tasks(session, task_count=1)
|
||||
task_id = task_ids[0]
|
||||
await approvals_api.create_approval(
|
||||
payload=ApprovalCreate(
|
||||
action_type="task.execute",
|
||||
task_id=task_id,
|
||||
confidence=80,
|
||||
status="pending",
|
||||
),
|
||||
board=board,
|
||||
session=session,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await approvals_api.create_approval(
|
||||
payload=ApprovalCreate(
|
||||
action_type="task.retry",
|
||||
task_id=task_id,
|
||||
confidence=77,
|
||||
status="pending",
|
||||
),
|
||||
board=board,
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert exc.value.status_code == 409
|
||||
detail = exc.value.detail
|
||||
assert isinstance(detail, dict)
|
||||
assert detail["message"] == "Each task can have only one pending approval."
|
||||
assert len(detail["conflicts"]) == 1
|
||||
assert detail["conflicts"][0]["task_id"] == str(task_id)
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_approval_rejects_pending_conflict_from_linked_task_ids() -> None:
|
||||
engine = await _make_engine()
|
||||
try:
|
||||
async with await _make_session(engine) as session:
|
||||
board, task_ids = await _seed_board_with_tasks(session, task_count=2)
|
||||
task_a, task_b = task_ids
|
||||
await approvals_api.create_approval(
|
||||
payload=ApprovalCreate(
|
||||
action_type="task.batch_execute",
|
||||
task_ids=[task_a, task_b],
|
||||
confidence=85,
|
||||
status="pending",
|
||||
),
|
||||
board=board,
|
||||
session=session,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await approvals_api.create_approval(
|
||||
payload=ApprovalCreate(
|
||||
action_type="task.execute",
|
||||
task_id=task_b,
|
||||
confidence=70,
|
||||
status="pending",
|
||||
),
|
||||
board=board,
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert exc.value.status_code == 409
|
||||
detail = exc.value.detail
|
||||
assert isinstance(detail, dict)
|
||||
assert detail["message"] == "Each task can have only one pending approval."
|
||||
assert len(detail["conflicts"]) == 1
|
||||
assert detail["conflicts"][0]["task_id"] == str(task_b)
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_approval_rejects_reopening_to_pending_with_existing_pending() -> None:
|
||||
engine = await _make_engine()
|
||||
try:
|
||||
async with await _make_session(engine) as session:
|
||||
board, task_ids = await _seed_board_with_tasks(session, task_count=1)
|
||||
task_id = task_ids[0]
|
||||
pending = await approvals_api.create_approval(
|
||||
payload=ApprovalCreate(
|
||||
action_type="task.execute",
|
||||
task_id=task_id,
|
||||
confidence=83,
|
||||
status="pending",
|
||||
),
|
||||
board=board,
|
||||
session=session,
|
||||
)
|
||||
resolved = await approvals_api.create_approval(
|
||||
payload=ApprovalCreate(
|
||||
action_type="task.review",
|
||||
task_id=task_id,
|
||||
confidence=90,
|
||||
status="approved",
|
||||
),
|
||||
board=board,
|
||||
session=session,
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await approvals_api.update_approval(
|
||||
approval_id=resolved.id, # type: ignore[arg-type]
|
||||
payload=ApprovalUpdate(status="pending"),
|
||||
board=board,
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert exc.value.status_code == 409
|
||||
detail = exc.value.detail
|
||||
assert isinstance(detail, dict)
|
||||
assert detail["message"] == "Each task can have only one pending approval."
|
||||
assert detail["conflicts"] == [
|
||||
{
|
||||
"task_id": str(task_id),
|
||||
"approval_id": str(pending.id),
|
||||
},
|
||||
]
|
||||
finally:
|
||||
await engine.dispose()
|
||||
Reference in New Issue
Block a user