Files
openclaw-mission-control/backend/app/services/task_dependencies.py

241 lines
7.1 KiB
Python

"""Task-dependency helpers for validation, querying, and replacement."""
from __future__ import annotations
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Final
from uuid import UUID
from fastapi import HTTPException, status
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.db import crud
from app.models.task_dependencies import TaskDependency
from app.models.tasks import Task
# Status string that marks a task as finished. ``blocked_by_dependency_ids``
# reports every dependency whose status differs from this value as blocking.
DONE_STATUS: Final[str] = "done"
# Tuple that references UUID, AsyncSession, Mapping and Sequence at runtime; in
# the code visible here those names otherwise appear only in annotations.
# NOTE(review): presumably this keeps import linters from flagging them as
# unused / type-checking-only under ``from __future__ import annotations`` —
# confirm before removing.
_RUNTIME_TYPE_REFERENCES = (UUID, AsyncSession, Mapping, Sequence)
def _dedupe_uuid_list(values: Sequence[UUID]) -> list[UUID]:
    """Return the ids from ``values`` with repeats dropped, in first-seen order."""
    # dict keys are unique and keep insertion order, so a keys-only dict holds
    # each id exactly once, positioned at its first occurrence.
    return list(dict.fromkeys(values))
async def dependency_ids_by_task_id(
    session: AsyncSession,
    *,
    board_id: UUID,
    task_ids: Sequence[UUID],
) -> dict[UUID, list[UUID]]:
    """Group dependency ids under the task that declares them.

    Reads ``TaskDependency`` rows on ``board_id`` whose ``task_id`` is one of
    ``task_ids``, ordered by ascending ``created_at``.

    Returns:
        A plain dict mapping each task id that has at least one dependency row
        to its ``depends_on_task_id`` values in query order. Tasks without
        rows get no key. An empty ``task_ids`` returns ``{}`` without querying.
    """
    if not task_ids:
        return {}

    statement = (
        select(col(TaskDependency.task_id), col(TaskDependency.depends_on_task_id))
        .where(col(TaskDependency.board_id) == board_id)
        .where(col(TaskDependency.task_id).in_(task_ids))
        .order_by(col(TaskDependency.created_at).asc())
    )
    result = await session.exec(statement)

    grouped: dict[UUID, list[UUID]] = {}
    for owner_id, dependency_id in result:
        grouped.setdefault(owner_id, []).append(dependency_id)
    return grouped
async def dependency_status_by_id(
    session: AsyncSession,
    *,
    board_id: UUID,
    dependency_ids: Sequence[UUID],
) -> dict[UUID, str]:
    """Look up the status of each dependency task on a board.

    Selects ``Task.id`` and ``Task.status`` for tasks on ``board_id`` whose id
    is in ``dependency_ids``.

    Returns:
        A dict from task id to status for every matching row; ids that match
        no row are absent. An empty ``dependency_ids`` returns ``{}`` without
        querying.
    """
    if not dependency_ids:
        return {}

    statement = (
        select(col(Task.id), col(Task.status))
        .where(col(Task.board_id) == board_id)
        .where(col(Task.id).in_(dependency_ids))
    )
    result = await session.exec(statement)
    return {found_id: task_status for found_id, task_status in result}
def blocked_by_dependency_ids(
    *,
    dependency_ids: Sequence[UUID],
    status_by_id: Mapping[UUID, str],
) -> list[UUID]:
    """Filter ``dependency_ids`` down to those whose status is not ``DONE_STATUS``.

    The input order is preserved.
    """
    blocked: list[UUID] = []
    for candidate in dependency_ids:
        # An id missing from the mapping yields None, which never equals
        # DONE_STATUS, so unknown dependencies are reported as blocking too.
        if status_by_id.get(candidate) == DONE_STATUS:
            continue
        blocked.append(candidate)
    return blocked
async def blocked_by_for_task(
    session: AsyncSession,
    *,
    board_id: UUID,
    task_id: UUID,
    dependency_ids: Sequence[UUID] | None = None,
) -> list[UUID]:
    """Return the dependencies of ``task_id`` that are still blocking it.

    Args:
        session: Session used for the lookups.
        board_id: Board the task and its dependencies are looked up on.
        task_id: Task whose blockers are wanted.
        dependency_ids: Already-known dependency ids for the task. When
            ``None`` they are loaded via ``dependency_ids_by_task_id``.

    Returns:
        The dependency ids reported by ``blocked_by_dependency_ids``, or ``[]``
        when the task has no dependencies.
    """
    if dependency_ids is None:
        by_task = await dependency_ids_by_task_id(
            session,
            board_id=board_id,
            task_ids=[task_id],
        )
        pending = by_task.get(task_id, [])
    else:
        pending = list(dependency_ids)

    # No dependencies means nothing can block; skip the status query.
    if not pending:
        return []

    statuses = await dependency_status_by_id(
        session,
        board_id=board_id,
        dependency_ids=pending,
    )
    return blocked_by_dependency_ids(dependency_ids=pending, status_by_id=statuses)
