274 lines
8.6 KiB
Python
274 lines
8.6 KiB
Python
"""Helpers for normalizing and querying approval-task associations."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Iterable, Sequence
|
|
from typing import TYPE_CHECKING
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import case, delete, exists, func
|
|
from sqlmodel import col, select
|
|
|
|
from app.models.approval_task_links import ApprovalTaskLink
|
|
from app.models.approvals import Approval
|
|
from app.models.tasks import Task
|
|
|
|
if TYPE_CHECKING:
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
# Payload keys that may hold a single task reference (snake_case and camelCase spellings).
TASK_ID_KEYS: tuple[str, ...] = (
    "task_id",
    "taskId",
    "taskID",
)
# Payload keys that may hold a sequence of task references.
TASK_IDS_KEYS: tuple[str, ...] = (
    "task_ids",
    "taskIds",
    "taskIDs",
)
|
|
|
|
|
|
def _coerce_uuid(value: object) -> UUID | None:
|
|
if isinstance(value, UUID):
|
|
return value
|
|
if isinstance(value, str):
|
|
try:
|
|
return UUID(value)
|
|
except ValueError:
|
|
return None
|
|
return None
|
|
|
|
|
|
def extract_task_ids(payload: dict[str, object] | None) -> list[UUID]:
    """Collect task UUIDs referenced by an approval payload.

    Plural keys (``TASK_IDS_KEYS``) are read first; each must hold a
    non-string sequence whose items are coerced one by one. Singular keys
    (``TASK_ID_KEYS``) are read afterwards. Values that are neither ``UUID``
    instances nor parseable UUID strings are skipped. The result keeps the
    first occurrence of each id, in the order encountered.
    """
    if not payload:
        return []

    candidates: list[object] = []
    for plural_key in TASK_IDS_KEYS:
        values = payload.get(plural_key)
        # Strings/bytes are sequences too, but are never lists of ids.
        if isinstance(values, (str, bytes, bytearray)) or not isinstance(values, Sequence):
            continue
        candidates.extend(values)
    candidates.extend(payload.get(single_key) for single_key in TASK_ID_KEYS)

    coerced = (_coerce_uuid(candidate) for candidate in candidates)
    # dict.fromkeys preserves insertion order, so this drops duplicates while
    # keeping the first occurrence of each id.
    return list(dict.fromkeys(uid for uid in coerced if uid is not None))
|
|
|
|
|
|
def normalize_task_ids(
    *,
    task_id: UUID | None,
    task_ids: Sequence[UUID],
    payload: dict[str, object] | None,
) -> list[UUID]:
    """Combine every task reference for an approval into one ordered, duplicate-free list.

    First occurrences are ranked as: explicit ``task_ids``, then the single
    ``task_id`` (when not ``None``), then ids found in ``payload`` by
    ``extract_task_ids``.
    """
    single: list[UUID] = [] if task_id is None else [task_id]
    combined = [*task_ids, *single, *extract_task_ids(payload)]
    # dict keys keep insertion order, so the first occurrence of each id wins.
    return list(dict.fromkeys(combined))
|
|
|
|
|
|
async def load_task_ids_by_approval(
    session: AsyncSession,
    *,
    approval_ids: Iterable[UUID],
) -> dict[UUID, list[UUID]]:
    """Return linked task ids grouped by approval id.

    Every requested approval id appears as a key, in the order the caller
    first supplied it. An approval with no link rows maps to an empty list.
    Within each list, task ids are ordered by the link's ``created_at``
    ascending, with ``task_id`` as a tie-breaker so the order is stable when
    timestamps collide.

    Args:
        session: Async database session used to run the query.
        approval_ids: Approval ids to look up; duplicates are ignored.

    Returns:
        Mapping of approval id to its ordered list of linked task ids.
    """
    # dict.fromkeys de-duplicates while keeping the caller's order (a set would not).
    ids = list(dict.fromkeys(approval_ids))
    if not ids:
        return {}

    statement = (
        select(col(ApprovalTaskLink.approval_id), col(ApprovalTaskLink.task_id))
        .where(col(ApprovalTaskLink.approval_id).in_(ids))
        .order_by(
            col(ApprovalTaskLink.created_at).asc(),
            # Links written together may share a created_at value; without a
            # secondary key their relative order is left up to the database.
            col(ApprovalTaskLink.task_id).asc(),
        )
    )
    rows = list(await session.exec(statement))

    mapping: dict[UUID, list[UUID]] = {approval_id: [] for approval_id in ids}
    for approval_id, task_id in rows:
        mapping.setdefault(approval_id, []).append(task_id)
    return mapping
|
|
|
|
|
|
async def replace_approval_task_links(
    session: AsyncSession,
    *,
    approval_id: UUID,
    task_ids: Sequence[UUID],
) -> None:
    """Replace approval-task link rows for an approval id.

    First, a DELETE for every existing link row of ``approval_id`` is executed
    through ``session.exec``. Then one new ``ApprovalTaskLink`` per distinct
    task id is added to the session, in first-seen order. The new rows are
    only staged; flushing and committing are left to the caller.

    Args:
        session: Async database session that receives the changes.
        approval_id: Approval whose links are being replaced.
        task_ids: Task ids to link. Duplicates are collapsed so the same
            (approval, task) pair is never staged twice.
    """
    await session.exec(
        delete(ApprovalTaskLink).where(
            col(ApprovalTaskLink.approval_id) == approval_id,
        ),
    )
    # Preserve caller order while dropping repeated ids.
    unique_task_ids = list(dict.fromkeys(task_ids))
    session.add_all(
        [
            ApprovalTaskLink(approval_id=approval_id, task_id=task_id)
            for task_id in unique_task_ids
        ],
    )
|
|
|
|
|
|
async def lock_tasks_for_approval(
    session: AsyncSession,
    *,
    task_ids: Sequence[UUID],
) -> None:
    """Take ``SELECT ... FOR UPDATE`` row locks on the given tasks within a transaction.

    Duplicate ids are dropped, and the query orders rows by ``Task.id``.
    Concurrent callers therefore request the same rows in a consistent order,
    which is the usual guard against lock-ordering deadlocks. Does nothing
    when ``task_ids`` is empty.
    """
    unique_ids = set(task_ids)
    if not unique_ids:
        return
    ordered_ids = sorted(unique_ids, key=str)

    lock_query = select(col(Task.id)).where(col(Task.id).in_(ordered_ids))
    lock_query = lock_query.order_by(col(Task.id).asc()).with_for_update()

    # Consume every row so the locking SELECT has fully executed before returning.
    locked_rows = await session.exec(lock_query)
    list(locked_rows)
|
|
|
|
|
|
async def pending_approval_conflicts_by_task(
    session: AsyncSession,
    *,
    board_id: UUID,
    task_ids: Sequence[UUID],
    exclude_approval_id: UUID | None = None,
) -> dict[UUID, UUID]:
    """Return the earliest conflicting pending approval id for each requested task id.

    Pending approvals on ``board_id`` are looked up through two sources:

    * link rows in ``ApprovalTaskLink``;
    * the legacy ``Approval.task_id`` column, for approvals that have no link rows.

    Rows from both sources are merged and ordered by ``(created_at, id)``.
    The approval returned for a task is therefore the earliest one overall,
    regardless of which source it came from.

    Args:
        session: Async database session used to run the queries.
        board_id: Board whose approvals are considered.
        task_ids: Task ids to check; duplicates are ignored.
        exclude_approval_id: Optional approval id to leave out of the search.

    Returns:
        Mapping of task id to conflicting approval id. Tasks with no pending
        approval are absent.
    """
    normalized_task_ids = list({*task_ids})
    if not normalized_task_ids:
        return {}

    linked_statement = (
        select(
            col(ApprovalTaskLink.task_id),
            col(Approval.id),
            col(Approval.created_at),
        )
        .join(Approval, col(Approval.id) == col(ApprovalTaskLink.approval_id))
        .where(col(Approval.board_id) == board_id)
        .where(col(Approval.status) == "pending")
        .where(col(ApprovalTaskLink.task_id).in_(normalized_task_ids))
        .order_by(col(Approval.created_at).asc(), col(Approval.id).asc())
    )
    if exclude_approval_id is not None:
        linked_statement = linked_statement.where(col(Approval.id) != exclude_approval_id)
    linked_rows = list(await session.exec(linked_statement))

    legacy_statement = (
        select(
            col(Approval.task_id),
            col(Approval.id),
            col(Approval.created_at),
        )
        .where(col(Approval.board_id) == board_id)
        .where(col(Approval.status) == "pending")
        .where(col(Approval.task_id).is_not(None))
        .where(col(Approval.task_id).in_(normalized_task_ids))
        # Only approvals without link rows; linked ones were covered above.
        .where(
            ~exists(
                select(1)
                .where(col(ApprovalTaskLink.approval_id) == col(Approval.id))
                .correlate(Approval),
            ),
        )
        .order_by(col(Approval.created_at).asc(), col(Approval.id).asc())
    )
    if exclude_approval_id is not None:
        legacy_statement = legacy_statement.where(col(Approval.id) != exclude_approval_id)
    legacy_rows = list(await session.exec(legacy_statement))

    # Each result set is ordered on its own, but the two must be merged before
    # picking. Otherwise an older legacy approval would be shadowed merely
    # because the linked rows were processed first.
    candidates: list[tuple[UUID, UUID, object]] = [
        (task_id, approval_id, created_at)
        for task_id, approval_id, created_at in linked_rows
    ]
    candidates.extend(
        (task_id, approval_id, created_at)
        for task_id, approval_id, created_at in legacy_rows
        if task_id is not None
    )
    # A missing created_at sorts last. The leading flag also keeps None from
    # ever being compared against a real timestamp.
    candidates.sort(key=lambda row: (row[2] is None, row[2], row[1]))

    conflicts: dict[UUID, UUID] = {}
    for task_id, approval_id, _created_at in candidates:
        # setdefault keeps the first (earliest) approval seen for each task.
        conflicts.setdefault(task_id, approval_id)
    return conflicts
|
|
|
|
|
|
async def task_counts_for_board(
    session: AsyncSession,
    *,
    board_id: UUID,
    task_ids: set[UUID] | None = None,
) -> dict[UUID, tuple[int, int]]:
    """Return ``(total, pending)`` approval counts per task on a board.

    Two sources are summed per task id:

    * approvals on ``board_id`` linked through ``ApprovalTaskLink`` rows;
    * legacy approvals on ``board_id`` that set ``Approval.task_id`` and have
      no link rows at all, so nothing is counted twice.

    ``pending`` counts those approvals whose status is ``"pending"``.

    Args:
        session: Async database session used to run the queries.
        board_id: Board whose approvals are counted.
        task_ids: Optional restriction to these tasks. An empty set returns
            ``{}`` without querying.

    Returns:
        Mapping of task id to ``(total, pending)``. Tasks with no approvals
        are omitted.
    """
    if task_ids is not None and not task_ids:
        return {}

    # 1 for a pending approval, 0 otherwise; summed per group below.
    pending_flag = case(
        (col(Approval.status) == "pending", 1),
        else_=0,
    )

    linked_query = (
        select(
            col(ApprovalTaskLink.task_id),
            func.count(col(Approval.id)).label("total"),
            func.sum(pending_flag).label("pending"),
        )
        .join(Approval, col(Approval.id) == col(ApprovalTaskLink.approval_id))
        .where(col(Approval.board_id) == board_id)
    )
    # Backward compatibility: legacy approvals that carry task_id but no link rows.
    legacy_query = (
        select(
            col(Approval.task_id),
            func.count(col(Approval.id)).label("total"),
            func.sum(pending_flag).label("pending"),
        )
        .where(col(Approval.board_id) == board_id)
        .where(col(Approval.task_id).is_not(None))
        .where(
            ~exists(
                select(1)
                .where(col(ApprovalTaskLink.approval_id) == col(Approval.id))
                .correlate(Approval),
            ),
        )
    )
    if task_ids is not None:
        linked_query = linked_query.where(col(ApprovalTaskLink.task_id).in_(task_ids))
        legacy_query = legacy_query.where(col(Approval.task_id).in_(task_ids))
    linked_query = linked_query.group_by(col(ApprovalTaskLink.task_id))
    legacy_query = legacy_query.group_by(col(Approval.task_id))

    linked_rows = list(await session.exec(linked_query))
    counts: dict[UUID, tuple[int, int]] = {
        task_id: (int(total or 0), int(pending or 0))
        for task_id, total, pending in linked_rows
    }

    # Fold legacy totals into whatever the link-based query produced.
    legacy_rows = list(await session.exec(legacy_query))
    for legacy_task_id, total, pending in legacy_rows:
        if legacy_task_id is None:
            continue
        base_total, base_pending = counts.get(legacy_task_id, (0, 0))
        counts[legacy_task_id] = (
            base_total + int(total or 0),
            base_pending + int(pending or 0),
        )
    return counts
|