# Source: openclaw-mission-control/backend/app/services/approval_task_links.py
# (274 lines, 8.6 KiB, Python)
"""Helpers for normalizing and querying approval-task associations."""
from __future__ import annotations
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING
from uuid import UUID
from sqlalchemy import case, delete, exists, func
from sqlmodel import col, select
from app.models.approval_task_links import ApprovalTaskLink
from app.models.approvals import Approval
from app.models.tasks import Task
if TYPE_CHECKING:
    # Imported only for type annotations; never needed at runtime.
    from sqlmodel.ext.asyncio.session import AsyncSession

# Payload keys that may hold a single task reference (spelling variants).
TASK_ID_KEYS: tuple[str, ...] = (
    "task_id",
    "taskId",
    "taskID",
)
# Payload keys that may hold a list of task references (spelling variants).
TASK_IDS_KEYS: tuple[str, ...] = (
    "task_ids",
    "taskIds",
    "taskIDs",
)
def _coerce_uuid(value: object) -> UUID | None:
    """Best-effort conversion of ``value`` to a ``UUID``.

    ``UUID`` instances are returned unchanged and strings are parsed with the
    ``UUID`` constructor. A string that fails to parse, or a value of any other
    type, yields ``None``.
    """
    if isinstance(value, UUID):
        return value
    if not isinstance(value, str):
        return None
    try:
        parsed: UUID | None = UUID(value)
    except ValueError:
        parsed = None
    return parsed
def extract_task_ids(payload: dict[str, object] | None) -> list[UUID]:
    """Collect task UUIDs referenced by an approval payload.

    The list-valued aliases in ``TASK_IDS_KEYS`` are scanned first, then the
    single-value aliases in ``TASK_ID_KEYS``. Values that cannot be coerced to
    a UUID are skipped. Duplicates are dropped, keeping the position of each
    id's first appearance. An empty or ``None`` payload yields ``[]``.
    """
    if not payload:
        return []
    # A dict keeps insertion order, so it doubles as an ordered set.
    found: dict[UUID, None] = {}
    for list_key in TASK_IDS_KEYS:
        values = payload.get(list_key)
        # str/bytes are Sequences too, but never represent a list of ids.
        if not isinstance(values, Sequence) or isinstance(
            values, (str, bytes, bytearray)
        ):
            continue
        for candidate in values:
            uuid_value = _coerce_uuid(candidate)
            if uuid_value is not None:
                found.setdefault(uuid_value, None)
    for single_key in TASK_ID_KEYS:
        uuid_value = _coerce_uuid(payload.get(single_key))
        if uuid_value is not None:
            found.setdefault(uuid_value, None)
    return list(found)
def normalize_task_ids(
    *,
    task_id: UUID | None,
    task_ids: Sequence[UUID],
    payload: dict[str, object] | None,
) -> list[UUID]:
    """Combine every task reference for an approval into one ordered, unique list.

    Ids are taken from ``task_ids`` first, then ``task_id`` (when not ``None``),
    then whatever ``extract_task_ids`` finds in ``payload``. The first
    occurrence of each id determines its position in the result.
    """
    explicit_single: list[UUID] = [] if task_id is None else [task_id]
    combined = [*task_ids, *explicit_single, *extract_task_ids(payload)]
    # dict.fromkeys keeps the first occurrence of each id, in order.
    return list(dict.fromkeys(combined))
async def load_task_ids_by_approval(
    session: AsyncSession,
    *,
    approval_ids: Iterable[UUID],
) -> dict[UUID, list[UUID]]:
    """Return linked task ids grouped by approval id.

    Every requested approval id appears as a key, with an empty list when it
    has no link rows. Keys follow the order in which the ids were first given.
    Task ids within each list are ordered by the link's ``created_at``, with
    ties broken by ``task_id`` so the output is deterministic.

    Args:
        session: Async database session used for the lookup.
        approval_ids: Approval ids to load; duplicates are ignored.

    Returns:
        Mapping of approval id to its linked task ids, or ``{}`` when no ids
        are given.
    """
    # dict.fromkeys de-duplicates while preserving caller order (a set would not).
    ids = list(dict.fromkeys(approval_ids))
    if not ids:
        return {}
    statement = (
        select(col(ApprovalTaskLink.approval_id), col(ApprovalTaskLink.task_id))
        .where(col(ApprovalTaskLink.approval_id).in_(ids))
        .order_by(
            col(ApprovalTaskLink.created_at).asc(),
            # Tiebreaker: links inserted together can share a timestamp.
            col(ApprovalTaskLink.task_id).asc(),
        )
    )
    rows = list(await session.exec(statement))
    mapping: dict[UUID, list[UUID]] = {approval_id: [] for approval_id in ids}
    for approval_id, task_id in rows:
        mapping.setdefault(approval_id, []).append(task_id)
    return mapping