def _has_cycle(nodes: Sequence[UUID], edges: Mapping[UUID, set[UUID]]) -> bool:
    """Detect cycles in a directed dependency graph.

    Runs a depth-first search from every entry in ``nodes`` with an explicit
    stack instead of recursion, so a long dependency chain cannot exceed the
    interpreter's recursion limit.

    Args:
        nodes: Start nodes for the search.
        edges: Adjacency mapping from a node to the nodes it points at. Nodes
            without an entry are treated as having no outgoing edges.

    Returns:
        ``True`` if a cycle is reachable from any start node, else ``False``.
    """
    visited: set[UUID] = set()
    # Nodes on the current DFS path; an edge into this set closes a cycle.
    on_path: set[UUID] = set()

    for start in nodes:
        if start in visited:
            continue
        visited.add(start)
        on_path.add(start)
        # Each frame keeps a live iterator so a node resumes where it left off
        # once the child pushed above it has been fully explored.
        stack = [(start, iter(edges.get(start, ())))]
        while stack:
            current, neighbours = stack[-1]
            for nxt in neighbours:
                if nxt in on_path:
                    return True
                if nxt not in visited:
                    visited.add(nxt)
                    on_path.add(nxt)
                    stack.append((nxt, iter(edges.get(nxt, ()))))
                    break
            else:
                # All neighbours explored: the node leaves the current path.
                on_path.discard(current)
                stack.pop()
    return False
async def validate_dependency_update(
    session: AsyncSession,
    *,
    board_id: UUID,
    task_id: UUID,
    depends_on_task_ids: Sequence[UUID],
) -> list[UUID]:
    """Check a proposed dependency list for a task and return it de-duplicated.

    Args:
        session: Session used for the lookups.
        board_id: Board the task and its dependencies must live on.
        task_id: Task whose dependencies are being replaced.
        depends_on_task_ids: Proposed dependency ids; may contain repeats.

    Returns:
        The proposed ids with duplicates removed, in first-seen order. An empty
        proposal returns ``[]`` without querying.

    Raises:
        HTTPException: 422 if the task lists itself, 404 if any dependency id
            is not a task on ``board_id``, 409 if the board graph with this
            edit applied contains a cycle.
    """
    normalized = _dedupe_uuid_list(depends_on_task_ids)

    if task_id in normalized:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
            detail="Task cannot depend on itself.",
        )
    if not normalized:
        return []

    # Every requested dependency has to be a task on this board.
    found_statement = (
        select(col(Task.id))
        .where(col(Task.board_id) == board_id)
        .where(col(Task.id).in_(normalized))
    )
    found_ids = set(await session.exec(found_statement))
    missing_ids = [str(dep_id) for dep_id in normalized if dep_id not in found_ids]
    if missing_ids:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={
                "message": "One or more dependency tasks were not found on this board.",
                "missing_task_ids": missing_ids,
            },
        )

    # Load the whole board graph, then swap in the pending edges for this task,
    # so cycles that only close through existing edges are caught as well.
    board_tasks_statement = select(col(Task.id)).where(col(Task.board_id) == board_id)
    board_task_ids = list(await session.exec(board_tasks_statement))

    edges_statement = select(
        col(TaskDependency.task_id),
        col(TaskDependency.depends_on_task_id),
    ).where(col(TaskDependency.board_id) == board_id)
    edge_rows = await session.exec(edges_statement)

    graph: dict[UUID, set[UUID]] = {}
    for source_id, target_id in edge_rows:
        graph.setdefault(source_id, set()).add(target_id)
    # The pending edit replaces, rather than extends, this task's stored edges.
    graph[task_id] = set(normalized)

    if _has_cycle(board_task_ids, graph):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Dependency cycle detected. Remove the cycle before saving.",
        )
    return normalized
async def replace_task_dependencies(
    session: AsyncSession,
    *,
    board_id: UUID,
    task_id: UUID,
    depends_on_task_ids: Sequence[UUID],
) -> list[UUID]:
    """Swap a task's stored dependencies for a validated new list.

    The proposal is first checked by ``validate_dependency_update``; any
    ``HTTPException`` it raises propagates before anything is changed. The
    task's existing ``TaskDependency`` rows on the board are then deleted and
    one new row is added to the session per dependency id.

    This function passes ``commit=False`` to the delete and never commits
    itself. NOTE(review): presumably the caller owns the commit — confirm.

    Returns:
        The de-duplicated dependency ids that were stored.
    """
    normalized = await validate_dependency_update(
        session,
        board_id=board_id,
        task_id=task_id,
        depends_on_task_ids=depends_on_task_ids,
    )

    await crud.delete_where(
        session,
        TaskDependency,
        col(TaskDependency.board_id) == board_id,
        col(TaskDependency.task_id) == task_id,
        commit=False,
    )

    replacement_rows = [
        TaskDependency(
            board_id=board_id,
            task_id=task_id,
            depends_on_task_id=dependency_id,
        )
        for dependency_id in normalized
    ]
    for row in replacement_rows:
        session.add(row)
    return normalized
async def dependent_task_ids(
    session: AsyncSession,
    *,
    board_id: UUID,
    dependency_task_id: UUID,
) -> list[UUID]:
    """List the tasks on a board that declare ``dependency_task_id`` as a dependency.

    Returns:
        The ``task_id`` of every ``TaskDependency`` row on ``board_id`` whose
        ``depends_on_task_id`` equals ``dependency_task_id``.
    """
    statement = (
        select(col(TaskDependency.task_id))
        .where(col(TaskDependency.board_id) == board_id)
        .where(col(TaskDependency.depends_on_task_id) == dependency_task_id)
    )
    result = await session.exec(statement)
    return list(result)