Files
openclaw-mission-control/backend/app/api/approvals.py

227 lines
8.5 KiB
Python

from __future__ import annotations
import asyncio
import json
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy import asc, case, func, or_
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sse_starlette.sse import EventSourceResponse
from app.api.deps import ActorContext, get_board_or_404, require_admin_auth, require_admin_or_agent
from app.core.auth import AuthContext
from app.core.time import utcnow
from app.db.pagination import paginate
from app.db.session import async_session_maker, get_session
from app.models.approvals import Approval
from app.models.boards import Board
from app.schemas.approvals import ApprovalCreate, ApprovalRead, ApprovalStatus, ApprovalUpdate
from app.schemas.pagination import DefaultLimitOffsetPage
# Router for the approvals endpoints; every route below is registered on it and
# therefore lives under the "/boards/{board_id}/approvals" prefix.
router = APIRouter(prefix="/boards/{board_id}/approvals", tags=["approvals"])
TASK_ID_KEYS: tuple[str, ...] = ("task_id", "taskId", "taskID")
def _extract_task_id(payload: dict[str, object] | None) -> UUID | None:
if not payload:
return None
for key in TASK_ID_KEYS:
value = payload.get(key)
if isinstance(value, UUID):
return value
if isinstance(value, str):
try:
return UUID(value)
except ValueError:
continue
return None
def _parse_since(value: str | None) -> datetime | None:
if not value:
return None
normalized = value.strip()
if not normalized:
return None
normalized = normalized.replace("Z", "+00:00")
try:
parsed = datetime.fromisoformat(normalized)
except ValueError:
return None
if parsed.tzinfo is not None:
return parsed.astimezone(timezone.utc).replace(tzinfo=None)
return parsed
def _approval_updated_at(approval: Approval) -> datetime:
return approval.resolved_at or approval.created_at
def _serialize_approval(approval: Approval) -> dict[str, object]:
    """Convert an ``Approval`` row into a JSON-compatible dict.

    The row is validated through the ``ApprovalRead`` schema (attribute mode)
    and dumped in ``"json"`` mode.
    """
    read_model = ApprovalRead.model_validate(approval, from_attributes=True)
    return read_model.model_dump(mode="json")
async def _fetch_approval_events(
    session: AsyncSession,
    board_id: UUID,
    since: datetime,
) -> list[Approval]:
    """Load a board's approvals that were created or resolved at/after ``since``.

    The bound is inclusive (``>=``) on both ``created_at`` and ``resolved_at``.
    Results are ordered by ``created_at`` ascending.
    """
    created_since = col(Approval.created_at) >= since
    resolved_since = col(Approval.resolved_at) >= since
    query = (
        select(Approval)
        .where(col(Approval.board_id) == board_id)
        .where(or_(created_since, resolved_since))
        .order_by(asc(col(Approval.created_at)))
    )
    rows = await session.exec(query)
    return list(rows)
@router.get("", response_model=DefaultLimitOffsetPage[ApprovalRead])
async def list_approvals(
    status_filter: ApprovalStatus | None = Query(default=None, alias="status"),
    board: Board = Depends(get_board_or_404),
    session: AsyncSession = Depends(get_session),
    actor: ActorContext = Depends(require_admin_or_agent),
) -> DefaultLimitOffsetPage[ApprovalRead]:
    """List a board's approvals, newest first, optionally filtered by status.

    An agent that is bound to a different board receives 403.
    """
    # Only agent callers carry a board binding that needs checking.
    agent = actor.agent if actor.actor_type == "agent" else None
    if agent and agent.board_id and agent.board_id != board.id:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)

    query = select(Approval).where(col(Approval.board_id) == board.id)
    if status_filter:
        query = query.where(col(Approval.status) == status_filter)
    newest_first = query.order_by(col(Approval.created_at).desc())
    return await paginate(session, newest_first)