async def replace_approval_task_links(
    session: AsyncSession,
    *,
    approval_id: UUID,
    task_ids: Sequence[UUID],
) -> None:
    """Replace the approval-task link rows for ``approval_id`` with ``task_ids``.

    All existing link rows for the approval are deleted. Then one
    ``ApprovalTaskLink`` is added to the session per distinct task id, in
    first-occurrence order. No flush or commit happens here; that is left to
    the caller.

    Args:
        session: Async database session the delete and inserts run in.
        approval_id: Approval whose links are replaced.
        task_ids: Task ids to link; repeated ids are linked only once.
    """
    await session.exec(
        delete(ApprovalTaskLink).where(
            col(ApprovalTaskLink.approval_id) == approval_id,
        ),
    )
    # A repeated id would create duplicate link rows. Those inflate per-task
    # counts computed from link rows, and may violate a uniqueness constraint
    # if the table has one.
    for task_id in dict.fromkeys(task_ids):
        session.add(ApprovalTaskLink(approval_id=approval_id, task_id=task_id))
async def lock_tasks_for_approval(
    session: AsyncSession,
    *,
    task_ids: Sequence[UUID],
) -> None:
    """Row-lock the given tasks with ``SELECT ... FOR UPDATE``.

    Meant to be called inside an open transaction. Duplicate ids are
    collapsed, and the locking query is ordered by ``Task.id`` so the lock
    order is deterministic. An empty ``task_ids`` is a no-op.
    """
    unique_ids = sorted(set(task_ids), key=str)
    if not unique_ids:
        return
    lock_query = select(col(Task.id)).where(col(Task.id).in_(unique_ids))
    lock_query = lock_query.order_by(col(Task.id).asc()).with_for_update()
    result = await session.exec(lock_query)
    # Consume every row so the whole locking statement has executed before returning.
    list(result)
async def pending_approval_conflicts_by_task(
    session: AsyncSession,
    *,
    board_id: UUID,
    task_ids: Sequence[UUID],
    exclude_approval_id: UUID | None = None,
) -> dict[UUID, UUID]:
    """Return the earliest conflicting pending approval id for each requested task id.

    Pending approvals on ``board_id`` can reference tasks in two ways, and
    both are checked:

    * through ``ApprovalTaskLink`` rows, and
    * legacy approvals that set ``Approval.task_id`` but have no link rows.

    Candidates from both sources are merged before choosing. For each task,
    the approval with the smallest ``(created_at, id)`` wins, whichever source
    it came from.

    Args:
        session: Async database session used for the lookups.
        board_id: Board whose pending approvals are considered.
        task_ids: Task ids to check; duplicates are ignored.
        exclude_approval_id: Approval id to leave out of the search, if any.

    Returns:
        Mapping of task id to the conflicting approval id. Tasks without a
        pending conflict are absent, and an empty ``task_ids`` returns ``{}``.
    """
    normalized_task_ids = list({*task_ids})
    if not normalized_task_ids:
        return {}

    linked_statement = (
        select(
            col(ApprovalTaskLink.task_id),
            col(Approval.id),
            col(Approval.created_at),
        )
        .join(Approval, col(Approval.id) == col(ApprovalTaskLink.approval_id))
        .where(col(Approval.board_id) == board_id)
        .where(col(Approval.status) == "pending")
        .where(col(ApprovalTaskLink.task_id).in_(normalized_task_ids))
        .order_by(col(Approval.created_at).asc(), col(Approval.id).asc())
    )
    if exclude_approval_id is not None:
        linked_statement = linked_statement.where(col(Approval.id) != exclude_approval_id)

    # Legacy approvals: task referenced only via Approval.task_id, no link rows.
    legacy_statement = (
        select(
            col(Approval.task_id),
            col(Approval.id),
            col(Approval.created_at),
        )
        .where(col(Approval.board_id) == board_id)
        .where(col(Approval.status) == "pending")
        .where(col(Approval.task_id).is_not(None))
        .where(col(Approval.task_id).in_(normalized_task_ids))
        .where(
            ~exists(
                select(1)
                .where(col(ApprovalTaskLink.approval_id) == col(Approval.id))
                .correlate(Approval),
            ),
        )
        .order_by(col(Approval.created_at).asc(), col(Approval.id).asc())
    )
    if exclude_approval_id is not None:
        legacy_statement = legacy_statement.where(col(Approval.id) != exclude_approval_id)

    linked_rows = list(await session.exec(linked_statement))
    legacy_rows = list(await session.exec(legacy_statement))

    # Choosing per task across *both* result sets avoids preferring a newer
    # linked approval over an older legacy one. The rank mirrors the SQL
    # ``ORDER BY created_at ASC, id ASC``. The leading flag sorts a missing
    # created_at last (like ASC NULLS LAST) instead of raising on a
    # None-vs-datetime comparison.
    best_rank: dict[UUID, tuple[bool, object, UUID]] = {}
    for task_id, approval_id, created_at in [*linked_rows, *legacy_rows]:
        if task_id is None:
            continue
        rank = (created_at is None, created_at, approval_id)
        current = best_rank.get(task_id)
        if current is None or rank < current:
            best_rank[task_id] = rank
    return {task_id: rank[2] for task_id, rank in best_rank.items()}
