feat(approvals): implement task locking and conflict resolution for pending approvals

This commit is contained in:
Abhimanyu Saharan
2026-02-12 14:46:31 +05:30
parent c427a8240f
commit e93b1864e5
3 changed files with 316 additions and 1 deletions

View File

@@ -11,6 +11,7 @@ from sqlmodel import col, select
from app.models.approval_task_links import ApprovalTaskLink
from app.models.approvals import Approval
from app.models.tasks import Task
if TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -121,6 +122,88 @@ async def replace_approval_task_links(
session.add(ApprovalTaskLink(approval_id=approval_id, task_id=task_id))
async def lock_tasks_for_approval(
session: AsyncSession,
*,
task_ids: Sequence[UUID],
) -> None:
"""Acquire row locks for task ids in deterministic order within a transaction."""
normalized_task_ids = sorted({*task_ids}, key=str)
if not normalized_task_ids:
return
statement = (
select(col(Task.id))
.where(col(Task.id).in_(normalized_task_ids))
.order_by(col(Task.id).asc())
.with_for_update()
)
# Materialize results so the lock query fully executes before proceeding.
_ = list(await session.exec(statement))
async def pending_approval_conflicts_by_task(
session: AsyncSession,
*,
board_id: UUID,
task_ids: Sequence[UUID],
exclude_approval_id: UUID | None = None,
) -> dict[UUID, UUID]:
"""Return the first conflicting pending approval id for each requested task id."""
normalized_task_ids = list({*task_ids})
if not normalized_task_ids:
return {}
linked_statement = (
select(
col(ApprovalTaskLink.task_id),
col(Approval.id),
col(Approval.created_at),
)
.join(Approval, col(Approval.id) == col(ApprovalTaskLink.approval_id))
.where(col(Approval.board_id) == board_id)
.where(col(Approval.status) == "pending")
.where(col(ApprovalTaskLink.task_id).in_(normalized_task_ids))
.order_by(col(Approval.created_at).asc(), col(Approval.id).asc())
)
if exclude_approval_id is not None:
linked_statement = linked_statement.where(col(Approval.id) != exclude_approval_id)
linked_rows = list(await session.exec(linked_statement))
conflicts: dict[UUID, UUID] = {}
for task_id, approval_id, _created_at in linked_rows:
conflicts.setdefault(task_id, approval_id)
legacy_statement = (
select(
col(Approval.task_id),
col(Approval.id),
col(Approval.created_at),
)
.where(col(Approval.board_id) == board_id)
.where(col(Approval.status) == "pending")
.where(col(Approval.task_id).is_not(None))
.where(col(Approval.task_id).in_(normalized_task_ids))
.where(
~exists(
select(1)
.where(col(ApprovalTaskLink.approval_id) == col(Approval.id))
.correlate(Approval),
),
)
.order_by(col(Approval.created_at).asc(), col(Approval.id).asc())
)
if exclude_approval_id is not None:
legacy_statement = legacy_statement.where(col(Approval.id) != exclude_approval_id)
legacy_rows = list(await session.exec(legacy_statement))
for task_id, approval_id, _created_at in legacy_rows:
if task_id is None:
continue
conflicts.setdefault(task_id, approval_id)
return conflicts
async def task_counts_for_board(
session: AsyncSession,
*,