@router.get("/stream")
async def stream_approvals(
    request: Request,
    board: Board = Depends(get_board_or_404),
    actor: ActorContext = Depends(require_admin_or_agent),
    since: str | None = Query(default=None),
) -> EventSourceResponse:
    """Stream a board's approval changes as server-sent events.

    The board is polled every two seconds for approvals created or resolved at
    or after the cursor. The cursor starts at ``since``, or at the current time
    when ``since`` is missing or unparseable. Each change is sent once as an
    ``approval`` event whose JSON body holds the serialized approval, the
    board-wide pending count and, when the approval has a task, that task's
    total and pending approval counts.

    An agent that is bound to a different board receives 403.
    """
    if actor.actor_type == "agent" and actor.agent:
        if actor.agent.board_id and actor.agent.board_id != board.id:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
    since_dt = _parse_since(since) or utcnow()
    last_seen = since_dt

    async def event_generator() -> AsyncIterator[dict[str, str]]:
        nonlocal last_seen
        # (approval id, change timestamp) pairs returned by the previous poll.
        # The query bound is inclusive, so the row(s) sitting exactly on the
        # cursor come back on every poll; this set keeps them from being
        # re-sent. A later resolve changes the timestamp, so it is a new key.
        delivered: set[tuple[object, datetime]] = set()
        while True:
            if await request.is_disconnected():
                break

            # Events are fully built while the session is open and yielded only
            # after it is closed, so a slow client never pins a DB connection.
            events: list[dict[str, str]] = []
            async with async_session_maker() as session:
                fetched = await _fetch_approval_events(session, board.id, last_seen)
                keyed = [
                    ((approval.id, _approval_updated_at(approval)), approval)
                    for approval in fetched
                ]
                fresh = [(key, approval) for key, approval in keyed if key not in delivered]
                # The cursor only moves forward, so a row absent from this poll
                # cannot reappear with the same key; remembering just the
                # current result set is sufficient and keeps the set bounded.
                delivered = {key for key, _ in keyed}

                if fresh:
                    # Aggregate counts are only needed when there is something to send.
                    pending_approvals_count = int(
                        (
                            await session.exec(
                                select(func.count(col(Approval.id)))
                                .where(col(Approval.board_id) == board.id)
                                .where(col(Approval.status) == "pending")
                            )
                        ).one()
                    )
                    task_ids = {
                        approval.task_id for _, approval in fresh if approval.task_id is not None
                    }
                    counts_by_task_id: dict[UUID, tuple[int, int]] = {}
                    if task_ids:
                        rows = list(
                            await session.exec(
                                select(
                                    col(Approval.task_id),
                                    func.count(col(Approval.id)).label("total"),
                                    func.sum(
                                        case((col(Approval.status) == "pending", 1), else_=0)
                                    ).label("pending"),
                                )
                                .where(col(Approval.board_id) == board.id)
                                .where(col(Approval.task_id).in_(task_ids))
                                .group_by(col(Approval.task_id))
                            )
                        )
                        for task_id, total, pending in rows:
                            if task_id is None:
                                continue
                            counts_by_task_id[task_id] = (int(total or 0), int(pending or 0))

                    for (_, updated_at), approval in fresh:
                        # Advance the cursor to the newest change seen so far.
                        if updated_at > last_seen:
                            last_seen = updated_at
                        payload: dict[str, object] = {
                            "approval": _serialize_approval(approval),
                            "pending_approvals_count": pending_approvals_count,
                        }
                        if approval.task_id is not None:
                            counts = counts_by_task_id.get(approval.task_id)
                            if counts is not None:
                                total, pending = counts
                                payload["task_counts"] = {
                                    "task_id": str(approval.task_id),
                                    "approvals_count": total,
                                    "approvals_pending_count": pending,
                                }
                        events.append({"event": "approval", "data": json.dumps(payload)})

            for event in events:
                yield event
            await asyncio.sleep(2)

    return EventSourceResponse(event_generator(), ping=15)
@router.post("", response_model=ApprovalRead)
async def create_approval(
    payload: ApprovalCreate,
    board: Board = Depends(get_board_or_404),
    session: AsyncSession = Depends(get_session),
    actor: ActorContext = Depends(require_admin_or_agent),
) -> Approval:
    """Create an approval on the board and return the stored row.

    The task id comes from the request's ``task_id`` field, or, when that is
    empty, from the free-form payload. An agent that is bound to a different
    board receives 403.
    """
    # Only agent callers carry a board binding that needs checking.
    agent = actor.agent if actor.actor_type == "agent" else None
    if agent and agent.board_id and agent.board_id != board.id:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)

    task_id = payload.task_id
    if not task_id:
        # Fall back to a task id embedded in the free-form payload.
        task_id = _extract_task_id(payload.payload)

    fields = {
        "board_id": board.id,
        "task_id": task_id,
        "agent_id": payload.agent_id,
        "action_type": payload.action_type,
        "payload": payload.payload,
        "confidence": payload.confidence,
        "rubric_scores": payload.rubric_scores,
        "status": payload.status,
    }
    created = Approval(**fields)
    session.add(created)
    await session.commit()
    # Reload so database-generated values are present on the returned row.
    await session.refresh(created)
    return created
@router.patch("/{approval_id}", response_model=ApprovalRead)
async def update_approval(
    approval_id: str,
    payload: ApprovalUpdate,
    board: Board = Depends(get_board_or_404),
    session: AsyncSession = Depends(get_session),
    auth: AuthContext = Depends(require_admin_auth),
) -> Approval:
    """Update an approval's status and return the stored row.

    Responds 404 when the approval does not exist or belongs to another board.
    When the request sets a status other than ``"pending"``, ``resolved_at`` is
    stamped with the current time.
    """
    # NOTE(review): `auth` is not read in the body; it appears to be declared
    # only so the admin-auth dependency runs. Confirm before removing.
    target = await session.get(Approval, approval_id)
    if target is None or target.board_id != board.id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

    # Only fields the client actually sent are considered.
    changes = payload.model_dump(exclude_unset=True)
    unset = object()
    requested_status = changes.get("status", unset)
    if requested_status is not unset:
        target.status = requested_status
        if target.status != "pending":
            target.resolved_at = utcnow()

    session.add(target)
    await session.commit()
    await session.refresh(target)
    return target