async def task_counts_for_board(
    session: AsyncSession,
    *,
    board_id: UUID,
    task_ids: set[UUID] | None = None,
) -> dict[UUID, tuple[int, int]]:
    """Count approvals per task on a board as ``(total, pending)`` pairs.

    Approvals reached through ``ApprovalTaskLink`` rows are counted first.
    Legacy approvals that set ``Approval.task_id`` but have no link rows are
    then added on top, for backward compatibility. When ``task_ids`` is given,
    only those tasks are counted; an empty set returns ``{}`` without querying.
    Tasks with no approvals do not appear in the result.
    """
    if task_ids is not None and not task_ids:
        return {}

    # 1 for a pending approval, 0 otherwise; summed to get the pending count.
    pending_flag = case(
        (col(Approval.status) == "pending", 1),
        else_=0,
    )

    linked_query = (
        select(
            col(ApprovalTaskLink.task_id),
            func.count(col(Approval.id)).label("total"),
            func.sum(pending_flag).label("pending"),
        )
        .join(Approval, col(Approval.id) == col(ApprovalTaskLink.approval_id))
        .where(col(Approval.board_id) == board_id)
    )
    # Backward compatibility: legacy rows that have task_id set but no link rows.
    legacy_query = (
        select(
            col(Approval.task_id),
            func.count(col(Approval.id)).label("total"),
            func.sum(pending_flag).label("pending"),
        )
        .where(col(Approval.board_id) == board_id)
        .where(col(Approval.task_id).is_not(None))
        .where(
            ~exists(
                select(1)
                .where(col(ApprovalTaskLink.approval_id) == col(Approval.id))
                .correlate(Approval),
            ),
        )
    )
    if task_ids is not None:
        linked_query = linked_query.where(col(ApprovalTaskLink.task_id).in_(task_ids))
        legacy_query = legacy_query.where(col(Approval.task_id).in_(task_ids))
    linked_query = linked_query.group_by(col(ApprovalTaskLink.task_id))
    legacy_query = legacy_query.group_by(col(Approval.task_id))

    counts: dict[UUID, tuple[int, int]] = {}
    linked_rows = list(await session.exec(linked_query))
    for linked_task_id, total, pending in linked_rows:
        counts[linked_task_id] = (int(total or 0), int(pending or 0))

    legacy_rows = list(await session.exec(legacy_query))
    for legacy_task_id, total, pending in legacy_rows:
        if legacy_task_id is None:
            continue
        prior_total, prior_pending = counts.get(legacy_task_id, (0, 0))
        counts[legacy_task_id] = (
            prior_total + int(total or 0),
            prior_pending + int(pending or 0),
        )
    return counts